├── .flake8
├── .github
└── workflows
│ ├── code-scan.yml
│ ├── notebook-linting.yml
│ ├── notebook-tests.yml
│ ├── python-linting.yml
│ ├── python-package-conda.yml
│ ├── python-package.yml
│ └── python-publish.yml
├── .gitignore
├── CODEOWNERS
├── LICENSE
├── MANIFEST.in
├── README.rst
├── dice_ml
├── __init__.py
├── constants.py
├── counterfactual_explanations.py
├── data.py
├── data_interfaces
│ ├── __init__.py
│ ├── base_data_interface.py
│ ├── private_data_interface.py
│ └── public_data_interface.py
├── dice.py
├── diverse_counterfactuals.py
├── explainer_interfaces
│ ├── __init__.py
│ ├── dice_KD.py
│ ├── dice_genetic.py
│ ├── dice_pytorch.py
│ ├── dice_random.py
│ ├── dice_tensorflow1.py
│ ├── dice_tensorflow2.py
│ ├── dice_xgboost.py
│ ├── explainer_base.py
│ ├── feasible_base_vae.py
│ └── feasible_model_approx.py
├── model.py
├── model_interfaces
│ ├── __init__.py
│ ├── base_model.py
│ ├── keras_tensorflow_model.py
│ ├── pytorch_model.py
│ └── xgboost_model.py
├── schema
│ ├── __init__.py
│ ├── counterfactual_explanations_v1.0.json
│ └── counterfactual_explanations_v2.0.json
└── utils
│ ├── __init__.py
│ ├── exception.py
│ ├── helpers.py
│ ├── neuralnetworks.py
│ ├── sample_architecture
│ ├── __init__.py
│ └── vae_model.py
│ ├── sample_trained_models
│ ├── adult-margin-0.165-validity_reg-42.0-epoch-25-base-gen.pth
│ ├── adult-margin-0.344-validity_reg-76.0-epoch-25-ae-gen.pth
│ ├── adult.h5
│ ├── adult.pkl
│ ├── adult.pth
│ ├── adult_2nodes.pth
│ ├── custom.sav
│ ├── custom_binary.sav
│ ├── custom_multiclass.sav
│ ├── custom_regression.sav
│ └── custom_vars.sav
│ └── serialize.py
├── docs
├── .buildinfo
├── .nojekyll
├── Makefile
├── _images
│ └── dice_getting_started_api.png
├── _modules
│ ├── dice_ml
│ │ ├── constants.html
│ │ ├── counterfactual_explanations.html
│ │ ├── data.html
│ │ ├── data_interfaces
│ │ │ ├── private_data_interface.html
│ │ │ └── public_data_interface.html
│ │ ├── dice.html
│ │ ├── dice_interfaces
│ │ │ ├── dice_base.html
│ │ │ ├── dice_tensorflow1.html
│ │ │ └── dice_tensorflow2.html
│ │ ├── diverse_counterfactuals.html
│ │ ├── explainer_interfaces
│ │ │ ├── dice_KD.html
│ │ │ ├── dice_genetic.html
│ │ │ ├── dice_pytorch.html
│ │ │ ├── dice_random.html
│ │ │ ├── dice_tensorflow1.html
│ │ │ ├── dice_tensorflow2.html
│ │ │ ├── explainer_base.html
│ │ │ ├── feasible_base_vae.html
│ │ │ └── feasible_model_approx.html
│ │ ├── model.html
│ │ ├── model_interfaces
│ │ │ ├── base_model.html
│ │ │ ├── keras_tensorflow_model.html
│ │ │ └── pytorch_model.html
│ │ └── utils
│ │ │ ├── exception.html
│ │ │ ├── helpers.html
│ │ │ ├── neuralnetworks.html
│ │ │ ├── sample_architecture
│ │ │ └── vae_model.html
│ │ │ └── serialize.html
│ └── index.html
├── _static
│ ├── _sphinx_javascript_frameworks_compat.js
│ ├── basic.css
│ ├── css
│ │ ├── badge_only.css
│ │ ├── fonts
│ │ │ ├── Roboto-Slab-Bold.woff
│ │ │ ├── Roboto-Slab-Bold.woff2
│ │ │ ├── Roboto-Slab-Regular.woff
│ │ │ ├── Roboto-Slab-Regular.woff2
│ │ │ ├── fontawesome-webfont.eot
│ │ │ ├── fontawesome-webfont.svg
│ │ │ ├── fontawesome-webfont.ttf
│ │ │ ├── fontawesome-webfont.woff
│ │ │ ├── fontawesome-webfont.woff2
│ │ │ ├── lato-bold-italic.woff
│ │ │ ├── lato-bold-italic.woff2
│ │ │ ├── lato-bold.woff
│ │ │ ├── lato-bold.woff2
│ │ │ ├── lato-normal-italic.woff
│ │ │ ├── lato-normal-italic.woff2
│ │ │ ├── lato-normal.woff
│ │ │ └── lato-normal.woff2
│ │ └── theme.css
│ ├── doctools.js
│ ├── documentation_options.js
│ ├── file.png
│ ├── fonts
│ │ ├── FontAwesome.otf
│ │ ├── Inconsolata-Bold.ttf
│ │ ├── Inconsolata-Regular.ttf
│ │ ├── Inconsolata.ttf
│ │ ├── Lato-Bold.ttf
│ │ ├── Lato-Regular.ttf
│ │ ├── Lato
│ │ │ ├── lato-bold.eot
│ │ │ ├── lato-bold.ttf
│ │ │ ├── lato-bold.woff
│ │ │ ├── lato-bold.woff2
│ │ │ ├── lato-bolditalic.eot
│ │ │ ├── lato-bolditalic.ttf
│ │ │ ├── lato-bolditalic.woff
│ │ │ ├── lato-bolditalic.woff2
│ │ │ ├── lato-italic.eot
│ │ │ ├── lato-italic.ttf
│ │ │ ├── lato-italic.woff
│ │ │ ├── lato-italic.woff2
│ │ │ ├── lato-regular.eot
│ │ │ ├── lato-regular.ttf
│ │ │ ├── lato-regular.woff
│ │ │ └── lato-regular.woff2
│ │ ├── Roboto-Slab-Bold.woff
│ │ ├── Roboto-Slab-Bold.woff2
│ │ ├── Roboto-Slab-Light.woff
│ │ ├── Roboto-Slab-Light.woff2
│ │ ├── Roboto-Slab-Regular.woff
│ │ ├── Roboto-Slab-Regular.woff2
│ │ ├── Roboto-Slab-Thin.woff
│ │ ├── Roboto-Slab-Thin.woff2
│ │ ├── RobotoSlab-Bold.ttf
│ │ ├── RobotoSlab-Regular.ttf
│ │ ├── RobotoSlab
│ │ │ ├── roboto-slab-v7-bold.eot
│ │ │ ├── roboto-slab-v7-bold.ttf
│ │ │ ├── roboto-slab-v7-bold.woff
│ │ │ ├── roboto-slab-v7-bold.woff2
│ │ │ ├── roboto-slab-v7-regular.eot
│ │ │ ├── roboto-slab-v7-regular.ttf
│ │ │ ├── roboto-slab-v7-regular.woff
│ │ │ └── roboto-slab-v7-regular.woff2
│ │ ├── fontawesome-webfont.eot
│ │ ├── fontawesome-webfont.svg
│ │ ├── fontawesome-webfont.ttf
│ │ ├── fontawesome-webfont.woff
│ │ ├── fontawesome-webfont.woff2
│ │ ├── lato-bold-italic.woff
│ │ ├── lato-bold-italic.woff2
│ │ ├── lato-bold.woff
│ │ ├── lato-bold.woff2
│ │ ├── lato-normal-italic.woff
│ │ ├── lato-normal-italic.woff2
│ │ ├── lato-normal.woff
│ │ └── lato-normal.woff2
│ ├── getting_started_output.png
│ ├── getting_started_updated.png
│ ├── jquery-3.2.1.js
│ ├── jquery-3.4.1.js
│ ├── jquery-3.5.1.js
│ ├── jquery-3.6.0.js
│ ├── jquery.js
│ ├── js
│ │ ├── badge_only.js
│ │ ├── html5shiv-printshiv.min.js
│ │ ├── html5shiv.min.js
│ │ ├── modernizr.min.js
│ │ └── theme.js
│ ├── language_data.js
│ ├── minus.png
│ ├── plus.png
│ ├── pygments.css
│ ├── searchtools.js
│ ├── sphinx_highlight.js
│ ├── underscore-1.13.1.js
│ ├── underscore-1.3.1.js
│ └── underscore.js
├── dice_ml.data_interfaces.html
├── dice_ml.dice_interfaces.html
├── dice_ml.explainer_interfaces.html
├── dice_ml.html
├── dice_ml.model_interfaces.html
├── dice_ml.schema.html
├── dice_ml.utils.html
├── dice_ml.utils.sample_architecture.html
├── genindex.html
├── index.html
├── make.bat
├── modules.html
├── notebooks
│ ├── Benchmarking_different_CF_explanation_methods.html
│ ├── DiCE_feature_importances.html
│ ├── DiCE_getting_started.html
│ ├── DiCE_getting_started_feasible.html
│ ├── DiCE_model_agnostic_CFs.html
│ ├── DiCE_multiclass_classification_and_regression.html
│ ├── DiCE_with_advanced_options.html
│ ├── DiCE_with_private_data.html
│ └── nb_index.html
├── objects.inv
├── py-modindex.html
├── readme.html
├── search.html
├── searchindex.js
├── source
│ ├── conf.py
│ ├── dice_ml.data_interfaces.rst
│ ├── dice_ml.explainer_interfaces.rst
│ ├── dice_ml.model_interfaces.rst
│ ├── dice_ml.rst
│ ├── dice_ml.schema.rst
│ ├── dice_ml.utils.rst
│ ├── dice_ml.utils.sample_architecture.rst
│ ├── index.rst
│ ├── modules.rst
│ ├── notebooks
│ │ ├── Benchmarking_different_CF_explanation_methods.ipynb
│ │ ├── DiCE_feature_importances.ipynb
│ │ ├── DiCE_getting_started.ipynb
│ │ ├── DiCE_getting_started_feasible.ipynb
│ │ ├── DiCE_model_agnostic_CFs.ipynb
│ │ ├── DiCE_multiclass_classification_and_regression.ipynb
│ │ ├── DiCE_with_advanced_options.ipynb
│ │ ├── DiCE_with_private_data.ipynb
│ │ ├── images
│ │ │ └── dice_getting_started_api.png
│ │ └── nb_index.rst
│ └── readme.rst
└── update_docs.sh
├── environment-deeplearning.yml
├── environment.yml
├── requirements-deeplearning.txt
├── requirements-linting.txt
├── requirements-test.txt
├── requirements.txt
├── setup.cfg
├── setup.py
└── tests
├── __init__.py
├── conftest.py
├── test_counterfactual_explanations.py
├── test_data.py
├── test_data_interface
├── __init__.py
├── test_base_data_interface.py
├── test_private_data_interface.py
└── test_public_data_interface.py
├── test_dice.py
├── test_dice_interface
├── __init__.py
├── test_dice_KD.py
├── test_dice_genetic.py
├── test_dice_pytorch.py
├── test_dice_random.py
├── test_dice_tensorflow.py
└── test_explainer_base.py
├── test_helpers.py
├── test_model.py
├── test_model_interface
├── __init__.py
├── test_base_model.py
├── test_keras_tensorflow_model.py
└── test_pytorch_model.py
└── test_notebooks.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 127
3 |
--------------------------------------------------------------------------------
/.github/workflows/code-scan.yml:
--------------------------------------------------------------------------------
name: code scan

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]
  schedule:
    - cron: '30 5 * * *'

jobs:
  analyze:
    name: Analyze
    runs-on: ubuntu-latest
    permissions:
      actions: read
      contents: read
      security-events: write

    strategy:
      fail-fast: false
      matrix:
        language: ["python"]
        # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
        # Learn more:
        # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      # Initializes the CodeQL tools for scanning.
      # NOTE: codeql-action v1/v2 are deprecated and no longer run on GitHub-hosted
      # runners; v3 is the supported major version.
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        with:
          languages: ${{ matrix.language }}
          # If you wish to specify custom queries, you can do so here or in a config file.
          # By default, queries listed here will override any specified in a config file.
          # Prefix the list here with "+" to use these queries and those in the config file.
          # queries: ./path/to/local/query, your-org/your-repo/queries@main

      # Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
      # If this step fails, then you should remove it and run the build manually (see below)
      - name: Autobuild
        uses: github/codeql-action/autobuild@v3

      # ℹ️ Command-line programs to run using the OS shell.
      # 📚 https://git.io/JvXDl

      # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
      # and modify them (or add more) to build your code if your project
      # uses a compiled language

      #- run: |
      #   make bootstrap
      #   make release

      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3
--------------------------------------------------------------------------------
/.github/workflows/notebook-linting.yml:
--------------------------------------------------------------------------------
1 | # This workflow will lint jupyter notebooks with flake8-nb.
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Notebook linting
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 | schedule:
12 | - cron: '30 5 * * *'
13 |
14 | jobs:
15 | build:
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 | - name: Set up Python 3.11
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: '3.11'
24 | - name: Install dependencies
25 | run: |
26 | python -m pip install --upgrade pip
27 | pip install flake8-nb==0.4.0
28 | - name: Lint notebooks with flake8_nb
29 | run: |
30 | # stop the build if there are flake8 errors in notebooks
31 | flake8_nb docs/source/notebooks/ --statistics --max-line-length=127
32 |
--------------------------------------------------------------------------------
/.github/workflows/notebook-tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Notebook tests
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 | schedule:
12 | - cron: '30 5 * * *'
13 |
14 | jobs:
15 | build:
16 |
17 | runs-on: ${{ matrix.os }}
18 | strategy:
19 | matrix:
20 | python-version: ["3.9", "3.10", "3.11", "3.12"]
21 | os: [ubuntu-latest, macos-latest]
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up Python ${{ matrix.python-version }} ${{ matrix.os }}
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 | - name: Upgrade pip
30 | run: |
31 | python -m pip install --upgrade pip
32 | - name: Install core dependencies
33 | run: |
34 | pip install -r requirements.txt
35 | - name: Install deep learning dependencies
36 | run: |
37 | pip install -r requirements-deeplearning.txt
38 | - name: Install test dependencies
39 | run: |
40 | pip install -r requirements-test.txt
41 | - name: Test with pytest
42 | run: |
43 | # pytest
44 | pytest tests/ -m "notebook_tests" --durations=10 --doctest-modules --junitxml=junit/test-results.xml --cov=dice_ml --cov-report=xml --cov-report=html
45 |
--------------------------------------------------------------------------------
/.github/workflows/python-linting.yml:
--------------------------------------------------------------------------------
1 | # This workflow will lint python code with flake8.
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python linting
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 | schedule:
12 | - cron: '30 5 * * *'
13 |
14 | jobs:
15 | build:
16 | runs-on: ubuntu-latest
17 |
18 | steps:
19 | - uses: actions/checkout@v4
20 | - name: Set up Python 3.11
21 | uses: actions/setup-python@v5
22 | with:
23 | python-version: '3.11'
24 | - name: Install dependencies
25 | run: |
26 | python -m pip install --upgrade pip
27 | pip install -r requirements-linting.txt
28 | - name: Check sorted python imports using isort
29 | run: |
30 | isort . -c
31 | - name: Lint code with flake8
32 | run: |
33 | # stop the build if there are Python syntax errors or undefined names
34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35 | # The GitHub editor is 127 chars wide.
36 | flake8 . --count --max-complexity=30 --max-line-length=127 --statistics
37 | # Check for cyclometric complexity for specific files where this metric has been
38 | # reduced to ten and below
39 | flake8 dice_ml/data_interfaces/ --count --max-complexity=10 --max-line-length=127
40 |
41 |
--------------------------------------------------------------------------------
/.github/workflows/python-package-conda.yml:
--------------------------------------------------------------------------------
1 | name: Python Package using Conda
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 | schedule:
9 | - cron: '30 5 * * *'
10 |
11 | jobs:
12 | build-linux:
13 | runs-on: ubuntu-latest
14 | strategy:
15 | max-parallel: 5
16 |
17 | steps:
18 | - uses: actions/checkout@v4
19 | - name: Set up Python 3.8
20 | uses: actions/setup-python@v5
21 | with:
22 | python-version: 3.8
23 | - name: Add conda to system path
24 | run: |
25 | # $CONDA is an environment variable pointing to the root of the miniconda directory
26 | echo $CONDA/bin >> $GITHUB_PATH
27 | - name: Install core dependencies
28 | run: |
29 | conda env update --file environment.yml --name base
30 | - name: Install deep learning dependencies
31 | run: |
32 | conda env update --file environment-deeplearning.yml --name base
33 | - name: Test with pytest
34 | run: |
35 | conda install pytest ipython jupyter nbformat pytest-mock
36 |
37 | pytest
38 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package test
5 |
6 | on:
7 | push:
8 | branches: [ main ]
9 | pull_request:
10 | branches: [ main ]
11 | schedule:
12 | - cron: '30 5 * * *'
13 |
14 | jobs:
15 | build:
16 |
17 | runs-on: ${{ matrix.os }}
18 | strategy:
19 | matrix:
20 | python-version: ["3.9", "3.10", "3.11", "3.12"]
21 | os: [ubuntu-latest, macos-latest, windows-latest]
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up Python ${{ matrix.python-version }} ${{ matrix.os }}
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: ${{ matrix.python-version }}
29 | - name: Upgrade pip
30 | run: |
31 | python -m pip install --upgrade pip
32 | - name: Install core dependencies
33 | run: |
34 | pip install -r requirements.txt
35 | - name: Install deep learning dependencies
36 | run: |
37 | pip install -r requirements-deeplearning.txt
38 | - name: Install test dependencies
39 | run: |
40 | pip install -r requirements-test.txt
41 | - name: Test with pytest
42 | run: |
43 | # pytest
44 | pytest tests/ -m "not notebook_tests" --durations=10 --doctest-modules --junitxml=junit/test-results.xml --cov=dice_ml --cov-report=xml --cov-report=html
45 | - name: Publish Unit Test Results
46 | uses: EnricoMi/publish-unit-test-result-action/composite@v1
47 | if: ${{ (matrix.python-version == '3.9') && (matrix.os == 'ubuntu-latest') }}
48 | with:
49 | files: junit/test-results.xml
50 | # - name: Upload coverage to Codecov
51 | # uses: codecov/codecov-action@v3
52 | # if: ${{ (matrix.python-version == '3.9') && (matrix.os == 'ubuntu-latest') }}
53 | # with:
54 | # token: ${{ secrets.CODECOV_TOKEN }}
55 | # directory: .
56 | # env_vars: OS,PYTHON
57 | # fail_ci_if_error: true
58 | # files: ./coverage.xml
59 | # flags: unittests
60 | # name: codecov-umbrella
61 | # path_to_write_report: ./coverage/codecov_report.txt
62 | # verbose: true
63 | - name: Check package consistency with twine
64 | run: |
65 | python setup.py check sdist bdist_wheel
66 | twine check dist/*
67 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | on:
2 | release:
3 | types: [created]
4 | workflow_dispatch:
5 |
6 | jobs:
7 | pypi-publish:
8 | name: upload release to PyPI
9 | runs-on: ubuntu-latest
10 | # Specifying a GitHub environment is optional, but strongly encouraged
11 | environment: release
12 | permissions:
13 | # IMPORTANT: this permission is mandatory for trusted publishing
14 | id-token: write
15 | steps:
16 | - uses: actions/checkout@v4
17 | - name: Set up Python
18 | uses: actions/setup-python@v5
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | run: |
27 | python setup.py sdist bdist_wheel
28 | - name: Publish package distributions to PyPI
29 | uses: pypa/gh-action-pypi-publish@release/v1
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 | docs/build/
69 | docs/_sources/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 | *.data
108 |
109 | # notebook to solve issues
110 | notebooks/DiCE_issues.ipynb
111 | # to avoid auto copied notebooks to be stored in repo
112 | docs/notebooks/DiCE_getting_started.ipynb
113 | docs/notebooks/DiCE_getting_started_feasible.ipynb
114 | docs/notebooks/DiCE_with_advanced_options.ipynb
115 | docs/notebooks/DiCE_with_private_data.ipynb
116 | docs/notebooks/*.ipynb
117 |
118 |
119 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # dice-ml package
2 | /dice_ml @gaugup @amit-sharma
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation. All rights reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | include requirements-deeplearning.txt
3 | include requirements-test.txt
4 | include requirements-linting.txt
5 | include environment.yml
6 | include environment-deeplearning.yml
7 | include LICENSE
8 | include CODEOWNERS
9 | recursive-include docs *
10 | recursive-include tests *.py
11 | include dice_ml/utils/sample_trained_models/*
12 |
--------------------------------------------------------------------------------
/dice_ml/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import Data
2 | from .dice import Dice
3 | from .model import Model
4 |
5 | __all__ = ["Data",
6 | "Model",
7 | "Dice"]
8 |
--------------------------------------------------------------------------------
/dice_ml/constants.py:
--------------------------------------------------------------------------------
1 | """Constants for dice-ml package."""
2 |
3 |
class BackEndTypes:
    """String identifiers for the ML frameworks ("backends") supported by dice-ml."""
    Sklearn = 'sklearn'  # scikit-learn models
    Tensorflow1 = 'TF1'  # TensorFlow 1.x
    Tensorflow2 = 'TF2'  # TensorFlow 2.x
    Pytorch = 'PYT'  # PyTorch

    # Full list of supported backends, e.g. for input validation.
    ALL = [Sklearn, Tensorflow1, Tensorflow2, Pytorch]
11 |
12 |
class SamplingStrategy:
    """String identifiers for the counterfactual generation strategies."""
    Random = 'random'  # random sampling of feature values
    Genetic = 'genetic'  # genetic-algorithm search
    KdTree = 'kdtree'  # nearest points in training data via a KD-tree
    Gradient = 'gradient'  # gradient-based optimization (differentiable models)

    # Full list of supported strategies, for consistency with
    # BackEndTypes.ALL and ModelTypes.ALL (useful for input validation).
    ALL = [Random, Genetic, KdTree, Gradient]
18 |
19 |
class ModelTypes:
    """String identifiers for the kind of ML task the model performs."""
    Classifier = 'classifier'  # classification model
    Regressor = 'regressor'  # regression model

    # Full list of supported model types, e.g. for input validation.
    ALL = [Classifier, Regressor]
25 |
26 |
class _SchemaVersions:
    """Versions of the counterfactual-explanations JSON serialization schema.

    These correspond to the counterfactual_explanations_v<version>.json files
    in the dice_ml/schema package.
    """
    V1 = '1.0'
    V2 = '2.0'
    # Version used when writing new serialized output.
    CURRENT_VERSION = V2

    # All schema versions accepted when reading serialized output.
    ALL_VERSIONS = [V1, V2]
33 |
34 |
class _PostHocSparsityTypes:
    """Identifiers for the post-hoc sparsity enhancement algorithms."""
    LINEAR = 'linear'  # linear search over feature values
    BINARY = 'binary'  # binary search over feature values

    # Full list of supported sparsity algorithms, e.g. for input validation.
    ALL = [LINEAR, BINARY]
40 |
--------------------------------------------------------------------------------
/dice_ml/data.py:
--------------------------------------------------------------------------------
1 | """Module pointing to different implementations of Data class
2 |
3 | DiCE requires only few parameters about the data such as the range of continuous
4 | features and the levels of categorical features. Hence, DiCE can be used for a
5 | private data whose meta data are only available (such as the feature names and
6 | range/levels of different features) by specifying appropriate parameters.
7 | """
8 |
9 | from dice_ml.data_interfaces.base_data_interface import _BaseData
10 |
11 |
class Data(_BaseData):
    """Class containing all required information about the data for DiCE."""

    def __init__(self, **params):
        """Init method

        :param **params: a dictionary of required parameters.
        """
        # Dispatches to a concrete implementation (public or private data
        # interface) chosen from the supplied parameters.
        self.decide_implementation_type(params)

    def decide_implementation_type(self, params):
        """Decides if the Data class is for public or private data."""
        # Swap this instance's class to the implementation returned by
        # decide(), then re-run __init__ so the instance is initialized
        # by that implementation class.
        self.__class__ = decide(params)
        self.__init__(params)
26 |
27 |
def decide(params):
    """Select the concrete Data implementation class for the given params.

    Returns PublicData when a full pandas dataframe is supplied, otherwise
    PrivateData (only meta data such as feature names and ranges/levels).

    To add new implementations of Data, add the class in the data_interfaces
    subpackage and return it from a new branch below.
    """
    if 'dataframe' in params:
        # A complete dataset was provided: use the public-data interface.
        from dice_ml.data_interfaces.public_data_interface import PublicData
        return PublicData

    # Only meta data is available: fall back to the private-data interface.
    from dice_ml.data_interfaces.private_data_interface import PrivateData
    return PrivateData
43 |
--------------------------------------------------------------------------------
/dice_ml/data_interfaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/data_interfaces/__init__.py
--------------------------------------------------------------------------------
/dice_ml/data_interfaces/base_data_interface.py:
--------------------------------------------------------------------------------
1 | """Module containing base class for data interfaces for dice-ml."""
2 |
3 | from abc import ABC, abstractmethod
4 |
5 | import pandas as pd
6 | from raiutils.exceptions import UserConfigValidationException
7 |
8 | from dice_ml.utils.exception import SystemException
9 |
10 |
11 | class _BaseData(ABC):
12 |
13 | def _validate_and_set_data_name(self, params):
14 | """Validate and set the data name."""
15 | if 'data_name' in params:
16 | self.data_name = params['data_name']
17 | else:
18 | self.data_name = 'mydata'
19 |
20 | def _validate_and_set_outcome_name(self, params):
21 | """Validate and set the outcome name."""
22 | if 'outcome_name' not in params:
23 | raise ValueError("should provide the name of outcome feature")
24 |
25 | if isinstance(params['outcome_name'], str):
26 | self.outcome_name = params['outcome_name']
27 | else:
28 | raise ValueError("should provide the name of outcome feature as a string")
29 |
30 | def set_continuous_feature_indexes(self, query_instance):
31 | """Remaps continuous feature indices based on the query instance"""
32 | self.continuous_feature_indexes = [query_instance.columns.get_loc(name) for name in
33 | self.continuous_feature_names]
34 |
35 | def check_features_to_vary(self, features_to_vary):
36 | if features_to_vary is not None and features_to_vary != 'all':
37 | not_training_features = set(features_to_vary) - set(self.feature_names)
38 | if len(not_training_features) > 0:
39 | raise UserConfigValidationException("Got features {0} which are not present in training data".format(
40 | not_training_features))
41 |
42 | def check_permitted_range(self, permitted_range):
43 | if permitted_range is not None:
44 | permitted_range_features = list(permitted_range)
45 | not_training_features = set(permitted_range_features) - set(self.feature_names)
46 | if len(not_training_features) > 0:
47 | raise UserConfigValidationException("Got features {0} which are not present in training data".format(
48 | not_training_features))
49 |
50 | for feature in permitted_range_features:
51 | if feature in self.categorical_feature_names:
52 | train_categories = self.permitted_range[feature]
53 | for test_category in permitted_range[feature]:
54 | if test_category not in train_categories:
55 | raise UserConfigValidationException(
56 | 'The category {0} does not occur in the training data for feature {1}.'
57 | ' Allowed categories are {2}'.format(test_category, feature, train_categories))
58 |
59 | def _validate_and_set_permitted_range(self, params, features_dict=None):
60 | """Validate and set the dictionary of permitted ranges for continuous features."""
61 | input_permitted_range = None
62 | if 'permitted_range' in params:
63 | input_permitted_range = params['permitted_range']
64 |
65 | if not hasattr(self, 'feature_names'):
66 | raise SystemException('Feature names not correctly set in public data interface')
67 |
68 | for input_permitted_range_feature_name in input_permitted_range:
69 | if input_permitted_range_feature_name not in self.feature_names:
70 | raise UserConfigValidationException(
71 | "permitted_range contains some feature names which are not part of columns in dataframe"
72 | )
73 | self.permitted_range, _ = self.get_features_range(input_permitted_range, features_dict)
74 |
75 | def ensure_consistent_type(self, output_df, query_instance):
76 | qdf = self.query_instance_to_df(query_instance)
77 | output_df = output_df.astype(qdf.dtypes.to_dict())
78 | return output_df
79 |
80 | def query_instance_to_df(self, query_instance):
81 | if isinstance(query_instance, list):
82 | if isinstance(query_instance[0], dict): # prepare a list of query instances
83 | test = pd.DataFrame(query_instance, columns=self.feature_names)
84 |
85 | else: # prepare a single query instance in list
86 | query_instance = {'row1': query_instance}
87 | test = pd.DataFrame.from_dict(
88 | query_instance, orient='index', columns=self.feature_names)
89 |
90 | elif isinstance(query_instance, dict):
91 | test = pd.DataFrame({k: [v] for k, v in query_instance.items()}, columns=self.feature_names)
92 |
93 | elif isinstance(query_instance, pd.DataFrame):
94 | test = query_instance.copy()
95 |
96 | else:
97 | raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts")
98 | return test
99 |
100 | @abstractmethod
101 | def __init__(self, params):
102 | """The init method needs to be implemented by the inherting classes."""
103 | pass
104 |
--------------------------------------------------------------------------------
/dice_ml/dice.py:
--------------------------------------------------------------------------------
1 | """Module pointing to different implementations of DiCE based on different
2 | frameworks such as Tensorflow or PyTorch or sklearn, and different methods
3 | such as RandomSampling, DiCEKD or DiCEGenetic"""
4 |
5 | from raiutils.exceptions import UserConfigValidationException
6 |
7 | from dice_ml.constants import BackEndTypes, SamplingStrategy
8 | from dice_ml.data_interfaces.private_data_interface import PrivateData
9 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
10 |
11 |
class Dice(ExplainerBase):
    """Facade that dispatches to a concrete DiCE explainer implementation."""

    def __init__(self, data_interface, model_interface, method="random", **kwargs):
        """Build the facade and immediately rebind it to a concrete explainer.

        :param data_interface: an interface to access data related params.
        :param model_interface: an interface to access the output or gradients of a trained ML model.
        :param method: name of the method to use for generating counterfactuals.
        """
        self.decide_implementation_type(data_interface, model_interface, method, **kwargs)

    def decide_implementation_type(self, data_interface, model_interface, method, **kwargs):
        """Swap this instance's class for the chosen DiCE implementation and re-init it."""
        if model_interface.backend == BackEndTypes.Sklearn:
            # kdtree needs the full training data, which private data cannot supply
            if method == SamplingStrategy.KdTree and isinstance(data_interface, PrivateData):
                raise UserConfigValidationException(
                    'Private data interface is not supported with kdtree explainer'
                    ' since kdtree explainer needs access to entire training data')
        self.__class__ = decide(model_interface, method)
        self.__init__(data_interface, model_interface, **kwargs)

    def _generate_counterfactuals(self, query_instance, total_CFs,
                                  desired_class="opposite", desired_range=None,
                                  permitted_range=None, features_to_vary="all",
                                  stopping_threshold=0.5, posthoc_sparsity_param=0.1,
                                  posthoc_sparsity_algorithm="linear", verbose=False, **kwargs):
        # Dice itself never generates CFs: decide_implementation_type rebinds the
        # instance to a concrete subclass before this could ever be reached.
        raise NotImplementedError("This method should be implemented by the concrete classes "
                                  "that inherit from ExplainerBase")
41 |
42 |
def decide(model_interface, method):
    """Map a (model backend, method) pair to a concrete DiCE explainer class.

    To add new implementations of DiCE, add the class in the
    explainer_interfaces subpackage and import-and-return it from a new
    branch below.
    """
    if method == SamplingStrategy.Random:
        # random sampling of CFs
        from dice_ml.explainer_interfaces.dice_random import DiceRandom
        return DiceRandom

    if method == SamplingStrategy.Genetic:
        from dice_ml.explainer_interfaces.dice_genetic import DiceGenetic
        return DiceGenetic

    if method == SamplingStrategy.KdTree:
        from dice_ml.explainer_interfaces.dice_KD import DiceKD
        return DiceKD

    if method == SamplingStrategy.Gradient:
        backend = model_interface.backend
        if backend == BackEndTypes.Tensorflow1:
            # pretrained Keras Sequential model with Tensorflow 1.x backend
            from dice_ml.explainer_interfaces.dice_tensorflow1 import \
                DiceTensorFlow1
            return DiceTensorFlow1
        if backend == BackEndTypes.Tensorflow2:
            # pretrained Keras Sequential model with Tensorflow 2.x backend
            from dice_ml.explainer_interfaces.dice_tensorflow2 import \
                DiceTensorFlow2
            return DiceTensorFlow2
        if backend == BackEndTypes.Pytorch:
            # PyTorch backend
            from dice_ml.explainer_interfaces.dice_pytorch import DicePyTorch
            return DicePyTorch
        # gradient-based generation requires a differentiable model backend
        raise UserConfigValidationException(
            "{0} is only supported for differentiable neural network models. "
            "Please choose one of {1}, {2} or {3}".format(
                method, SamplingStrategy.Random,
                SamplingStrategy.Genetic,
                SamplingStrategy.KdTree
            ))

    if method is None:
        # all other backends: the backend dict names the explainer class as
        # "module_name.class_name"
        module_name, class_name = model_interface.backend['explainer'].split('.')
        module = __import__("dice_ml.explainer_interfaces." + module_name, fromlist=[class_name])
        return getattr(module, class_name)

    raise UserConfigValidationException("Unsupported sample strategy {0} provided. "
                                        "Please choose one of {1}, {2} or {3}".format(
                                            method, SamplingStrategy.Random,
                                            SamplingStrategy.Genetic,
                                            SamplingStrategy.KdTree
                                        ))
98 |
--------------------------------------------------------------------------------
/dice_ml/explainer_interfaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/explainer_interfaces/__init__.py
--------------------------------------------------------------------------------
/dice_ml/explainer_interfaces/dice_xgboost.py:
--------------------------------------------------------------------------------
1 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
2 |
3 |
4 | class DiceXGBoost(ExplainerBase):
5 | def __init__(self, data_interface, model_interface):
6 | """Initialize with data and model interfaces"""
7 | super().__init__(data_interface, model_interface)
8 |
9 | def generate_counterfactuals(self, query_instance, total_CFs=5):
10 | """Generate counterfactuals"""
11 | # Implement your logic to generate counterfactuals
12 | raise NotImplementedError("Counterfactual generation for XGBoost is not implemented yet.")
13 |
--------------------------------------------------------------------------------
/dice_ml/explainer_interfaces/feasible_model_approx.py:
--------------------------------------------------------------------------------
1 | # Dice Imports
2 | # Pytorch
3 | import torch
4 | import torch.utils.data
5 | from torch.nn import functional as F
6 |
7 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase
8 | from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE
9 | from dice_ml.utils.helpers import get_base_gen_cf_initialization
10 |
11 |
class FeasibleModelApprox(FeasibleBaseVAE, ExplainerBase):
    """CF explainer that trains a variational autoencoder with a
    model-approximation feasibility penalty (unary constraints) added to the
    base VAE loss."""

    def __init__(self, data_interface, model_interface, **kwargs):
        """
        :param data_interface: an interface class to data related params
        :param model_interface: an interface class to access trained ML model

        Required keyword arguments (hyperparameters): encoded_size, lr,
        batch_size, validity_reg, margin, epochs, wm1, wm2, wm3.
        """

        # initiating data related parameters
        ExplainerBase.__init__(self, data_interface)

        # Black Box ML Model to be explained
        self.pred_model = model_interface.model

        self.minx, self.maxx, self.encoded_categorical_feature_indexes, \
            self.encoded_continuous_feature_indexes, self.cont_minx, self.cont_maxx, self.cont_precisions = \
            self.data_interface.get_data_params_for_gradient_dice()
        self.data_interface.one_hot_encoded_data = self.data_interface.one_hot_encode_data(self.data_interface.data_df)
        # Hyperparams
        self.encoded_size = kwargs['encoded_size']
        self.learning_rate = kwargs['lr']
        self.batch_size = kwargs['batch_size']
        self.validity_reg = kwargs['validity_reg']
        self.margin = kwargs['margin']
        self.epochs = kwargs['epochs']
        self.wm1 = kwargs['wm1']
        self.wm2 = kwargs['wm2']
        self.wm3 = kwargs['wm3']

        # Initializing parameters for the DiceModelApproxGenCF
        self.vae_train_dataset, self.vae_val_dataset, self.vae_test_dataset, self.normalise_weights, \
            self.cf_vae, self.cf_vae_optimizer = \
            get_base_gen_cf_initialization(
                self.data_interface, self.encoded_size, self.cont_minx,
                self.cont_maxx, self.margin, self.validity_reg, self.epochs,
                self.wm1, self.wm2, self.wm3, self.learning_rate)

        # Data paths
        self.base_model_dir = '../../../dice_ml/utils/sample_trained_models/'
        self.save_path = self.base_model_dir + self.data_interface.data_name + \
            '-margin-' + str(self.margin) + '-validity_reg-' + str(self.validity_reg) + \
            '-epoch-' + str(self.epochs) + '-' + 'ae-gen' + '.pth'

    def train(self, constraint_type, constraint_variables, constraint_direction, constraint_reg, pre_trained=False):
        '''
        :param constraint_type: Binary Variable currently: (1) unary / (0) monotonic
        :param constraint_variables: List of List: [[Effect, Cause1, Cause2, .... ]]
        :param constraint_direction: -1: Negative, 1: Positive ( By default has to be one for monotonic constraints )
        :param constraint_reg: Tunable Hyperparameter
        :param pre_trained: Bool Variable to check whether pre trained model exists to avoid training again

        :return None
        '''
        if pre_trained:
            self.cf_vae.load_state_dict(torch.load(self.save_path))
            self.cf_vae.eval()
            return

        # TODO: Handling such dataset specific constraints in a more general way
        # CF Generation for only low to high income data points
        self.vae_train_dataset = self.vae_train_dataset[self.vae_train_dataset[:, -1] == 0, :]
        self.vae_val_dataset = self.vae_val_dataset[self.vae_val_dataset[:, -1] == 0, :]

        # Removing the outcome variable from the datasets
        self.vae_train_feat = self.vae_train_dataset[:, :-1]
        self.vae_val_feat = self.vae_val_dataset[:, :-1]

        for epoch in range(self.epochs):
            batch_num = 0
            train_loss = 0.0
            train_size = 0

            train_dataset = torch.tensor(self.vae_train_feat).float()
            train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            for train_x in train_dataset:
                self.cf_vae_optimizer.zero_grad()

                # Target label is the flipped model prediction (e.g. low -> high income)
                train_y = 1.0-torch.argmax(self.pred_model(train_x), dim=1)
                train_size += train_x.shape[0]

                out = self.cf_vae(train_x, train_y)
                loss = self.compute_loss(out, train_x, train_y)

                # Unary Case
                if constraint_type:
                    for const in constraint_variables:
                        # Get the index from the feature name
                        # Handle the categorical variable case here too
                        const_idx = const[0]
                        dm = out['x_pred']
                        mc_samples = out['mc_samples']
                        x_pred = dm[0]

                        # Hinge penalty pushes the CF feature to move in constraint_direction
                        constraint_loss = F.hinge_embedding_loss(
                            constraint_direction*(x_pred[:, const_idx] - train_x[:, const_idx]), torch.tensor(-1), 0)

                        # Average the penalty over all Monte Carlo decoder samples
                        for j in range(1, mc_samples):
                            x_pred = dm[j]
                            constraint_loss += F.hinge_embedding_loss(
                                constraint_direction*(x_pred[:, const_idx] - train_x[:, const_idx]), torch.tensor(-1), 0)

                        constraint_loss = constraint_loss/mc_samples
                        constraint_loss = constraint_reg*constraint_loss
                        loss += constraint_loss
                        print('Constraint: ', constraint_loss, torch.mean(constraint_loss))
                else:
                    # Train the regression model
                    raise NotImplementedError(
                        "This has not been implemented yet. If you'd like this to be implemented in the next version, "
                        "please raise an issue at https://github.com/interpretml/DiCE/issues")

                loss.backward()
                train_loss += loss.item()
                self.cf_vae_optimizer.step()

                batch_num += 1

            # FIX: report the epoch's average loss; the original divided only the
            # LAST batch's loss by the batch count.
            ret = train_loss/batch_num
            print('Train Avg Loss: ', ret, train_size)

            # Save the model after training every 10 epochs and at the last epoch
            if (epoch != 0 and (epoch % 10) == 0) or epoch == self.epochs-1:
                torch.save(self.cf_vae.state_dict(), self.save_path)
136 |
--------------------------------------------------------------------------------
/dice_ml/model.py:
--------------------------------------------------------------------------------
1 | """Module pointing to different implementations of Model class
2 |
3 | The implementations contain methods to access the output or gradients of ML models trained based on different
4 | frameworks such as Tensorflow or PyTorch.
5 | """
6 | import warnings
7 |
8 | from raiutils.exceptions import UserConfigValidationException
9 |
10 | from dice_ml.constants import BackEndTypes, ModelTypes
11 |
12 |
class Model:
    """An interface class to different ML Model implementations."""

    def __init__(self, model=None, model_path='', backend=BackEndTypes.Tensorflow1, model_type=ModelTypes.Classifier,
                 func=None, kw_args=None):
        """Init method

        :param model: trained ML model.
        :param model_path: path to trained ML model.
        :param backend: "TF1" ("TF2") for TensorFLow 1.0 (2.0), "PYT" for PyTorch implementations,
                        "sklearn" for Scikit-Learn implementations of standard
                        DiCE (https://arxiv.org/pdf/1905.07697.pdf). For all other frameworks and
                        implementations, provide a dictionary with "model" and "explainer" as keys,
                        and include module and class names as values in the form module_name.class_name.
                        For instance, if there is a model interface class "XGBoostModel" in module "xgboost_model.py"
                        inside the subpackage dice_ml.model_interfaces, and dice interface class "DiceXGBoost"
                        in module "dice_xgboost" inside dice_ml.explainer_interfaces, then backend parameter
                        should be {"model": "xgboost_model.XGBoostModel", "explainer": "dice_xgboost.DiceXGBoost"}.
        :param func: function transformation required for ML model. If func is None, then func will be the identity function.
        :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended
                        to the dictionary of kw_args, by default.
        """
        # an unknown backend is only a warning: dict-style backends are valid too
        if backend not in BackEndTypes.ALL:
            warnings.warn('{0} backend not in supported backends {1}'.format(
                backend, ','.join(BackEndTypes.ALL)), stacklevel=2)

        # unlike backend, an unknown model type is a hard error
        if model_type not in ModelTypes.ALL:
            raise UserConfigValidationException('{0} model type not in supported model types {1}'.format(
                model_type, ','.join(ModelTypes.ALL))
            )

        self.model_type = model_type
        if model is None and model_path == '':
            raise ValueError("should provide either a trained model or the path to a model")
        self.decide_implementation_type(model, model_path, backend, func, kw_args)

    def decide_implementation_type(self, model, model_path, backend, func, kw_args):
        """Rebind this instance to the concrete Model implementation for the backend."""
        self.__class__ = decide(backend)
        self.__init__(model, model_path, backend, func, kw_args)
54 |
55 |
def decide(backend):
    """Map a backend identifier to a concrete Model implementation class.

    To add new implementations of Model, add the class in the model_interfaces
    subpackage and import-and-return it from a new branch below.
    """
    if backend == BackEndTypes.Sklearn:
        # standard sklearn-style models
        from dice_ml.model_interfaces.base_model import BaseModel
        return BaseModel

    if backend in (BackEndTypes.Tensorflow1, BackEndTypes.Tensorflow2):
        # Tensorflow 1 or 2 backend
        try:
            import tensorflow  # noqa: F401
        except ImportError:
            raise UserConfigValidationException("Unable to import tensorflow. Please install tensorflow")
        from dice_ml.model_interfaces.keras_tensorflow_model import \
            KerasTensorFlowModel
        return KerasTensorFlowModel

    if backend == BackEndTypes.Pytorch:
        # PyTorch backend
        try:
            import torch  # noqa: F401
        except ImportError:
            raise UserConfigValidationException("Unable to import torch. Please install torch from https://pytorch.org/")
        from dice_ml.model_interfaces.pytorch_model import PyTorchModel
        return PyTorchModel

    # all other implementations and frameworks: backend is a dict naming the
    # model interface class as "module_name.class_name"
    module_name, class_name = backend['model'].split('.')
    module = __import__("dice_ml.model_interfaces." + module_name, fromlist=[class_name])
    return getattr(module, class_name)
92 |
--------------------------------------------------------------------------------
/dice_ml/model_interfaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/model_interfaces/__init__.py
--------------------------------------------------------------------------------
/dice_ml/model_interfaces/base_model.py:
--------------------------------------------------------------------------------
1 | """Module containing a template class as an interface to ML model.
2 | Subclasses implement model interfaces for different ML frameworks such as TensorFlow, PyTorch OR Sklearn.
3 | All model interface methods are in dice_ml.model_interfaces"""
4 |
5 | import pickle
6 |
7 | import numpy as np
8 |
9 | from dice_ml.constants import ModelTypes
10 | from dice_ml.utils.exception import SystemException
11 | from dice_ml.utils.helpers import DataTransfomer
12 |
13 |
class BaseModel:
    """Template model interface; concrete interfaces for TensorFlow/PyTorch subclass it."""

    def __init__(self, model=None, model_path='', backend='', func=None, kw_args=None):
        """Init method

        :param model: trained ML Model.
        :param model_path: path to trained model.
        :param backend: ML framework. For frameworks other than TensorFlow or PyTorch,
                        or for implementations other than standard DiCE
                        (https://arxiv.org/pdf/1905.07697.pdf),
                        provide both the module and class names as module_name.class_name.
                        For instance, if there is a model interface class "SklearnModel"
                        in module "sklearn_model.py" inside the subpackage dice_ml.model_interfaces,
                        then backend parameter should be "sklearn_model.SklearnModel".
        :param func: function transformation required for ML model. If func is None, then func will be the identity function.
        :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended to the
                        dictionary of kw_args, by default.
        """
        self.model = model
        self.model_path = model_path
        self.backend = backend
        # calls FunctionTransformer of scikit-learn internally
        # (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html)
        self.transformer = DataTransfomer(func, kw_args)

    def load_model(self):
        """Unpickle the model from self.model_path, if a path was given."""
        if self.model_path == '':
            return
        with open(self.model_path, 'rb') as filehandle:
            self.model = pickle.load(filehandle)

    def get_output(self, input_instance, model_score=True):
        """returns prediction probabilities for a classifier and the predicted output for a regressor.

        :returns: an array of output scores for a classifier, and a singleton
                  array of predicted value for a regressor.
        """
        transformed = self.transformer.transform(input_instance)
        if model_score and self.model_type == ModelTypes.Classifier:
            return self.model.predict_proba(transformed)
        return self.model.predict(transformed)

    def get_gradient(self):
        # gradients are only meaningful for differentiable-model subclasses
        raise NotImplementedError

    def get_num_output_nodes(self, inp_size):
        """Probe the model with one random row to discover the output width."""
        probe = np.random.uniform(0, 1, size=(1, inp_size))
        return self.get_output(probe).shape[1]

    def get_num_output_nodes2(self, input_instance):
        """Output width for a given instance; undefined for regressors."""
        if self.model_type == ModelTypes.Regressor:
            raise SystemException('Number of output nodes not supported for regression')
        return self.get_output(input_instance).shape[1]
71 |
--------------------------------------------------------------------------------
/dice_ml/model_interfaces/keras_tensorflow_model.py:
--------------------------------------------------------------------------------
1 | """Module containing an interface to trained Keras Tensorflow model."""
2 |
3 | import tensorflow as tf
4 | from tensorflow import keras
5 |
6 | from dice_ml.model_interfaces.base_model import BaseModel
7 |
8 |
class KerasTensorFlowModel(BaseModel):
    """Interface to a trained Keras model under a TF1 or TF2 backend."""

    def __init__(self, model=None, model_path='', backend='TF1', func=None, kw_args=None):
        """Init method

        :param model: trained Keras Sequential Model.
        :param model_path: path to trained model.
        :param backend: "TF1" for TensorFlow 1 and "TF2" for TensorFlow 2.
        :param func: function transformation required for ML model. If func is None, then func will be the identity function.
        :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended to the
                        dictionary of kw_args, by default.
        """
        super().__init__(model, model_path, backend, func, kw_args)

    def load_model(self):
        """Load a serialized Keras model from disk, if a path was supplied."""
        if self.model_path != '':
            self.model = keras.models.load_model(self.model_path)

    def get_output(self, input_tensor, training=False, transform_data=False):
        """returns prediction probabilities

        :param input_tensor: test input.
        :param training: to determine training mode in TF2.
        :param transform_data: boolean to indicate if data transformation is required.
        """
        if transform_data or not tf.is_tensor(input_tensor):
            # non-tensor input is run through the transformer and converted
            transformed = self.transformer.transform(input_tensor).to_numpy()
            input_tensor = tf.constant(transformed, dtype=tf.float32)
        if self.backend == 'TF2':
            return self.model(input_tensor, training=training)
        return self.model(input_tensor)

    def get_gradient(self, input_instance):
        # Future Support
        raise NotImplementedError("Future Support")

    def get_num_output_nodes(self, inp_size):
        """Probe the model with one random row; returns the model output tensor."""
        probe = tf.convert_to_tensor([tf.random.uniform([inp_size])], dtype=tf.float32)
        return self.get_output(probe)
49 |
--------------------------------------------------------------------------------
/dice_ml/model_interfaces/pytorch_model.py:
--------------------------------------------------------------------------------
1 | """Module containing an interface to trained PyTorch model."""
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from dice_ml.constants import ModelTypes
7 | from dice_ml.model_interfaces.base_model import BaseModel
8 |
9 |
class PyTorchModel(BaseModel):
    """Interface to a trained PyTorch model."""

    def __init__(self, model=None, model_path='', backend='PYT', func=None, kw_args=None):
        """Init method

        :param model: trained PyTorch Model.
        :param model_path: path to trained model.
        :param backend: "PYT" for PyTorch framework.
        :param func: function transformation required for ML model. If func is None, then func will be the identity function.
        :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended to the
                        dictionary of kw_args, by default.
        """

        super().__init__(model, model_path, backend, func, kw_args)

    def load_model(self, weights_only=False):
        """Load a serialized PyTorch model from self.model_path, if a path was given.

        :param weights_only: forwarded to torch.load; keep False to load full pickled models.
        """
        if self.model_path != '':
            self.model = torch.load(self.model_path, weights_only=weights_only)

    def get_output(self, input_instance, model_score=True,
                   transform_data=False, out_tensor=False):
        """returns prediction probabilities

        :param input_instance: test input (tensor or raw data accepted by the transformer).
        :param model_score: if False for a classifier, outputs are rounded to hard labels.
        :param transform_data: boolean to indicate if data transformation is required.
        :param out_tensor: if True, return a torch tensor instead of a numpy array.
        """
        input_tensor = input_instance
        # FIX: the original ran the transformer twice when transform_data was True
        # and the input was not already a tensor; one combined check is equivalent.
        if transform_data or not torch.is_tensor(input_instance):
            input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy(dtype=np.float64)).float()
        out = self.model(input_tensor).float()
        if not out_tensor:
            out = out.data.numpy()
        if model_score is False and self.model_type == ModelTypes.Classifier:
            out = np.round(out)  # TODO need to generalize for n-class classifier
        return out

    def set_eval_mode(self):
        """Put the wrapped model into evaluation mode (disables dropout etc.)."""
        self.model.eval()

    def get_gradient(self, input_instance):
        # Future Support
        raise NotImplementedError("Future Support")

    def get_num_output_nodes(self, inp_size):
        """Probe the model with one random row.

        NOTE(review): get_output returns a numpy array here, and ndarray.data is a
        memoryview — looks suspicious; confirm what callers expect from this value.
        """
        temp_input = torch.rand(1, inp_size).float()
        return self.get_output(temp_input).data
58 |
--------------------------------------------------------------------------------
/dice_ml/model_interfaces/xgboost_model.py:
--------------------------------------------------------------------------------
1 | import xgboost as xgb
2 |
3 | from dice_ml.constants import ModelTypes
4 | from dice_ml.model_interfaces.base_model import BaseModel
5 |
6 |
class XGBoostModel(BaseModel):
    """Interface to a trained XGBoost model."""

    def __init__(self, model=None, model_path='', backend='', func=None, kw_args=None):
        # the backend argument is accepted for signature compatibility but the
        # stored backend is always 'xgboost'
        super().__init__(model=model, model_path=model_path, backend='xgboost', func=func, kw_args=kw_args)
        if model is None and model_path:
            self.load_model()

    def load_model(self):
        """Load an XGBoost Booster from self.model_path, if one was given."""
        if self.model_path != '':
            self.model = xgb.Booster()
            self.model.load_model(self.model_path)

    def get_output(self, input_instance, model_score=True):
        """Return probabilities (classifier with model_score=True) or predictions."""
        frame = self.transformer.transform(input_instance)
        # NOTE(review): every column is coerced to int64 — presumably the model was
        # trained on label-encoded integer features; confirm for continuous data.
        for column in frame.columns:
            frame[column] = frame[column].astype('int64')
        if model_score and self.model_type == ModelTypes.Classifier:
            # NOTE(review): predict_proba exists on sklearn-style XGBClassifier,
            # not on a raw Booster loaded by load_model — verify the model type.
            return self.model.predict_proba(frame)
        return self.model.predict(frame)

    def get_gradient(self):
        raise NotImplementedError("XGBoost does not support gradient calculation in this context")
33 |
--------------------------------------------------------------------------------
/dice_ml/schema/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/schema/__init__.py
--------------------------------------------------------------------------------
/dice_ml/schema/counterfactual_explanations_v1.0.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "http://json-schema.org/draft-07/schema#",
3 | "title": "Dashboard Dictionary for Counterfactual outputs",
4 | "description": "The original JSON format for counterfactual examples",
5 | "type": "object",
6 | "properties": {
7 | "cf_examples_list": {
8 | "description": "The list of the computed counterfactual examples.",
9 | "type": "array",
10 | "items": {
11 | "type": "string"
12 | },
13 | "uniqueItems": true
14 | },
15 | "local_importance": {
16 | "description": "The list of counterfactual local importance for the features in input data.",
17 | "type": ["array", "null"],
18 | "items": {
19 | "type": "object"
20 | }
21 | },
22 | "summary_importance": {
23 | "description": "The list of counterfactual summary importance for the features in input data.",
24 | "type": ["object", "null"]
25 | },
26 | "metadata": {
27 | "description": "The metadata about the generated counterfactuals.",
28 | "type": "object"
29 | }
30 | },
31 | "required": [
32 | "cf_examples_list",
33 | "local_importance",
34 | "summary_importance",
35 | "metadata"
36 | ]
37 | }
--------------------------------------------------------------------------------
/dice_ml/schema/counterfactual_explanations_v2.0.json:
--------------------------------------------------------------------------------
1 | {
2 | "$schema": "http://json-schema.org/draft-07/schema#",
3 | "title": "Output contract for Counterfactual outputs",
  "description": "The v2.0 JSON format for counterfactual examples",
5 | "type": "object",
6 | "properties": {
7 | "cfs_list": {
8 | "description": "The list of the computed counterfactual examples.",
9 | "type": "array",
10 | "items": {
11 | "type": ["array", "null"]
12 | }
13 | },
14 | "test_data": {
15 | "description": "The list of the input test samples for which counterfactual examples need to be computed.",
16 | "type": "array",
17 | "items": {
18 | "type": "array"
19 | }
20 | },
21 | "local_importance": {
22 | "description": "The list of counterfactual local importance for the features in input data.",
23 | "type": ["array", "null"],
24 | "items": {
25 | "type": "array"
26 | }
27 | },
28 | "summary_importance": {
29 | "description": "The list of counterfactual summary importance for the features in input data.",
30 | "type": ["array", "null"],
31 | "items": {
32 | "type": "number"
33 | }
34 | },
35 | "feature_names": {
36 | "description": "The list of features in the input data.",
37 | "type": ["array", "null"],
38 | "items": {
39 | "type": "string"
40 | }
41 | },
42 | "feature_names_including_target": {
43 | "description": "The list of features including the target in input data.",
44 | "type": ["array", "null"],
45 | "items": {
46 | "type": "string"
47 | }
48 | },
49 | "model_type": {
50 | "description": "The type of model is either a classifier/regressor",
51 | "type": ["string", "null"]
52 | },
53 | "desired_class": {
54 | "description": "The target class for the generated counterfactual examples",
55 | "type": ["string", "integer", "null"]
56 | },
57 | "desired_range": {
58 | "description": "The target range for the generated counterfactual examples",
59 | "type": ["array", "null"],
60 | "items": {
61 | "type": "number"
62 | }
63 | },
64 | "data_interface": {
65 | "description": "The data interface details including outcome name.",
66 | "type": ["object", "null"]
67 | },
68 | "metadata": {
69 | "description": "The metadata about the generated counterfactuals.",
70 | "type": "object"
71 | }
72 | },
73 | "required": [
74 | "cfs_list",
75 | "test_data",
76 | "local_importance",
77 | "summary_importance",
78 | "feature_names",
79 | "feature_names_including_target",
80 | "model_type",
81 | "desired_class",
82 | "desired_range",
83 | "data_interface",
84 | "metadata"
85 | ]
86 | }
--------------------------------------------------------------------------------
/dice_ml/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/__init__.py
--------------------------------------------------------------------------------
/dice_ml/utils/exception.py:
--------------------------------------------------------------------------------
1 | """Exceptions for the package."""
2 |
3 |
class SystemException(Exception):
    """Raised when an internal/system-level error occurs during execution.

    :param exception_message: A message describing the error.
    :type exception_message: str
    """

    # short machine-readable code identifying the error category
    _error_code = "System Error"
11 |
--------------------------------------------------------------------------------
/dice_ml/utils/neuralnetworks.py:
--------------------------------------------------------------------------------
1 | from torch import nn, sigmoid
2 |
3 |
class FFNetwork(nn.Module):
    """Small feed-forward net: Linear(input, 16) -> ReLU -> Linear(16, 1) -> sigmoid.

    With is_classifier=False the sigmoid output is scaled by 3, giving a
    regression-style output in (0, 3).
    """

    def __init__(self, input_size, is_classifier=True):
        super(FFNetwork, self).__init__()
        self.is_classifier = is_classifier
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )

    def forward(self, x):
        flat = self.flatten(x)
        logits = self.linear_relu_stack(flat)
        squashed = sigmoid(logits)
        # regression variant: rescale (0, 1) -> (0, 3)
        return squashed if self.is_classifier else 3 * squashed
22 |
23 |
class MulticlassNetwork(nn.Module):
    """Feed-forward multiclass net: Linear(input, 16) -> ReLU -> Linear(16, num_class) -> softmax."""

    def __init__(self, input_size: int, num_class: int):
        super(MulticlassNetwork, self).__init__()

        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, num_class)
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        hidden = self.linear_relu_stack(x)
        # softmax over the class dimension so each row sums to 1
        return self.softmax(hidden)
40 |
--------------------------------------------------------------------------------
/dice_ml/utils/sample_architecture/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_architecture/__init__.py
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/adult-margin-0.165-validity_reg-42.0-epoch-25-base-gen.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult-margin-0.165-validity_reg-42.0-epoch-25-base-gen.pth
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/adult-margin-0.344-validity_reg-76.0-epoch-25-ae-gen.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult-margin-0.344-validity_reg-76.0-epoch-25-ae-gen.pth
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/adult.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult.h5
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/adult.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult.pkl
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/adult.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult.pth
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/adult_2nodes.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult_2nodes.pth
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/custom.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom.sav
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/custom_binary.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_binary.sav
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/custom_multiclass.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_multiclass.sav
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/custom_regression.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_regression.sav
--------------------------------------------------------------------------------
/dice_ml/utils/sample_trained_models/custom_vars.sav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_vars.sav
--------------------------------------------------------------------------------
/dice_ml/utils/serialize.py:
--------------------------------------------------------------------------------
class DummyDataInterface:
    """Minimal stand-in for a full data interface used during serialization.

    Carries only the outcome (target) column name and, optionally, the
    underlying dataframe.

    :param outcome_name: name of the outcome/target column.
    :param data_df: optional dataframe backing this interface; kept as
        None when not provided so callers can tell "no data" apart.
    """

    def __init__(self, outcome_name, data_df=None):
        self.outcome_name = outcome_name
        self.data_df = data_df

    def to_json(self):
        """Return a plain dict representation for JSON serialization."""
        return {'outcome_name': self.outcome_name,
                'data_df': self.data_df}
13 |
--------------------------------------------------------------------------------
/docs/.buildinfo:
--------------------------------------------------------------------------------
1 | # Sphinx build info version 1
2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done.
3 | config: 7648e4eb332244e06501f513e1affc13
4 | tags: 645f666f9bcd5a90fca523b33c5a78b7
5 |
--------------------------------------------------------------------------------
/docs/.nojekyll:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/.nojekyll
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile


# Build the HTML docs (the "html" prerequisite is handled by the
# catch-all rule below) and copy the result into this directory so it
# can be served directly (e.g. by GitHub Pages). The _sources tree is
# not needed for serving, so it is removed afterwards.
movehtmlfiles: html
	cp -r $(BUILDDIR)/html/* .
	rm -r _sources

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
26 |
--------------------------------------------------------------------------------
/docs/_images/dice_getting_started_api.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_images/dice_getting_started_api.png
--------------------------------------------------------------------------------
/docs/_modules/dice_ml/utils/exception.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | dice_ml.utils.exception — DiCE 0.11 documentation
7 |
8 |
9 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
62 |
63 |
67 |
68 |
69 |
70 |
71 |
72 | - »
73 | - Module code »
74 | - dice_ml.utils.exception
75 | -
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
Source code for dice_ml.utils.exception
84 | """Exceptions for the package."""
85 |
86 |
87 | [docs]class SystemException(Exception):
88 |
"""An exception indicating that some system exception happened during execution.
89 |
90 |
:param exception_message: A message describing the error.
91 |
:type exception_message: str
92 |
"""
93 |
_error_code = "System Error"
94 |
95 |
96 |
97 |
98 |
112 |
113 |
114 |
115 |
116 |
121 |
122 |
123 |
--------------------------------------------------------------------------------
/docs/_modules/dice_ml/utils/serialize.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | dice_ml.utils.serialize — DiCE 0.11 documentation
7 |
8 |
9 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
62 |
63 |
67 |
68 |
69 |
70 |
71 |
72 | - »
73 | - Module code »
74 | - dice_ml.utils.serialize
75 | -
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
Source code for dice_ml.utils.serialize
84 | [docs]class DummyDataInterface:
85 |
def __init__(self, outcome_name, data_df=None):
86 |
self.outcome_name = outcome_name
87 |
self.data_df = None
88 |
if data_df is not None:
89 |
self.data_df = data_df
90 |
91 |
[docs] def to_json(self):
92 |
return {
93 |
'outcome_name': self.outcome_name,
94 |
'data_df': self.data_df
95 |
}
96 |
97 |
98 |
99 |
100 |
114 |
115 |
116 |
117 |
118 |
123 |
124 |
125 |
--------------------------------------------------------------------------------
/docs/_modules/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Overview: module code — DiCE 0.11 documentation
7 |
8 |
9 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
62 |
63 |
67 |
68 |
69 |
70 |
71 |
72 | - »
73 | - Overview: module code
74 | -
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
All modules for which code is available
83 |
109 |
110 |
111 |
112 |
126 |
127 |
128 |
129 |
130 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/docs/_static/_sphinx_javascript_frameworks_compat.js:
--------------------------------------------------------------------------------
1 | /*
2 | * _sphinx_javascript_frameworks_compat.js
3 | * ~~~~~~~~~~
4 | *
5 | * Compatability shim for jQuery and underscores.js.
6 | *
7 | * WILL BE REMOVED IN Sphinx 6.0
8 | * xref RemovedInSphinx60Warning
9 | *
10 | */
11 |
12 | /**
13 | * select a different prefix for underscore
14 | */
15 | $u = _.noConflict();
16 |
17 |
18 | /**
19 | * small helper function to urldecode strings
20 | *
21 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL
22 | */
23 | jQuery.urldecode = function(x) {
24 | if (!x) {
25 | return x
26 | }
27 | return decodeURIComponent(x.replace(/\+/g, ' '));
28 | };
29 |
30 | /**
31 | * small helper function to urlencode strings
32 | */
33 | jQuery.urlencode = encodeURIComponent;
34 |
35 | /**
36 | * This function returns the parsed url parameters of the
37 | * current request. Multiple values per key are supported,
38 | * it will always return arrays of strings for the value parts.
39 | */
40 | jQuery.getQueryParameters = function(s) {
41 | if (typeof s === 'undefined')
42 | s = document.location.search;
43 | var parts = s.substr(s.indexOf('?') + 1).split('&');
44 | var result = {};
45 | for (var i = 0; i < parts.length; i++) {
46 | var tmp = parts[i].split('=', 2);
47 | var key = jQuery.urldecode(tmp[0]);
48 | var value = jQuery.urldecode(tmp[1]);
49 | if (key in result)
50 | result[key].push(value);
51 | else
52 | result[key] = [value];
53 | }
54 | return result;
55 | };
56 |
57 | /**
58 | * highlight a given string on a jquery object by wrapping it in
59 | * span elements with the given class name.
60 | */
61 | jQuery.fn.highlightText = function(text, className) {
62 | function highlight(node, addItems) {
63 | if (node.nodeType === 3) {
64 | var val = node.nodeValue;
65 | var pos = val.toLowerCase().indexOf(text);
66 | if (pos >= 0 &&
67 | !jQuery(node.parentNode).hasClass(className) &&
68 | !jQuery(node.parentNode).hasClass("nohighlight")) {
69 | var span;
70 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg");
71 | if (isInSVG) {
72 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan");
73 | } else {
74 | span = document.createElement("span");
75 | span.className = className;
76 | }
77 | span.appendChild(document.createTextNode(val.substr(pos, text.length)));
78 | node.parentNode.insertBefore(span, node.parentNode.insertBefore(
79 | document.createTextNode(val.substr(pos + text.length)),
80 | node.nextSibling));
81 | node.nodeValue = val.substr(0, pos);
82 | if (isInSVG) {
83 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect");
84 | var bbox = node.parentElement.getBBox();
85 | rect.x.baseVal.value = bbox.x;
86 | rect.y.baseVal.value = bbox.y;
87 | rect.width.baseVal.value = bbox.width;
88 | rect.height.baseVal.value = bbox.height;
89 | rect.setAttribute('class', className);
90 | addItems.push({
91 | "parent": node.parentNode,
92 | "target": rect});
93 | }
94 | }
95 | }
96 | else if (!jQuery(node).is("button, select, textarea")) {
97 | jQuery.each(node.childNodes, function() {
98 | highlight(this, addItems);
99 | });
100 | }
101 | }
102 | var addItems = [];
103 | var result = this.each(function() {
104 | highlight(this, addItems);
105 | });
106 | for (var i = 0; i < addItems.length; ++i) {
107 | jQuery(addItems[i].parent).before(addItems[i].target);
108 | }
109 | return result;
110 | };
111 |
112 | /*
113 | * backward compatibility for jQuery.browser
114 | * This will be supported until firefox bug is fixed.
115 | */
116 | if (!jQuery.browser) {
117 | jQuery.uaMatch = function(ua) {
118 | ua = ua.toLowerCase();
119 |
120 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) ||
121 | /(webkit)[ \/]([\w.]+)/.exec(ua) ||
122 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) ||
123 | /(msie) ([\w.]+)/.exec(ua) ||
124 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) ||
125 | [];
126 |
127 | return {
128 | browser: match[ 1 ] || "",
129 | version: match[ 2 ] || "0"
130 | };
131 | };
132 | jQuery.browser = {};
133 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true;
134 | }
135 |
--------------------------------------------------------------------------------
/docs/_static/css/badge_only.css:
--------------------------------------------------------------------------------
1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li .fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions 
.rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version .fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}}
--------------------------------------------------------------------------------
/docs/_static/css/fonts/Roboto-Slab-Bold.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Bold.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/Roboto-Slab-Bold.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Bold.woff2
--------------------------------------------------------------------------------
/docs/_static/css/fonts/Roboto-Slab-Regular.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Regular.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/Roboto-Slab-Regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Regular.woff2
--------------------------------------------------------------------------------
/docs/_static/css/fonts/fontawesome-webfont.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.eot
--------------------------------------------------------------------------------
/docs/_static/css/fonts/fontawesome-webfont.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.ttf
--------------------------------------------------------------------------------
/docs/_static/css/fonts/fontawesome-webfont.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/fontawesome-webfont.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.woff2
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-bold-italic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold-italic.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-bold-italic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold-italic.woff2
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-bold.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-bold.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold.woff2
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-normal-italic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal-italic.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-normal-italic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal-italic.woff2
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-normal.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal.woff
--------------------------------------------------------------------------------
/docs/_static/css/fonts/lato-normal.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal.woff2
--------------------------------------------------------------------------------
/docs/_static/doctools.js:
--------------------------------------------------------------------------------
1 | /*
2 | * doctools.js
3 | * ~~~~~~~~~~~
4 | *
5 | * Base JavaScript utilities for all Sphinx HTML documentation.
6 | *
7 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS.
8 | * :license: BSD, see LICENSE for details.
9 | *
10 | */
11 | "use strict";
12 |
13 | const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([
14 | "TEXTAREA",
15 | "INPUT",
16 | "SELECT",
17 | "BUTTON",
18 | ]);
19 |
20 | const _ready = (callback) => {
21 | if (document.readyState !== "loading") {
22 | callback();
23 | } else {
24 | document.addEventListener("DOMContentLoaded", callback);
25 | }
26 | };
27 |
28 | /**
29 | * Small JavaScript module for the documentation.
30 | */
31 | const Documentation = {
32 | init: () => {
33 | Documentation.initDomainIndexTable();
34 | Documentation.initOnKeyListeners();
35 | },
36 |
37 | /**
38 | * i18n support
39 | */
40 | TRANSLATIONS: {},
41 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1),
42 | LOCALE: "unknown",
43 |
44 | // gettext and ngettext don't access this so that the functions
45 | // can safely bound to a different name (_ = Documentation.gettext)
46 | gettext: (string) => {
47 | const translated = Documentation.TRANSLATIONS[string];
48 | switch (typeof translated) {
49 | case "undefined":
50 | return string; // no translation
51 | case "string":
52 | return translated; // translation exists
53 | default:
54 | return translated[0]; // (singular, plural) translation tuple exists
55 | }
56 | },
57 |
58 | ngettext: (singular, plural, n) => {
59 | const translated = Documentation.TRANSLATIONS[singular];
60 | if (typeof translated !== "undefined")
61 | return translated[Documentation.PLURAL_EXPR(n)];
62 | return n === 1 ? singular : plural;
63 | },
64 |
65 | addTranslations: (catalog) => {
66 | Object.assign(Documentation.TRANSLATIONS, catalog.messages);
67 | Documentation.PLURAL_EXPR = new Function(
68 | "n",
69 | `return (${catalog.plural_expr})`
70 | );
71 | Documentation.LOCALE = catalog.locale;
72 | },
73 |
74 | /**
75 | * helper function to focus on search bar
76 | */
77 | focusSearchBar: () => {
78 | document.querySelectorAll("input[name=q]")[0]?.focus();
79 | },
80 |
81 | /**
82 | * Initialise the domain index toggle buttons
83 | */
84 | initDomainIndexTable: () => {
85 | const toggler = (el) => {
86 | const idNumber = el.id.substr(7);
87 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`);
88 | if (el.src.substr(-9) === "minus.png") {
89 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`;
90 | toggledRows.forEach((el) => (el.style.display = "none"));
91 | } else {
92 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`;
93 | toggledRows.forEach((el) => (el.style.display = ""));
94 | }
95 | };
96 |
97 | const togglerElements = document.querySelectorAll("img.toggler");
98 | togglerElements.forEach((el) =>
99 | el.addEventListener("click", (event) => toggler(event.currentTarget))
100 | );
101 | togglerElements.forEach((el) => (el.style.display = ""));
102 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler);
103 | },
104 |
105 | initOnKeyListeners: () => {
106 | // only install a listener if it is really needed
107 | if (
108 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS &&
109 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS
110 | )
111 | return;
112 |
113 | document.addEventListener("keydown", (event) => {
114 | // bail for input elements
115 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return;
116 | // bail with special keys
117 | if (event.altKey || event.ctrlKey || event.metaKey) return;
118 |
119 | if (!event.shiftKey) {
120 | switch (event.key) {
121 | case "ArrowLeft":
122 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break;
123 |
124 | const prevLink = document.querySelector('link[rel="prev"]');
125 | if (prevLink && prevLink.href) {
126 | window.location.href = prevLink.href;
127 | event.preventDefault();
128 | }
129 | break;
130 | case "ArrowRight":
131 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break;
132 |
133 | const nextLink = document.querySelector('link[rel="next"]');
134 | if (nextLink && nextLink.href) {
135 | window.location.href = nextLink.href;
136 | event.preventDefault();
137 | }
138 | break;
139 | }
140 | }
141 |
142 | // some keyboard layouts may need Shift to get /
143 | switch (event.key) {
144 | case "/":
145 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break;
146 | Documentation.focusSearchBar();
147 | event.preventDefault();
148 | }
149 | });
150 | },
151 | };
152 |
153 | // quick alias for translations
154 | const _ = Documentation.gettext;
155 |
156 | _ready(Documentation.init);
157 |
--------------------------------------------------------------------------------
/docs/_static/documentation_options.js:
--------------------------------------------------------------------------------
1 | var DOCUMENTATION_OPTIONS = {
2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
3 | VERSION: '0.11',
4 | LANGUAGE: 'en',
5 | COLLAPSE_INDEX: false,
6 | BUILDER: 'html',
7 | FILE_SUFFIX: '.html',
8 | LINK_SUFFIX: '.html',
9 | HAS_SOURCE: true,
10 | SOURCELINK_SUFFIX: '.txt',
11 | NAVIGATION_WITH_KEYS: false,
12 | SHOW_SEARCH_SUMMARY: true,
13 | ENABLE_SEARCH_SHORTCUTS: true,
14 | };
--------------------------------------------------------------------------------
/docs/_static/file.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/file.png
--------------------------------------------------------------------------------
/docs/_static/fonts/FontAwesome.otf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/FontAwesome.otf
--------------------------------------------------------------------------------
/docs/_static/fonts/Inconsolata-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Inconsolata-Bold.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Inconsolata-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Inconsolata-Regular.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Inconsolata.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Inconsolata.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato-Bold.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato-Regular.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bold.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bold.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bold.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bolditalic.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bolditalic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bolditalic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-bolditalic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-italic.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-italic.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-italic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-italic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-regular.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-regular.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Lato/lato-regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Bold.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Bold.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Bold.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Bold.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Light.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Light.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Light.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Light.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Regular.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Regular.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Regular.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Thin.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Thin.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/Roboto-Slab-Thin.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Thin.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab-Bold.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab-Regular.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/fontawesome-webfont.eot:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.eot
--------------------------------------------------------------------------------
/docs/_static/fonts/fontawesome-webfont.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.ttf
--------------------------------------------------------------------------------
/docs/_static/fonts/fontawesome-webfont.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/fontawesome-webfont.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-bold-italic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold-italic.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-bold-italic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold-italic.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-bold.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-bold.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-normal-italic.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal-italic.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-normal-italic.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal-italic.woff2
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-normal.woff:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal.woff
--------------------------------------------------------------------------------
/docs/_static/fonts/lato-normal.woff2:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal.woff2
--------------------------------------------------------------------------------
/docs/_static/getting_started_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/getting_started_output.png
--------------------------------------------------------------------------------
/docs/_static/getting_started_updated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/getting_started_updated.png
--------------------------------------------------------------------------------
/docs/_static/js/badge_only.js:
--------------------------------------------------------------------------------
1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}});
--------------------------------------------------------------------------------
/docs/_static/js/html5shiv-printshiv.min.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed
3 | */
4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var 
b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof 
b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document);
--------------------------------------------------------------------------------
/docs/_static/js/html5shiv.min.js:
--------------------------------------------------------------------------------
1 | /**
2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed
3 | */
4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof 
a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document);
--------------------------------------------------------------------------------
/docs/_static/js/theme.js:
--------------------------------------------------------------------------------
1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var 
t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap(""),n("table.docutils.footnote").wrap(""),n("table.docutils.citation").wrap(""),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}if(t.length>0){$(".wy-menu-vertical .current").removeClass("current").attr("aria-expanded","false"),t.addClass("current").attr("aria-expanded","true"),t.closest("li.toctree-l1").parent().addClass("current").attr("aria-expanded","true");for(let n=1;n<=10;n++)t.closest("li.toctree-l"+n).addClass("current").attr("aria-expanded","true");t[0].scrollIntoView()}}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current").attr("aria-expanded","false"),e.siblings().find("li.current").removeClass("current").attr("aria-expanded","false");var t=e.find("> ul 
li");t.length&&(t.removeClass("current").attr("aria-expanded","false"),e.toggleClass("current").attr("aria-expanded",(function(n,e){return"true"==e?"false":"true"})))}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t0
63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1
64 | var mgr1 = "^(" + C + ")?" + V + C + V + C; // [C]VCVC... is m>1
65 | var s_v = "^(" + C + ")?" + v; // vowel in stem
66 |
67 | this.stemWord = function (w) {
68 | var stem;
69 | var suffix;
70 | var firstch;
71 | var origword = w;
72 |
73 | if (w.length < 3)
74 | return w;
75 |
76 | var re;
77 | var re2;
78 | var re3;
79 | var re4;
80 |
81 | firstch = w.substr(0,1);
82 | if (firstch == "y")
83 | w = firstch.toUpperCase() + w.substr(1);
84 |
85 | // Step 1a
86 | re = /^(.+?)(ss|i)es$/;
87 | re2 = /^(.+?)([^s])s$/;
88 |
89 | if (re.test(w))
90 | w = w.replace(re,"$1$2");
91 | else if (re2.test(w))
92 | w = w.replace(re2,"$1$2");
93 |
94 | // Step 1b
95 | re = /^(.+?)eed$/;
96 | re2 = /^(.+?)(ed|ing)$/;
97 | if (re.test(w)) {
98 | var fp = re.exec(w);
99 | re = new RegExp(mgr0);
100 | if (re.test(fp[1])) {
101 | re = /.$/;
102 | w = w.replace(re,"");
103 | }
104 | }
105 | else if (re2.test(w)) {
106 | var fp = re2.exec(w);
107 | stem = fp[1];
108 | re2 = new RegExp(s_v);
109 | if (re2.test(stem)) {
110 | w = stem;
111 | re2 = /(at|bl|iz)$/;
112 | re3 = new RegExp("([^aeiouylsz])\\1$");
113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$");
114 | if (re2.test(w))
115 | w = w + "e";
116 | else if (re3.test(w)) {
117 | re = /.$/;
118 | w = w.replace(re,"");
119 | }
120 | else if (re4.test(w))
121 | w = w + "e";
122 | }
123 | }
124 |
125 | // Step 1c
126 | re = /^(.+?)y$/;
127 | if (re.test(w)) {
128 | var fp = re.exec(w);
129 | stem = fp[1];
130 | re = new RegExp(s_v);
131 | if (re.test(stem))
132 | w = stem + "i";
133 | }
134 |
135 | // Step 2
136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/;
137 | if (re.test(w)) {
138 | var fp = re.exec(w);
139 | stem = fp[1];
140 | suffix = fp[2];
141 | re = new RegExp(mgr0);
142 | if (re.test(stem))
143 | w = stem + step2list[suffix];
144 | }
145 |
146 | // Step 3
147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/;
148 | if (re.test(w)) {
149 | var fp = re.exec(w);
150 | stem = fp[1];
151 | suffix = fp[2];
152 | re = new RegExp(mgr0);
153 | if (re.test(stem))
154 | w = stem + step3list[suffix];
155 | }
156 |
157 | // Step 4
158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/;
159 | re2 = /^(.+?)(s|t)(ion)$/;
160 | if (re.test(w)) {
161 | var fp = re.exec(w);
162 | stem = fp[1];
163 | re = new RegExp(mgr1);
164 | if (re.test(stem))
165 | w = stem;
166 | }
167 | else if (re2.test(w)) {
168 | var fp = re2.exec(w);
169 | stem = fp[1] + fp[2];
170 | re2 = new RegExp(mgr1);
171 | if (re2.test(stem))
172 | w = stem;
173 | }
174 |
175 | // Step 5
176 | re = /^(.+?)e$/;
177 | if (re.test(w)) {
178 | var fp = re.exec(w);
179 | stem = fp[1];
180 | re = new RegExp(mgr1);
181 | re2 = new RegExp(meq1);
182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$");
183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem))))
184 | w = stem;
185 | }
186 | re = /ll$/;
187 | re2 = new RegExp(mgr1);
188 | if (re.test(w) && re2.test(w)) {
189 | re = /.$/;
190 | w = w.replace(re,"");
191 | }
192 |
193 | // and turn initial Y back to y
194 | if (firstch == "y")
195 | w = firstch.toLowerCase() + w.substr(1);
196 | return w;
197 | }
198 | }
199 |
200 |
--------------------------------------------------------------------------------
/docs/_static/minus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/minus.png
--------------------------------------------------------------------------------
/docs/_static/plus.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/plus.png
--------------------------------------------------------------------------------
/docs/_static/pygments.css:
--------------------------------------------------------------------------------
1 | pre { line-height: 125%; }
2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; }
4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; }
6 | .highlight .hll { background-color: #ffffcc }
7 | .highlight { background: #f8f8f8; }
8 | .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */
9 | .highlight .err { border: 1px solid #FF0000 } /* Error */
10 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */
11 | .highlight .o { color: #666666 } /* Operator */
12 | .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */
13 | .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */
14 | .highlight .cp { color: #9C6500 } /* Comment.Preproc */
15 | .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */
16 | .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */
17 | .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */
18 | .highlight .gd { color: #A00000 } /* Generic.Deleted */
19 | .highlight .ge { font-style: italic } /* Generic.Emph */
20 | .highlight .gr { color: #E40000 } /* Generic.Error */
21 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */
22 | .highlight .gi { color: #008400 } /* Generic.Inserted */
23 | .highlight .go { color: #717171 } /* Generic.Output */
24 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */
25 | .highlight .gs { font-weight: bold } /* Generic.Strong */
26 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
27 | .highlight .gt { color: #0044DD } /* Generic.Traceback */
28 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */
29 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */
30 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */
31 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */
32 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */
33 | .highlight .kt { color: #B00040 } /* Keyword.Type */
34 | .highlight .m { color: #666666 } /* Literal.Number */
35 | .highlight .s { color: #BA2121 } /* Literal.String */
36 | .highlight .na { color: #687822 } /* Name.Attribute */
37 | .highlight .nb { color: #008000 } /* Name.Builtin */
38 | .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */
39 | .highlight .no { color: #880000 } /* Name.Constant */
40 | .highlight .nd { color: #AA22FF } /* Name.Decorator */
41 | .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */
42 | .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */
43 | .highlight .nf { color: #0000FF } /* Name.Function */
44 | .highlight .nl { color: #767600 } /* Name.Label */
45 | .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */
46 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */
47 | .highlight .nv { color: #19177C } /* Name.Variable */
48 | .highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */
49 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */
50 | .highlight .mb { color: #666666 } /* Literal.Number.Bin */
51 | .highlight .mf { color: #666666 } /* Literal.Number.Float */
52 | .highlight .mh { color: #666666 } /* Literal.Number.Hex */
53 | .highlight .mi { color: #666666 } /* Literal.Number.Integer */
54 | .highlight .mo { color: #666666 } /* Literal.Number.Oct */
55 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */
56 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick */
57 | .highlight .sc { color: #BA2121 } /* Literal.String.Char */
58 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */
59 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */
60 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */
61 | .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */
62 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */
63 | .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */
64 | .highlight .sx { color: #008000 } /* Literal.String.Other */
65 | .highlight .sr { color: #A45A77 } /* Literal.String.Regex */
66 | .highlight .s1 { color: #BA2121 } /* Literal.String.Single */
67 | .highlight .ss { color: #19177C } /* Literal.String.Symbol */
68 | .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */
69 | .highlight .fm { color: #0000FF } /* Name.Function.Magic */
70 | .highlight .vc { color: #19177C } /* Name.Variable.Class */
71 | .highlight .vg { color: #19177C } /* Name.Variable.Global */
72 | .highlight .vi { color: #19177C } /* Name.Variable.Instance */
73 | .highlight .vm { color: #19177C } /* Name.Variable.Magic */
74 | .highlight .il { color: #666666 } /* Literal.Number.Integer.Long */
--------------------------------------------------------------------------------
/docs/_static/sphinx_highlight.js:
--------------------------------------------------------------------------------
1 | /* Highlighting utilities for Sphinx HTML documentation. */
2 | "use strict";
3 |
4 | const SPHINX_HIGHLIGHT_ENABLED = true
5 |
/**
 * Highlight every case-insensitive occurrence of `text` inside `node` by
 * wrapping the matched substring in a <span> (or, inside SVG, a <tspan>)
 * carrying `className`. Element nodes are walked recursively; SVG highlight
 * rectangles are collected in `addItems` so the caller can insert them after
 * the traversal finishes (inserting during traversal would disturb it).
 */
const _highlight = (node, addItems, text, className) => {
  if (node.nodeType === Node.TEXT_NODE) {
    const val = node.nodeValue;
    const parent = node.parentNode;
    const pos = val.toLowerCase().indexOf(text);
    if (
      pos >= 0 &&
      !parent.classList.contains(className) &&      // already highlighted
      !parent.classList.contains("nohighlight")     // explicitly opted out
    ) {
      let span;

      const closestNode = parent.closest("body, svg, foreignObject");
      const isInSVG = closestNode && closestNode.matches("svg");
      if (isInSVG) {
        // <span> is not valid inside SVG text content; use <tspan> instead.
        span = document.createElementNS("http://www.w3.org/2000/svg", "tspan");
      } else {
        span = document.createElement("span");
        span.classList.add(className);
      }

      // Split the text node into [before][span: match][after], in place.
      span.appendChild(document.createTextNode(val.substr(pos, text.length)));
      parent.insertBefore(
        span,
        parent.insertBefore(
          document.createTextNode(val.substr(pos + text.length)),
          node.nextSibling
        )
      );
      node.nodeValue = val.substr(0, pos);

      if (isInSVG) {
        // A tspan cannot carry a CSS background, so draw a <rect> sized to the
        // parent's bounding box; insertion is deferred via addItems.
        const rect = document.createElementNS(
          "http://www.w3.org/2000/svg",
          "rect"
        );
        const bbox = parent.getBBox();
        rect.x.baseVal.value = bbox.x;
        rect.y.baseVal.value = bbox.y;
        rect.width.baseVal.value = bbox.width;
        rect.height.baseVal.value = bbox.height;
        rect.setAttribute("class", className);
        addItems.push({ parent: parent, target: rect });
      }
    }
  } else if (node.matches && !node.matches("button, select, textarea")) {
    // Element node: recurse into children, skipping form controls.
    node.childNodes.forEach((el) => _highlight(el, addItems, text, className));
  }
};
/**
 * Highlight all occurrences of `text` under `thisNode`, then attach any
 * deferred SVG highlight rectangles collected during the traversal.
 */
const _highlightText = (thisNode, text, className) => {
  const pending = [];
  _highlight(thisNode, pending, text, className);
  for (const item of pending) {
    item.parent.insertAdjacentElement("beforebegin", item.target);
  }
};
66 |
/**
 * Small JavaScript module for the documentation.
 *
 * Wires up search-term highlighting: terms stored in localStorage (or passed
 * via the ?highlight= query parameter) are wrapped in .highlighted spans, and
 * an Escape-key shortcut removes the marks again.
 */
const SphinxHighlight = {

  /**
   * highlight the search words provided in localstorage in the text
   */
  highlightSearchWords: () => {
    if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight

    // get and clear terms from localstorage
    const url = new URL(window.location);
    const highlight =
      localStorage.getItem("sphinx_highlight_terms")
      || url.searchParams.get("highlight")
      || "";
    localStorage.removeItem("sphinx_highlight_terms")
    url.searchParams.delete("highlight");
    // Scrub ?highlight= from the address bar without reloading the page.
    window.history.replaceState({}, "", url);

    // get individual terms from highlight string
    const terms = highlight.toLowerCase().split(/\s+/).filter(x => x);
    if (terms.length === 0) return; // nothing to do

    // There should never be more than one element matching "div.body"
    const divBody = document.querySelectorAll("div.body");
    const body = divBody.length ? divBody[0] : document.querySelector("body");
    // Defer the (potentially large) DOM walk so initial render isn't blocked.
    window.setTimeout(() => {
      terms.forEach((term) => _highlightText(body, term, "highlighted"));
    }, 10);

    const searchBox = document.getElementById("searchbox");
    if (searchBox === null) return;
    // NOTE(review): the string literals below appear stripped/garbled — the
    // HTML markup of the "Hide Search Matches" link is missing, presumably an
    // extraction artifact. Compare against upstream Sphinx sphinx_highlight.js
    // before editing this fragment.
    searchBox.appendChild(
      document
        .createRange()
        .createContextualFragment(
          '' +
          '' +
          _("Hide Search Matches") +
          "
"
        )
    );
  },

  /**
   * helper function to hide the search marks again
   */
  hideSearchWords: () => {
    // Remove the "hide matches" link(s) from the search box...
    document
      .querySelectorAll("#searchbox .highlight-link")
      .forEach((el) => el.remove());
    // ...then strip the highlight class from every marked span.
    document
      .querySelectorAll("span.highlighted")
      .forEach((el) => el.classList.remove("highlighted"));
    localStorage.removeItem("sphinx_highlight_terms")
  },

  initEscapeListener: () => {
    // only install a listener if it is really needed
    if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return;

    document.addEventListener("keydown", (event) => {
      // bail for input elements
      if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return;
      // bail with special keys
      if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return;
      if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) {
        SphinxHighlight.hideSearchWords();
        event.preventDefault();
      }
    });
  },
};

// Run once the DOM is ready (_ready is defined elsewhere in the Sphinx
// static assets — presumably doctools.js; confirm before relying on it).
_ready(SphinxHighlight.highlightSearchWords);
_ready(SphinxHighlight.initEscapeListener);
145 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation

REM Allow the caller to override the sphinx-build executable via SPHINXBUILD.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build

REM No build target supplied: fall through to Sphinx's built-in help.
if "%1" == "" goto help

REM Probe that the executable exists; errorlevel 9009 is Windows'
REM "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

REM Delegate the requested target to sphinx-build's "make mode" (-M).
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
--------------------------------------------------------------------------------
/docs/objects.inv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/objects.inv
--------------------------------------------------------------------------------
/docs/search.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Search — DiCE 0.11 documentation
7 |
8 |
9 |
10 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
65 |
66 |
70 |
71 |
72 |
73 |
74 |
75 | - »
76 | - Search
77 | -
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
114 |
115 |
116 |
117 |
118 |
123 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

# Make the repository root importable so autodoc can find the dice_ml package.
sys.path.insert(0, os.path.abspath("../../"))

# -- Project information -----------------------------------------------------

project = 'DiCE'
copyright = '2020, Ramaravind, Amit, Chenhao'  # noqa: A001
author = 'Ramaravind, Amit, Chenhao'

# The full version, including alpha/beta/rc tags
release = '0.11'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.viewcode',
    'sphinx.ext.todo',
    'nbsphinx',
    'sphinx_rtd_theme'
]

# Heavy or optional dependencies are mocked so autodoc can import dice_ml
# modules on the docs build host without installing them.
# (Fix: 'collections' was listed twice; the duplicate entry is removed.)
autodoc_mock_imports = ['numpy', 'pandas', 'matplotlib', 'os', 'tensorflow', 'random', 'collections',
                        'timeit', 'tensorflow.keras', 'sklearn', 'sklearn.model_selection.train_test_split',
                        'copy', 'IPython', 'IPython.display.display', 'collections.OrderedDict',
                        'logging', 'torch', 'torchvision']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '.ipynb_checkpoints']


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Read the Docs injects the theme itself; only configure it for local builds.
on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
if not on_rtd:
    # only import and set the theme if we're building docs locally
    import sphinx_rtd_theme
    html_theme = 'sphinx_rtd_theme'
    html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []

html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', 'searchbox.html']}
--------------------------------------------------------------------------------
/docs/source/dice_ml.data_interfaces.rst:
--------------------------------------------------------------------------------
1 | dice\_ml.data\_interfaces package
2 | =================================
3 |
4 | Submodules
5 | ----------
6 |
7 | dice\_ml.data\_interfaces.base\_data\_interface module
8 | ------------------------------------------------------
9 |
10 | .. automodule:: dice_ml.data_interfaces.base_data_interface
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | dice\_ml.data\_interfaces.private\_data\_interface module
16 | ---------------------------------------------------------
17 |
18 | .. automodule:: dice_ml.data_interfaces.private_data_interface
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | dice\_ml.data\_interfaces.public\_data\_interface module
24 | --------------------------------------------------------
25 |
26 | .. automodule:: dice_ml.data_interfaces.public_data_interface
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | Module contents
32 | ---------------
33 |
34 | .. automodule:: dice_ml.data_interfaces
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
--------------------------------------------------------------------------------
/docs/source/dice_ml.explainer_interfaces.rst:
--------------------------------------------------------------------------------
1 | dice\_ml.explainer\_interfaces package
2 | ======================================
3 |
4 | Submodules
5 | ----------
6 |
7 | dice\_ml.explainer\_interfaces.dice\_KD module
8 | ----------------------------------------------
9 |
10 | .. automodule:: dice_ml.explainer_interfaces.dice_KD
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | dice\_ml.explainer\_interfaces.dice\_genetic module
16 | ---------------------------------------------------
17 |
18 | .. automodule:: dice_ml.explainer_interfaces.dice_genetic
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | dice\_ml.explainer\_interfaces.dice\_pytorch module
24 | ---------------------------------------------------
25 |
26 | .. automodule:: dice_ml.explainer_interfaces.dice_pytorch
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | dice\_ml.explainer\_interfaces.dice\_random module
32 | --------------------------------------------------
33 |
34 | .. automodule:: dice_ml.explainer_interfaces.dice_random
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | dice\_ml.explainer\_interfaces.dice\_tensorflow1 module
40 | -------------------------------------------------------
41 |
42 | .. automodule:: dice_ml.explainer_interfaces.dice_tensorflow1
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | dice\_ml.explainer\_interfaces.dice\_tensorflow2 module
48 | -------------------------------------------------------
49 |
50 | .. automodule:: dice_ml.explainer_interfaces.dice_tensorflow2
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
55 | dice\_ml.explainer\_interfaces.explainer\_base module
56 | -----------------------------------------------------
57 |
58 | .. automodule:: dice_ml.explainer_interfaces.explainer_base
59 | :members:
60 | :undoc-members:
61 | :show-inheritance:
62 |
63 | dice\_ml.explainer\_interfaces.feasible\_base\_vae module
64 | ---------------------------------------------------------
65 |
66 | .. automodule:: dice_ml.explainer_interfaces.feasible_base_vae
67 | :members:
68 | :undoc-members:
69 | :show-inheritance:
70 |
71 | dice\_ml.explainer\_interfaces.feasible\_model\_approx module
72 | -------------------------------------------------------------
73 |
74 | .. automodule:: dice_ml.explainer_interfaces.feasible_model_approx
75 | :members:
76 | :undoc-members:
77 | :show-inheritance:
78 |
79 | Module contents
80 | ---------------
81 |
82 | .. automodule:: dice_ml.explainer_interfaces
83 | :members:
84 | :undoc-members:
85 | :show-inheritance:
86 |
--------------------------------------------------------------------------------
/docs/source/dice_ml.model_interfaces.rst:
--------------------------------------------------------------------------------
1 | dice\_ml.model\_interfaces package
2 | ==================================
3 |
4 | Submodules
5 | ----------
6 |
7 | dice\_ml.model\_interfaces.base\_model module
8 | ---------------------------------------------
9 |
10 | .. automodule:: dice_ml.model_interfaces.base_model
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | dice\_ml.model\_interfaces.keras\_tensorflow\_model module
16 | ----------------------------------------------------------
17 |
18 | .. automodule:: dice_ml.model_interfaces.keras_tensorflow_model
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | dice\_ml.model\_interfaces.pytorch\_model module
24 | ------------------------------------------------
25 |
26 | .. automodule:: dice_ml.model_interfaces.pytorch_model
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | Module contents
32 | ---------------
33 |
34 | .. automodule:: dice_ml.model_interfaces
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
--------------------------------------------------------------------------------
/docs/source/dice_ml.rst:
--------------------------------------------------------------------------------
1 | dice\_ml package
2 | ================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | dice_ml.data_interfaces
11 | dice_ml.explainer_interfaces
12 | dice_ml.model_interfaces
13 | dice_ml.schema
14 | dice_ml.utils
15 |
16 | Submodules
17 | ----------
18 |
19 | dice\_ml.constants module
20 | -------------------------
21 |
22 | .. automodule:: dice_ml.constants
23 | :members:
24 | :undoc-members:
25 | :show-inheritance:
26 |
27 | dice\_ml.counterfactual\_explanations module
28 | --------------------------------------------
29 |
30 | .. automodule:: dice_ml.counterfactual_explanations
31 | :members:
32 | :undoc-members:
33 | :show-inheritance:
34 |
35 | dice\_ml.data module
36 | --------------------
37 |
38 | .. automodule:: dice_ml.data
39 | :members:
40 | :undoc-members:
41 | :show-inheritance:
42 |
43 | dice\_ml.dice module
44 | --------------------
45 |
46 | .. automodule:: dice_ml.dice
47 | :members:
48 | :undoc-members:
49 | :show-inheritance:
50 |
51 | dice\_ml.diverse\_counterfactuals module
52 | ----------------------------------------
53 |
54 | .. automodule:: dice_ml.diverse_counterfactuals
55 | :members:
56 | :undoc-members:
57 | :show-inheritance:
58 |
59 | dice\_ml.model module
60 | ---------------------
61 |
62 | .. automodule:: dice_ml.model
63 | :members:
64 | :undoc-members:
65 | :show-inheritance:
66 |
67 | Module contents
68 | ---------------
69 |
70 | .. automodule:: dice_ml
71 | :members:
72 | :undoc-members:
73 | :show-inheritance:
74 |
--------------------------------------------------------------------------------
/docs/source/dice_ml.schema.rst:
--------------------------------------------------------------------------------
1 | dice\_ml.schema package
2 | =======================
3 |
4 | Module contents
5 | ---------------
6 |
7 | .. automodule:: dice_ml.schema
8 | :members:
9 | :undoc-members:
10 | :show-inheritance:
11 |
--------------------------------------------------------------------------------
/docs/source/dice_ml.utils.rst:
--------------------------------------------------------------------------------
1 | dice\_ml.utils package
2 | ======================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 | :maxdepth: 4
9 |
10 | dice_ml.utils.sample_architecture
11 |
12 | Submodules
13 | ----------
14 |
15 | dice\_ml.utils.exception module
16 | -------------------------------
17 |
18 | .. automodule:: dice_ml.utils.exception
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
23 | dice\_ml.utils.helpers module
24 | -----------------------------
25 |
26 | .. automodule:: dice_ml.utils.helpers
27 | :members:
28 | :undoc-members:
29 | :show-inheritance:
30 |
31 | dice\_ml.utils.neuralnetworks module
32 | ------------------------------------
33 |
34 | .. automodule:: dice_ml.utils.neuralnetworks
35 | :members:
36 | :undoc-members:
37 | :show-inheritance:
38 |
39 | dice\_ml.utils.serialize module
40 | -------------------------------
41 |
42 | .. automodule:: dice_ml.utils.serialize
43 | :members:
44 | :undoc-members:
45 | :show-inheritance:
46 |
47 | Module contents
48 | ---------------
49 |
50 | .. automodule:: dice_ml.utils
51 | :members:
52 | :undoc-members:
53 | :show-inheritance:
54 |
--------------------------------------------------------------------------------
/docs/source/dice_ml.utils.sample_architecture.rst:
--------------------------------------------------------------------------------
1 | dice\_ml.utils.sample\_architecture package
2 | ===========================================
3 |
4 | Submodules
5 | ----------
6 |
7 | dice\_ml.utils.sample\_architecture.vae\_model module
8 | -----------------------------------------------------
9 |
10 | .. automodule:: dice_ml.utils.sample_architecture.vae_model
11 | :members:
12 | :undoc-members:
13 | :show-inheritance:
14 |
15 | Module contents
16 | ---------------
17 |
18 | .. automodule:: dice_ml.utils.sample_architecture
19 | :members:
20 | :undoc-members:
21 | :show-inheritance:
22 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. DiCE documentation master file, created by
2 | sphinx-quickstart on Tue May 19 13:15:18 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | .. include:: ../../README.rst
7 |
8 | .. toctree::
9 | :maxdepth: 2
10 | :caption: Getting Started:
11 |
12 | readme
13 |
14 | .. toctree::
15 | :maxdepth: 2
16 | :caption: Notebooks:
17 |
18 | notebooks/nb_index
19 |
20 | .. toctree::
21 | :maxdepth: 2
22 | :caption: Package:
23 |
24 | dice_ml
25 |
26 | Indices and tables
27 | ==================
28 |
29 | * :ref:`genindex`
30 | * :ref:`modindex`
31 | * :ref:`search`
32 |
--------------------------------------------------------------------------------
/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | dice_ml
2 | =======
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | dice_ml
8 |
--------------------------------------------------------------------------------
/docs/source/notebooks/DiCE_feature_importances.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Estimating local and global feature importance scores using DiCE\n",
8 | "\n",
9 | "Summaries of counterfactual examples can be used to estimate importance of features. Intuitively, a feature that is changed more often to generate a proximal counterfactual is an important feature. We use this intuition to build a feature importance score. \n",
10 | "\n",
11 | "This score can be interpreted as a measure of the **necessity** of a feature to cause a particular model output. That is, if the feature's value changes, then it is likely that the model's output class will also change (or the model's output will significantly change in case of regression model). \n",
12 | "\n",
13 | "Below we show how counterfactuals can be used to provide local feature importance scores for any input, and how those scores can be combined to yield a global importance score for each feature."
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "from sklearn.compose import ColumnTransformer\n",
23 | "from sklearn.model_selection import train_test_split\n",
24 | "from sklearn.pipeline import Pipeline\n",
25 | "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
26 | "from sklearn.ensemble import RandomForestClassifier\n",
27 | "\n",
28 | "import dice_ml\n",
29 | "from dice_ml import Dice\n",
30 | "from dice_ml.utils import helpers # helper functions"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "%load_ext autoreload\n",
40 | "%autoreload 2"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "metadata": {},
46 | "source": [
47 | "## Preliminaries: Loading the data and ML model"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "metadata": {},
54 | "outputs": [],
55 | "source": [
56 | "dataset = helpers.load_adult_income_dataset().sample(5000) # downsampling to reduce ML model fitting time\n",
57 | "helpers.get_adult_data_info()"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": null,
63 | "metadata": {},
64 | "outputs": [],
65 | "source": [
66 | "target = dataset[\"income\"]\n",
67 | "\n",
68 | "# Split data into train and test\n",
69 | "datasetX = dataset.drop(\"income\", axis=1)\n",
70 | "x_train, x_test, y_train, y_test = train_test_split(datasetX,\n",
71 | " target,\n",
72 | " test_size=0.2,\n",
73 | " random_state=0,\n",
74 | " stratify=target)\n",
75 | "\n",
76 | "numerical = [\"age\", \"hours_per_week\"]\n",
77 | "categorical = x_train.columns.difference(numerical)\n",
78 | "\n",
79 | "# We create the preprocessing pipelines for both numeric and categorical data.\n",
80 | "numeric_transformer = Pipeline(steps=[\n",
81 | " ('scaler', StandardScaler())])\n",
82 | "\n",
83 | "categorical_transformer = Pipeline(steps=[\n",
84 | " ('onehot', OneHotEncoder(handle_unknown='ignore'))])\n",
85 | "\n",
86 | "transformations = ColumnTransformer(\n",
87 | " transformers=[\n",
88 | " ('num', numeric_transformer, numerical),\n",
89 | " ('cat', categorical_transformer, categorical)])\n",
90 | "\n",
91 | "# Append classifier to preprocessing pipeline.\n",
92 | "# Now we have a full prediction pipeline.\n",
93 | "clf = Pipeline(steps=[('preprocessor', transformations),\n",
94 | " ('classifier', RandomForestClassifier())])\n",
95 | "model = clf.fit(x_train, y_train)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')\n",
105 | "m = dice_ml.Model(model=model, backend=\"sklearn\")"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "metadata": {},
111 | "source": [
112 | "## Local feature importance\n",
113 | "\n",
114 | "We first generate counterfactuals for a given input point. "
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "exp = Dice(d, m, method=\"random\")\n",
124 | "query_instance = x_train[1:2]\n",
125 | "e1 = exp.generate_counterfactuals(query_instance, total_CFs=10, desired_range=None,\n",
126 | " desired_class=\"opposite\",\n",
127 | " permitted_range=None, features_to_vary=\"all\")\n",
128 | "e1.visualize_as_dataframe(show_only_changes=True)"
129 | ]
130 | },
131 | {
132 | "cell_type": "markdown",
133 | "metadata": {},
134 | "source": [
135 | "These can now be used to calculate the feature importance scores. "
136 | ]
137 | },
138 | {
139 | "cell_type": "code",
140 | "execution_count": null,
141 | "metadata": {},
142 | "outputs": [],
143 | "source": [
144 | "imp = exp.local_feature_importance(query_instance, cf_examples_list=e1.cf_examples_list)\n",
145 | "print(imp.local_importance)"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "metadata": {},
151 | "source": [
152 | "Feature importance can also be estimated directly, by leaving the `cf_examples_list` argument blank."
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "metadata": {},
159 | "outputs": [],
160 | "source": [
161 | "imp = exp.local_feature_importance(query_instance, posthoc_sparsity_param=None)\n",
162 | "print(imp.local_importance)"
163 | ]
164 | },
165 | {
166 | "cell_type": "markdown",
167 | "metadata": {},
168 | "source": [
169 | "## Global importance\n",
170 | "\n",
171 | "For global importance, we need to generate counterfactuals for a representative sample of the dataset. "
172 | ]
173 | },
174 | {
175 | "cell_type": "code",
176 | "execution_count": null,
177 | "metadata": {},
178 | "outputs": [],
179 | "source": [
180 | "cobj = exp.global_feature_importance(x_train[0:10], total_CFs=10, posthoc_sparsity_param=None)\n",
181 | "print(cobj.summary_importance)"
182 | ]
183 | },
184 | {
185 | "cell_type": "markdown",
186 | "metadata": {},
187 | "source": [
188 | "## Convert the counterfactual output to json"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "json_str = cobj.to_json()\n",
198 | "print(json_str)"
199 | ]
200 | },
201 | {
202 | "cell_type": "markdown",
203 | "metadata": {},
204 | "source": [
205 | "## Convert the json output to a counterfactual object"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": null,
211 | "metadata": {},
212 | "outputs": [],
213 | "source": [
214 | "imp_r = imp.from_json(json_str)\n",
215 | "print([o.visualize_as_dataframe(show_only_changes=True) for o in imp_r.cf_examples_list])\n",
216 | "print(imp_r.local_importance)\n",
217 | "print(imp_r.summary_importance)"
218 | ]
219 | }
220 | ],
221 | "metadata": {
222 | "kernelspec": {
223 | "display_name": "Python 3",
224 | "language": "python",
225 | "name": "python3"
226 | },
227 | "language_info": {
228 | "codemirror_mode": {
229 | "name": "ipython",
230 | "version": 3
231 | },
232 | "file_extension": ".py",
233 | "mimetype": "text/x-python",
234 | "name": "python",
235 | "nbconvert_exporter": "python",
236 | "pygments_lexer": "ipython3",
237 | "version": "3.6.12"
238 | }
239 | },
240 | "nbformat": 4,
241 | "nbformat_minor": 4
242 | }
243 |
--------------------------------------------------------------------------------
/docs/source/notebooks/images/dice_getting_started_api.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/source/notebooks/images/dice_getting_started_api.png
--------------------------------------------------------------------------------
/docs/source/notebooks/nb_index.rst:
--------------------------------------------------------------------------------
1 | .. notebooks to be included in dice docs
2 |
3 | Example notebooks
4 |
5 | .. toctree::
6 | :maxdepth: 2
7 | :caption: Notebooks:
8 |
9 | DiCE_getting_started
10 | DiCE_feature_importances
11 | DiCE_multiclass_classification_and_regression
12 | DiCE_model_agnostic_CFs
13 | DiCE_with_private_data
14 | DiCE_with_advanced_options
15 | DiCE_getting_started_feasible
16 |
17 |
--------------------------------------------------------------------------------
/docs/source/readme.rst:
--------------------------------------------------------------------------------
1 | .. include:: ../../README.rst
2 |
3 |
--------------------------------------------------------------------------------
/docs/update_docs.sh:
--------------------------------------------------------------------------------
1 | sphinx-apidoc -f -o source ../dice_ml
2 |
--------------------------------------------------------------------------------
/environment-deeplearning.yml:
--------------------------------------------------------------------------------
1 | name: example-environment
2 |
3 | dependencies:
4 | - tensorflow>=1.13.0-rc1
5 | - pytorch
6 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: example-environment
2 |
3 | dependencies:
4 | - numpy
5 | - scikit-learn
6 | - pandas
7 | - jsonschema
8 | - tqdm
9 | - pip:
10 | - raiutils
11 |
--------------------------------------------------------------------------------
/requirements-deeplearning.txt:
--------------------------------------------------------------------------------
1 | tensorflow>=1.13.1
2 | torch
3 |
--------------------------------------------------------------------------------
/requirements-linting.txt:
--------------------------------------------------------------------------------
1 | flake8
2 | flake8-bugbear
3 | flake8-blind-except
4 | flake8-breakpoint
5 | flake8-builtins
6 | flake8-logging-format
7 | flake8-pytest-style
8 | flake8-all-not-strings
9 | isort
10 | packaging
11 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | ipython
2 | jupyter
3 | pytest
4 | pytest-cov
5 | twine
6 | pytest-mock
7 | rai_test_utils
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | jsonschema
2 | numpy # if you are using tensorflow 1.x, it requires numpy<=1.16
3 | pandas>=2.0.0
4 | scikit-learn
5 | tqdm
6 | raiutils>=0.4.0
7 | xgboost # if you are using xgboost
8 | lightgbm # if you are using lightgbm
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.rst
3 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

VERSION_STR = "0.11"

# Long description shown on PyPI comes straight from the README.
# Explicit encoding matches the requirements.txt read below and avoids
# platform-dependent default encodings.
with open("README.rst", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# Get the required packages
with open('requirements.txt', encoding='utf-8') as f:
    install_requires = f.read().splitlines()

# Deep learning packages are optional to install
extras = ["deeplearning"]
extras_require = {}
for extra in extras:
    req_file = "requirements-{0}.txt".format(extra)
    with open(req_file, encoding='utf-8') as f:
        # Skip blank lines so the extras list contains only real requirement specs.
        extras_require[extra] = [line.strip() for line in f if line.strip()]

setuptools.setup(
    name="dice_ml",
    version=VERSION_STR,
    license="MIT",
    author="Ramaravind Mothilal, Amit Sharma, Chenhao Tan",
    author_email="raam.arvind93@gmail.com",
    description="Generate Diverse Counterfactual Explanations for any machine learning model.",
    long_description=long_description,
    long_description_content_type="text/x-rst",
    url="https://github.com/interpretml/DiCE",
    download_url="https://github.com/interpretml/DiCE/archive/v"+VERSION_STR+".tar.gz",
    python_requires='>=3.9',
    packages=setuptools.find_packages(exclude=['tests*']),
    classifiers=[
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    keywords='machine-learning explanation interpretability counterfactual',
    install_requires=install_requires,
    extras_require=extras_require,
    include_package_data=True,
    package_data={
        # If any package contains *.h5 files, include them:
        '': ['*.h5',
             'counterfactual_explanations_v1.0.json',
             'counterfactual_explanations_v2.0.json']
    }
)
54 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import dice_ml
4 |
5 |
def test_data_initiation(public_data_object, private_data_object):
    """The conftest fixtures should resolve to the matching data interface classes."""
    public_cls = dice_ml.data_interfaces.public_data_interface.PublicData
    private_cls = dice_ml.data_interfaces.private_data_interface.PrivateData

    assert isinstance(public_data_object, public_cls), \
        "the given parameters should instantiate PublicData class"
    assert isinstance(private_data_object, private_cls), \
        "the given parameters should instantiate PrivateData class"
12 |
13 |
class TestCommonDataMethods:
    """
    Test methods common to Public and Private data interfaces modules.
    """
    @pytest.fixture(autouse=True)
    def _get_data_object(self, public_data_object, private_data_object):
        # d[0] is the public data interface, d[1] the private one.
        self.d = [public_data_object, private_data_object]

    def test_get_valid_mads(self):
        """MADs must be positive and can never exceed a feature's value range."""
        # public data
        for normalized in [True, False]:
            mads = self.d[0].get_valid_mads(normalized=normalized, display_warnings=False, return_mads=True)
            # mads denotes variability in features and should be positive for DiCE.
            assert all(mads[feature] > 0 for feature in mads)

            for feature in mads:
                if normalized:
                    min_value, max_value = 0, 1
                else:
                    min_value = self.d[0].data_df[feature].min()
                    max_value = self.d[0].data_df[feature].max()
                # mads can't be larger than the feature range
                assert mads[feature] <= max_value - min_value

        # private data
        for normalized in [True, False]:
            mads = self.d[1].get_valid_mads(normalized=normalized, display_warnings=False, return_mads=True)
            # no mad is provided for private data by default, so a practical alternative is keeping all value at 1.
            # Check get_valid_mads() in data interface classes for more info.
            assert all(mads[feature] == 1 for feature in mads)

    # TODO: add separate test method for label-encoded data
    def test_ohe_min_max_transformed_query_instance(self, sample_adultincome_query):
        """
        Tests method that converts continuous features into [0,1] range and one-hot encodes categorical features.
        """
        output_query = [0.068, 0.449, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,
                        0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0]
        d = self.d[0]
        prepared_query = d.get_ohe_min_max_normalized_data(query_instance=sample_adultincome_query).iloc[0].tolist()
        assert output_query == pytest.approx(prepared_query, abs=1e-3)

    def test_encoded_categorical_features(self):
        """
        Tests if correct encoded categorical feature indexes are returned. Should work even when feature names
        are starting with common names.
        TODO also test for private data interface
        """
        d = self.d[0]
        temp_ohe_data = d.get_ohe_min_max_normalized_data(d.data_df.iloc[[0]])
        d.create_ohe_params(temp_ohe_data)
        res = d.get_encoded_categorical_feature_indexes()
        # 2,3,4,5 are correct indexes of encoded categorical features and the data object's method should not
        # return the first continuous feature that starts with the same name. Returned value should be a list of lists.
        assert [2, 3, 4, 5] == res[0]  # there are 4 types of workclass
        assert len(res[1]) == 8  # eight types of education
        assert len(res[-1]) == 2  # two types of gender in the data

    def test_features_to_vary(self):
        """
        Tests if correct indexes of features are returned. Should work even when feature names are starting with common names.

        TODO: also make it work for private_data_interface
        """
        d = self.d[0]
        temp_ohe_data = d.get_ohe_min_max_normalized_data(d.data_df.iloc[[0]])
        d.create_ohe_params(temp_ohe_data)
        # 2,3,4,5 are correct indexes of features that can be varied and the data object's method should not return
        # the first continuous feature that starts with the same name.
        assert [2, 3, 4, 5] == d.get_indexes_of_features_to_vary(features_to_vary=['workclass'])
117 |
--------------------------------------------------------------------------------
/tests/test_data_interface/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/test_data_interface/__init__.py
--------------------------------------------------------------------------------
/tests/test_data_interface/test_base_data_interface.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from dice_ml.data_interfaces.base_data_interface import _BaseData
4 |
5 |
class TestBaseData:
    def test_base_data_initialization(self):
        """_BaseData is an abstract base; constructing it directly must raise TypeError."""
        params = {}
        with pytest.raises(TypeError):
            _BaseData(params)
10 |
--------------------------------------------------------------------------------
/tests/test_data_interface/test_private_data_interface.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import pytest
4 |
5 | import dice_ml
6 |
7 |
@pytest.fixture(scope='session')
def data_object():
    """Private dice_ml.Data object built from feature metadata only (no dataframe)."""
    # providing an OrderedDict to make it work for Python<3.6 (deterministic feature order)
    features = OrderedDict()
    features['age'] = [17, 90]
    features['workclass'] = ['Government', 'Other/Unknown', 'Private', 'Self-Employed']
    features['education'] = ['Assoc', 'Bachelors', 'Doctorate', 'HS-grad', 'Masters', 'Prof-school', 'School', 'Some-college']
    features['marital_status'] = ['Divorced', 'Married', 'Separated', 'Single', 'Widowed']
    features['occupation'] = ['Blue-Collar', 'Other/Unknown', 'Professional', 'Sales', 'Service', 'White-Collar']
    features['race'] = ['Other', 'White']
    features['gender'] = ['Female', 'Male']
    features['hours_per_week'] = [1, 99]
    return dice_ml.Data(features=features, outcome_name='income',
                        type_and_precision={'hours_per_week': ['float', 2]}, mad={'age': 10})
22 |
23 |
class TestPrivateDataMethods:
    """Checks behavior specific to the private (metadata-only) data interface."""

    @pytest.fixture(autouse=True)
    def _get_data_object(self, data_object):
        self.d = data_object

    def test_mads(self):
        # normalized=True is already tested in test_data.py
        mads = self.d.get_valid_mads(normalized=False)
        # 10 is given as the mad of feature 'age' while initiating private Data object; 1.0 is the default value.
        # Check get_valid_mads() in private_data_interface for more info.
        assert [10.0, 1.0] == list(mads.values())

    def test_feature_precision(self):
        # feature precision decides the least change that can be made to the feature in optimization,
        # given as 2-decimal place for 'hours_per_week' feature while initiating private Data object.
        precisions = self.d.get_decimal_precisions()
        assert precisions[1] == 2
40 |
--------------------------------------------------------------------------------
/tests/test_dice.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from raiutils.exceptions import UserConfigValidationException
3 |
4 | import dice_ml
5 | from dice_ml.utils import helpers
6 |
7 |
class TestBaseExplainerLoader:
    """Checks that dice_ml.Dice resolves backend/method combinations to the right explainer class."""

    def _get_exp(self, backend, method="random", is_public_data_interface=True):
        """Construct a Dice explainer for the adult-income model on the given backend."""
        if is_public_data_interface:
            adult_df = helpers.load_adult_income_dataset()
            d = dice_ml.Data(dataframe=adult_df, continuous_features=['age', 'hours_per_week'], outcome_name='income')
        else:
            feature_ranges = {
                'age': [17, 90],
                'workclass': ['Government', 'Other/Unknown', 'Private', 'Self-Employed'],
                'education': ['Assoc', 'Bachelors', 'Doctorate', 'HS-grad', 'Masters',
                              'Prof-school', 'School', 'Some-college'],
                'marital_status': ['Divorced', 'Married', 'Separated', 'Single', 'Widowed'],
                'occupation': ['Blue-Collar', 'Other/Unknown', 'Professional', 'Sales', 'Service', 'White-Collar'],
                'race': ['Other', 'White'],
                'gender': ['Female', 'Male'],
                'hours_per_week': [1, 99],
            }
            d = dice_ml.Data(features=feature_ranges, outcome_name='income')
        model_path = helpers.get_adult_income_modelpath(backend=backend)
        m = dice_ml.Model(model_path=model_path, backend=backend, func="ohe-min-max")
        return dice_ml.Dice(d, m, method=method)

    def test_tf(self):
        tf = pytest.importorskip("tensorflow")
        exp = self._get_exp('TF' + tf.__version__[0], method="gradient")
        assert issubclass(type(exp), dice_ml.explainer_interfaces.explainer_base.ExplainerBase)
        assert isinstance(exp, (dice_ml.explainer_interfaces.dice_tensorflow2.DiceTensorFlow2,
                                dice_ml.explainer_interfaces.dice_tensorflow1.DiceTensorFlow1))

    def test_pyt(self):
        pytest.importorskip("torch")
        exp = self._get_exp('PYT', method="gradient")
        assert issubclass(type(exp), dice_ml.explainer_interfaces.explainer_base.ExplainerBase)
        assert isinstance(exp, dice_ml.explainer_interfaces.dice_pytorch.DicePyTorch)

    @pytest.mark.skip(reason="Need to fix this test")
    @pytest.mark.parametrize('method', ['random'])
    def test_sklearn(self, method):
        pytest.importorskip("sklearn")
        exp = self._get_exp('sklearn', method=method)
        assert issubclass(type(exp), dice_ml.explainer_interfaces.explainer_base.ExplainerBase)
        assert isinstance(exp, dice_ml.explainer_interfaces.dice_random.DiceRandom)

    @pytest.mark.skip(reason="Need to fix this test")
    def test_minimum_query_instances(self):
        pytest.importorskip('sklearn')
        exp = self._get_exp('sklearn')
        query_instances = helpers.load_adult_income_dataset().drop("income", axis=1)[0:1]
        # Too few query instances for global importance should be rejected.
        with pytest.raises(UserConfigValidationException):
            exp.global_feature_importance(query_instances)

    def test_unsupported_sampling_strategy(self):
        pytest.importorskip('sklearn')
        with pytest.raises(UserConfigValidationException):
            self._get_exp('sklearn', method="unsupported")

    def test_private_data_interface_dice_kdtree(self):
        pytest.importorskip("sklearn")
        with pytest.raises(UserConfigValidationException) as ucve:
            self._get_exp('sklearn', method='kdtree', is_public_data_interface=False)

        expected = ('Private data interface is not supported with kdtree explainer'
                    ' since kdtree explainer needs access to entire training data')
        assert expected in str(ucve)
77 |
--------------------------------------------------------------------------------
/tests/test_dice_interface/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/test_dice_interface/__init__.py
--------------------------------------------------------------------------------
/tests/test_dice_interface/test_dice_pytorch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import dice_ml
5 | from dice_ml.counterfactual_explanations import CounterfactualExplanations
6 | from dice_ml.utils import helpers
7 |
8 | torch = pytest.importorskip("torch")
9 |
10 |
@pytest.fixture(scope='session')
def pyt_exp_object():
    """Gradient-based DiCE explainer backed by the pretrained PyTorch adult-income model."""
    backend = 'PYT'
    adult_df = helpers.load_adult_income_dataset()
    data = dice_ml.Data(dataframe=adult_df, continuous_features=['age', 'hours_per_week'], outcome_name='income')
    model_path = helpers.get_adult_income_modelpath(backend=backend)
    model = dice_ml.Model(model_path=model_path, backend=backend, func="ohe-min-max")
    return dice_ml.Dice(data, model, method="gradient")
20 |
21 |
class TestDiceTorchMethods:
    """Unit tests for the loss terms and CF generation of the PyTorch DiCE explainer.

    The expected numeric loss values below are pinned by the fixed RNG seed and the
    pretrained model, so the setup statements must run in exactly this order.
    """

    @pytest.fixture(autouse=True)
    def _initiate_exp_object(self, pyt_exp_object, sample_adultincome_query):
        """Set up the explainer with 4 CFs, a normalized query instance and seeded feature weights."""
        self.exp = pyt_exp_object  # explainer object
        # initialize required params for CF computations
        self.exp.do_cf_initializations(total_CFs=4, algorithm="DiverseCF", features_to_vary="all")

        # prepare query instance for CF optimization:
        # one-hot encode categoricals and min-max normalize continuous features
        self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(
            sample_adultincome_query).iloc[0].to_numpy(dtype=np.float64)

        self.exp.initialize_CFs(self.query_instance, init_near_query_instance=True)  # initialize CFs
        self.exp.target_cf_class = torch.tensor(1).float()  # set desired class to 1

        # setting random feature weights; the fixed seed makes the loss values below reproducible
        np.random.seed(42)
        weights = np.random.rand(len(self.exp.data_interface.ohe_encoded_feature_names))
        self.exp.feature_weights_list = torch.tensor(weights)

    @pytest.mark.parametrize(("yloss", "output"), [("hinge_loss", 10.8443), ("l2_loss", 0.9999), ("log_loss", 9.8443)])
    def test_yloss(self, yloss, output):
        """Each supported y-loss type should reproduce its known value for the fixed setup."""
        self.exp.yloss_type = yloss
        loss1 = self.exp.compute_yloss()
        assert pytest.approx(loss1.data.detach().numpy(), abs=1e-4) == output

    def test_proximity_loss(self):
        self.exp.x1 = torch.tensor(self.query_instance)
        loss2 = self.exp.compute_proximity_loss()
        # proximity loss computed for given query instance and feature weights.
        assert pytest.approx(loss2.data.detach().numpy(), abs=1e-4) == 0.0068

    @pytest.mark.parametrize(("diversity_loss", "output"), [("dpp_style:inverse_dist", 0.0104), ("avg_dist", 0.1743)])
    def test_diversity_loss(self, diversity_loss, output):
        """Both diversity formulations should reproduce their known values."""
        self.exp.diversity_loss_type = diversity_loss
        loss3 = self.exp.compute_diversity_loss()
        assert pytest.approx(loss3.data.detach().numpy(), abs=1e-4) == output

    def test_regularization_loss(self):
        loss4 = self.exp.compute_regularization_loss()
        # regularization loss computed for given query instance and feature weights.
        assert pytest.approx(loss4.data.detach().numpy(), abs=1e-4) == 0.2086

    def test_final_cfs_and_preds(self, sample_adultincome_query):
        """
        Tests correctness of final CFs and their predictions for sample query instance.
        """
        counterfactual_explanations = self.exp.generate_counterfactuals(
            sample_adultincome_query, total_CFs=4, desired_class="opposite")
        assert isinstance(counterfactual_explanations, CounterfactualExplanations)
        # TODO The model predictions changed after update to posthoc sparsity. Need to investigate.
        # Previously this also compared final_cfs_df_sparse against a fixed expected table:
        # assert dice_exp.final_cfs_df_sparse.values.tolist() == test_cfs
80 |
--------------------------------------------------------------------------------
/tests/test_dice_interface/test_dice_tensorflow.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import dice_ml
5 | from dice_ml.counterfactual_explanations import CounterfactualExplanations
6 | from dice_ml.utils import helpers
7 |
8 | tf = pytest.importorskip("tensorflow")
9 |
10 |
@pytest.fixture(scope='session')
def tf_exp_object():
    """Gradient-based DiCE explainer backed by the pretrained TensorFlow adult-income model."""
    backend = 'TF' + tf.__version__[0]
    adult_df = helpers.load_adult_income_dataset()
    data = dice_ml.Data(dataframe=adult_df, continuous_features=['age', 'hours_per_week'], outcome_name='income')
    model_path = helpers.get_adult_income_modelpath(backend=backend)
    model = dice_ml.Model(model_path=model_path, backend=backend, func="ohe-min-max")
    return dice_ml.Dice(data, model, method="gradient")
20 |
21 |
class TestDiceTensorFlowMethods:
    """Unit tests for the loss terms and CF generation of the TensorFlow DiCE explainer.

    Every test handles both execution models: TF1 (graph/session, results fetched via
    ``dice_sess.run``) and TF2 (eager, results read with ``.numpy()``). The expected
    numeric loss values are pinned by the fixed RNG seed and the pretrained model.
    """

    @pytest.fixture(autouse=True)
    def _initiate_exp_object(self, tf_exp_object, sample_adultincome_query):
        """Set up the explainer with 4 CFs, a normalized query instance and seeded feature weights."""
        self.exp = tf_exp_object  # explainer object
        # initialize required params for CF computations
        self.exp.do_cf_initializations(total_CFs=4, algorithm="DiverseCF", features_to_vary="all")

        # prepare query instance for CF optimization:
        # one-hot encode categoricals and min-max normalize continuous features
        self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(sample_adultincome_query).values

        init_arrs = self.exp.initialize_CFs(self.query_instance, init_near_query_instance=True)  # initialize CFs
        self.desired_class = 1  # desired class is 1

        # setting random feature weights; the fixed seed makes the loss values below reproducible
        np.random.seed(42)
        weights = np.random.rand(len(self.exp.data_interface.ohe_encoded_feature_names))
        weights = np.array([weights], dtype=np.float32)
        if tf.__version__[0] == '1':
            # TF1: CF tensors and feature weights are graph variables assigned through the session
            for i in range(4):
                self.exp.dice_sess.run(self.exp.cf_assign[i], feed_dict={self.exp.cf_init: init_arrs[i]})
            self.exp.feature_weights = tf.Variable(self.exp.minx, dtype=tf.float32)
            self.exp.dice_sess.run(tf.assign(self.exp.feature_weights, weights))
        else:
            # TF2: eager mode, weights can be a plain constant on the explainer
            self.exp.feature_weights_list = tf.constant([weights], dtype=tf.float32)

    @pytest.mark.parametrize(("yloss", "output"), [("hinge_loss", 4.6711), ("l2_loss", 0.9501), ("log_loss", 3.6968)])
    def test_yloss(self, yloss, output):
        """Each supported y-loss type should reproduce its known value for the fixed setup."""
        if tf.__version__[0] == '1':
            loss1 = self.exp.compute_yloss(method=yloss)
            loss1 = self.exp.dice_sess.run(loss1, feed_dict={self.exp.target_cf: np.array([[1]])})
        else:
            self.exp.target_cf_class = np.array([[self.desired_class]], dtype=np.float32)
            self.exp.yloss_type = yloss
            loss1 = self.exp.compute_yloss().numpy()
        assert pytest.approx(loss1, abs=1e-4) == output

    def test_proximity_loss(self):
        if tf.__version__[0] == '1':
            loss2 = self.exp.compute_proximity_loss()
            loss2 = self.exp.dice_sess.run(loss2, feed_dict={self.exp.x1: self.query_instance})
        else:
            self.exp.x1 = tf.constant(self.query_instance, dtype=tf.float32)
            loss2 = self.exp.compute_proximity_loss().numpy()
        # proximity loss computed for given query instance and feature weights.
        assert pytest.approx(loss2, abs=1e-4) == 0.0068

    @pytest.mark.parametrize(("diversity_loss", "output"), [("dpp_style:inverse_dist", 0.0104), ("avg_dist", 0.1743)])
    def test_diversity_loss(self, diversity_loss, output):
        """Both diversity formulations should reproduce their known values."""
        if tf.__version__[0] == '1':
            loss3 = self.exp.compute_diversity_loss(diversity_loss)
            loss3 = self.exp.dice_sess.run(loss3)
        else:
            self.exp.diversity_loss_type = diversity_loss
            loss3 = self.exp.compute_diversity_loss().numpy()
        assert pytest.approx(loss3, abs=1e-4) == output

    def test_regularization_loss(self):
        loss4 = self.exp.compute_regularization_loss()
        if tf.__version__[0] == '1':
            loss4 = self.exp.dice_sess.run(loss4)
        else:
            loss4 = loss4.numpy()
        # regularization loss computed for given query instance and feature weights.
        assert pytest.approx(loss4, abs=1e-4) == 0.2086

    def test_final_cfs_and_preds(self, sample_adultincome_query):
        """
        Tests correctness of final CFs and their predictions for sample query instance.
        """
        counterfactual_explanations = self.exp.generate_counterfactuals(
            sample_adultincome_query, total_CFs=4, desired_class="opposite")
        assert isinstance(counterfactual_explanations, CounterfactualExplanations)
        # TODO The model predictions changed after update to posthoc sparsity. Need to investigate.
        # Previously this also compared final_cfs_df_sparse against a fixed expected table:
        # assert dice_exp.final_cfs_df_sparse.values.tolist() == test_cfs
103 |
--------------------------------------------------------------------------------
/tests/test_helpers.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | from dice_ml.utils.helpers import load_adult_income_dataset
4 |
5 |
class TestHelpers:
    def test_load_adult_income_dataset(self):
        """The bundled adult-income dataset should load as a pandas DataFrame."""
        df = load_adult_income_dataset()
        assert df is not None
        assert isinstance(df, pd.DataFrame)
11 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from rai_test_utils.datasets.tabular import create_iris_data
3 | from raiutils.exceptions import UserConfigValidationException
4 | from sklearn.ensemble import RandomForestClassifier
5 |
6 | import dice_ml
7 | from dice_ml.utils import helpers
8 |
9 |
class TestBaseModelLoader:
    """Checks that dice_ml.Model resolves each backend string to the right model interface."""

    def _get_model(self, backend):
        """Load the pretrained adult-income model for the given backend."""
        model_path = helpers.get_adult_income_modelpath(backend=backend)
        return dice_ml.Model(model_path=model_path, backend=backend)

    def test_tf(self):
        tf = pytest.importorskip("tensorflow")
        m = self._get_model('TF' + tf.__version__[0])
        assert issubclass(type(m), dice_ml.model_interfaces.base_model.BaseModel)
        assert isinstance(m, dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel)

    def test_pyt(self):
        pytest.importorskip("torch")
        m = self._get_model('PYT')
        assert issubclass(type(m), dice_ml.model_interfaces.base_model.BaseModel)
        assert isinstance(m, dice_ml.model_interfaces.pytorch_model.PyTorchModel)

    def test_sklearn(self):
        pytest.importorskip("sklearn")
        m = self._get_model('sklearn')
        assert isinstance(m, dice_ml.model_interfaces.base_model.BaseModel)
35 |
36 |
class TestModelUserValidations:
    """Validates the user-facing argument checks of the dice_ml.Model constructor."""

    def create_sklearn_random_forest_classifier(self, X, y):
        """Fit a small deterministic random forest on (X, y) and return it."""
        rfc = RandomForestClassifier(n_estimators=10, max_depth=4, random_state=777)
        return rfc.fit(X, y)

    def test_model_user_validation_model_type(self):
        x_train, _, y_train, _, _, _ = create_iris_data()
        trained_model = self.create_sklearn_random_forest_classifier(x_train, y_train)

        # Both recognized model types are accepted.
        for model_type in ('classifier', 'regressor'):
            assert dice_ml.Model(model=trained_model, backend='sklearn', model_type=model_type) is not None

        # Any other model_type string is rejected.
        with pytest.raises(UserConfigValidationException):
            dice_ml.Model(model=trained_model, backend='sklearn', model_type='random')

    def test_model_user_validation_no_valid_model(self):
        expected_msg = "should provide either a trained model or the path to a model"
        with pytest.raises(ValueError, match=expected_msg):
            dice_ml.Model(backend='sklearn')
61 |
--------------------------------------------------------------------------------
/tests/test_model_interface/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/test_model_interface/__init__.py
--------------------------------------------------------------------------------
/tests/test_model_interface/test_base_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from rai_test_utils.datasets.tabular import (create_housing_data,
4 | create_iris_data)
5 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
6 |
7 | import dice_ml
8 | from dice_ml.utils.exception import SystemException
9 |
10 |
class TestModelClassification:
    """Exercise the BaseModel wrapper around a sklearn classifier."""

    def create_sklearn_random_forest_classifier(self, X, y):
        # A small, deterministic forest keeps the test fast.
        classifier = RandomForestClassifier(n_estimators=10, max_depth=4, random_state=777)
        return classifier.fit(X, y)

    def test_base_model_classification(self):
        x_train, x_test, y_train, y_test, _, classes = create_iris_data()
        trained = self.create_sklearn_random_forest_classifier(x_train, y_train)

        dice_model = dice_ml.Model(model=trained, backend='sklearn')
        dice_model.transformer.initialize_transform_func()
        assert dice_model is not None

        # Default get_output returns one score per class.
        scores = dice_model.get_output(x_test)
        assert scores.shape[0] == x_test.shape[0]
        assert scores.shape[1] == len(classes)

        # With model_score=False, hard class labels are returned instead.
        labels = dice_model.get_output(x_test, model_score=False).reshape(-1, 1)
        assert labels.shape[0] == x_test.shape[0]
        assert labels.shape[1] == 1
        assert np.all(np.unique(labels) == np.unique(y_test))

        # Gradients are not available for sklearn models.
        with pytest.raises(NotImplementedError):
            dice_model.get_gradient()

        assert dice_model.get_num_output_nodes2(x_test) == len(classes)
42 |
43 |
class TestModelRegression:
    """Exercise the BaseModel wrapper around a sklearn regressor."""

    def create_sklearn_random_forest_regressor(self, X, y):
        # A small, deterministic forest keeps the test fast.
        regressor = RandomForestRegressor(n_estimators=10, max_depth=4, random_state=777)
        return regressor.fit(X, y)

    def test_base_model_regression(self):
        x_train, x_test, y_train, y_test, feature_names = create_housing_data()
        trained = self.create_sklearn_random_forest_regressor(x_train, y_train)

        dice_model = dice_ml.Model(model=trained, model_type='regressor', backend='sklearn')
        dice_model.transformer.initialize_transform_func()
        assert dice_model is not None

        # For a regressor, get_output yields a single predicted value per row.
        scores = dice_model.get_output(x_test).reshape(-1, 1)
        assert scores.shape[0] == x_test.shape[0]
        assert scores.shape[1] == 1

        predictions = dice_model.get_output(x_test, model_score=False).reshape(-1, 1)
        assert predictions.shape[0] == x_test.shape[0]
        assert predictions.shape[1] == 1

        # Gradients are not available for sklearn models.
        with pytest.raises(NotImplementedError):
            dice_model.get_gradient()

        # Output-node count is undefined for regression models.
        with pytest.raises(SystemException):
            dice_model.get_num_output_nodes2(x_test)
75 |
--------------------------------------------------------------------------------
/tests/test_model_interface/test_keras_tensorflow_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import dice_ml
4 | from dice_ml.utils import helpers
5 | from dice_ml.utils.helpers import DataTransfomer
6 |
7 | tf = pytest.importorskip("tensorflow")
8 |
9 |
@pytest.fixture(scope='session')
def tf_session():
    """Create an interactive session for TF 1.x; TF 2.x tests receive None."""
    if tf.__version__.startswith('1'):
        return tf.InteractiveSession()
15 |
16 |
@pytest.fixture(scope='session')
def tf_model_object():
    """Session-scoped dice_ml.Model wrapping the pre-trained TF adult-income model."""
    backend = 'TF' + tf.__version__[0]  # 'TF1' or 'TF2', depending on installed version
    model_path = helpers.get_adult_income_modelpath(backend=backend)
    return dice_ml.Model(model_path=model_path, backend=backend, func='ohe-min-max')
23 |
24 |
def test_model_initiation(tf_model_object):
    """The fixture model should be a KerasTensorFlowModel instance."""
    assert isinstance(tf_model_object, dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel)
27 |
28 |
def test_model_initiation_fullpath():
    """
    Tests if model is initiated when full path to a model and explainer class is given to backend parameter.
    """
    major = tf.__version__[0]
    backend = {
        'model': 'keras_tensorflow_model.KerasTensorFlowModel',
        'explainer': 'dice_tensorflow{0}.DiceTensorFlow{0}'.format(major),
    }
    model_path = helpers.get_adult_income_modelpath(backend=backend)
    model = dice_ml.Model(model_path=model_path, backend=backend)
    assert isinstance(model, dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel)
39 |
40 |
class TestKerasModelMethods:
    """Behavioral checks on a loaded Keras/TensorFlow model wrapper."""

    @pytest.fixture(autouse=True)
    def _get_model_object(self, tf_model_object, tf_session):
        self.m = tf_model_object
        self.sess = tf_session

    def test_load_model(self):
        self.m.load_model()
        assert self.m.model is not None

    @pytest.mark.parametrize("prediction", [0.747])
    def test_model_output(self, sample_adultincome_query, public_data_object, prediction):
        # Load the model and wire up the OHE/min-max data transformation it expects.
        self.m.load_model()
        self.m.transformer = DataTransfomer(func='ohe-min-max', kw_args=None)
        self.m.transformer.feed_data_params(public_data_object)
        self.m.transformer.initialize_transform_func()

        output_instance = self.m.get_output(sample_adultincome_query, transform_data=True)
        # TF1 hands back a graph tensor that must be evaluated in the session;
        # TF2 hands back an eager tensor with .numpy().
        if tf.__version__[0] == '1':
            predicted = self.sess.run(output_instance)[0][0]
        else:
            predicted = output_instance.numpy()[0][0]
        assert predicted is not None
        # TODO: the approximate comparison against `prediction` currently fails:
        # assert pytest.approx(predicted, abs=1e-3) == prediction
79 |
--------------------------------------------------------------------------------
/tests/test_model_interface/test_pytorch_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import dice_ml
4 | from dice_ml.utils import helpers
5 | from dice_ml.utils.helpers import DataTransfomer
6 |
7 | pyt = pytest.importorskip("torch")
8 |
9 |
@pytest.fixture(scope='session')
def pyt_model_object():
    """Session-scoped dice_ml.Model wrapping the pre-trained PyTorch adult-income model."""
    model_path = helpers.get_adult_income_modelpath(backend='PYT')
    return dice_ml.Model(model_path=model_path, backend='PYT', func='ohe-min-max')
16 |
17 |
def test_model_initiation(pyt_model_object):
    """The fixture model should be a PyTorchModel instance."""
    assert isinstance(pyt_model_object, dice_ml.model_interfaces.pytorch_model.PyTorchModel)
20 |
21 |
def test_model_initiation_fullpath():
    """
    Tests if model is initiated when full path to a model and explainer class is given to backend parameter.
    """
    pytest.importorskip("torch")
    backend = {
        'model': 'pytorch_model.PyTorchModel',
        'explainer': 'dice_pytorch.DicePyTorch',
    }
    model_path = helpers.get_adult_income_modelpath(backend=backend)
    model = dice_ml.Model(model_path=model_path, backend=backend)
    assert isinstance(model, dice_ml.model_interfaces.pytorch_model.PyTorchModel)
32 |
33 |
class TestPyTorchModelMethods:
    """Behavioral checks on a loaded PyTorch model wrapper."""

    @pytest.fixture(autouse=True)
    def _get_model_object(self, pyt_model_object):
        self.m = pyt_model_object

    def test_load_model(self):
        self.m.load_model()
        assert self.m.model is not None

    @pytest.mark.parametrize("prediction", [0.0957])
    def test_model_output(self, sample_adultincome_query, public_data_object, prediction):
        # Load the model and wire up the OHE/min-max data transformation it expects.
        self.m.load_model()
        self.m.transformer = DataTransfomer(func='ohe-min-max', kw_args=None)
        self.m.transformer.feed_data_params(public_data_object)
        self.m.transformer.initialize_transform_func()

        output_instance = self.m.get_output(sample_adultincome_query, transform_data=True)
        predicted = output_instance[0][0]
        assert predicted is not None
        # TODO: the approximate comparison against `prediction` currently fails:
        # assert pytest.approx(predicted, abs=1e-3) == prediction
62 |
--------------------------------------------------------------------------------
/tests/test_notebooks.py:
--------------------------------------------------------------------------------
1 | """ Code to automatically run all notebooks as a test.
2 | Adapted from the same code for the Microsoft DoWhy library.
3 | """
4 |
5 | import os
6 | import subprocess
7 | import sys
8 | import tempfile
9 |
10 | import nbformat
11 | import pytest
12 |
13 | NOTEBOOKS_PATH = "docs/source/notebooks/"
14 |
# Ensure the repo root is on PYTHONPATH so that the jupyter notebook
# subprocesses can import the package under test.
_cwd = os.getcwd()
_existing_path = os.environ.get('PYTHONPATH')
if _existing_path is None:
    os.environ['PYTHONPATH'] = _cwd
elif _cwd not in _existing_path.split(os.pathsep):
    os.environ['PYTHONPATH'] = _existing_path + os.pathsep + _cwd
20 |
21 |
def get_notebook_parameter_list():
    """Build the pytest parameter list of notebooks to run.

    Excluded notebooks (the advanced list, plus torch notebooks on
    python >= 3.10) are still included but carry skip marks, so they show
    up as skipped in the report instead of silently disappearing.

    :returns: list of pytest.param objects, one per .ipynb file found.
    """
    notebooks_list = [f.name for f in os.scandir(NOTEBOOKS_PATH) if f.name.endswith(".ipynb")]
    # notebooks that should not be run
    advanced_notebooks = [
        "DiCE_with_advanced_options.ipynb",  # requires tensorflow 1.x
        "DiCE_getting_started_feasible.ipynb",  # needs changes after latest refactor
        "DiCE_with_private_data.ipynb",  # needs compatible version of sklearn to load model
        "Benchmarking_different_CF_explanation_methods.ipynb"
    ]
    # notebooks that don't need to run on python 3.10
    torch_notebooks_not_3_10 = [
        "DiCE_getting_started.ipynb"
    ]

    # Creating the list of notebooks to run. Both exclusion reasons apply
    # the same skip/advanced marks, so the two previously-duplicated
    # branches are merged into a single condition.
    parameter_list = []
    for nb in notebooks_list:
        if nb in advanced_notebooks or (
                sys.version_info >= (3, 10) and nb in torch_notebooks_not_3_10):
            param = pytest.param(
                nb,
                marks=[pytest.mark.skip, pytest.mark.advanced],
                id=nb)
        else:
            param = pytest.param(nb, id=nb)
        parameter_list.append(param)

    return parameter_list
54 |
55 |
def _check_notebook_cell_outputs(filepath):
    """Convert the notebook via nbconvert and fail if any cell kept its outputs.

    :param filepath: file path for the notebook
    :raises AssertionError: if a cell with non-empty outputs is found
    """
    with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout:
        subprocess.check_call([
            "jupyter", "nbconvert", "--to", "notebook",
            "-y", "--no-prompt",
            "--output", fout.name, filepath])
        fout.seek(0)
        nb = nbformat.read(fout, nbformat.current_nbformat)

    for cell in nb.cells:
        if "outputs" in cell and len(cell["outputs"]) > 0:
            raise AssertionError("Output cell found in notebook. Please clean your notebook")
73 |
74 |
def _notebook_run(filepath):
    """Execute a notebook via nbconvert and collect any error outputs.

    Source of this function: http://www.christianmoscardi.com/blog/2016/01/20/jupyter-testing.html

    :param filepath: file path for the notebook
    :returns List of execution errors
    """
    with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout:
        subprocess.check_call([
            "jupyter", "nbconvert", "--to", "notebook", "--execute",
            # "--ExecutePreprocessor.timeout=600",
            "-y", "--no-prompt",
            "--output", fout.name, filepath])

        fout.seek(0)
        nb = nbformat.read(fout, nbformat.current_nbformat)

    errors = []
    for cell in nb.cells:
        if "outputs" not in cell:
            continue
        for output in cell["outputs"]:
            if output.output_type == "error":
                errors.append(output)
    return errors
98 |
99 |
@pytest.mark.parametrize("notebook_filename", get_notebook_parameter_list())
@pytest.mark.notebook_tests
def test_notebook(notebook_filename):
    """Fail if the notebook has committed outputs or raises errors when executed."""
    notebook_path = NOTEBOOKS_PATH + notebook_filename
    _check_notebook_cell_outputs(notebook_path)
    errors = _notebook_run(notebook_path)
    assert not errors
106 |
--------------------------------------------------------------------------------