├── .flake8 ├── .github └── workflows │ ├── code-scan.yml │ ├── notebook-linting.yml │ ├── notebook-tests.yml │ ├── python-linting.yml │ ├── python-package-conda.yml │ ├── python-package.yml │ └── python-publish.yml ├── .gitignore ├── CODEOWNERS ├── LICENSE ├── MANIFEST.in ├── README.rst ├── dice_ml ├── __init__.py ├── constants.py ├── counterfactual_explanations.py ├── data.py ├── data_interfaces │ ├── __init__.py │ ├── base_data_interface.py │ ├── private_data_interface.py │ └── public_data_interface.py ├── dice.py ├── diverse_counterfactuals.py ├── explainer_interfaces │ ├── __init__.py │ ├── dice_KD.py │ ├── dice_genetic.py │ ├── dice_pytorch.py │ ├── dice_random.py │ ├── dice_tensorflow1.py │ ├── dice_tensorflow2.py │ ├── dice_xgboost.py │ ├── explainer_base.py │ ├── feasible_base_vae.py │ └── feasible_model_approx.py ├── model.py ├── model_interfaces │ ├── __init__.py │ ├── base_model.py │ ├── keras_tensorflow_model.py │ ├── pytorch_model.py │ └── xgboost_model.py ├── schema │ ├── __init__.py │ ├── counterfactual_explanations_v1.0.json │ └── counterfactual_explanations_v2.0.json └── utils │ ├── __init__.py │ ├── exception.py │ ├── helpers.py │ ├── neuralnetworks.py │ ├── sample_architecture │ ├── __init__.py │ └── vae_model.py │ ├── sample_trained_models │ ├── adult-margin-0.165-validity_reg-42.0-epoch-25-base-gen.pth │ ├── adult-margin-0.344-validity_reg-76.0-epoch-25-ae-gen.pth │ ├── adult.h5 │ ├── adult.pkl │ ├── adult.pth │ ├── adult_2nodes.pth │ ├── custom.sav │ ├── custom_binary.sav │ ├── custom_multiclass.sav │ ├── custom_regression.sav │ └── custom_vars.sav │ └── serialize.py ├── docs ├── .buildinfo ├── .nojekyll ├── Makefile ├── _images │ └── dice_getting_started_api.png ├── _modules │ ├── dice_ml │ │ ├── constants.html │ │ ├── counterfactual_explanations.html │ │ ├── data.html │ │ ├── data_interfaces │ │ │ ├── private_data_interface.html │ │ │ └── public_data_interface.html │ │ ├── dice.html │ │ ├── dice_interfaces │ │ │ ├── dice_base.html │ │ │ ├── 
dice_tensorflow1.html │ │ │ └── dice_tensorflow2.html │ │ ├── diverse_counterfactuals.html │ │ ├── explainer_interfaces │ │ │ ├── dice_KD.html │ │ │ ├── dice_genetic.html │ │ │ ├── dice_pytorch.html │ │ │ ├── dice_random.html │ │ │ ├── dice_tensorflow1.html │ │ │ ├── dice_tensorflow2.html │ │ │ ├── explainer_base.html │ │ │ ├── feasible_base_vae.html │ │ │ └── feasible_model_approx.html │ │ ├── model.html │ │ ├── model_interfaces │ │ │ ├── base_model.html │ │ │ ├── keras_tensorflow_model.html │ │ │ └── pytorch_model.html │ │ └── utils │ │ │ ├── exception.html │ │ │ ├── helpers.html │ │ │ ├── neuralnetworks.html │ │ │ ├── sample_architecture │ │ │ └── vae_model.html │ │ │ └── serialize.html │ └── index.html ├── _static │ ├── _sphinx_javascript_frameworks_compat.js │ ├── basic.css │ ├── css │ │ ├── badge_only.css │ │ ├── fonts │ │ │ ├── Roboto-Slab-Bold.woff │ │ │ ├── Roboto-Slab-Bold.woff2 │ │ │ ├── Roboto-Slab-Regular.woff │ │ │ ├── Roboto-Slab-Regular.woff2 │ │ │ ├── fontawesome-webfont.eot │ │ │ ├── fontawesome-webfont.svg │ │ │ ├── fontawesome-webfont.ttf │ │ │ ├── fontawesome-webfont.woff │ │ │ ├── fontawesome-webfont.woff2 │ │ │ ├── lato-bold-italic.woff │ │ │ ├── lato-bold-italic.woff2 │ │ │ ├── lato-bold.woff │ │ │ ├── lato-bold.woff2 │ │ │ ├── lato-normal-italic.woff │ │ │ ├── lato-normal-italic.woff2 │ │ │ ├── lato-normal.woff │ │ │ └── lato-normal.woff2 │ │ └── theme.css │ ├── doctools.js │ ├── documentation_options.js │ ├── file.png │ ├── fonts │ │ ├── FontAwesome.otf │ │ ├── Inconsolata-Bold.ttf │ │ ├── Inconsolata-Regular.ttf │ │ ├── Inconsolata.ttf │ │ ├── Lato-Bold.ttf │ │ ├── Lato-Regular.ttf │ │ ├── Lato │ │ │ ├── lato-bold.eot │ │ │ ├── lato-bold.ttf │ │ │ ├── lato-bold.woff │ │ │ ├── lato-bold.woff2 │ │ │ ├── lato-bolditalic.eot │ │ │ ├── lato-bolditalic.ttf │ │ │ ├── lato-bolditalic.woff │ │ │ ├── lato-bolditalic.woff2 │ │ │ ├── lato-italic.eot │ │ │ ├── lato-italic.ttf │ │ │ ├── lato-italic.woff │ │ │ ├── lato-italic.woff2 │ │ │ ├── 
lato-regular.eot │ │ │ ├── lato-regular.ttf │ │ │ ├── lato-regular.woff │ │ │ └── lato-regular.woff2 │ │ ├── Roboto-Slab-Bold.woff │ │ ├── Roboto-Slab-Bold.woff2 │ │ ├── Roboto-Slab-Light.woff │ │ ├── Roboto-Slab-Light.woff2 │ │ ├── Roboto-Slab-Regular.woff │ │ ├── Roboto-Slab-Regular.woff2 │ │ ├── Roboto-Slab-Thin.woff │ │ ├── Roboto-Slab-Thin.woff2 │ │ ├── RobotoSlab-Bold.ttf │ │ ├── RobotoSlab-Regular.ttf │ │ ├── RobotoSlab │ │ │ ├── roboto-slab-v7-bold.eot │ │ │ ├── roboto-slab-v7-bold.ttf │ │ │ ├── roboto-slab-v7-bold.woff │ │ │ ├── roboto-slab-v7-bold.woff2 │ │ │ ├── roboto-slab-v7-regular.eot │ │ │ ├── roboto-slab-v7-regular.ttf │ │ │ ├── roboto-slab-v7-regular.woff │ │ │ └── roboto-slab-v7-regular.woff2 │ │ ├── fontawesome-webfont.eot │ │ ├── fontawesome-webfont.svg │ │ ├── fontawesome-webfont.ttf │ │ ├── fontawesome-webfont.woff │ │ ├── fontawesome-webfont.woff2 │ │ ├── lato-bold-italic.woff │ │ ├── lato-bold-italic.woff2 │ │ ├── lato-bold.woff │ │ ├── lato-bold.woff2 │ │ ├── lato-normal-italic.woff │ │ ├── lato-normal-italic.woff2 │ │ ├── lato-normal.woff │ │ └── lato-normal.woff2 │ ├── getting_started_output.png │ ├── getting_started_updated.png │ ├── jquery-3.2.1.js │ ├── jquery-3.4.1.js │ ├── jquery-3.5.1.js │ ├── jquery-3.6.0.js │ ├── jquery.js │ ├── js │ │ ├── badge_only.js │ │ ├── html5shiv-printshiv.min.js │ │ ├── html5shiv.min.js │ │ ├── modernizr.min.js │ │ └── theme.js │ ├── language_data.js │ ├── minus.png │ ├── plus.png │ ├── pygments.css │ ├── searchtools.js │ ├── sphinx_highlight.js │ ├── underscore-1.13.1.js │ ├── underscore-1.3.1.js │ └── underscore.js ├── dice_ml.data_interfaces.html ├── dice_ml.dice_interfaces.html ├── dice_ml.explainer_interfaces.html ├── dice_ml.html ├── dice_ml.model_interfaces.html ├── dice_ml.schema.html ├── dice_ml.utils.html ├── dice_ml.utils.sample_architecture.html ├── genindex.html ├── index.html ├── make.bat ├── modules.html ├── notebooks │ ├── Benchmarking_different_CF_explanation_methods.html │ ├── 
DiCE_feature_importances.html │ ├── DiCE_getting_started.html │ ├── DiCE_getting_started_feasible.html │ ├── DiCE_model_agnostic_CFs.html │ ├── DiCE_multiclass_classification_and_regression.html │ ├── DiCE_with_advanced_options.html │ ├── DiCE_with_private_data.html │ └── nb_index.html ├── objects.inv ├── py-modindex.html ├── readme.html ├── search.html ├── searchindex.js ├── source │ ├── conf.py │ ├── dice_ml.data_interfaces.rst │ ├── dice_ml.explainer_interfaces.rst │ ├── dice_ml.model_interfaces.rst │ ├── dice_ml.rst │ ├── dice_ml.schema.rst │ ├── dice_ml.utils.rst │ ├── dice_ml.utils.sample_architecture.rst │ ├── index.rst │ ├── modules.rst │ ├── notebooks │ │ ├── Benchmarking_different_CF_explanation_methods.ipynb │ │ ├── DiCE_feature_importances.ipynb │ │ ├── DiCE_getting_started.ipynb │ │ ├── DiCE_getting_started_feasible.ipynb │ │ ├── DiCE_model_agnostic_CFs.ipynb │ │ ├── DiCE_multiclass_classification_and_regression.ipynb │ │ ├── DiCE_with_advanced_options.ipynb │ │ ├── DiCE_with_private_data.ipynb │ │ ├── images │ │ │ └── dice_getting_started_api.png │ │ └── nb_index.rst │ └── readme.rst └── update_docs.sh ├── environment-deeplearning.yml ├── environment.yml ├── requirements-deeplearning.txt ├── requirements-linting.txt ├── requirements-test.txt ├── requirements.txt ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── test_counterfactual_explanations.py ├── test_data.py ├── test_data_interface ├── __init__.py ├── test_base_data_interface.py ├── test_private_data_interface.py └── test_public_data_interface.py ├── test_dice.py ├── test_dice_interface ├── __init__.py ├── test_dice_KD.py ├── test_dice_genetic.py ├── test_dice_pytorch.py ├── test_dice_random.py ├── test_dice_tensorflow.py └── test_explainer_base.py ├── test_helpers.py ├── test_model.py ├── test_model_interface ├── __init__.py ├── test_base_model.py ├── test_keras_tensorflow_model.py └── test_pytorch_model.py └── test_notebooks.py /.flake8: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 127 3 | -------------------------------------------------------------------------------- /.github/workflows/code-scan.yml: -------------------------------------------------------------------------------- 1 | name: code scan 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | schedule: 9 | - cron: '30 5 * * *' 10 | 11 | jobs: 12 | analyze: 13 | name: Analyze 14 | runs-on: ubuntu-latest 15 | permissions: 16 | actions: read 17 | contents: read 18 | security-events: write 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | language: ["python"] 24 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] 25 | # Learn more: 26 | # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed 27 | 28 | steps: 29 | - name: Checkout repository 30 | uses: actions/checkout@v4 31 | 32 | # Initializes the CodeQL tools for scanning. 33 | - name: Initialize CodeQL 34 | uses: github/codeql-action/init@v3 35 | with: 36 | languages: ${{ matrix.language }} 37 | # If you wish to specify custom queries, you can do so here or in a config file. 38 | # By default, queries listed here will override any specified in a config file. 39 | # Prefix the list here with "+" to use these queries and those in the config file. 40 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 41 | 42 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 43 | # If this step fails, then you should remove it and run the build manually (see below) 44 | - name: Autobuild 45 | uses: github/codeql-action/autobuild@v3 46 | 47 | # ℹ️ Command-line programs to run using the OS shell.
48 | # 📚 https://git.io/JvXDl 49 | 50 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 51 | # and modify them (or add more) to build your code if your project 52 | # uses a compiled language 53 | 54 | #- run: | 55 | # make bootstrap 56 | # make release 57 | 58 | - name: Perform CodeQL Analysis 59 | uses: github/codeql-action/analyze@v3 60 | -------------------------------------------------------------------------------- /.github/workflows/notebook-linting.yml: -------------------------------------------------------------------------------- 1 | # This workflow will lint jupyter notebooks with flake8-nb. 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Notebook linting 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | schedule: 12 | - cron: '30 5 * * *' 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python 3.11 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.11' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install flake8-nb==0.4.0 28 | - name: Lint notebooks with flake8_nb 29 | run: | 30 | # stop the build if there are flake8 errors in notebooks 31 | flake8_nb docs/source/notebooks/ --statistics --max-line-length=127 32 | -------------------------------------------------------------------------------- /.github/workflows/notebook-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Notebook tests 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | 
branches: [ main ] 11 | schedule: 12 | - cron: '30 5 * * *' 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | matrix: 20 | python-version: ["3.9", "3.10", "3.11", "3.12"] 21 | os: [ubuntu-latest, macos-latest] 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python ${{ matrix.python-version }} ${{ matrix.os }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Upgrade pip 30 | run: | 31 | python -m pip install --upgrade pip 32 | - name: Install core dependencies 33 | run: | 34 | pip install -r requirements.txt 35 | - name: Install deep learning dependencies 36 | run: | 37 | pip install -r requirements-deeplearning.txt 38 | - name: Install test dependencies 39 | run: | 40 | pip install -r requirements-test.txt 41 | - name: Test with pytest 42 | run: | 43 | # pytest 44 | pytest tests/ -m "notebook_tests" --durations=10 --doctest-modules --junitxml=junit/test-results.xml --cov=dice_ml --cov-report=xml --cov-report=html 45 | -------------------------------------------------------------------------------- /.github/workflows/python-linting.yml: -------------------------------------------------------------------------------- 1 | # This workflow will lint python code with flake8. 
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python linting 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | schedule: 12 | - cron: '30 5 * * *' 13 | 14 | jobs: 15 | build: 16 | runs-on: ubuntu-latest 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python 3.11 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.11' 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install -r requirements-linting.txt 28 | - name: Check sorted python imports using isort 29 | run: | 30 | isort . -c 31 | - name: Lint code with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # The GitHub editor is 127 chars wide. 36 | flake8 . --count --max-complexity=30 --max-line-length=127 --statistics 37 | # Check for cyclomatic complexity for specific files where this metric has been 38 | # reduced to ten and below 39 | flake8 dice_ml/data_interfaces/ --count --max-complexity=10 --max-line-length=127 40 | 41 | -------------------------------------------------------------------------------- /.github/workflows/python-package-conda.yml: -------------------------------------------------------------------------------- 1 | name: Python Package using Conda 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | schedule: 9 | - cron: '30 5 * * *' 10 | 11 | jobs: 12 | build-linux: 13 | runs-on: ubuntu-latest 14 | strategy: 15 | max-parallel: 5 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 3.8 20 | uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.8 23 | - name: Add conda to system path 24 | run: | 25 | # $CONDA is an environment variable pointing to the root of the miniconda directory 26 
| echo $CONDA/bin >> $GITHUB_PATH 27 | - name: Install core dependencies 28 | run: | 29 | conda env update --file environment.yml --name base 30 | - name: Install deep learning dependencies 31 | run: | 32 | conda env update --file environment-deeplearning.yml --name base 33 | - name: Test with pytest 34 | run: | 35 | conda install pytest ipython jupyter nbformat pytest-mock 36 | 37 | pytest 38 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package test 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | schedule: 12 | - cron: '30 5 * * *' 13 | 14 | jobs: 15 | build: 16 | 17 | runs-on: ${{ matrix.os }} 18 | strategy: 19 | matrix: 20 | python-version: ["3.9", "3.10", "3.11", "3.12"] 21 | os: [ubuntu-latest, macos-latest, windows-latest] 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python ${{ matrix.python-version }} ${{ matrix.os }} 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: ${{ matrix.python-version }} 29 | - name: Upgrade pip 30 | run: | 31 | python -m pip install --upgrade pip 32 | - name: Install core dependencies 33 | run: | 34 | pip install -r requirements.txt 35 | - name: Install deep learning dependencies 36 | run: | 37 | pip install -r requirements-deeplearning.txt 38 | - name: Install test dependencies 39 | run: | 40 | pip install -r requirements-test.txt 41 | - name: Test with pytest 42 | run: | 43 | # pytest 44 | pytest tests/ -m "not notebook_tests" --durations=10 --doctest-modules --junitxml=junit/test-results.xml --cov=dice_ml --cov-report=xml --cov-report=html 
45 | - name: Publish Unit Test Results 46 | uses: EnricoMi/publish-unit-test-result-action/composite@v1 47 | if: ${{ (matrix.python-version == '3.9') && (matrix.os == 'ubuntu-latest') }} 48 | with: 49 | files: junit/test-results.xml 50 | # - name: Upload coverage to Codecov 51 | # uses: codecov/codecov-action@v3 52 | # if: ${{ (matrix.python-version == '3.9') && (matrix.os == 'ubuntu-latest') }} 53 | # with: 54 | # token: ${{ secrets.CODECOV_TOKEN }} 55 | # directory: . 56 | # env_vars: OS,PYTHON 57 | # fail_ci_if_error: true 58 | # files: ./coverage.xml 59 | # flags: unittests 60 | # name: codecov-umbrella 61 | # path_to_write_report: ./coverage/codecov_report.txt 62 | # verbose: true 63 | - name: Check package consistency with twine 64 | run: | 65 | python setup.py check sdist bdist_wheel 66 | twine check dist/* 67 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | on: 2 | release: 3 | types: [created] 4 | workflow_dispatch: 5 | 6 | jobs: 7 | pypi-publish: 8 | name: upload release to PyPI 9 | runs-on: ubuntu-latest 10 | # Specifying a GitHub environment is optional, but strongly encouraged 11 | environment: release 12 | permissions: 13 | # IMPORTANT: this permission is mandatory for trusted publishing 14 | id-token: write 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python 18 | uses: actions/setup-python@v5 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | run: | 27 | python setup.py sdist bdist_wheel 28 | - name: Publish package distributions to PyPI 29 | uses: pypa/gh-action-pypi-publish@release/v1 30 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | docs/build/ 69 | docs/_sources/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | *.data 108 | 109 | # notebook to solve issues 110 | notebooks/DiCE_issues.ipynb 111 | # to avoid auto copied notebooks to be stored in repo 112 | 
docs/notebooks/DiCE_getting_started.ipynb 113 | docs/notebooks/DiCE_getting_started_feasible.ipynb 114 | docs/notebooks/DiCE_with_advanced_options.ipynb 115 | docs/notebooks/DiCE_with_private_data.ipynb 116 | docs/notebooks/*.ipynb 117 | 118 | 119 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # dice-ml package 2 | /dice_ml @gaugup @amit-sharma -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include requirements.txt 2 | include requirements-deeplearning.txt 3 | include requirements-test.txt 4 | include requirements-linting.txt 5 | include environment.yml 6 | include environment-deeplearning.yml 7 | include LICENSE 8 | include CODEOWNERS 9 | recursive-include docs * 10 | recursive-include tests *.py 11 | include dice_ml/utils/sample_trained_models/* 12 | -------------------------------------------------------------------------------- /dice_ml/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import Data 2 | from .dice import Dice 3 | from .model import Model 4 | 5 | __all__ = ["Data", 6 | "Model", 7 | "Dice"] 8 | -------------------------------------------------------------------------------- /dice_ml/constants.py: -------------------------------------------------------------------------------- 1 | """Constants for dice-ml package.""" 2 | 3 | 4 | class BackEndTypes: 5 | Sklearn = 'sklearn' 6 | Tensorflow1 = 'TF1' 7 | Tensorflow2 = 'TF2' 8 | Pytorch = 'PYT' 9 | 10 | ALL = [Sklearn, Tensorflow1, Tensorflow2, Pytorch] 11 | 12 | 13 | class SamplingStrategy: 14 | Random = 'random' 15 | Genetic = 'genetic' 16 | KdTree = 'kdtree' 17 | Gradient = 'gradient' 18 | 19 | 20 | class ModelTypes: 21 | Classifier = 'classifier' 22 | Regressor = 'regressor' 23 | 24 | ALL = [Classifier, Regressor] 25 | 26 | 27 | class _SchemaVersions: 28 | V1 = '1.0' 29 | V2 = '2.0' 30 | CURRENT_VERSION = V2 31 | 32 | ALL_VERSIONS = [V1, V2] 33 | 34 | 35 | class 
_PostHocSparsityTypes: 36 | LINEAR = 'linear' 37 | BINARY = 'binary' 38 | 39 | ALL = [LINEAR, BINARY] 40 | -------------------------------------------------------------------------------- /dice_ml/data.py: -------------------------------------------------------------------------------- 1 | """Module pointing to different implementations of Data class 2 | 3 | DiCE requires only few parameters about the data such as the range of continuous 4 | features and the levels of categorical features. Hence, DiCE can be used for a 5 | private data whose meta data are only available (such as the feature names and 6 | range/levels of different features) by specifying appropriate parameters. 7 | """ 8 | 9 | from dice_ml.data_interfaces.base_data_interface import _BaseData 10 | 11 | 12 | class Data(_BaseData): 13 | """Class containing all required information about the data for DiCE.""" 14 | 15 | def __init__(self, **params): 16 | """Init method 17 | 18 | :param **params: a dictionary of required parameters. 19 | """ 20 | self.decide_implementation_type(params) 21 | 22 | def decide_implementation_type(self, params): 23 | """Decides if the Data class is for public or private data.""" 24 | self.__class__ = decide(params) 25 | self.__init__(params) 26 | 27 | 28 | def decide(params): 29 | """Decides if the Data class is for public or private data. 30 | 31 | To add new implementations of Data, add the class in data_interfaces 32 | subpackage and import-and-return the class in an elif loop as shown 33 | in the below method. 
34 | """ 35 | if 'dataframe' in params: 36 | # if params contain a Pandas dataframe, then use PublicData class 37 | from dice_ml.data_interfaces.public_data_interface import PublicData 38 | return PublicData 39 | else: 40 | # use PrivateData if only meta data is provided 41 | from dice_ml.data_interfaces.private_data_interface import PrivateData 42 | return PrivateData 43 | -------------------------------------------------------------------------------- /dice_ml/data_interfaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/data_interfaces/__init__.py -------------------------------------------------------------------------------- /dice_ml/data_interfaces/base_data_interface.py: -------------------------------------------------------------------------------- 1 | """Module containing base class for data interfaces for dice-ml.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import pandas as pd 6 | from raiutils.exceptions import UserConfigValidationException 7 | 8 | from dice_ml.utils.exception import SystemException 9 | 10 | 11 | class _BaseData(ABC): 12 | 13 | def _validate_and_set_data_name(self, params): 14 | """Validate and set the data name.""" 15 | if 'data_name' in params: 16 | self.data_name = params['data_name'] 17 | else: 18 | self.data_name = 'mydata' 19 | 20 | def _validate_and_set_outcome_name(self, params): 21 | """Validate and set the outcome name.""" 22 | if 'outcome_name' not in params: 23 | raise ValueError("should provide the name of outcome feature") 24 | 25 | if isinstance(params['outcome_name'], str): 26 | self.outcome_name = params['outcome_name'] 27 | else: 28 | raise ValueError("should provide the name of outcome feature as a string") 29 | 30 | def set_continuous_feature_indexes(self, query_instance): 31 | """Remaps continuous feature indices based on the query instance""" 32 | 
self.continuous_feature_indexes = [query_instance.columns.get_loc(name) for name in 33 | self.continuous_feature_names] 34 | 35 | def check_features_to_vary(self, features_to_vary): 36 | if features_to_vary is not None and features_to_vary != 'all': 37 | not_training_features = set(features_to_vary) - set(self.feature_names) 38 | if len(not_training_features) > 0: 39 | raise UserConfigValidationException("Got features {0} which are not present in training data".format( 40 | not_training_features)) 41 | 42 | def check_permitted_range(self, permitted_range): 43 | if permitted_range is not None: 44 | permitted_range_features = list(permitted_range) 45 | not_training_features = set(permitted_range_features) - set(self.feature_names) 46 | if len(not_training_features) > 0: 47 | raise UserConfigValidationException("Got features {0} which are not present in training data".format( 48 | not_training_features)) 49 | 50 | for feature in permitted_range_features: 51 | if feature in self.categorical_feature_names: 52 | train_categories = self.permitted_range[feature] 53 | for test_category in permitted_range[feature]: 54 | if test_category not in train_categories: 55 | raise UserConfigValidationException( 56 | 'The category {0} does not occur in the training data for feature {1}.' 
57 | ' Allowed categories are {2}'.format(test_category, feature, train_categories)) 58 | 59 | def _validate_and_set_permitted_range(self, params, features_dict=None): 60 | """Validate and set the dictionary of permitted ranges for continuous features.""" 61 | input_permitted_range = None 62 | if 'permitted_range' in params: 63 | input_permitted_range = params['permitted_range'] 64 | 65 | if not hasattr(self, 'feature_names'): 66 | raise SystemException('Feature names not correctly set in public data interface') 67 | 68 | for input_permitted_range_feature_name in input_permitted_range: 69 | if input_permitted_range_feature_name not in self.feature_names: 70 | raise UserConfigValidationException( 71 | "permitted_range contains some feature names which are not part of columns in dataframe" 72 | ) 73 | self.permitted_range, _ = self.get_features_range(input_permitted_range, features_dict) 74 | 75 | def ensure_consistent_type(self, output_df, query_instance): 76 | qdf = self.query_instance_to_df(query_instance) 77 | output_df = output_df.astype(qdf.dtypes.to_dict()) 78 | return output_df 79 | 80 | def query_instance_to_df(self, query_instance): 81 | if isinstance(query_instance, list): 82 | if isinstance(query_instance[0], dict): # prepare a list of query instances 83 | test = pd.DataFrame(query_instance, columns=self.feature_names) 84 | 85 | else: # prepare a single query instance in list 86 | query_instance = {'row1': query_instance} 87 | test = pd.DataFrame.from_dict( 88 | query_instance, orient='index', columns=self.feature_names) 89 | 90 | elif isinstance(query_instance, dict): 91 | test = pd.DataFrame({k: [v] for k, v in query_instance.items()}, columns=self.feature_names) 92 | 93 | elif isinstance(query_instance, pd.DataFrame): 94 | test = query_instance.copy() 95 | 96 | else: 97 | raise ValueError("Query instance should be a dict, a pandas dataframe, a list, or a list of dicts") 98 | return test 99 | 100 | @abstractmethod 101 | def __init__(self, params): 102 | 
"""The init method needs to be implemented by the inheriting classes.""" 103 | pass 104 | -------------------------------------------------------------------------------- /dice_ml/dice.py: -------------------------------------------------------------------------------- 1 | """Module pointing to different implementations of DiCE based on different 2 | frameworks such as Tensorflow or PyTorch or sklearn, and different methods 3 | such as RandomSampling, DiCEKD or DiCEGenetic""" 4 | 5 | from raiutils.exceptions import UserConfigValidationException 6 | 7 | from dice_ml.constants import BackEndTypes, SamplingStrategy 8 | from dice_ml.data_interfaces.private_data_interface import PrivateData 9 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase 10 | 11 | 12 | class Dice(ExplainerBase): 13 | """An interface class to different DiCE implementations.""" 14 | 15 | def __init__(self, data_interface, model_interface, method="random", **kwargs): 16 | """Init method 17 | 18 | :param data_interface: an interface to access data related params. 19 | :param model_interface: an interface to access the output or gradients of a trained ML model.
20 | :param method: Name of the method to use for generating counterfactuals 21 | """ 22 | self.decide_implementation_type(data_interface, model_interface, method, **kwargs) 23 | 24 | def decide_implementation_type(self, data_interface, model_interface, method, **kwargs): 25 | """Decides DiCE implementation type.""" 26 | if model_interface.backend == BackEndTypes.Sklearn: 27 | if method == SamplingStrategy.KdTree and isinstance(data_interface, PrivateData): 28 | raise UserConfigValidationException( 29 | 'Private data interface is not supported with kdtree explainer' 30 | ' since kdtree explainer needs access to entire training data') 31 | self.__class__ = decide(model_interface, method) 32 | self.__init__(data_interface, model_interface, **kwargs) 33 | 34 | def _generate_counterfactuals(self, query_instance, total_CFs, 35 | desired_class="opposite", desired_range=None, 36 | permitted_range=None, features_to_vary="all", 37 | stopping_threshold=0.5, posthoc_sparsity_param=0.1, 38 | posthoc_sparsity_algorithm="linear", verbose=False, **kwargs): 39 | raise NotImplementedError("This method should be implemented by the concrete classes " 40 | "that inherit from ExplainerBase") 41 | 42 | 43 | def decide(model_interface, method): 44 | """Decides DiCE implementation type. 45 | 46 | To add new implementations of DiCE, add the class in explainer_interfaces 47 | subpackage and import-and-return the class in an elif loop as shown in 48 | the below method. 
49 | """ 50 | if method == SamplingStrategy.Random: 51 | # random sampling of CFs 52 | from dice_ml.explainer_interfaces.dice_random import DiceRandom 53 | return DiceRandom 54 | elif method == SamplingStrategy.Genetic: 55 | from dice_ml.explainer_interfaces.dice_genetic import DiceGenetic 56 | return DiceGenetic 57 | elif method == SamplingStrategy.KdTree: 58 | from dice_ml.explainer_interfaces.dice_KD import DiceKD 59 | return DiceKD 60 | elif method == SamplingStrategy.Gradient: 61 | if model_interface.backend == BackEndTypes.Tensorflow1: 62 | # pretrained Keras Sequential model with Tensorflow 1.x backend 63 | from dice_ml.explainer_interfaces.dice_tensorflow1 import \ 64 | DiceTensorFlow1 65 | return DiceTensorFlow1 66 | 67 | elif model_interface.backend == BackEndTypes.Tensorflow2: 68 | # pretrained Keras Sequential model with Tensorflow 2.x backend 69 | from dice_ml.explainer_interfaces.dice_tensorflow2 import \ 70 | DiceTensorFlow2 71 | return DiceTensorFlow2 72 | 73 | elif model_interface.backend == BackEndTypes.Pytorch: 74 | # PyTorch backend 75 | from dice_ml.explainer_interfaces.dice_pytorch import DicePyTorch 76 | return DicePyTorch 77 | else: 78 | raise UserConfigValidationException( 79 | "{0} is only supported for differentiable neural network models. " 80 | "Please choose one of {1}, {2} or {3}".format( 81 | method, SamplingStrategy.Random, 82 | SamplingStrategy.Genetic, 83 | SamplingStrategy.KdTree 84 | )) 85 | elif method is None: 86 | # all other backends 87 | backend_dice = model_interface.backend['explainer'] 88 | module_name, class_name = backend_dice.split('.') 89 | module = __import__("dice_ml.explainer_interfaces." + module_name, fromlist=[class_name]) 90 | return getattr(module, class_name) 91 | else: 92 | raise UserConfigValidationException("Unsupported sample strategy {0} provided. 
" 93 | "Please choose one of {1}, {2} or {3}".format( 94 | method, SamplingStrategy.Random, 95 | SamplingStrategy.Genetic, 96 | SamplingStrategy.KdTree 97 | )) 98 | -------------------------------------------------------------------------------- /dice_ml/explainer_interfaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/explainer_interfaces/__init__.py -------------------------------------------------------------------------------- /dice_ml/explainer_interfaces/dice_xgboost.py: -------------------------------------------------------------------------------- 1 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase 2 | 3 | 4 | class DiceXGBoost(ExplainerBase): 5 | def __init__(self, data_interface, model_interface): 6 | """Initialize with data and model interfaces""" 7 | super().__init__(data_interface, model_interface) 8 | 9 | def generate_counterfactuals(self, query_instance, total_CFs=5): 10 | """Generate counterfactuals""" 11 | # Implement your logic to generate counterfactuals 12 | raise NotImplementedError("Counterfactual generation for XGBoost is not implemented yet.") 13 | -------------------------------------------------------------------------------- /dice_ml/explainer_interfaces/feasible_model_approx.py: -------------------------------------------------------------------------------- 1 | # Dice Imports 2 | # Pytorch 3 | import torch 4 | import torch.utils.data 5 | from torch.nn import functional as F 6 | 7 | from dice_ml.explainer_interfaces.explainer_base import ExplainerBase 8 | from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE 9 | from dice_ml.utils.helpers import get_base_gen_cf_initialization 10 | 11 | 12 | class FeasibleModelApprox(FeasibleBaseVAE, ExplainerBase): 13 | 14 | def __init__(self, data_interface, model_interface, **kwargs): 15 | """ 16 | :param 
data_interface: an interface class to data related params 17 | :param model_interface: an interface class to access trained ML model 18 | """ 19 | 20 | # initiating data related parameters 21 | ExplainerBase.__init__(self, data_interface) 22 | 23 | # Black Box ML Model to be explained 24 | self.pred_model = model_interface.model 25 | 26 | self.minx, self.maxx, self.encoded_categorical_feature_indexes, \ 27 | self.encoded_continuous_feature_indexes, self.cont_minx, self.cont_maxx, self.cont_precisions = \ 28 | self.data_interface.get_data_params_for_gradient_dice() 29 | self.data_interface.one_hot_encoded_data = self.data_interface.one_hot_encode_data(self.data_interface.data_df) 30 | # Hyperparam 31 | self.encoded_size = kwargs['encoded_size'] 32 | self.learning_rate = kwargs['lr'] 33 | self.batch_size = kwargs['batch_size'] 34 | self.validity_reg = kwargs['validity_reg'] 35 | self.margin = kwargs['margin'] 36 | self.epochs = kwargs['epochs'] 37 | self.wm1 = kwargs['wm1'] 38 | self.wm2 = kwargs['wm2'] 39 | self.wm3 = kwargs['wm3'] 40 | 41 | # Initializing parameters for the DiceModelApproxGenCF 42 | self.vae_train_dataset, self.vae_val_dataset, self.vae_test_dataset, self.normalise_weights, \ 43 | self.cf_vae, self.cf_vae_optimizer = \ 44 | get_base_gen_cf_initialization( 45 | self.data_interface, self.encoded_size, self.cont_minx, 46 | self.cont_maxx, self.margin, self.validity_reg, self.epochs, 47 | self.wm1, self.wm2, self.wm3, self.learning_rate) 48 | 49 | # Data paths 50 | self.base_model_dir = '../../../dice_ml/utils/sample_trained_models/' 51 | self.save_path = self.base_model_dir + self.data_interface.data_name + \ 52 | '-margin-' + str(self.margin) + '-validity_reg-' + str(self.validity_reg) + \ 53 | '-epoch-' + str(self.epochs) + '-' + 'ae-gen' + '.pth' 54 | 55 | def train(self, constraint_type, constraint_variables, constraint_direction, constraint_reg, pre_trained=False): 56 | ''' 57 | :param pre_trained: Bool Variable to check whether pre trained model 
exists to avoid training again 58 | :param constraint_type: Binary Variable currently: (1) unary / (0) monotonic 59 | :param constraint_variables: List of List: [[Effect, Cause1, Cause2, .... ]] 60 | :param constraint_direction: -1: Negative, 1: Positive ( By default has to be one for monotonic constraints ) 61 | :param constraint_reg: Tunable Hyperparamter 62 | 63 | :return None 64 | ''' 65 | if pre_trained: 66 | self.cf_vae.load_state_dict(torch.load(self.save_path)) 67 | self.cf_vae.eval() 68 | return 69 | 70 | # TODO: Handling such dataset specific constraints in a more general way 71 | # CF Generation for only low to high income data points 72 | self.vae_train_dataset = self.vae_train_dataset[self.vae_train_dataset[:, -1] == 0, :] 73 | self.vae_val_dataset = self.vae_val_dataset[self.vae_val_dataset[:, -1] == 0, :] 74 | 75 | # Removing the outcome variable from the datasets 76 | self.vae_train_feat = self.vae_train_dataset[:, :-1] 77 | self.vae_val_feat = self.vae_val_dataset[:, :-1] 78 | 79 | for epoch in range(self.epochs): 80 | batch_num = 0 81 | train_loss = 0.0 82 | train_size = 0 83 | 84 | train_dataset = torch.tensor(self.vae_train_feat).float() 85 | train_dataset = torch.utils.data.DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) 86 | for train in enumerate(train_dataset): 87 | self.cf_vae_optimizer.zero_grad() 88 | 89 | train_x = train[1] 90 | train_y = 1.0-torch.argmax(self.pred_model(train_x), dim=1) 91 | train_size += train_x.shape[0] 92 | 93 | out = self.cf_vae(train_x, train_y) 94 | loss = self.compute_loss(out, train_x, train_y) 95 | 96 | # Unary Case 97 | if constraint_type: 98 | for const in constraint_variables: 99 | # Get the index from the feature name 100 | # Handle the categorical variable case here too 101 | const_idx = const[0] 102 | dm = out['x_pred'] 103 | mc_samples = out['mc_samples'] 104 | x_pred = dm[0] 105 | 106 | constraint_loss = F.hinge_embedding_loss( 107 | constraint_direction*(x_pred[:, const_idx] - 
train_x[:, const_idx]), torch.tensor(-1), 0) 108 | 109 | for j in range(1, mc_samples): 110 | x_pred = dm[j] 111 | constraint_loss += F.hinge_embedding_loss( 112 | constraint_direction*(x_pred[:, const_idx] - train_x[:, const_idx]), torch.tensor(-1), 0) 113 | 114 | constraint_loss = constraint_loss/mc_samples 115 | constraint_loss = constraint_reg*constraint_loss 116 | loss += constraint_loss 117 | print('Constraint: ', constraint_loss, torch.mean(constraint_loss)) 118 | else: 119 | # Train the regression model 120 | raise NotImplementedError( 121 | "This has not been implemented yet. If you'd like this to be implemented in the next version, " 122 | "please raise an issue at https://github.com/interpretml/DiCE/issues") 123 | 124 | loss.backward() 125 | train_loss += loss.item() 126 | self.cf_vae_optimizer.step() 127 | 128 | batch_num += 1 129 | 130 | ret = loss/batch_num 131 | print('Train Avg Loss: ', ret, train_size) 132 | 133 | # Save the model after training every 10 epochs and at the last epoch 134 | if (epoch != 0 and (epoch % 10) == 0) or epoch == self.epochs-1: 135 | torch.save(self.cf_vae.state_dict(), self.save_path) 136 | -------------------------------------------------------------------------------- /dice_ml/model.py: -------------------------------------------------------------------------------- 1 | """Module pointing to different implementations of Model class 2 | 3 | The implementations contain methods to access the output or gradients of ML models trained based on different 4 | frameworks such as Tensorflow or PyTorch. 
5 | """ 6 | import warnings 7 | 8 | from raiutils.exceptions import UserConfigValidationException 9 | 10 | from dice_ml.constants import BackEndTypes, ModelTypes 11 | 12 | 13 | class Model: 14 | """An interface class to different ML Model implementations.""" 15 | def __init__(self, model=None, model_path='', backend=BackEndTypes.Tensorflow1, model_type=ModelTypes.Classifier, 16 | func=None, kw_args=None): 17 | """Init method 18 | 19 | :param model: trained ML model. 20 | :param model_path: path to trained ML model. 21 | :param backend: "TF1" ("TF2") for TensorFLow 1.0 (2.0), "PYT" for PyTorch implementations, 22 | "sklearn" for Scikit-Learn implementations of standard 23 | DiCE (https://arxiv.org/pdf/1905.07697.pdf). For all other frameworks and 24 | implementations, provide a dictionary with "model" and "explainer" as keys, 25 | and include module and class names as values in the form module_name.class_name. 26 | For instance, if there is a model interface class "XGBoostModel" in module "xgboost_model.py" 27 | inside the subpackage dice_ml.model_interfaces, and dice interface class "DiceXGBoost" 28 | in module "dice_xgboost" inside dice_ml.explainer_interfaces, then backend parameter 29 | should be {"model": "xgboost_model.XGBoostModel", "explainer": dice_xgboost.DiceXGBoost}. 30 | :param func: function transformation required for ML model. If func is None, then func will be the identity function. 31 | :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended 32 | to the dictionary of kw_args, by default. 
33 | """ 34 | if backend not in BackEndTypes.ALL: 35 | warnings.warn('{0} backend not in supported backends {1}'.format( 36 | backend, ','.join(BackEndTypes.ALL)), stacklevel=2) 37 | 38 | if model_type not in ModelTypes.ALL: 39 | raise UserConfigValidationException('{0} model type not in supported model types {1}'.format( 40 | model_type, ','.join(ModelTypes.ALL)) 41 | ) 42 | 43 | self.model_type = model_type 44 | if model is None and model_path == '': 45 | raise ValueError("should provide either a trained model or the path to a model") 46 | else: 47 | self.decide_implementation_type(model, model_path, backend, func, kw_args) 48 | 49 | def decide_implementation_type(self, model, model_path, backend, func, kw_args): 50 | """Decides the Model implementation type.""" 51 | 52 | self.__class__ = decide(backend) 53 | self.__init__(model, model_path, backend, func, kw_args) 54 | 55 | 56 | def decide(backend): 57 | """Decides the Model implementation type. 58 | 59 | To add new implementations of Model, add the class in model_interfaces subpackage and 60 | import-and-return the class in an elif loop as shown in the below method. 61 | """ 62 | if backend == BackEndTypes.Sklearn: 63 | # random sampling of CFs 64 | from dice_ml.model_interfaces.base_model import BaseModel 65 | return BaseModel 66 | 67 | elif backend == BackEndTypes.Tensorflow1 or backend == BackEndTypes.Tensorflow2: 68 | # Tensorflow 1 or 2 backend 69 | try: 70 | import tensorflow # noqa: F401 71 | except ImportError: 72 | raise UserConfigValidationException("Unable to import tensorflow. Please install tensorflow") 73 | from dice_ml.model_interfaces.keras_tensorflow_model import \ 74 | KerasTensorFlowModel 75 | return KerasTensorFlowModel 76 | 77 | elif backend == BackEndTypes.Pytorch: 78 | # PyTorch backend 79 | try: 80 | import torch # noqa: F401 81 | except ImportError: 82 | raise UserConfigValidationException("Unable to import torch. 
Please install torch from https://pytorch.org/") 83 | from dice_ml.model_interfaces.pytorch_model import PyTorchModel 84 | return PyTorchModel 85 | 86 | else: 87 | # all other implementations and frameworks 88 | backend_model = backend['model'] 89 | module_name, class_name = backend_model.split('.') 90 | module = __import__("dice_ml.model_interfaces." + module_name, fromlist=[class_name]) 91 | return getattr(module, class_name) 92 | -------------------------------------------------------------------------------- /dice_ml/model_interfaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/model_interfaces/__init__.py -------------------------------------------------------------------------------- /dice_ml/model_interfaces/base_model.py: -------------------------------------------------------------------------------- 1 | """Module containing a template class as an interface to ML model. 2 | Subclasses implement model interfaces for different ML frameworks such as TensorFlow, PyTorch OR Sklearn. 3 | All model interface methods are in dice_ml.model_interfaces""" 4 | 5 | import pickle 6 | 7 | import numpy as np 8 | 9 | from dice_ml.constants import ModelTypes 10 | from dice_ml.utils.exception import SystemException 11 | from dice_ml.utils.helpers import DataTransfomer 12 | 13 | 14 | class BaseModel: 15 | 16 | def __init__(self, model=None, model_path='', backend='', func=None, kw_args=None): 17 | """Init method 18 | 19 | :param model: trained ML Model. 20 | :param model_path: path to trained model. 21 | :param backend: ML framework. For frameworks other than TensorFlow or PyTorch, 22 | or for implementations other than standard DiCE 23 | (https://arxiv.org/pdf/1905.07697.pdf), 24 | provide both the module and class names as module_name.class_name. 
25 | For instance, if there is a model interface class "SklearnModel" 26 | in module "sklearn_model.py" inside the subpackage dice_ml.model_interfaces, 27 | then backend parameter should be "sklearn_model.SklearnModel". 28 | :param func: function transformation required for ML model. If func is None, then func will be the identity function. 29 | :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended to the 30 | dictionary of kw_args, by default. 31 | 32 | """ 33 | self.model = model 34 | self.model_path = model_path 35 | self.backend = backend 36 | # calls FunctionTransformer of scikit-learn internally 37 | # (https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.FunctionTransformer.html) 38 | self.transformer = DataTransfomer(func, kw_args) 39 | 40 | def load_model(self): 41 | if self.model_path != '': 42 | with open(self.model_path, 'rb') as filehandle: 43 | self.model = pickle.load(filehandle) 44 | 45 | def get_output(self, input_instance, model_score=True): 46 | """returns prediction probabilities for a classifier and the predicted output for a regressor. 47 | 48 | :returns: an array of output scores for a classifier, and a singleton 49 | array of predicted value for a regressor. 
50 | """ 51 | input_instance = self.transformer.transform(input_instance) 52 | if model_score: 53 | if self.model_type == ModelTypes.Classifier: 54 | return self.model.predict_proba(input_instance) 55 | else: 56 | return self.model.predict(input_instance) 57 | else: 58 | return self.model.predict(input_instance) 59 | 60 | def get_gradient(self): 61 | raise NotImplementedError 62 | 63 | def get_num_output_nodes(self, inp_size): 64 | temp_input = np.transpose(np.array([np.random.uniform(0, 1) for i in range(inp_size)]).reshape(-1, 1)) 65 | return self.get_output(temp_input).shape[1] 66 | 67 | def get_num_output_nodes2(self, input_instance): 68 | if self.model_type == ModelTypes.Regressor: 69 | raise SystemException('Number of output nodes not supported for regression') 70 | return self.get_output(input_instance).shape[1] 71 | -------------------------------------------------------------------------------- /dice_ml/model_interfaces/keras_tensorflow_model.py: -------------------------------------------------------------------------------- 1 | """Module containing an interface to trained Keras Tensorflow model.""" 2 | 3 | import tensorflow as tf 4 | from tensorflow import keras 5 | 6 | from dice_ml.model_interfaces.base_model import BaseModel 7 | 8 | 9 | class KerasTensorFlowModel(BaseModel): 10 | 11 | def __init__(self, model=None, model_path='', backend='TF1', func=None, kw_args=None): 12 | """Init method 13 | 14 | :param model: trained Keras Sequential Model. 15 | :param model_path: path to trained model. 16 | :param backend: "TF1" for TensorFlow 1 and "TF2" for TensorFlow 2. 17 | :param func: function transformation required for ML model. If func is None, then func will be the identity function. 18 | :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended to the 19 | dictionary of kw_args, by default. 
20 | """ 21 | 22 | super().__init__(model, model_path, backend, func, kw_args) 23 | 24 | def load_model(self): 25 | if self.model_path != '': 26 | self.model = keras.models.load_model(self.model_path) 27 | 28 | def get_output(self, input_tensor, training=False, transform_data=False): 29 | """returns prediction probabilities 30 | 31 | :param input_tensor: test input. 32 | :param training: to determine training mode in TF2. 33 | :param transform_data: boolean to indicate if data transformation is required. 34 | """ 35 | if transform_data or not tf.is_tensor(input_tensor): 36 | input_tensor = tf.constant(self.transformer.transform(input_tensor).to_numpy(), dtype=tf.float32) 37 | if self.backend == 'TF2': 38 | return self.model(input_tensor, training=training) 39 | else: 40 | return self.model(input_tensor) 41 | 42 | def get_gradient(self, input_instance): 43 | # Future Support 44 | raise NotImplementedError("Future Support") 45 | 46 | def get_num_output_nodes(self, inp_size): 47 | temp_input = tf.convert_to_tensor([tf.random.uniform([inp_size])], dtype=tf.float32) 48 | return self.get_output(temp_input) 49 | -------------------------------------------------------------------------------- /dice_ml/model_interfaces/pytorch_model.py: -------------------------------------------------------------------------------- 1 | """Module containing an interface to trained PyTorch model.""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from dice_ml.constants import ModelTypes 7 | from dice_ml.model_interfaces.base_model import BaseModel 8 | 9 | 10 | class PyTorchModel(BaseModel): 11 | 12 | def __init__(self, model=None, model_path='', backend='PYT', func=None, kw_args=None): 13 | """Init method 14 | 15 | :param model: trained PyTorch Model. 16 | :param model_path: path to trained model. 17 | :param backend: "PYT" for PyTorch framework. 18 | :param func: function transformation required for ML model. If func is None, then func will be the identity function. 
19 | :param kw_args: Dictionary of additional keyword arguments to pass to func. DiCE's data_interface is appended to the 20 | dictionary of kw_args, by default. 21 | """ 22 | 23 | super().__init__(model, model_path, backend, func, kw_args) 24 | 25 | def load_model(self, weights_only=False): 26 | if self.model_path != '': 27 | self.model = torch.load(self.model_path, weights_only=weights_only) 28 | 29 | def get_output(self, input_instance, model_score=True, 30 | transform_data=False, out_tensor=False): 31 | """returns prediction probabilities 32 | 33 | :param input_tensor: test input. 34 | :param transform_data: boolean to indicate if data transformation is required. 35 | """ 36 | input_tensor = input_instance 37 | if transform_data: 38 | input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy(dtype=np.float64)).float() 39 | if not torch.is_tensor(input_instance): 40 | input_tensor = torch.tensor(self.transformer.transform(input_instance).to_numpy(dtype=np.float64)).float() 41 | out = self.model(input_tensor).float() 42 | if not out_tensor: 43 | out = out.data.numpy() 44 | if model_score is False and self.model_type == ModelTypes.Classifier: 45 | out = np.round(out) # TODO need to generalize for n-class classifier 46 | return out 47 | 48 | def set_eval_mode(self): 49 | self.model.eval() 50 | 51 | def get_gradient(self, input_instance): 52 | # Future Support 53 | raise NotImplementedError("Future Support") 54 | 55 | def get_num_output_nodes(self, inp_size): 56 | temp_input = torch.rand(1, inp_size).float() 57 | return self.get_output(temp_input).data 58 | -------------------------------------------------------------------------------- /dice_ml/model_interfaces/xgboost_model.py: -------------------------------------------------------------------------------- 1 | import xgboost as xgb 2 | 3 | from dice_ml.constants import ModelTypes 4 | from dice_ml.model_interfaces.base_model import BaseModel 5 | 6 | 7 | class XGBoostModel(BaseModel): 8 | 9 | 
def __init__(self, model=None, model_path='', backend='', func=None, kw_args=None): 10 | super().__init__(model=model, model_path=model_path, backend='xgboost', func=func, kw_args=kw_args) 11 | if model is None and model_path: 12 | self.load_model() 13 | 14 | def load_model(self): 15 | if self.model_path != '': 16 | self.model = xgb.Booster() 17 | self.model.load_model(self.model_path) 18 | 19 | def get_output(self, input_instance, model_score=True): 20 | input_instance = self.transformer.transform(input_instance) 21 | for col in input_instance.columns: 22 | input_instance[col] = input_instance[col].astype('int64') 23 | if model_score: 24 | if self.model_type == ModelTypes.Classifier: 25 | return self.model.predict_proba(input_instance) 26 | else: 27 | return self.model.predict(input_instance) 28 | else: 29 | return self.model.predict(input_instance) 30 | 31 | def get_gradient(self): 32 | raise NotImplementedError("XGBoost does not support gradient calculation in this context") 33 | -------------------------------------------------------------------------------- /dice_ml/schema/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/schema/__init__.py -------------------------------------------------------------------------------- /dice_ml/schema/counterfactual_explanations_v1.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "title": "Dashboard Dictionary for Counterfactual outputs", 4 | "description": "The original JSON format for counterfactual examples", 5 | "type": "object", 6 | "properties": { 7 | "cf_examples_list": { 8 | "description": "The list of the computed counterfactual examples.", 9 | "type": "array", 10 | "items": { 11 | "type": "string" 12 | }, 13 | "uniqueItems": true 14 | }, 15 | "local_importance": { 16 | 
"description": "The list of counterfactual local importance for the features in input data.", 17 | "type": ["array", "null"], 18 | "items": { 19 | "type": "object" 20 | } 21 | }, 22 | "summary_importance": { 23 | "description": "The list of counterfactual summary importance for the features in input data.", 24 | "type": ["object", "null"] 25 | }, 26 | "metadata": { 27 | "description": "The metadata about the generated counterfactuals.", 28 | "type": "object" 29 | } 30 | }, 31 | "required": [ 32 | "cf_examples_list", 33 | "local_importance", 34 | "summary_importance", 35 | "metadata" 36 | ] 37 | } -------------------------------------------------------------------------------- /dice_ml/schema/counterfactual_explanations_v2.0.json: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "http://json-schema.org/draft-07/schema#", 3 | "title": "Output contract for Counterfactual outputs", 4 | "description": "The original JSON format for counterfactual examples", 5 | "type": "object", 6 | "properties": { 7 | "cfs_list": { 8 | "description": "The list of the computed counterfactual examples.", 9 | "type": "array", 10 | "items": { 11 | "type": ["array", "null"] 12 | } 13 | }, 14 | "test_data": { 15 | "description": "The list of the input test samples for which counterfactual examples need to be computed.", 16 | "type": "array", 17 | "items": { 18 | "type": "array" 19 | } 20 | }, 21 | "local_importance": { 22 | "description": "The list of counterfactual local importance for the features in input data.", 23 | "type": ["array", "null"], 24 | "items": { 25 | "type": "array" 26 | } 27 | }, 28 | "summary_importance": { 29 | "description": "The list of counterfactual summary importance for the features in input data.", 30 | "type": ["array", "null"], 31 | "items": { 32 | "type": "number" 33 | } 34 | }, 35 | "feature_names": { 36 | "description": "The list of features in the input data.", 37 | "type": ["array", "null"], 38 | "items": 
{ 39 | "type": "string" 40 | } 41 | }, 42 | "feature_names_including_target": { 43 | "description": "The list of features including the target in input data.", 44 | "type": ["array", "null"], 45 | "items": { 46 | "type": "string" 47 | } 48 | }, 49 | "model_type": { 50 | "description": "The type of model is either a classifier/regressor", 51 | "type": ["string", "null"] 52 | }, 53 | "desired_class": { 54 | "description": "The target class for the generated counterfactual examples", 55 | "type": ["string", "integer", "null"] 56 | }, 57 | "desired_range": { 58 | "description": "The target range for the generated counterfactual examples", 59 | "type": ["array", "null"], 60 | "items": { 61 | "type": "number" 62 | } 63 | }, 64 | "data_interface": { 65 | "description": "The data interface details including outcome name.", 66 | "type": ["object", "null"] 67 | }, 68 | "metadata": { 69 | "description": "The metadata about the generated counterfactuals.", 70 | "type": "object" 71 | } 72 | }, 73 | "required": [ 74 | "cfs_list", 75 | "test_data", 76 | "local_importance", 77 | "summary_importance", 78 | "feature_names", 79 | "feature_names_including_target", 80 | "model_type", 81 | "desired_class", 82 | "desired_range", 83 | "data_interface", 84 | "metadata" 85 | ] 86 | } -------------------------------------------------------------------------------- /dice_ml/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/__init__.py -------------------------------------------------------------------------------- /dice_ml/utils/exception.py: -------------------------------------------------------------------------------- 1 | """Exceptions for the package.""" 2 | 3 | 4 | class SystemException(Exception): 5 | """An exception indicating that some system exception happened during execution. 
6 | 7 | :param exception_message: A message describing the error. 8 | :type exception_message: str 9 | """ 10 | _error_code = "System Error" 11 | -------------------------------------------------------------------------------- /dice_ml/utils/neuralnetworks.py: -------------------------------------------------------------------------------- 1 | from torch import nn, sigmoid 2 | 3 | 4 | class FFNetwork(nn.Module): 5 | def __init__(self, input_size, is_classifier=True): 6 | super(FFNetwork, self).__init__() 7 | self.is_classifier = is_classifier 8 | self.flatten = nn.Flatten() 9 | self.linear_relu_stack = nn.Sequential( 10 | nn.Linear(input_size, 16), 11 | nn.ReLU(), 12 | nn.Linear(16, 1), 13 | ) 14 | 15 | def forward(self, x): 16 | x = self.flatten(x) 17 | out = self.linear_relu_stack(x) 18 | out = sigmoid(out) 19 | if not self.is_classifier: 20 | out = 3 * out # output between 0 and 3 21 | return out 22 | 23 | 24 | class MulticlassNetwork(nn.Module): 25 | def __init__(self, input_size: int, num_class: int): 26 | super(MulticlassNetwork, self).__init__() 27 | 28 | self.linear_relu_stack = nn.Sequential( 29 | nn.Linear(input_size, 16), 30 | nn.ReLU(), 31 | nn.Linear(16, num_class) 32 | ) 33 | self.softmax = nn.Softmax(dim=1) 34 | 35 | def forward(self, x): 36 | x = self.linear_relu_stack(x) 37 | out = self.softmax(x) 38 | 39 | return out 40 | -------------------------------------------------------------------------------- /dice_ml/utils/sample_architecture/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_architecture/__init__.py -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/adult-margin-0.165-validity_reg-42.0-epoch-25-base-gen.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult-margin-0.165-validity_reg-42.0-epoch-25-base-gen.pth -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/adult-margin-0.344-validity_reg-76.0-epoch-25-ae-gen.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult-margin-0.344-validity_reg-76.0-epoch-25-ae-gen.pth -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/adult.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult.h5 -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/adult.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult.pkl -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/adult.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult.pth -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/adult_2nodes.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/adult_2nodes.pth -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/custom.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom.sav -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/custom_binary.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_binary.sav -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/custom_multiclass.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_multiclass.sav -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/custom_regression.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_regression.sav -------------------------------------------------------------------------------- /dice_ml/utils/sample_trained_models/custom_vars.sav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/dice_ml/utils/sample_trained_models/custom_vars.sav 
-------------------------------------------------------------------------------- /dice_ml/utils/serialize.py: -------------------------------------------------------------------------------- 1 | class DummyDataInterface: 2 | def __init__(self, outcome_name, data_df=None): 3 | self.outcome_name = outcome_name 4 | self.data_df = None 5 | if data_df is not None: 6 | self.data_df = data_df 7 | 8 | def to_json(self): 9 | return { 10 | 'outcome_name': self.outcome_name, 11 | 'data_df': self.data_df 12 | } 13 | -------------------------------------------------------------------------------- /docs/.buildinfo: -------------------------------------------------------------------------------- 1 | # Sphinx build info version 1 2 | # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. 3 | config: 7648e4eb332244e06501f513e1affc13 4 | tags: 645f666f9bcd5a90fca523b33c5a78b7 5 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/.nojekyll -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | 18 | movehtmlfiles: html 19 | cp -r $(BUILDDIR)/html/* . 
20 | rm -r _sources 21 | 22 | # Catch-all target: route all unknown targets to Sphinx using the new 23 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 24 | %: Makefile 25 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 26 | -------------------------------------------------------------------------------- /docs/_images/dice_getting_started_api.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_images/dice_getting_started_api.png -------------------------------------------------------------------------------- /docs/_modules/dice_ml/utils/exception.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | dice_ml.utils.exception — DiCE 0.11 documentation 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 62 | 63 |
67 | 68 |
69 |
70 |
71 |
    72 |
  • »
  • 73 |
  • Module code »
  • 74 |
  • dice_ml.utils.exception
  • 75 |
  • 76 |
  • 77 |
78 |
79 |
80 |
81 |
82 | 83 |

Source code for dice_ml.utils.exception

 84 | """Exceptions for the package."""
 85 | 
 86 | 
 87 | 
[docs]class SystemException(Exception): 88 | """An exception indicating that some system exception happened during execution. 89 | 90 | :param exception_message: A message describing the error. 91 | :type exception_message: str 92 | """ 93 | _error_code = "System Error"
94 |
95 | 96 |
97 |
98 | 112 |
113 |
114 |
115 |
116 | 121 | 122 | 123 | -------------------------------------------------------------------------------- /docs/_modules/dice_ml/utils/serialize.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | dice_ml.utils.serialize — DiCE 0.11 documentation 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 62 | 63 |
67 | 68 |
69 |
70 |
71 |
    72 |
  • »
  • 73 |
  • Module code »
  • 74 |
  • dice_ml.utils.serialize
  • 75 |
  • 76 |
  • 77 |
78 |
79 |
80 |
81 |
82 | 83 |

Source code for dice_ml.utils.serialize

 84 | 
[docs]class DummyDataInterface: 85 | def __init__(self, outcome_name, data_df=None): 86 | self.outcome_name = outcome_name 87 | self.data_df = None 88 | if data_df is not None: 89 | self.data_df = data_df 90 | 91 |
[docs] def to_json(self): 92 | return { 93 | 'outcome_name': self.outcome_name, 94 | 'data_df': self.data_df 95 | }
96 |
97 | 98 |
99 |
100 | 114 |
115 |
116 |
117 |
118 | 123 | 124 | 125 | -------------------------------------------------------------------------------- /docs/_modules/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Overview: module code — DiCE 0.11 documentation 7 | 8 | 9 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 62 | 63 |
67 | 68 | 128 |
129 |
130 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /docs/_static/_sphinx_javascript_frameworks_compat.js: -------------------------------------------------------------------------------- 1 | /* 2 | * _sphinx_javascript_frameworks_compat.js 3 | * ~~~~~~~~~~ 4 | * 5 | * Compatability shim for jQuery and underscores.js. 6 | * 7 | * WILL BE REMOVED IN Sphinx 6.0 8 | * xref RemovedInSphinx60Warning 9 | * 10 | */ 11 | 12 | /** 13 | * select a different prefix for underscore 14 | */ 15 | $u = _.noConflict(); 16 | 17 | 18 | /** 19 | * small helper function to urldecode strings 20 | * 21 | * See https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/decodeURIComponent#Decoding_query_parameters_from_a_URL 22 | */ 23 | jQuery.urldecode = function(x) { 24 | if (!x) { 25 | return x 26 | } 27 | return decodeURIComponent(x.replace(/\+/g, ' ')); 28 | }; 29 | 30 | /** 31 | * small helper function to urlencode strings 32 | */ 33 | jQuery.urlencode = encodeURIComponent; 34 | 35 | /** 36 | * This function returns the parsed url parameters of the 37 | * current request. Multiple values per key are supported, 38 | * it will always return arrays of strings for the value parts. 39 | */ 40 | jQuery.getQueryParameters = function(s) { 41 | if (typeof s === 'undefined') 42 | s = document.location.search; 43 | var parts = s.substr(s.indexOf('?') + 1).split('&'); 44 | var result = {}; 45 | for (var i = 0; i < parts.length; i++) { 46 | var tmp = parts[i].split('=', 2); 47 | var key = jQuery.urldecode(tmp[0]); 48 | var value = jQuery.urldecode(tmp[1]); 49 | if (key in result) 50 | result[key].push(value); 51 | else 52 | result[key] = [value]; 53 | } 54 | return result; 55 | }; 56 | 57 | /** 58 | * highlight a given string on a jquery object by wrapping it in 59 | * span elements with the given class name. 
60 | */ 61 | jQuery.fn.highlightText = function(text, className) { 62 | function highlight(node, addItems) { 63 | if (node.nodeType === 3) { 64 | var val = node.nodeValue; 65 | var pos = val.toLowerCase().indexOf(text); 66 | if (pos >= 0 && 67 | !jQuery(node.parentNode).hasClass(className) && 68 | !jQuery(node.parentNode).hasClass("nohighlight")) { 69 | var span; 70 | var isInSVG = jQuery(node).closest("body, svg, foreignObject").is("svg"); 71 | if (isInSVG) { 72 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 73 | } else { 74 | span = document.createElement("span"); 75 | span.className = className; 76 | } 77 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 78 | node.parentNode.insertBefore(span, node.parentNode.insertBefore( 79 | document.createTextNode(val.substr(pos + text.length)), 80 | node.nextSibling)); 81 | node.nodeValue = val.substr(0, pos); 82 | if (isInSVG) { 83 | var rect = document.createElementNS("http://www.w3.org/2000/svg", "rect"); 84 | var bbox = node.parentElement.getBBox(); 85 | rect.x.baseVal.value = bbox.x; 86 | rect.y.baseVal.value = bbox.y; 87 | rect.width.baseVal.value = bbox.width; 88 | rect.height.baseVal.value = bbox.height; 89 | rect.setAttribute('class', className); 90 | addItems.push({ 91 | "parent": node.parentNode, 92 | "target": rect}); 93 | } 94 | } 95 | } 96 | else if (!jQuery(node).is("button, select, textarea")) { 97 | jQuery.each(node.childNodes, function() { 98 | highlight(this, addItems); 99 | }); 100 | } 101 | } 102 | var addItems = []; 103 | var result = this.each(function() { 104 | highlight(this, addItems); 105 | }); 106 | for (var i = 0; i < addItems.length; ++i) { 107 | jQuery(addItems[i].parent).before(addItems[i].target); 108 | } 109 | return result; 110 | }; 111 | 112 | /* 113 | * backward compatibility for jQuery.browser 114 | * This will be supported until firefox bug is fixed. 
115 | */ 116 | if (!jQuery.browser) { 117 | jQuery.uaMatch = function(ua) { 118 | ua = ua.toLowerCase(); 119 | 120 | var match = /(chrome)[ \/]([\w.]+)/.exec(ua) || 121 | /(webkit)[ \/]([\w.]+)/.exec(ua) || 122 | /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(ua) || 123 | /(msie) ([\w.]+)/.exec(ua) || 124 | ua.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(ua) || 125 | []; 126 | 127 | return { 128 | browser: match[ 1 ] || "", 129 | version: match[ 2 ] || "0" 130 | }; 131 | }; 132 | jQuery.browser = {}; 133 | jQuery.browser[jQuery.uaMatch(navigator.userAgent).browser] = true; 134 | } 135 | -------------------------------------------------------------------------------- /docs/_static/css/badge_only.css: -------------------------------------------------------------------------------- 1 | .fa:before{-webkit-font-smoothing:antialiased}.clearfix{*zoom:1}.clearfix:after,.clearfix:before{display:table;content:""}.clearfix:after{clear:both}@font-face{font-family:FontAwesome;font-style:normal;font-weight:400;src:url(fonts/fontawesome-webfont.eot?674f50d287a8c48dc19ba404d20fe713?#iefix) format("embedded-opentype"),url(fonts/fontawesome-webfont.woff2?af7ae505a9eed503f8b8e6982036873e) format("woff2"),url(fonts/fontawesome-webfont.woff?fee66e712a8a08eef5805a46892932ad) format("woff"),url(fonts/fontawesome-webfont.ttf?b06871f281fee6b241d60582ae9369b9) format("truetype"),url(fonts/fontawesome-webfont.svg?912ec66d7572ff821749319396470bde#FontAwesome) format("svg")}.fa:before{font-family:FontAwesome;font-style:normal;font-weight:400;line-height:1}.fa:before,a .fa{text-decoration:inherit}.fa:before,a .fa,li .fa{display:inline-block}li .fa-large:before{width:1.875em}ul.fas{list-style-type:none;margin-left:2em;text-indent:-.8em}ul.fas li .fa{width:.8em}ul.fas li 
.fa-large:before{vertical-align:baseline}.fa-book:before,.icon-book:before{content:"\f02d"}.fa-caret-down:before,.icon-caret-down:before{content:"\f0d7"}.fa-caret-up:before,.icon-caret-up:before{content:"\f0d8"}.fa-caret-left:before,.icon-caret-left:before{content:"\f0d9"}.fa-caret-right:before,.icon-caret-right:before{content:"\f0da"}.rst-versions{position:fixed;bottom:0;left:0;width:300px;color:#fcfcfc;background:#1f1d1d;font-family:Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;z-index:400}.rst-versions a{color:#2980b9;text-decoration:none}.rst-versions .rst-badge-small{display:none}.rst-versions .rst-current-version{padding:12px;background-color:#272525;display:block;text-align:right;font-size:90%;cursor:pointer;color:#27ae60}.rst-versions .rst-current-version:after{clear:both;content:"";display:block}.rst-versions .rst-current-version .fa{color:#fcfcfc}.rst-versions .rst-current-version .fa-book,.rst-versions .rst-current-version .icon-book{float:left}.rst-versions .rst-current-version.rst-out-of-date{background-color:#e74c3c;color:#fff}.rst-versions .rst-current-version.rst-active-old-version{background-color:#f1c40f;color:#000}.rst-versions.shift-up{height:auto;max-height:100%;overflow-y:scroll}.rst-versions.shift-up .rst-other-versions{display:block}.rst-versions .rst-other-versions{font-size:90%;padding:12px;color:grey;display:none}.rst-versions .rst-other-versions hr{display:block;height:1px;border:0;margin:20px 0;padding:0;border-top:1px solid #413d3d}.rst-versions .rst-other-versions dd{display:inline-block;margin:0}.rst-versions .rst-other-versions dd a{display:inline-block;padding:6px;color:#fcfcfc}.rst-versions.rst-badge{width:auto;bottom:20px;right:20px;left:auto;border:none;max-width:300px;max-height:90%}.rst-versions.rst-badge .fa-book,.rst-versions.rst-badge .icon-book{float:none;line-height:30px}.rst-versions.rst-badge.shift-up .rst-current-version{text-align:right}.rst-versions.rst-badge.shift-up .rst-current-version 
.fa-book,.rst-versions.rst-badge.shift-up .rst-current-version .icon-book{float:left}.rst-versions.rst-badge>.rst-current-version{width:auto;height:30px;line-height:30px;padding:0 6px;display:block;text-align:center}@media screen and (max-width:768px){.rst-versions{width:85%;display:none}.rst-versions.shift{display:block}} -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-bold-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- 
/docs/_static/css/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_static/css/fonts/lato-normal.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/css/fonts/lato-normal.woff2 
-------------------------------------------------------------------------------- /docs/_static/doctools.js: -------------------------------------------------------------------------------- 1 | /* 2 | * doctools.js 3 | * ~~~~~~~~~~~ 4 | * 5 | * Base JavaScript utilities for all Sphinx HTML documentation. 6 | * 7 | * :copyright: Copyright 2007-2022 by the Sphinx team, see AUTHORS. 8 | * :license: BSD, see LICENSE for details. 9 | * 10 | */ 11 | "use strict"; 12 | 13 | const BLACKLISTED_KEY_CONTROL_ELEMENTS = new Set([ 14 | "TEXTAREA", 15 | "INPUT", 16 | "SELECT", 17 | "BUTTON", 18 | ]); 19 | 20 | const _ready = (callback) => { 21 | if (document.readyState !== "loading") { 22 | callback(); 23 | } else { 24 | document.addEventListener("DOMContentLoaded", callback); 25 | } 26 | }; 27 | 28 | /** 29 | * Small JavaScript module for the documentation. 30 | */ 31 | const Documentation = { 32 | init: () => { 33 | Documentation.initDomainIndexTable(); 34 | Documentation.initOnKeyListeners(); 35 | }, 36 | 37 | /** 38 | * i18n support 39 | */ 40 | TRANSLATIONS: {}, 41 | PLURAL_EXPR: (n) => (n === 1 ? 0 : 1), 42 | LOCALE: "unknown", 43 | 44 | // gettext and ngettext don't access this so that the functions 45 | // can safely bound to a different name (_ = Documentation.gettext) 46 | gettext: (string) => { 47 | const translated = Documentation.TRANSLATIONS[string]; 48 | switch (typeof translated) { 49 | case "undefined": 50 | return string; // no translation 51 | case "string": 52 | return translated; // translation exists 53 | default: 54 | return translated[0]; // (singular, plural) translation tuple exists 55 | } 56 | }, 57 | 58 | ngettext: (singular, plural, n) => { 59 | const translated = Documentation.TRANSLATIONS[singular]; 60 | if (typeof translated !== "undefined") 61 | return translated[Documentation.PLURAL_EXPR(n)]; 62 | return n === 1 ? 
singular : plural; 63 | }, 64 | 65 | addTranslations: (catalog) => { 66 | Object.assign(Documentation.TRANSLATIONS, catalog.messages); 67 | Documentation.PLURAL_EXPR = new Function( 68 | "n", 69 | `return (${catalog.plural_expr})` 70 | ); 71 | Documentation.LOCALE = catalog.locale; 72 | }, 73 | 74 | /** 75 | * helper function to focus on search bar 76 | */ 77 | focusSearchBar: () => { 78 | document.querySelectorAll("input[name=q]")[0]?.focus(); 79 | }, 80 | 81 | /** 82 | * Initialise the domain index toggle buttons 83 | */ 84 | initDomainIndexTable: () => { 85 | const toggler = (el) => { 86 | const idNumber = el.id.substr(7); 87 | const toggledRows = document.querySelectorAll(`tr.cg-${idNumber}`); 88 | if (el.src.substr(-9) === "minus.png") { 89 | el.src = `${el.src.substr(0, el.src.length - 9)}plus.png`; 90 | toggledRows.forEach((el) => (el.style.display = "none")); 91 | } else { 92 | el.src = `${el.src.substr(0, el.src.length - 8)}minus.png`; 93 | toggledRows.forEach((el) => (el.style.display = "")); 94 | } 95 | }; 96 | 97 | const togglerElements = document.querySelectorAll("img.toggler"); 98 | togglerElements.forEach((el) => 99 | el.addEventListener("click", (event) => toggler(event.currentTarget)) 100 | ); 101 | togglerElements.forEach((el) => (el.style.display = "")); 102 | if (DOCUMENTATION_OPTIONS.COLLAPSE_INDEX) togglerElements.forEach(toggler); 103 | }, 104 | 105 | initOnKeyListeners: () => { 106 | // only install a listener if it is really needed 107 | if ( 108 | !DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS && 109 | !DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS 110 | ) 111 | return; 112 | 113 | document.addEventListener("keydown", (event) => { 114 | // bail for input elements 115 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 116 | // bail with special keys 117 | if (event.altKey || event.ctrlKey || event.metaKey) return; 118 | 119 | if (!event.shiftKey) { 120 | switch (event.key) { 121 | case "ArrowLeft": 122 | if 
(!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 123 | 124 | const prevLink = document.querySelector('link[rel="prev"]'); 125 | if (prevLink && prevLink.href) { 126 | window.location.href = prevLink.href; 127 | event.preventDefault(); 128 | } 129 | break; 130 | case "ArrowRight": 131 | if (!DOCUMENTATION_OPTIONS.NAVIGATION_WITH_KEYS) break; 132 | 133 | const nextLink = document.querySelector('link[rel="next"]'); 134 | if (nextLink && nextLink.href) { 135 | window.location.href = nextLink.href; 136 | event.preventDefault(); 137 | } 138 | break; 139 | } 140 | } 141 | 142 | // some keyboard layouts may need Shift to get / 143 | switch (event.key) { 144 | case "/": 145 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) break; 146 | Documentation.focusSearchBar(); 147 | event.preventDefault(); 148 | } 149 | }); 150 | }, 151 | }; 152 | 153 | // quick alias for translations 154 | const _ = Documentation.gettext; 155 | 156 | _ready(Documentation.init); 157 | -------------------------------------------------------------------------------- /docs/_static/documentation_options.js: -------------------------------------------------------------------------------- 1 | var DOCUMENTATION_OPTIONS = { 2 | URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), 3 | VERSION: '0.11', 4 | LANGUAGE: 'en', 5 | COLLAPSE_INDEX: false, 6 | BUILDER: 'html', 7 | FILE_SUFFIX: '.html', 8 | LINK_SUFFIX: '.html', 9 | HAS_SOURCE: true, 10 | SOURCELINK_SUFFIX: '.txt', 11 | NAVIGATION_WITH_KEYS: false, 12 | SHOW_SEARCH_SUMMARY: true, 13 | ENABLE_SEARCH_SHORTCUTS: true, 14 | }; -------------------------------------------------------------------------------- /docs/_static/file.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/file.png -------------------------------------------------------------------------------- 
/docs/_static/fonts/FontAwesome.otf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/FontAwesome.otf -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Inconsolata-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Inconsolata-Regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Inconsolata.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Inconsolata.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato-Regular.ttf -------------------------------------------------------------------------------- 
/docs/_static/fonts/Lato/lato-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.ttf 
-------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-bolditalic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-bolditalic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-italic.woff2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Lato/lato-regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Lato/lato-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Bold.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Light.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Light.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Light.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Light.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Regular.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Thin.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Thin.woff 
-------------------------------------------------------------------------------- /docs/_static/fonts/Roboto-Slab-Thin.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/Roboto-Slab-Thin.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab-Bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab-Regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.eot -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff -------------------------------------------------------------------------------- /docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/RobotoSlab/roboto-slab-v7-regular.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.eot: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.eot -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.ttf -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.woff -------------------------------------------------------------------------------- /docs/_static/fonts/fontawesome-webfont.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/fontawesome-webfont.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold-italic.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-bold.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-bold.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal-italic.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal-italic.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal-italic.woff2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal-italic.woff2 -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal.woff: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal.woff -------------------------------------------------------------------------------- /docs/_static/fonts/lato-normal.woff2: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/fonts/lato-normal.woff2 -------------------------------------------------------------------------------- /docs/_static/getting_started_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/getting_started_output.png -------------------------------------------------------------------------------- /docs/_static/getting_started_updated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/getting_started_updated.png -------------------------------------------------------------------------------- /docs/_static/js/badge_only.js: -------------------------------------------------------------------------------- 1 | !function(e){var t={};function r(n){if(t[n])return t[n].exports;var o=t[n]={i:n,l:!1,exports:{}};return e[n].call(o.exports,o,o.exports,r),o.l=!0,o.exports}r.m=e,r.c=t,r.d=function(e,t,n){r.o(e,t)||Object.defineProperty(e,t,{enumerable:!0,get:n})},r.r=function(e){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(e,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(e,"__esModule",{value:!0})},r.t=function(e,t){if(1&t&&(e=r(e)),8&t)return e;if(4&t&&"object"==typeof e&&e&&e.__esModule)return e;var n=Object.create(null);if(r.r(n),Object.defineProperty(n,"default",{enumerable:!0,value:e}),2&t&&"string"!=typeof e)for(var o in e)r.d(n,o,function(t){return e[t]}.bind(null,o));return n},r.n=function(e){var t=e&&e.__esModule?function(){return e.default}:function(){return e};return r.d(t,"a",t),t},r.o=function(e,t){return 
Object.prototype.hasOwnProperty.call(e,t)},r.p="",r(r.s=4)}({4:function(e,t,r){}}); -------------------------------------------------------------------------------- /docs/_static/js/html5shiv-printshiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3-pre | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=y.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=y.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),y.elements=c+" "+a,j(b)}function f(a){var b=x[a[v]];return b||(b={},w++,a[v]=w,x[w]=b),b}function g(a,c,d){if(c||(c=b),q)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():u.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||t.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),q)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return y.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(y,b.frag)}function j(a){a||(a=b);var d=f(a);return!y.shivCSS||p||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),q||i(a,d),a}function k(a){for(var 
b,c=a.getElementsByTagName("*"),e=c.length,f=RegExp("^(?:"+d().join("|")+")$","i"),g=[];e--;)b=c[e],f.test(b.nodeName)&&g.push(b.applyElement(l(b)));return g}function l(a){for(var b,c=a.attributes,d=c.length,e=a.ownerDocument.createElement(A+":"+a.nodeName);d--;)b=c[d],b.specified&&e.setAttribute(b.nodeName,b.nodeValue);return e.style.cssText=a.style.cssText,e}function m(a){for(var b,c=a.split("{"),e=c.length,f=RegExp("(^|[\\s,>+~])("+d().join("|")+")(?=[[\\s,>+~#.:]|$)","gi"),g="$1"+A+"\\:$2";e--;)b=c[e]=c[e].split("}"),b[b.length-1]=b[b.length-1].replace(f,g),c[e]=b.join("}");return c.join("{")}function n(a){for(var b=a.length;b--;)a[b].removeNode()}function o(a){function b(){clearTimeout(g._removeSheetTimer),d&&d.removeNode(!0),d=null}var d,e,g=f(a),h=a.namespaces,i=a.parentWindow;return!B||a.printShived?a:("undefined"==typeof h[A]&&h.add(A),i.attachEvent("onbeforeprint",function(){b();for(var f,g,h,i=a.styleSheets,j=[],l=i.length,n=Array(l);l--;)n[l]=i[l];for(;h=n.pop();)if(!h.disabled&&z.test(h.media)){try{f=h.imports,g=f.length}catch(o){g=0}for(l=0;g>l;l++)n.push(f[l]);try{j.push(h.cssText)}catch(o){}}j=m(j.reverse().join("")),e=k(a),d=c(a,j)}),i.attachEvent("onafterprint",function(){n(e),clearTimeout(g._removeSheetTimer),g._removeSheetTimer=setTimeout(b,500)}),a.printShived=!0,a)}var p,q,r="3.7.3",s=a.html5||{},t=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,u=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,v="_html5shiv",w=0,x={};!function(){try{var a=b.createElement("a");a.innerHTML="",p="hidden"in a,q=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){p=!0,q=!0}}();var y={elements:s.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark 
meter nav output picture progress section summary template time video",version:r,shivCSS:s.shivCSS!==!1,supportsUnknownElements:q,shivMethods:s.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=y,j(b);var z=/^$|\b(?:all|print)\b/,A="html5shiv",B=!q&&function(){var c=b.documentElement;return!("undefined"==typeof b.namespaces||"undefined"==typeof b.parentWindow||"undefined"==typeof c.applyElement||"undefined"==typeof c.removeNode||"undefined"==typeof a.attachEvent)}();y.type+=" print",y.shivPrint=o,o(b),"object"==typeof module&&module.exports&&(module.exports=y)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_static/js/html5shiv.min.js: -------------------------------------------------------------------------------- 1 | /** 2 | * @preserve HTML5 Shiv 3.7.3 | @afarkas @jdalton @jon_neal @rem | MIT/GPL2 Licensed 3 | */ 4 | !function(a,b){function c(a,b){var c=a.createElement("p"),d=a.getElementsByTagName("head")[0]||a.documentElement;return c.innerHTML="x",d.insertBefore(c.lastChild,d.firstChild)}function d(){var a=t.elements;return"string"==typeof a?a.split(" "):a}function e(a,b){var c=t.elements;"string"!=typeof c&&(c=c.join(" ")),"string"!=typeof a&&(a=a.join(" ")),t.elements=c+" "+a,j(b)}function f(a){var b=s[a[q]];return b||(b={},r++,a[q]=r,s[r]=b),b}function g(a,c,d){if(c||(c=b),l)return c.createElement(a);d||(d=f(c));var e;return e=d.cache[a]?d.cache[a].cloneNode():p.test(a)?(d.cache[a]=d.createElem(a)).cloneNode():d.createElem(a),!e.canHaveChildren||o.test(a)||e.tagUrn?e:d.frag.appendChild(e)}function h(a,c){if(a||(a=b),l)return a.createDocumentFragment();c=c||f(a);for(var e=c.frag.cloneNode(),g=0,h=d(),i=h.length;i>g;g++)e.createElement(h[g]);return e}function i(a,b){b.cache||(b.cache={},b.createElem=a.createElement,b.createFrag=a.createDocumentFragment,b.frag=b.createFrag()),a.createElement=function(c){return 
t.shivMethods?g(c,a,b):b.createElem(c)},a.createDocumentFragment=Function("h,f","return function(){var n=f.cloneNode(),c=n.createElement;h.shivMethods&&("+d().join().replace(/[\w\-:]+/g,function(a){return b.createElem(a),b.frag.createElement(a),'c("'+a+'")'})+");return n}")(t,b.frag)}function j(a){a||(a=b);var d=f(a);return!t.shivCSS||k||d.hasCSS||(d.hasCSS=!!c(a,"article,aside,dialog,figcaption,figure,footer,header,hgroup,main,nav,section{display:block}mark{background:#FF0;color:#000}template{display:none}")),l||i(a,d),a}var k,l,m="3.7.3-pre",n=a.html5||{},o=/^<|^(?:button|map|select|textarea|object|iframe|option|optgroup)$/i,p=/^(?:a|b|code|div|fieldset|h1|h2|h3|h4|h5|h6|i|label|li|ol|p|q|span|strong|style|table|tbody|td|th|tr|ul)$/i,q="_html5shiv",r=0,s={};!function(){try{var a=b.createElement("a");a.innerHTML="",k="hidden"in a,l=1==a.childNodes.length||function(){b.createElement("a");var a=b.createDocumentFragment();return"undefined"==typeof a.cloneNode||"undefined"==typeof a.createDocumentFragment||"undefined"==typeof a.createElement}()}catch(c){k=!0,l=!0}}();var t={elements:n.elements||"abbr article aside audio bdi canvas data datalist details dialog figcaption figure footer header hgroup main mark meter nav output picture progress section summary template time video",version:m,shivCSS:n.shivCSS!==!1,supportsUnknownElements:l,shivMethods:n.shivMethods!==!1,type:"default",shivDocument:j,createElement:g,createDocumentFragment:h,addElements:e};a.html5=t,j(b),"object"==typeof module&&module.exports&&(module.exports=t)}("undefined"!=typeof window?window:this,document); -------------------------------------------------------------------------------- /docs/_static/js/theme.js: -------------------------------------------------------------------------------- 1 | !function(n){var e={};function t(i){if(e[i])return e[i].exports;var o=e[i]={i:i,l:!1,exports:{}};return 
n[i].call(o.exports,o,o.exports,t),o.l=!0,o.exports}t.m=n,t.c=e,t.d=function(n,e,i){t.o(n,e)||Object.defineProperty(n,e,{enumerable:!0,get:i})},t.r=function(n){"undefined"!=typeof Symbol&&Symbol.toStringTag&&Object.defineProperty(n,Symbol.toStringTag,{value:"Module"}),Object.defineProperty(n,"__esModule",{value:!0})},t.t=function(n,e){if(1&e&&(n=t(n)),8&e)return n;if(4&e&&"object"==typeof n&&n&&n.__esModule)return n;var i=Object.create(null);if(t.r(i),Object.defineProperty(i,"default",{enumerable:!0,value:n}),2&e&&"string"!=typeof n)for(var o in n)t.d(i,o,function(e){return n[e]}.bind(null,o));return i},t.n=function(n){var e=n&&n.__esModule?function(){return n.default}:function(){return n};return t.d(e,"a",e),e},t.o=function(n,e){return Object.prototype.hasOwnProperty.call(n,e)},t.p="",t(t.s=0)}([function(n,e,t){t(1),n.exports=t(3)},function(n,e,t){(function(){var e="undefined"!=typeof window?window.jQuery:t(2);n.exports.ThemeNav={navBar:null,win:null,winScroll:!1,winResize:!1,linkScroll:!1,winPosition:0,winHeight:null,docHeight:null,isRunning:!1,enable:function(n){var t=this;void 0===n&&(n=!0),t.isRunning||(t.isRunning=!0,e((function(e){t.init(e),t.reset(),t.win.on("hashchange",t.reset),n&&t.win.on("scroll",(function(){t.linkScroll||t.winScroll||(t.winScroll=!0,requestAnimationFrame((function(){t.onScroll()})))})),t.win.on("resize",(function(){t.winResize||(t.winResize=!0,requestAnimationFrame((function(){t.onResize()})))})),t.onResize()})))},enableSticky:function(){this.enable(!0)},init:function(n){n(document);var e=this;this.navBar=n("div.wy-side-scroll:first"),this.win=n(window),n(document).on("click","[data-toggle='wy-nav-top']",(function(){n("[data-toggle='wy-nav-shift']").toggleClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift")})).on("click",".wy-menu-vertical .current ul li a",(function(){var 
t=n(this);n("[data-toggle='wy-nav-shift']").removeClass("shift"),n("[data-toggle='rst-versions']").toggleClass("shift"),e.toggleCurrent(t),e.hashChange()})).on("click","[data-toggle='rst-current-version']",(function(){n("[data-toggle='rst-versions']").toggleClass("shift-up")})),n("table.docutils:not(.field-list,.footnote,.citation)").wrap("
"),n("table.docutils.footnote").wrap("
"),n("table.docutils.citation").wrap("
"),n(".wy-menu-vertical ul").not(".simple").siblings("a").each((function(){var t=n(this);expand=n(''),expand.on("click",(function(n){return e.toggleCurrent(t),n.stopPropagation(),!1})),t.prepend(expand)}))},reset:function(){var n=encodeURI(window.location.hash)||"#";try{var e=$(".wy-menu-vertical"),t=e.find('[href="'+n+'"]');if(0===t.length){var i=$('.document [id="'+n.substring(1)+'"]').closest("div.section");0===(t=e.find('[href="#'+i.attr("id")+'"]')).length&&(t=e.find('[href="#"]'))}if(t.length>0){$(".wy-menu-vertical .current").removeClass("current").attr("aria-expanded","false"),t.addClass("current").attr("aria-expanded","true"),t.closest("li.toctree-l1").parent().addClass("current").attr("aria-expanded","true");for(let n=1;n<=10;n++)t.closest("li.toctree-l"+n).addClass("current").attr("aria-expanded","true");t[0].scrollIntoView()}}catch(n){console.log("Error expanding nav for anchor",n)}},onScroll:function(){this.winScroll=!1;var n=this.win.scrollTop(),e=n+this.winHeight,t=this.navBar.scrollTop()+(n-this.winPosition);n<0||e>this.docHeight||(this.navBar.scrollTop(t),this.winPosition=n)},onResize:function(){this.winResize=!1,this.winHeight=this.win.height(),this.docHeight=$(document).height()},hashChange:function(){this.linkScroll=!0,this.win.one("hashchange",(function(){this.linkScroll=!1}))},toggleCurrent:function(n){var e=n.closest("li");e.siblings("li.current").removeClass("current").attr("aria-expanded","false"),e.siblings().find("li.current").removeClass("current").attr("aria-expanded","false");var t=e.find("> ul li");t.length&&(t.removeClass("current").attr("aria-expanded","false"),e.toggleClass("current").attr("aria-expanded",(function(n,e){return"true"==e?"false":"true"})))}},"undefined"!=typeof window&&(window.SphinxRtdTheme={Navigation:n.exports.ThemeNav,StickyNav:n.exports.ThemeNav}),function(){for(var n=0,e=["ms","moz","webkit","o"],t=0;t0 63 | var meq1 = "^(" + C + ")?" + V + C + "(" + V + ")?$"; // [C]VC[V] is m=1 64 | var mgr1 = "^(" + C + ")?" 
+ V + C + V + C; // [C]VCVC... is m>1 65 | var s_v = "^(" + C + ")?" + v; // vowel in stem 66 | 67 | this.stemWord = function (w) { 68 | var stem; 69 | var suffix; 70 | var firstch; 71 | var origword = w; 72 | 73 | if (w.length < 3) 74 | return w; 75 | 76 | var re; 77 | var re2; 78 | var re3; 79 | var re4; 80 | 81 | firstch = w.substr(0,1); 82 | if (firstch == "y") 83 | w = firstch.toUpperCase() + w.substr(1); 84 | 85 | // Step 1a 86 | re = /^(.+?)(ss|i)es$/; 87 | re2 = /^(.+?)([^s])s$/; 88 | 89 | if (re.test(w)) 90 | w = w.replace(re,"$1$2"); 91 | else if (re2.test(w)) 92 | w = w.replace(re2,"$1$2"); 93 | 94 | // Step 1b 95 | re = /^(.+?)eed$/; 96 | re2 = /^(.+?)(ed|ing)$/; 97 | if (re.test(w)) { 98 | var fp = re.exec(w); 99 | re = new RegExp(mgr0); 100 | if (re.test(fp[1])) { 101 | re = /.$/; 102 | w = w.replace(re,""); 103 | } 104 | } 105 | else if (re2.test(w)) { 106 | var fp = re2.exec(w); 107 | stem = fp[1]; 108 | re2 = new RegExp(s_v); 109 | if (re2.test(stem)) { 110 | w = stem; 111 | re2 = /(at|bl|iz)$/; 112 | re3 = new RegExp("([^aeiouylsz])\\1$"); 113 | re4 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 114 | if (re2.test(w)) 115 | w = w + "e"; 116 | else if (re3.test(w)) { 117 | re = /.$/; 118 | w = w.replace(re,""); 119 | } 120 | else if (re4.test(w)) 121 | w = w + "e"; 122 | } 123 | } 124 | 125 | // Step 1c 126 | re = /^(.+?)y$/; 127 | if (re.test(w)) { 128 | var fp = re.exec(w); 129 | stem = fp[1]; 130 | re = new RegExp(s_v); 131 | if (re.test(stem)) 132 | w = stem + "i"; 133 | } 134 | 135 | // Step 2 136 | re = /^(.+?)(ational|tional|enci|anci|izer|bli|alli|entli|eli|ousli|ization|ation|ator|alism|iveness|fulness|ousness|aliti|iviti|biliti|logi)$/; 137 | if (re.test(w)) { 138 | var fp = re.exec(w); 139 | stem = fp[1]; 140 | suffix = fp[2]; 141 | re = new RegExp(mgr0); 142 | if (re.test(stem)) 143 | w = stem + step2list[suffix]; 144 | } 145 | 146 | // Step 3 147 | re = /^(.+?)(icate|ative|alize|iciti|ical|ful|ness)$/; 148 | if (re.test(w)) { 149 | var fp 
= re.exec(w); 150 | stem = fp[1]; 151 | suffix = fp[2]; 152 | re = new RegExp(mgr0); 153 | if (re.test(stem)) 154 | w = stem + step3list[suffix]; 155 | } 156 | 157 | // Step 4 158 | re = /^(.+?)(al|ance|ence|er|ic|able|ible|ant|ement|ment|ent|ou|ism|ate|iti|ous|ive|ize)$/; 159 | re2 = /^(.+?)(s|t)(ion)$/; 160 | if (re.test(w)) { 161 | var fp = re.exec(w); 162 | stem = fp[1]; 163 | re = new RegExp(mgr1); 164 | if (re.test(stem)) 165 | w = stem; 166 | } 167 | else if (re2.test(w)) { 168 | var fp = re2.exec(w); 169 | stem = fp[1] + fp[2]; 170 | re2 = new RegExp(mgr1); 171 | if (re2.test(stem)) 172 | w = stem; 173 | } 174 | 175 | // Step 5 176 | re = /^(.+?)e$/; 177 | if (re.test(w)) { 178 | var fp = re.exec(w); 179 | stem = fp[1]; 180 | re = new RegExp(mgr1); 181 | re2 = new RegExp(meq1); 182 | re3 = new RegExp("^" + C + v + "[^aeiouwxy]$"); 183 | if (re.test(stem) || (re2.test(stem) && !(re3.test(stem)))) 184 | w = stem; 185 | } 186 | re = /ll$/; 187 | re2 = new RegExp(mgr1); 188 | if (re.test(w) && re2.test(w)) { 189 | re = /.$/; 190 | w = w.replace(re,""); 191 | } 192 | 193 | // and turn initial Y back to y 194 | if (firstch == "y") 195 | w = firstch.toLowerCase() + w.substr(1); 196 | return w; 197 | } 198 | } 199 | 200 | -------------------------------------------------------------------------------- /docs/_static/minus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/minus.png -------------------------------------------------------------------------------- /docs/_static/plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/_static/plus.png -------------------------------------------------------------------------------- /docs/_static/pygments.css: 
-------------------------------------------------------------------------------- 1 | pre { line-height: 125%; } 2 | td.linenos .normal { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 3 | span.linenos { color: inherit; background-color: transparent; padding-left: 5px; padding-right: 5px; } 4 | td.linenos .special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 5 | span.linenos.special { color: #000000; background-color: #ffffc0; padding-left: 5px; padding-right: 5px; } 6 | .highlight .hll { background-color: #ffffcc } 7 | .highlight { background: #f8f8f8; } 8 | .highlight .c { color: #3D7B7B; font-style: italic } /* Comment */ 9 | .highlight .err { border: 1px solid #FF0000 } /* Error */ 10 | .highlight .k { color: #008000; font-weight: bold } /* Keyword */ 11 | .highlight .o { color: #666666 } /* Operator */ 12 | .highlight .ch { color: #3D7B7B; font-style: italic } /* Comment.Hashbang */ 13 | .highlight .cm { color: #3D7B7B; font-style: italic } /* Comment.Multiline */ 14 | .highlight .cp { color: #9C6500 } /* Comment.Preproc */ 15 | .highlight .cpf { color: #3D7B7B; font-style: italic } /* Comment.PreprocFile */ 16 | .highlight .c1 { color: #3D7B7B; font-style: italic } /* Comment.Single */ 17 | .highlight .cs { color: #3D7B7B; font-style: italic } /* Comment.Special */ 18 | .highlight .gd { color: #A00000 } /* Generic.Deleted */ 19 | .highlight .ge { font-style: italic } /* Generic.Emph */ 20 | .highlight .gr { color: #E40000 } /* Generic.Error */ 21 | .highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */ 22 | .highlight .gi { color: #008400 } /* Generic.Inserted */ 23 | .highlight .go { color: #717171 } /* Generic.Output */ 24 | .highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */ 25 | .highlight .gs { font-weight: bold } /* Generic.Strong */ 26 | .highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */ 27 | .highlight 
.gt { color: #0044DD } /* Generic.Traceback */ 28 | .highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */ 29 | .highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */ 30 | .highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */ 31 | .highlight .kp { color: #008000 } /* Keyword.Pseudo */ 32 | .highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */ 33 | .highlight .kt { color: #B00040 } /* Keyword.Type */ 34 | .highlight .m { color: #666666 } /* Literal.Number */ 35 | .highlight .s { color: #BA2121 } /* Literal.String */ 36 | .highlight .na { color: #687822 } /* Name.Attribute */ 37 | .highlight .nb { color: #008000 } /* Name.Builtin */ 38 | .highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */ 39 | .highlight .no { color: #880000 } /* Name.Constant */ 40 | .highlight .nd { color: #AA22FF } /* Name.Decorator */ 41 | .highlight .ni { color: #717171; font-weight: bold } /* Name.Entity */ 42 | .highlight .ne { color: #CB3F38; font-weight: bold } /* Name.Exception */ 43 | .highlight .nf { color: #0000FF } /* Name.Function */ 44 | .highlight .nl { color: #767600 } /* Name.Label */ 45 | .highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */ 46 | .highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */ 47 | .highlight .nv { color: #19177C } /* Name.Variable */ 48 | .highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */ 49 | .highlight .w { color: #bbbbbb } /* Text.Whitespace */ 50 | .highlight .mb { color: #666666 } /* Literal.Number.Bin */ 51 | .highlight .mf { color: #666666 } /* Literal.Number.Float */ 52 | .highlight .mh { color: #666666 } /* Literal.Number.Hex */ 53 | .highlight .mi { color: #666666 } /* Literal.Number.Integer */ 54 | .highlight .mo { color: #666666 } /* Literal.Number.Oct */ 55 | .highlight .sa { color: #BA2121 } /* Literal.String.Affix */ 56 | .highlight .sb { color: #BA2121 } /* Literal.String.Backtick 
*/ 57 | .highlight .sc { color: #BA2121 } /* Literal.String.Char */ 58 | .highlight .dl { color: #BA2121 } /* Literal.String.Delimiter */ 59 | .highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */ 60 | .highlight .s2 { color: #BA2121 } /* Literal.String.Double */ 61 | .highlight .se { color: #AA5D1F; font-weight: bold } /* Literal.String.Escape */ 62 | .highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */ 63 | .highlight .si { color: #A45A77; font-weight: bold } /* Literal.String.Interpol */ 64 | .highlight .sx { color: #008000 } /* Literal.String.Other */ 65 | .highlight .sr { color: #A45A77 } /* Literal.String.Regex */ 66 | .highlight .s1 { color: #BA2121 } /* Literal.String.Single */ 67 | .highlight .ss { color: #19177C } /* Literal.String.Symbol */ 68 | .highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */ 69 | .highlight .fm { color: #0000FF } /* Name.Function.Magic */ 70 | .highlight .vc { color: #19177C } /* Name.Variable.Class */ 71 | .highlight .vg { color: #19177C } /* Name.Variable.Global */ 72 | .highlight .vi { color: #19177C } /* Name.Variable.Instance */ 73 | .highlight .vm { color: #19177C } /* Name.Variable.Magic */ 74 | .highlight .il { color: #666666 } /* Literal.Number.Integer.Long */ -------------------------------------------------------------------------------- /docs/_static/sphinx_highlight.js: -------------------------------------------------------------------------------- 1 | /* Highlighting utilities for Sphinx HTML documentation. */ 2 | "use strict"; 3 | 4 | const SPHINX_HIGHLIGHT_ENABLED = true 5 | 6 | /** 7 | * highlight a given string on a node by wrapping it in 8 | * span elements with the given class name. 
9 | */ 10 | const _highlight = (node, addItems, text, className) => { 11 | if (node.nodeType === Node.TEXT_NODE) { 12 | const val = node.nodeValue; 13 | const parent = node.parentNode; 14 | const pos = val.toLowerCase().indexOf(text); 15 | if ( 16 | pos >= 0 && 17 | !parent.classList.contains(className) && 18 | !parent.classList.contains("nohighlight") 19 | ) { 20 | let span; 21 | 22 | const closestNode = parent.closest("body, svg, foreignObject"); 23 | const isInSVG = closestNode && closestNode.matches("svg"); 24 | if (isInSVG) { 25 | span = document.createElementNS("http://www.w3.org/2000/svg", "tspan"); 26 | } else { 27 | span = document.createElement("span"); 28 | span.classList.add(className); 29 | } 30 | 31 | span.appendChild(document.createTextNode(val.substr(pos, text.length))); 32 | parent.insertBefore( 33 | span, 34 | parent.insertBefore( 35 | document.createTextNode(val.substr(pos + text.length)), 36 | node.nextSibling 37 | ) 38 | ); 39 | node.nodeValue = val.substr(0, pos); 40 | 41 | if (isInSVG) { 42 | const rect = document.createElementNS( 43 | "http://www.w3.org/2000/svg", 44 | "rect" 45 | ); 46 | const bbox = parent.getBBox(); 47 | rect.x.baseVal.value = bbox.x; 48 | rect.y.baseVal.value = bbox.y; 49 | rect.width.baseVal.value = bbox.width; 50 | rect.height.baseVal.value = bbox.height; 51 | rect.setAttribute("class", className); 52 | addItems.push({ parent: parent, target: rect }); 53 | } 54 | } 55 | } else if (node.matches && !node.matches("button, select, textarea")) { 56 | node.childNodes.forEach((el) => _highlight(el, addItems, text, className)); 57 | } 58 | }; 59 | const _highlightText = (thisNode, text, className) => { 60 | let addItems = []; 61 | _highlight(thisNode, addItems, text, className); 62 | addItems.forEach((obj) => 63 | obj.parent.insertAdjacentElement("beforebegin", obj.target) 64 | ); 65 | }; 66 | 67 | /** 68 | * Small JavaScript module for the documentation. 
69 | */ 70 | const SphinxHighlight = { 71 | 72 | /** 73 | * highlight the search words provided in localstorage in the text 74 | */ 75 | highlightSearchWords: () => { 76 | if (!SPHINX_HIGHLIGHT_ENABLED) return; // bail if no highlight 77 | 78 | // get and clear terms from localstorage 79 | const url = new URL(window.location); 80 | const highlight = 81 | localStorage.getItem("sphinx_highlight_terms") 82 | || url.searchParams.get("highlight") 83 | || ""; 84 | localStorage.removeItem("sphinx_highlight_terms") 85 | url.searchParams.delete("highlight"); 86 | window.history.replaceState({}, "", url); 87 | 88 | // get individual terms from highlight string 89 | const terms = highlight.toLowerCase().split(/\s+/).filter(x => x); 90 | if (terms.length === 0) return; // nothing to do 91 | 92 | // There should never be more than one element matching "div.body" 93 | const divBody = document.querySelectorAll("div.body"); 94 | const body = divBody.length ? divBody[0] : document.querySelector("body"); 95 | window.setTimeout(() => { 96 | terms.forEach((term) => _highlightText(body, term, "highlighted")); 97 | }, 10); 98 | 99 | const searchBox = document.getElementById("searchbox"); 100 | if (searchBox === null) return; 101 | searchBox.appendChild( 102 | document 103 | .createRange() 104 | .createContextualFragment( 105 | '" 109 | ) 110 | ); 111 | }, 112 | 113 | /** 114 | * helper function to hide the search marks again 115 | */ 116 | hideSearchWords: () => { 117 | document 118 | .querySelectorAll("#searchbox .highlight-link") 119 | .forEach((el) => el.remove()); 120 | document 121 | .querySelectorAll("span.highlighted") 122 | .forEach((el) => el.classList.remove("highlighted")); 123 | localStorage.removeItem("sphinx_highlight_terms") 124 | }, 125 | 126 | initEscapeListener: () => { 127 | // only install a listener if it is really needed 128 | if (!DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS) return; 129 | 130 | document.addEventListener("keydown", (event) => { 131 | // bail for 
input elements 132 | if (BLACKLISTED_KEY_CONTROL_ELEMENTS.has(document.activeElement.tagName)) return; 133 | // bail with special keys 134 | if (event.shiftKey || event.altKey || event.ctrlKey || event.metaKey) return; 135 | if (DOCUMENTATION_OPTIONS.ENABLE_SEARCH_SHORTCUTS && (event.key === "Escape")) { 136 | SphinxHighlight.hideSearchWords(); 137 | event.preventDefault(); 138 | } 139 | }); 140 | }, 141 | }; 142 | 143 | _ready(SphinxHighlight.highlightSearchWords); 144 | _ready(SphinxHighlight.initEscapeListener); 145 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/objects.inv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/objects.inv -------------------------------------------------------------------------------- /docs/search.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Search — DiCE 0.11 documentation 7 | 8 | 9 | 10 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 |
30 | 65 | 66 |
70 | 71 |
72 |
73 |
74 |
    75 |
  • »
  • 76 |
  • Search
  • 77 |
  • 78 |
  • 79 |
80 |
81 |
82 |
83 |
84 | 85 | 92 | 93 | 94 |
95 | 96 |
97 | 98 |
99 |
100 |
101 | 102 |
103 | 104 |
105 |

© Copyright 2020, Ramaravind, Amit, Chenhao.

106 |
107 | 108 | Built with Sphinx using a 109 | theme 110 | provided by Read the Docs. 111 | 112 | 113 |
114 |
115 |
116 |
117 |
118 | 123 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../../")) 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'DiCE' 21 | copyright = '2020, Ramaravind, Amit, Chenhao' # noqa: A001 22 | author = 'Ramaravind, Amit, Chenhao' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.11' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 
33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.viewcode', 36 | 'sphinx.ext.todo', 37 | 'nbsphinx', 38 | 'sphinx_rtd_theme' 39 | ] 40 | 41 | autodoc_mock_imports = ['numpy', 'pandas', 'matplotlib', 'os', 'tensorflow', 'random', 'collections', 42 | 'timeit', 'tensorflow.keras', 'sklearn', 'sklearn.model_selection.train_test_split', 43 | 'copy', 'IPython', 'IPython.display.display', 'collections', 'collections.OrderedDict', 44 | 'logging', 'torch', 'torchvision'] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ['_templates'] 48 | 49 | # source_suffix = ['.rst', '.md'] 50 | source_suffix = '.rst' 51 | 52 | # The master toctree document. 53 | master_doc = 'index' 54 | 55 | language = 'en' 56 | 57 | # List of patterns, relative to source directory, that match files and 58 | # directories to ignore when looking for source files. 59 | # This pattern also affects html_static_path and html_extra_path. 60 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '.ipynb_checkpoints'] 61 | 62 | 63 | # -- Options for HTML output ------------------------------------------------- 64 | 65 | # The theme to use for HTML and HTML Help pages. See the documentation for 66 | # a list of builtin themes. 67 | # 68 | html_theme = 'sphinx_rtd_theme' 69 | 70 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 71 | if not on_rtd: 72 | # only import and set the theme if we're building docs locally 73 | import sphinx_rtd_theme 74 | html_theme = 'sphinx_rtd_theme' 75 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 76 | 77 | # Add any paths that contain custom static files (such as style sheets) here, 78 | # relative to this directory. They are copied after the builtin static files, 79 | # so a file named "default.css" will overwrite the builtin "default.css". 
80 | html_static_path = [] 81 | 82 | html_sidebars = {'**': ['globaltoc.html', 'relations.html', 'sourcelink.html', 'searchbox.html']} 83 | -------------------------------------------------------------------------------- /docs/source/dice_ml.data_interfaces.rst: -------------------------------------------------------------------------------- 1 | dice\_ml.data\_interfaces package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | dice\_ml.data\_interfaces.base\_data\_interface module 8 | ------------------------------------------------------ 9 | 10 | .. automodule:: dice_ml.data_interfaces.base_data_interface 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | dice\_ml.data\_interfaces.private\_data\_interface module 16 | --------------------------------------------------------- 17 | 18 | .. automodule:: dice_ml.data_interfaces.private_data_interface 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dice\_ml.data\_interfaces.public\_data\_interface module 24 | -------------------------------------------------------- 25 | 26 | .. automodule:: dice_ml.data_interfaces.public_data_interface 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: dice_ml.data_interfaces 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/dice_ml.explainer_interfaces.rst: -------------------------------------------------------------------------------- 1 | dice\_ml.explainer\_interfaces package 2 | ====================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dice\_ml.explainer\_interfaces.dice\_KD module 8 | ---------------------------------------------- 9 | 10 | .. 
automodule:: dice_ml.explainer_interfaces.dice_KD 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | dice\_ml.explainer\_interfaces.dice\_genetic module 16 | --------------------------------------------------- 17 | 18 | .. automodule:: dice_ml.explainer_interfaces.dice_genetic 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dice\_ml.explainer\_interfaces.dice\_pytorch module 24 | --------------------------------------------------- 25 | 26 | .. automodule:: dice_ml.explainer_interfaces.dice_pytorch 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | dice\_ml.explainer\_interfaces.dice\_random module 32 | -------------------------------------------------- 33 | 34 | .. automodule:: dice_ml.explainer_interfaces.dice_random 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | dice\_ml.explainer\_interfaces.dice\_tensorflow1 module 40 | ------------------------------------------------------- 41 | 42 | .. automodule:: dice_ml.explainer_interfaces.dice_tensorflow1 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | dice\_ml.explainer\_interfaces.dice\_tensorflow2 module 48 | ------------------------------------------------------- 49 | 50 | .. automodule:: dice_ml.explainer_interfaces.dice_tensorflow2 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | dice\_ml.explainer\_interfaces.explainer\_base module 56 | ----------------------------------------------------- 57 | 58 | .. automodule:: dice_ml.explainer_interfaces.explainer_base 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | dice\_ml.explainer\_interfaces.feasible\_base\_vae module 64 | --------------------------------------------------------- 65 | 66 | .. 
automodule:: dice_ml.explainer_interfaces.feasible_base_vae 67 | :members: 68 | :undoc-members: 69 | :show-inheritance: 70 | 71 | dice\_ml.explainer\_interfaces.feasible\_model\_approx module 72 | ------------------------------------------------------------- 73 | 74 | .. automodule:: dice_ml.explainer_interfaces.feasible_model_approx 75 | :members: 76 | :undoc-members: 77 | :show-inheritance: 78 | 79 | Module contents 80 | --------------- 81 | 82 | .. automodule:: dice_ml.explainer_interfaces 83 | :members: 84 | :undoc-members: 85 | :show-inheritance: 86 | -------------------------------------------------------------------------------- /docs/source/dice_ml.model_interfaces.rst: -------------------------------------------------------------------------------- 1 | dice\_ml.model\_interfaces package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dice\_ml.model\_interfaces.base\_model module 8 | --------------------------------------------- 9 | 10 | .. automodule:: dice_ml.model_interfaces.base_model 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | dice\_ml.model\_interfaces.keras\_tensorflow\_model module 16 | ---------------------------------------------------------- 17 | 18 | .. automodule:: dice_ml.model_interfaces.keras_tensorflow_model 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dice\_ml.model\_interfaces.pytorch\_model module 24 | ------------------------------------------------ 25 | 26 | .. automodule:: dice_ml.model_interfaces.pytorch_model 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. 
automodule:: dice_ml.model_interfaces 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/dice_ml.rst: -------------------------------------------------------------------------------- 1 | dice\_ml package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | dice_ml.data_interfaces 11 | dice_ml.explainer_interfaces 12 | dice_ml.model_interfaces 13 | dice_ml.schema 14 | dice_ml.utils 15 | 16 | Submodules 17 | ---------- 18 | 19 | dice\_ml.constants module 20 | ------------------------- 21 | 22 | .. automodule:: dice_ml.constants 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | 27 | dice\_ml.counterfactual\_explanations module 28 | -------------------------------------------- 29 | 30 | .. automodule:: dice_ml.counterfactual_explanations 31 | :members: 32 | :undoc-members: 33 | :show-inheritance: 34 | 35 | dice\_ml.data module 36 | -------------------- 37 | 38 | .. automodule:: dice_ml.data 39 | :members: 40 | :undoc-members: 41 | :show-inheritance: 42 | 43 | dice\_ml.dice module 44 | -------------------- 45 | 46 | .. automodule:: dice_ml.dice 47 | :members: 48 | :undoc-members: 49 | :show-inheritance: 50 | 51 | dice\_ml.diverse\_counterfactuals module 52 | ---------------------------------------- 53 | 54 | .. automodule:: dice_ml.diverse_counterfactuals 55 | :members: 56 | :undoc-members: 57 | :show-inheritance: 58 | 59 | dice\_ml.model module 60 | --------------------- 61 | 62 | .. automodule:: dice_ml.model 63 | :members: 64 | :undoc-members: 65 | :show-inheritance: 66 | 67 | Module contents 68 | --------------- 69 | 70 | .. 
automodule:: dice_ml 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | -------------------------------------------------------------------------------- /docs/source/dice_ml.schema.rst: -------------------------------------------------------------------------------- 1 | dice\_ml.schema package 2 | ======================= 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: dice_ml.schema 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/dice_ml.utils.rst: -------------------------------------------------------------------------------- 1 | dice\_ml.utils package 2 | ====================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | dice_ml.utils.sample_architecture 11 | 12 | Submodules 13 | ---------- 14 | 15 | dice\_ml.utils.exception module 16 | ------------------------------- 17 | 18 | .. automodule:: dice_ml.utils.exception 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | dice\_ml.utils.helpers module 24 | ----------------------------- 25 | 26 | .. automodule:: dice_ml.utils.helpers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | dice\_ml.utils.neuralnetworks module 32 | ------------------------------------ 33 | 34 | .. automodule:: dice_ml.utils.neuralnetworks 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | dice\_ml.utils.serialize module 40 | ------------------------------- 41 | 42 | .. automodule:: dice_ml.utils.serialize 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. 
automodule:: dice_ml.utils 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/source/dice_ml.utils.sample_architecture.rst: -------------------------------------------------------------------------------- 1 | dice\_ml.utils.sample\_architecture package 2 | =========================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | dice\_ml.utils.sample\_architecture.vae\_model module 8 | ----------------------------------------------------- 9 | 10 | .. automodule:: dice_ml.utils.sample_architecture.vae_model 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: dice_ml.utils.sample_architecture 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. DiCE documentation master file, created by 2 | sphinx-quickstart on Tue May 19 13:15:18 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. include:: ../../README.rst 7 | 8 | .. toctree:: 9 | :maxdepth: 2 10 | :caption: Getting Started: 11 | 12 | readme 13 | 14 | .. toctree:: 15 | :maxdepth: 2 16 | :caption: Notebooks: 17 | 18 | notebooks/nb_index 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | :caption: Package: 23 | 24 | dice_ml 25 | 26 | Indices and tables 27 | ================== 28 | 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | dice_ml 2 | ======= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | dice_ml 8 | -------------------------------------------------------------------------------- /docs/source/notebooks/DiCE_feature_importances.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Estimating local and global feature importance scores using DiCE\n", 8 | "\n", 9 | "Summaries of counterfactual examples can be used to estimate importance of features. Intuitively, a feature that is changed more often to generate a proximal counterfactual is an important feature. We use this intuition to build a feature importance score. \n", 10 | "\n", 11 | "This score can be interpreted as a measure of the **necessity** of a feature to cause a particular model output. That is, if the feature's value changes, then it is likely that the model's output class will also change (or the model's output will significantly change in case of regression model). \n", 12 | "\n", 13 | "Below we show how counterfactuals can be used to provide local feature importance scores for any input, and how those scores can be combined to yield a global importance score for each feature." 
14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "from sklearn.compose import ColumnTransformer\n", 23 | "from sklearn.model_selection import train_test_split\n", 24 | "from sklearn.pipeline import Pipeline\n", 25 | "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", 26 | "from sklearn.ensemble import RandomForestClassifier\n", 27 | "\n", 28 | "import dice_ml\n", 29 | "from dice_ml import Dice\n", 30 | "from dice_ml.utils import helpers # helper functions" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "%load_ext autoreload\n", 40 | "%autoreload 2" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": {}, 46 | "source": [ 47 | "## Preliminaries: Loading the data and ML model" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": null, 53 | "metadata": {}, 54 | "outputs": [], 55 | "source": [ 56 | "dataset = helpers.load_adult_income_dataset().sample(5000) # downsampling to reduce ML model fitting time\n", 57 | "helpers.get_adult_data_info()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": null, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "target = dataset[\"income\"]\n", 67 | "\n", 68 | "# Split data into train and test\n", 69 | "datasetX = dataset.drop(\"income\", axis=1)\n", 70 | "x_train, x_test, y_train, y_test = train_test_split(datasetX,\n", 71 | " target,\n", 72 | " test_size=0.2,\n", 73 | " random_state=0,\n", 74 | " stratify=target)\n", 75 | "\n", 76 | "numerical = [\"age\", \"hours_per_week\"]\n", 77 | "categorical = x_train.columns.difference(numerical)\n", 78 | "\n", 79 | "# We create the preprocessing pipelines for both numeric and categorical data.\n", 80 | "numeric_transformer = Pipeline(steps=[\n", 81 | " ('scaler', StandardScaler())])\n", 82 | "\n", 83 | 
"categorical_transformer = Pipeline(steps=[\n", 84 | " ('onehot', OneHotEncoder(handle_unknown='ignore'))])\n", 85 | "\n", 86 | "transformations = ColumnTransformer(\n", 87 | " transformers=[\n", 88 | " ('num', numeric_transformer, numerical),\n", 89 | " ('cat', categorical_transformer, categorical)])\n", 90 | "\n", 91 | "# Append classifier to preprocessing pipeline.\n", 92 | "# Now we have a full prediction pipeline.\n", 93 | "clf = Pipeline(steps=[('preprocessor', transformations),\n", 94 | " ('classifier', RandomForestClassifier())])\n", 95 | "model = clf.fit(x_train, y_train)" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income')\n", 105 | "m = dice_ml.Model(model=model, backend=\"sklearn\")" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "## Local feature importance\n", 113 | "\n", 114 | "We first generate counterfactuals for a given input point. " 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "exp = Dice(d, m, method=\"random\")\n", 124 | "query_instance = x_train[1:2]\n", 125 | "e1 = exp.generate_counterfactuals(query_instance, total_CFs=10, desired_range=None,\n", 126 | " desired_class=\"opposite\",\n", 127 | " permitted_range=None, features_to_vary=\"all\")\n", 128 | "e1.visualize_as_dataframe(show_only_changes=True)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "These can now be used to calculate the feature importance scores. 
" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": null, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "imp = exp.local_feature_importance(query_instance, cf_examples_list=e1.cf_examples_list)\n", 145 | "print(imp.local_importance)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "markdown", 150 | "metadata": {}, 151 | "source": [ 152 | "Feature importance can also be estimated directly, by leaving the `cf_examples_list` argument blank." 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "imp = exp.local_feature_importance(query_instance, posthoc_sparsity_param=None)\n", 162 | "print(imp.local_importance)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "metadata": {}, 168 | "source": [ 169 | "## Global importance\n", 170 | "\n", 171 | "For global importance, we need to generate counterfactuals for a representative sample of the dataset. 
" 172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "metadata": {}, 178 | "outputs": [], 179 | "source": [ 180 | "cobj = exp.global_feature_importance(x_train[0:10], total_CFs=10, posthoc_sparsity_param=None)\n", 181 | "print(cobj.summary_importance)" 182 | ] 183 | }, 184 | { 185 | "cell_type": "markdown", 186 | "metadata": {}, 187 | "source": [ 188 | "## Convert the counterfactual output to json" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "json_str = cobj.to_json()\n", 198 | "print(json_str)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "markdown", 203 | "metadata": {}, 204 | "source": [ 205 | "## Convert the json output to a counterfactual object" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "metadata": {}, 212 | "outputs": [], 213 | "source": [ 214 | "imp_r = imp.from_json(json_str)\n", 215 | "print([o.visualize_as_dataframe(show_only_changes=True) for o in imp_r.cf_examples_list])\n", 216 | "print(imp_r.local_importance)\n", 217 | "print(imp_r.summary_importance)" 218 | ] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "Python 3", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.6.12" 238 | } 239 | }, 240 | "nbformat": 4, 241 | "nbformat_minor": 4 242 | } 243 | -------------------------------------------------------------------------------- /docs/source/notebooks/images/dice_getting_started_api.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/docs/source/notebooks/images/dice_getting_started_api.png -------------------------------------------------------------------------------- /docs/source/notebooks/nb_index.rst: -------------------------------------------------------------------------------- 1 | .. notebooks to be included in dice docs 2 | 3 | Example notebooks 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :caption: Notebooks: 8 | 9 | DiCE_getting_started 10 | DiCE_feature_importances.ipynb 11 | DiCE_multiclass_classification_and_regression.ipynb 12 | DiCE_model_agnostic_CFs.ipynb 13 | DiCE_with_private_data 14 | DiCE_with_advanced_options 15 | DiCE_getting_started_feasible 16 | 17 | -------------------------------------------------------------------------------- /docs/source/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../../README.rst 2 | 3 | -------------------------------------------------------------------------------- /docs/update_docs.sh: -------------------------------------------------------------------------------- 1 | sphinx-apidoc -f -o source ../dice_ml 2 | -------------------------------------------------------------------------------- /environment-deeplearning.yml: -------------------------------------------------------------------------------- 1 | name: example-environment 2 | 3 | dependencies: 4 | - tensorflow>=1.13.0-rc1 5 | - pytorch 6 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: example-environment 2 | 3 | dependencies: 4 | - numpy 5 | - scikit-learn 6 | - pandas 7 | - jsonschema 8 | - tqdm 9 | - pip: 10 | - raiutils 11 | -------------------------------------------------------------------------------- /requirements-deeplearning.txt: 
-------------------------------------------------------------------------------- 1 | tensorflow>=1.13.1 2 | torch 3 | -------------------------------------------------------------------------------- /requirements-linting.txt: -------------------------------------------------------------------------------- 1 | flake8 2 | flake8-bugbear 3 | flake8-blind-except 4 | flake8-breakpoint 5 | flake8-builtins 6 | flake8-logging-format 7 | flake8-pytest-style 8 | flake8-all-not-strings 9 | isort 10 | packaging 11 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | ipython 2 | jupyter 3 | pytest 4 | pytest-cov 5 | twine 6 | pytest-mock 7 | rai_test_utils 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | jsonschema 2 | numpy # if you are using tensorflow 1.x, it requires numpy<=1.16 3 | pandas>=2.0.0 4 | scikit-learn 5 | tqdm 6 | raiutils>=0.4.0 7 | xgboost # if you are using xgboost 8 | lightgbm # if you are using lightgbm -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.rst 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | VERSION_STR = "0.11" 4 | 5 | with open("README.rst", "r") as fh: 6 | long_description = fh.read() 7 | 8 | # Get the required packages 9 | with open('requirements.txt', encoding='utf-8') as f: 10 | install_requires = f.read().splitlines() 11 | 12 | # Deep learning packages are optional to install 13 | extras = ["deeplearning"] 14 | extras_require = dict() 15 | for 
e in extras: 16 | req_file = "requirements-{0}.txt".format(e) 17 | with open(req_file) as f: 18 | extras_require[e] = [line.strip() for line in f] 19 | 20 | setuptools.setup( 21 | name="dice_ml", 22 | version=VERSION_STR, 23 | license="MIT", 24 | author="Ramaravind Mothilal, Amit Sharma, Chenhao Tan", 25 | author_email="raam.arvind93@gmail.com", 26 | description="Generate Diverse Counterfactual Explanations for any machine learning model.", 27 | long_description=long_description, 28 | long_description_content_type="text/x-rst", 29 | url="https://github.com/interpretml/DiCE", 30 | download_url="https://github.com/interpretml/DiCE/archive/v"+VERSION_STR+".tar.gz", 31 | python_requires='>=3.9', 32 | packages=setuptools.find_packages(exclude=['tests*']), 33 | classifiers=[ 34 | "Intended Audience :: Developers", 35 | "Intended Audience :: Science/Research", 36 | "Programming Language :: Python :: 3", 37 | "Programming Language :: Python :: 3.9", 38 | "Programming Language :: Python :: 3.10", 39 | "Programming Language :: Python :: 3.11", 40 | "License :: OSI Approved :: MIT License", 41 | "Operating System :: OS Independent", 42 | ], 43 | keywords='machine-learning explanation interpretability counterfactual', 44 | install_requires=install_requires, 45 | extras_require=extras_require, 46 | include_package_data=True, 47 | package_data={ 48 | # If any package contains *.h5 files, include them: 49 | '': ['*.h5', 50 | 'counterfactual_explanations_v1.0.json', 51 | 'counterfactual_explanations_v2.0.json'] 52 | } 53 | ) 54 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_data.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import dice_ml 4 | 5 | 6 | def test_data_initiation(public_data_object, private_data_object): 7 | assert isinstance(public_data_object, dice_ml.data_interfaces.public_data_interface.PublicData), \ 8 | "the given parameters should instantiate PublicData class" 9 | 10 | assert isinstance(private_data_object, dice_ml.data_interfaces.private_data_interface.PrivateData), \ 11 | "the given parameters should instantiate PrivateData class" 12 | 13 | 14 | class TestCommonDataMethods: 15 | """ 16 | Test methods common to Public and Private data interfaces modules. 17 | """ 18 | @pytest.fixture(autouse=True) 19 | def _get_data_object(self, public_data_object, private_data_object): 20 | self.d = [public_data_object, private_data_object] 21 | 22 | def test_get_valid_mads(self): 23 | # public data 24 | for normalized in [True, False]: 25 | mads = self.d[0].get_valid_mads(normalized=normalized, display_warnings=False, return_mads=True) 26 | # mads denotes variability in features and should be positive for DiCE. 27 | assert all(mads[feature] > 0 for feature in mads) 28 | 29 | if normalized: 30 | min_value = 0 31 | max_value = 1 32 | 33 | errors = 0 34 | for feature in mads: 35 | if not normalized: 36 | min_value = self.d[0].data_df[feature].min() 37 | max_value = self.d[0].data_df[feature].max() 38 | 39 | if mads[feature] > max_value - min_value: 40 | errors += 1 41 | assert errors == 0 # mads can't be larger than the feature range 42 | 43 | # private data 44 | for normalized in [True, False]: 45 | mads = self.d[1].get_valid_mads(normalized=normalized, display_warnings=False, return_mads=True) 46 | # no mad is provided for private data by default, so a practical alternative is keeping all value at 1. 47 | # Check get_valid_mads() in data interface classes for more info. 
48 | assert all(mads[feature] == 1 for feature in mads) 49 | 50 | # @pytest.mark.parametrize( 51 | # "encode_categorical, output_query", 52 | # [('label', [0.068, (2/3), (3/7), (3/4), (4/5), 1.0, 0.0, 0.449]), 53 | # ('one-hot', [0.068, 0.449, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 54 | # 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0])] 55 | # ) 56 | # def test_prepare_query_instance(self, sample_adultincome_query, encode_categorical, output_query): 57 | # """ 58 | # Tests prepare_query_instance method that covnerts continuous features into [0,1] range and 59 | # one-hot encodes categorical features. 60 | # """ 61 | # for d in self.d: 62 | # prepared_query = d.prepare_query_instance( 63 | # query_instance=sample_adultincome_query, encoding=encode_categorical).iloc[0].tolist() 64 | # assert output_query == pytest.approx(prepared_query, abs=1e-3) 65 | 66 | # TODO: add separate test method for label-encoded data 67 | def test_ohe_min_max_transformed_query_instance(self, sample_adultincome_query): 68 | """ 69 | Tests method that converts continuous features into [0,1] range and one-hot encodes categorical features. 70 | """ 71 | output_query = [0.068, 0.449, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 72 | 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0] 73 | d = self.d[0] 74 | prepared_query = d.get_ohe_min_max_normalized_data(query_instance=sample_adultincome_query).iloc[0].tolist() 75 | assert output_query == pytest.approx(prepared_query, abs=1e-3) 76 | 77 | def test_encoded_categorical_features(self): 78 | """ 79 | Tests if correct encoded categorical feature indexes are returned. Should work even when feature names 80 | are starting with common names. 
81 | TODO also test for private data interface 82 | """ 83 | res = [] 84 | d = self.d[0] 85 | # for d in self.d: 86 | # d.categorical_feature_names = ['cat1', 'cat2'] 87 | # d.continuous_feature_names = ['cat2_cont1', 'cont2'] 88 | # d.encoded_feature_names = ['cat2_cont1', 'cont2', 'cat1_val1', 'cat1_val2', 'cat2_val1', 'cat2_val2'] 89 | print(d.data_df) 90 | temp_ohe_data = d.get_ohe_min_max_normalized_data(d.data_df.iloc[[0]]) 91 | d.create_ohe_params(temp_ohe_data) 92 | res.append(d.get_encoded_categorical_feature_indexes()) 93 | assert [2, 3, 4, 5] == res[0][0] # there are 4 types of workclass 94 | assert len(res[0][1]) == 8 # eight types of education 95 | assert len(res[0][-1]) == 2 # two types of gender in the data 96 | # 2,3,4,5 are correct indexes of encoded categorical features and the data object's method should not 97 | # return the first continuous feature that starts with the same name. Returned value should be a list of lists. 98 | 99 | def test_features_to_vary(self): 100 | """ 101 | Tests if correct indexes of features are returned. Should work even when feature names are starting with common names. 102 | 103 | TODO: also make it work for private_data_interface 104 | """ 105 | res = [] 106 | d = self.d[0] 107 | temp_ohe_data = d.get_ohe_min_max_normalized_data(d.data_df.iloc[[0]]) 108 | d.create_ohe_params(temp_ohe_data) 109 | # d.create_ohe_params() 110 | # d.categorical_feature_names = ['cat1', 'cat2'] 111 | # d.encoded_feature_names = ['cat2_cont1', 'cont2', 'cat1_val1', 'cat1_val2', 'cat2_val1', 'cat2_val2'] 112 | # d.continuous_feature_names = ['cat2_cont1', 'cont2'] 113 | res.append(d.get_indexes_of_features_to_vary(features_to_vary=['workclass'])) 114 | # 2,3,4,5 are correct indexes of features that can be varied and the data object's method should not return 115 | # the first continuous feature that starts with the same name. 
116 | assert [2, 3, 4, 5] == res[0] 117 | -------------------------------------------------------------------------------- /tests/test_data_interface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/test_data_interface/__init__.py -------------------------------------------------------------------------------- /tests/test_data_interface/test_base_data_interface.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from dice_ml.data_interfaces.base_data_interface import _BaseData 4 | 5 | 6 | class TestBaseData: 7 | def test_base_data_initialization(self): 8 | with pytest.raises(TypeError): 9 | _BaseData({}) 10 | -------------------------------------------------------------------------------- /tests/test_data_interface/test_private_data_interface.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import pytest 4 | 5 | import dice_ml 6 | 7 | 8 | @pytest.fixture(scope='session') 9 | def data_object(): 10 | features_dict = OrderedDict( 11 | [('age', [17, 90]), 12 | ('workclass', ['Government', 'Other/Unknown', 'Private', 'Self-Employed']), 13 | ('education', ['Assoc', 'Bachelors', 'Doctorate', 'HS-grad', 'Masters', 'Prof-school', 'School', 'Some-college']), 14 | ('marital_status', ['Divorced', 'Married', 'Separated', 'Single', 'Widowed']), 15 | ('occupation', ['Blue-Collar', 'Other/Unknown', 'Professional', 'Sales', 'Service', 'White-Collar']), 16 | ('race', ['Other', 'White']), 17 | ('gender', ['Female', 'Male']), 18 | ('hours_per_week', [1, 99])] 19 | ) # providing an OrderedDict to make it work for Python<3.6 20 | return dice_ml.Data(features=features_dict, outcome_name='income', 21 | type_and_precision={'hours_per_week': ['float', 2]}, mad={'age': 10}) 22 | 23 | 24 | class 
TestPrivateDataMethods: 25 | @pytest.fixture(autouse=True) 26 | def _get_data_object(self, data_object): 27 | self.d = data_object 28 | 29 | def test_mads(self): 30 | # normalized=True is already tested in test_data.py 31 | mads = self.d.get_valid_mads(normalized=False) 32 | # 10 is given as the mad of feature 'age' while initiating private Data object; 1.0 is the default value. 33 | # Check get_valid_mads() in private_data_interface for more info. 34 | assert list(mads.values()) == [10.0, 1.0] 35 | 36 | def test_feature_precision(self): 37 | # feature precision decides the least change that can be made to the feature in optimization, 38 | # given as 2-decimal place for 'hours_per_week' feature while initiating private Data object. 39 | assert self.d.get_decimal_precisions()[1] == 2 40 | -------------------------------------------------------------------------------- /tests/test_dice.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from raiutils.exceptions import UserConfigValidationException 3 | 4 | import dice_ml 5 | from dice_ml.utils import helpers 6 | 7 | 8 | class TestBaseExplainerLoader: 9 | def _get_exp(self, backend, method="random", is_public_data_interface=True): 10 | if is_public_data_interface: 11 | dataset = helpers.load_adult_income_dataset() 12 | d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income') 13 | else: 14 | d = dice_ml.Data(features={ 15 | 'age': [17, 90], 16 | 'workclass': ['Government', 'Other/Unknown', 'Private', 'Self-Employed'], 17 | 'education': ['Assoc', 'Bachelors', 'Doctorate', 'HS-grad', 'Masters', 18 | 'Prof-school', 'School', 'Some-college'], 19 | 'marital_status': ['Divorced', 'Married', 'Separated', 'Single', 'Widowed'], 20 | 'occupation': ['Blue-Collar', 'Other/Unknown', 'Professional', 'Sales', 'Service', 'White-Collar'], 21 | 'race': ['Other', 'White'], 22 | 'gender': ['Female', 'Male'], 23 | 'hours_per_week': [1, 
99]}, 24 | outcome_name='income') 25 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 26 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func="ohe-min-max") 27 | exp = dice_ml.Dice(d, m, method=method) 28 | return exp 29 | 30 | def test_tf(self): 31 | tf = pytest.importorskip("tensorflow") 32 | backend = 'TF'+tf.__version__[0] 33 | exp = self._get_exp(backend, method="gradient") 34 | assert issubclass(type(exp), dice_ml.explainer_interfaces.explainer_base.ExplainerBase) 35 | assert isinstance(exp, dice_ml.explainer_interfaces.dice_tensorflow2.DiceTensorFlow2) or \ 36 | isinstance(exp, dice_ml.explainer_interfaces.dice_tensorflow1.DiceTensorFlow1) 37 | 38 | def test_pyt(self): 39 | pytest.importorskip("torch") 40 | backend = 'PYT' 41 | exp = self._get_exp(backend, method="gradient") 42 | assert issubclass(type(exp), dice_ml.explainer_interfaces.explainer_base.ExplainerBase) 43 | assert isinstance(exp, dice_ml.explainer_interfaces.dice_pytorch.DicePyTorch) 44 | 45 | @pytest.mark.skip(reason="Need to fix this test") 46 | @pytest.mark.parametrize('method', ['random']) 47 | def test_sklearn(self, method): 48 | pytest.importorskip("sklearn") 49 | backend = 'sklearn' 50 | exp = self._get_exp(backend, method=method) 51 | assert issubclass(type(exp), dice_ml.explainer_interfaces.explainer_base.ExplainerBase) 52 | assert isinstance(exp, dice_ml.explainer_interfaces.dice_random.DiceRandom) 53 | 54 | @pytest.mark.skip(reason="Need to fix this test") 55 | def test_minimum_query_instances(self): 56 | pytest.importorskip('sklearn') 57 | backend = 'sklearn' 58 | exp = self._get_exp(backend) 59 | query_instances = helpers.load_adult_income_dataset().drop("income", axis=1)[0:1] 60 | with pytest.raises(UserConfigValidationException): 61 | exp.global_feature_importance(query_instances) 62 | 63 | def test_unsupported_sampling_strategy(self): 64 | pytest.importorskip('sklearn') 65 | backend = 'sklearn' 66 | with 
pytest.raises(UserConfigValidationException): 67 | self._get_exp(backend, method="unsupported") 68 | 69 | def test_private_data_interface_dice_kdtree(self): 70 | pytest.importorskip("sklearn") 71 | backend = 'sklearn' 72 | with pytest.raises(UserConfigValidationException) as ucve: 73 | self._get_exp(backend, method='kdtree', is_public_data_interface=False) 74 | 75 | assert 'Private data interface is not supported with kdtree explainer' + \ 76 | ' since kdtree explainer needs access to entire training data' in str(ucve) 77 | -------------------------------------------------------------------------------- /tests/test_dice_interface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/test_dice_interface/__init__.py -------------------------------------------------------------------------------- /tests/test_dice_interface/test_dice_pytorch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import dice_ml 5 | from dice_ml.counterfactual_explanations import CounterfactualExplanations 6 | from dice_ml.utils import helpers 7 | 8 | torch = pytest.importorskip("torch") 9 | 10 | 11 | @pytest.fixture(scope='session') 12 | def pyt_exp_object(): 13 | backend = 'PYT' 14 | dataset = helpers.load_adult_income_dataset() 15 | d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income') 16 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 17 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func="ohe-min-max") 18 | exp = dice_ml.Dice(d, m, method="gradient") 19 | return exp 20 | 21 | 22 | class TestDiceTorchMethods: 23 | @pytest.fixture(autouse=True) 24 | def _initiate_exp_object(self, pyt_exp_object, sample_adultincome_query): 25 | self.exp = pyt_exp_object # explainer object 26 | # 
initialize required params for CF computations 27 | self.exp.do_cf_initializations(total_CFs=4, algorithm="DiverseCF", features_to_vary="all") 28 | 29 | # prepare query instance for CF optimization 30 | # query_instance = self.exp.data_interface.prepare_query_instance( 31 | # query_instance=sample_adultincome_query, encoding='one-hot') 32 | # self.query_instance = query_instance.iloc[0].values 33 | self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data( 34 | sample_adultincome_query).iloc[0].to_numpy(dtype=np.float64) 35 | 36 | self.exp.initialize_CFs(self.query_instance, init_near_query_instance=True) # initialize CFs 37 | self.exp.target_cf_class = torch.tensor(1).float() # set desired class to 1 38 | 39 | # setting random feature weights 40 | np.random.seed(42) 41 | weights = np.random.rand(len(self.exp.data_interface.ohe_encoded_feature_names)) 42 | self.exp.feature_weights_list = torch.tensor(weights) 43 | 44 | @pytest.mark.parametrize(("yloss", "output"), [("hinge_loss", 10.8443), ("l2_loss", 0.9999), ("log_loss", 9.8443)]) 45 | def test_yloss(self, yloss, output): 46 | self.exp.yloss_type = yloss 47 | loss1 = self.exp.compute_yloss() 48 | assert pytest.approx(loss1.data.detach().numpy(), abs=1e-4) == output 49 | 50 | def test_proximity_loss(self): 51 | self.exp.x1 = torch.tensor(self.query_instance) 52 | loss2 = self.exp.compute_proximity_loss() 53 | # proximity loss computed for given query instance and feature weights. 
54 | assert pytest.approx(loss2.data.detach().numpy(), abs=1e-4) == 0.0068 55 | 56 | @pytest.mark.parametrize(("diversity_loss", "output"), [("dpp_style:inverse_dist", 0.0104), ("avg_dist", 0.1743)]) 57 | def test_diversity_loss(self, diversity_loss, output): 58 | self.exp.diversity_loss_type = diversity_loss 59 | loss3 = self.exp.compute_diversity_loss() 60 | assert pytest.approx(loss3.data.detach().numpy(), abs=1e-4) == output 61 | 62 | def test_regularization_loss(self): 63 | loss4 = self.exp.compute_regularization_loss() 64 | # regularization loss computed for given query instance and feature weights. 65 | assert pytest.approx(loss4.data.detach().numpy(), abs=1e-4) == 0.2086 66 | 67 | def test_final_cfs_and_preds(self, sample_adultincome_query): 68 | """ 69 | Tests correctness of final CFs and their predictions for sample query instance. 70 | """ 71 | counterfactual_explanations = self.exp.generate_counterfactuals( 72 | sample_adultincome_query, total_CFs=4, desired_class="opposite") 73 | assert isinstance(counterfactual_explanations, CounterfactualExplanations) 74 | # test_cfs = [[72.0, 'Private', 'HS-grad', 'Married', 'White-Collar', 'White', 'Female', 45.0, 0.691], 75 | # [29.0, 'Private', 'Prof-school', 'Married', 'Service', 'White', 'Male', 45.0, 0.954], 76 | # [52.0, 'Private', 'Doctorate', 'Married', 'Service', 'White', 'Female', 45.0, 0.971], 77 | # [47.0, 'Private', 'Masters', 'Married', 'Service', 'White', 'Female', 73.0, 0.971]] 78 | # TODO The model predictions changed after update to posthoc sparsity. Need to investigate. 
79 | # assert dice_exp.final_cfs_df_sparse.values.tolist() == test_cfs 80 | -------------------------------------------------------------------------------- /tests/test_dice_interface/test_dice_tensorflow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import dice_ml 5 | from dice_ml.counterfactual_explanations import CounterfactualExplanations 6 | from dice_ml.utils import helpers 7 | 8 | tf = pytest.importorskip("tensorflow") 9 | 10 | 11 | @pytest.fixture(scope='session') 12 | def tf_exp_object(): 13 | backend = 'TF'+tf.__version__[0] 14 | dataset = helpers.load_adult_income_dataset() 15 | d = dice_ml.Data(dataframe=dataset, continuous_features=['age', 'hours_per_week'], outcome_name='income') 16 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 17 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func="ohe-min-max") 18 | exp = dice_ml.Dice(d, m, method="gradient") 19 | return exp 20 | 21 | 22 | class TestDiceTensorFlowMethods: 23 | @pytest.fixture(autouse=True) 24 | def _initiate_exp_object(self, tf_exp_object, sample_adultincome_query): 25 | self.exp = tf_exp_object # explainer object 26 | # initialize required params for CF computations 27 | self.exp.do_cf_initializations(total_CFs=4, algorithm="DiverseCF", features_to_vary="all") 28 | 29 | # prepare query isntance for CF optimization 30 | # query_instance = self.exp.data_interface.prepare_query_instance( 31 | # query_instance=sample_adultincome_query, encoding='one-hot') 32 | # self.query_instance = np.array([query_instance.iloc[0].values], dtype=np.float32) 33 | self.query_instance = self.exp.data_interface.get_ohe_min_max_normalized_data(sample_adultincome_query).values 34 | 35 | init_arrs = self.exp.initialize_CFs(self.query_instance, init_near_query_instance=True) # initialize CFs 36 | self.desired_class = 1 # desired class is 1 37 | 38 | # setting random feature weights 39 | 
np.random.seed(42) 40 | weights = np.random.rand(len(self.exp.data_interface.ohe_encoded_feature_names)) 41 | weights = np.array([weights], dtype=np.float32) 42 | if tf.__version__[0] == '1': 43 | for i in range(4): 44 | self.exp.dice_sess.run(self.exp.cf_assign[i], feed_dict={self.exp.cf_init: init_arrs[i]}) 45 | self.exp.feature_weights = tf.Variable(self.exp.minx, dtype=tf.float32) 46 | self.exp.dice_sess.run(tf.assign(self.exp.feature_weights, weights)) 47 | else: 48 | self.exp.feature_weights_list = tf.constant([weights], dtype=tf.float32) 49 | 50 | @pytest.mark.parametrize(("yloss", "output"), [("hinge_loss", 4.6711), ("l2_loss", 0.9501), ("log_loss", 3.6968)]) 51 | def test_yloss(self, yloss, output): 52 | if tf.__version__[0] == '1': 53 | loss1 = self.exp.compute_yloss(method=yloss) 54 | loss1 = self.exp.dice_sess.run(loss1, feed_dict={self.exp.target_cf: np.array([[1]])}) 55 | else: 56 | self.exp.target_cf_class = np.array([[self.desired_class]], dtype=np.float32) 57 | self.exp.yloss_type = yloss 58 | loss1 = self.exp.compute_yloss().numpy() 59 | assert pytest.approx(loss1, abs=1e-4) == output 60 | 61 | def test_proximity_loss(self): 62 | if tf.__version__[0] == '1': 63 | loss2 = self.exp.compute_proximity_loss() 64 | loss2 = self.exp.dice_sess.run(loss2, feed_dict={self.exp.x1: self.query_instance}) 65 | else: 66 | self.exp.x1 = tf.constant(self.query_instance, dtype=tf.float32) 67 | loss2 = self.exp.compute_proximity_loss().numpy() 68 | # proximity loss computed for given query instance and feature weights. 
69 | assert pytest.approx(loss2, abs=1e-4) == 0.0068 70 | 71 | @pytest.mark.parametrize(("diversity_loss", "output"), [("dpp_style:inverse_dist", 0.0104), ("avg_dist", 0.1743)]) 72 | def test_diversity_loss(self, diversity_loss, output): 73 | if tf.__version__[0] == '1': 74 | loss3 = self.exp.compute_diversity_loss(diversity_loss) 75 | loss3 = self.exp.dice_sess.run(loss3) 76 | else: 77 | self.exp.diversity_loss_type = diversity_loss 78 | loss3 = self.exp.compute_diversity_loss().numpy() 79 | assert pytest.approx(loss3, abs=1e-4) == output 80 | 81 | def test_regularization_loss(self): 82 | loss4 = self.exp.compute_regularization_loss() 83 | if tf.__version__[0] == '1': 84 | loss4 = self.exp.dice_sess.run(loss4) 85 | else: 86 | loss4 = loss4.numpy() 87 | # regularization loss computed for given query instance and feature weights. 88 | assert pytest.approx(loss4, abs=1e-4) == 0.2086 89 | 90 | def test_final_cfs_and_preds(self, sample_adultincome_query): 91 | """ 92 | Tests correctness of final CFs and their predictions for sample query instance. 93 | """ 94 | counterfactual_explanations = self.exp.generate_counterfactuals( 95 | sample_adultincome_query, total_CFs=4, desired_class="opposite") 96 | assert isinstance(counterfactual_explanations, CounterfactualExplanations) 97 | # test_cfs = [[70.0, 'Private', 'Masters', 'Single', 'White-Collar', 'White', 'Female', 51.0, 0.534], 98 | # [22.0, 'Self-Employed', 'Doctorate', 'Married', 'Service', 'White', 'Female', 45.0, 0.861], 99 | # [47.0, 'Private', 'HS-grad', 'Married', 'Service', 'White', 'Female', 45.0, 0.589], 100 | # [36.0, 'Private', 'Prof-school', 'Married', 'Service', 'White', 'Female', 62.0, 0.937]] 101 | # TODO The model predictions changed after update to posthoc sparsity. Need to investigate. 
102 | # assert dice_exp.final_cfs_df_sparse.values.tolist() == test_cfs 103 | -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | from dice_ml.utils.helpers import load_adult_income_dataset 4 | 5 | 6 | class TestHelpers: 7 | def test_load_adult_income_dataset(self): 8 | adult_data = load_adult_income_dataset() 9 | assert adult_data is not None 10 | assert isinstance(adult_data, pd.DataFrame) 11 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from rai_test_utils.datasets.tabular import create_iris_data 3 | from raiutils.exceptions import UserConfigValidationException 4 | from sklearn.ensemble import RandomForestClassifier 5 | 6 | import dice_ml 7 | from dice_ml.utils import helpers 8 | 9 | 10 | class TestBaseModelLoader: 11 | def _get_model(self, backend): 12 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 13 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend) 14 | return m 15 | 16 | def test_tf(self): 17 | tf = pytest.importorskip("tensorflow") 18 | backend = 'TF'+tf.__version__[0] 19 | m = self._get_model(backend) 20 | assert issubclass(type(m), dice_ml.model_interfaces.base_model.BaseModel) 21 | assert isinstance(m, dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel) 22 | 23 | def test_pyt(self): 24 | pytest.importorskip("torch") 25 | backend = 'PYT' 26 | m = self._get_model(backend) 27 | assert issubclass(type(m), dice_ml.model_interfaces.base_model.BaseModel) 28 | assert isinstance(m, dice_ml.model_interfaces.pytorch_model.PyTorchModel) 29 | 30 | def test_sklearn(self): 31 | pytest.importorskip("sklearn") 32 | backend = 'sklearn' 33 | m = self._get_model(backend) 34 | assert isinstance(m, 
dice_ml.model_interfaces.base_model.BaseModel) 35 | 36 | 37 | class TestModelUserValidations: 38 | 39 | def create_sklearn_random_forest_classifier(self, X, y): 40 | rfc = RandomForestClassifier(n_estimators=10, max_depth=4, 41 | random_state=777) 42 | model = rfc.fit(X, y) 43 | return model 44 | 45 | def test_model_user_validation_model_type(self): 46 | x_train, x_test, y_train, y_test, feature_names, classes = \ 47 | create_iris_data() 48 | trained_model = self.create_sklearn_random_forest_classifier(x_train, y_train) 49 | 50 | assert dice_ml.Model(model=trained_model, backend='sklearn', model_type='classifier') is not None 51 | assert dice_ml.Model(model=trained_model, backend='sklearn', model_type='regressor') is not None 52 | 53 | with pytest.raises(UserConfigValidationException): 54 | dice_ml.Model(model=trained_model, backend='sklearn', model_type='random') 55 | 56 | def test_model_user_validation_no_valid_model(self): 57 | with pytest.raises( 58 | ValueError, 59 | match="should provide either a trained model or the path to a model"): 60 | dice_ml.Model(backend='sklearn') 61 | -------------------------------------------------------------------------------- /tests/test_model_interface/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/interpretml/DiCE/7a55aabd4ce3a27d887ab8b7c8df90972d32fcc8/tests/test_model_interface/__init__.py -------------------------------------------------------------------------------- /tests/test_model_interface/test_base_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from rai_test_utils.datasets.tabular import (create_housing_data, 4 | create_iris_data) 5 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 6 | 7 | import dice_ml 8 | from dice_ml.utils.exception import SystemException 9 | 10 | 11 | class TestModelClassification: 12 | 13 | def 
create_sklearn_random_forest_classifier(self, X, y): 14 | rfc = RandomForestClassifier(n_estimators=10, max_depth=4, 15 | random_state=777) 16 | model = rfc.fit(X, y) 17 | return model 18 | 19 | def test_base_model_classification(self): 20 | x_train, x_test, y_train, y_test, feature_names, classes = \ 21 | create_iris_data() 22 | trained_model = self.create_sklearn_random_forest_classifier(x_train, y_train) 23 | 24 | diceml_model = dice_ml.Model(model=trained_model, backend='sklearn') 25 | diceml_model.transformer.initialize_transform_func() 26 | 27 | assert diceml_model is not None 28 | 29 | prediction_probabilities = diceml_model.get_output(x_test) 30 | assert prediction_probabilities.shape[0] == x_test.shape[0] 31 | assert prediction_probabilities.shape[1] == len(classes) 32 | 33 | predictions = diceml_model.get_output(x_test, model_score=False).reshape(-1, 1) 34 | assert predictions.shape[0] == x_test.shape[0] 35 | assert predictions.shape[1] == 1 36 | assert np.all(np.unique(predictions) == np.unique(y_test)) 37 | 38 | with pytest.raises(NotImplementedError): 39 | diceml_model.get_gradient() 40 | 41 | assert diceml_model.get_num_output_nodes2(x_test) == len(classes) 42 | 43 | 44 | class TestModelRegression: 45 | 46 | def create_sklearn_random_forest_regressor(self, X, y): 47 | rfc = RandomForestRegressor(n_estimators=10, max_depth=4, 48 | random_state=777) 49 | model = rfc.fit(X, y) 50 | return model 51 | 52 | def test_base_model_regression(self): 53 | x_train, x_test, y_train, y_test, feature_names = \ 54 | create_housing_data() 55 | trained_model = self.create_sklearn_random_forest_regressor(x_train, y_train) 56 | 57 | diceml_model = dice_ml.Model(model=trained_model, model_type='regressor', backend='sklearn') 58 | diceml_model.transformer.initialize_transform_func() 59 | 60 | assert diceml_model is not None 61 | 62 | prediction_probabilities = diceml_model.get_output(x_test).reshape(-1, 1) 63 | assert prediction_probabilities.shape[0] == x_test.shape[0] 64 
| assert prediction_probabilities.shape[1] == 1 65 | 66 | predictions = diceml_model.get_output(x_test, model_score=False).reshape(-1, 1) 67 | assert predictions.shape[0] == x_test.shape[0] 68 | assert predictions.shape[1] == 1 69 | 70 | with pytest.raises(NotImplementedError): 71 | diceml_model.get_gradient() 72 | 73 | with pytest.raises(SystemException): 74 | diceml_model.get_num_output_nodes2(x_test) 75 | -------------------------------------------------------------------------------- /tests/test_model_interface/test_keras_tensorflow_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import dice_ml 4 | from dice_ml.utils import helpers 5 | from dice_ml.utils.helpers import DataTransfomer 6 | 7 | tf = pytest.importorskip("tensorflow") 8 | 9 | 10 | @pytest.fixture(scope='session') 11 | def tf_session(): 12 | if tf.__version__[0] == '1': 13 | sess = tf.InteractiveSession() 14 | return sess 15 | 16 | 17 | @pytest.fixture(scope='session') 18 | def tf_model_object(): 19 | backend = 'TF'+tf.__version__[0] 20 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 21 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func='ohe-min-max') 22 | return m 23 | 24 | 25 | def test_model_initiation(tf_model_object): 26 | assert isinstance(tf_model_object, dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel) 27 | 28 | 29 | def test_model_initiation_fullpath(): 30 | """ 31 | Tests if model is initiated when full path to a model and explainer class is given to backend parameter. 
32 | """ 33 | tf_version = tf.__version__[0] 34 | backend = {'model': 'keras_tensorflow_model.KerasTensorFlowModel', 35 | 'explainer': 'dice_tensorflow'+tf_version+'.DiceTensorFlow'+tf_version} 36 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 37 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend) 38 | assert isinstance(m, dice_ml.model_interfaces.keras_tensorflow_model.KerasTensorFlowModel) 39 | 40 | 41 | class TestKerasModelMethods: 42 | @pytest.fixture(autouse=True) 43 | def _get_model_object(self, tf_model_object, tf_session): 44 | self.m = tf_model_object 45 | self.sess = tf_session 46 | 47 | def test_load_model(self): 48 | self.m.load_model() 49 | assert self.m.model is not None 50 | 51 | # @pytest.mark.parametrize("input_instance, prediction",[(np.array([[0.5]*29], dtype=np.float32), 0.747)]) 52 | # def test_model_output(self, input_instance, prediction): 53 | # self.m.load_model() 54 | # if tf.__version__[0] == '1': 55 | # input_instance_tf = tf.Variable(input_instance, dtype=tf.float32) 56 | # output_instance = self.m.get_output(input_instance_tf) 57 | # prediction = self.sess.run(output_instance, feed_dict={input_instance_tf:input_instance})[0][0] 58 | # else: 59 | # prediction = self.m.get_output(input_instance).numpy()[0][0] 60 | # pytest.approx(prediction, abs=1e-3) == prediction 61 | 62 | @pytest.mark.parametrize("prediction", [0.747]) 63 | def test_model_output(self, sample_adultincome_query, public_data_object, prediction): 64 | # Initializing data and model objects 65 | self.m.load_model() 66 | # initializing data transformation required for ML model 67 | self.m.transformer = DataTransfomer(func='ohe-min-max', kw_args=None) 68 | self.m.transformer.feed_data_params(public_data_object) 69 | self.m.transformer.initialize_transform_func() 70 | output_instance = self.m.get_output(sample_adultincome_query, transform_data=True) 71 | 72 | if tf.__version__[0] == '1': 73 | predictval = self.sess.run(output_instance)[0][0] 74 
| else: 75 | predictval = output_instance.numpy()[0][0] 76 | assert predictval is not None 77 | # TODO: The assert below fails. 78 | # assert pytest.approx(predictval, abs=1e-3) == prediction 79 | -------------------------------------------------------------------------------- /tests/test_model_interface/test_pytorch_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import dice_ml 4 | from dice_ml.utils import helpers 5 | from dice_ml.utils.helpers import DataTransfomer 6 | 7 | pyt = pytest.importorskip("torch") 8 | 9 | 10 | @pytest.fixture(scope='session') 11 | def pyt_model_object(): 12 | backend = 'PYT' 13 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 14 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend, func='ohe-min-max') 15 | return m 16 | 17 | 18 | def test_model_initiation(pyt_model_object): 19 | assert isinstance(pyt_model_object, dice_ml.model_interfaces.pytorch_model.PyTorchModel) 20 | 21 | 22 | def test_model_initiation_fullpath(): 23 | """ 24 | Tests if model is initiated when full path to a model and explainer class is given to backend parameter. 
25 | """ 26 | pytest.importorskip("torch") 27 | backend = {'model': 'pytorch_model.PyTorchModel', 28 | 'explainer': 'dice_pytorch.DicePyTorch'} 29 | ML_modelpath = helpers.get_adult_income_modelpath(backend=backend) 30 | m = dice_ml.Model(model_path=ML_modelpath, backend=backend) 31 | assert isinstance(m, dice_ml.model_interfaces.pytorch_model.PyTorchModel) 32 | 33 | 34 | class TestPyTorchModelMethods: 35 | @pytest.fixture(autouse=True) 36 | def _get_model_object(self, pyt_model_object): 37 | self.m = pyt_model_object 38 | 39 | def test_load_model(self): 40 | self.m.load_model() 41 | assert self.m.model is not None 42 | 43 | # @pytest.mark.parametrize("input_instance, prediction",[(np.array([[0.5]*29], dtype=np.float32), 0.0957)]) 44 | # def test_model_output(self, input_instance, prediction): 45 | # self.m.load_model() 46 | # input_instance_pyt = pyt.tensor(input_instance) 47 | # prediction = self.m.get_output(input_instance_pyt).detach().numpy()[0][0] 48 | # pytest.approx(prediction, abs=1e-3) == prediction 49 | 50 | @pytest.mark.parametrize("prediction", [0.0957]) 51 | def test_model_output(self, sample_adultincome_query, public_data_object, prediction): 52 | # initializing data transormation required for ML model 53 | self.m.load_model() 54 | self.m.transformer = DataTransfomer(func='ohe-min-max', kw_args=None) 55 | self.m.transformer.feed_data_params(public_data_object) 56 | self.m.transformer.initialize_transform_func() 57 | output_instance = self.m.get_output(sample_adultincome_query, transform_data=True) 58 | predictval = output_instance[0][0] 59 | assert predictval is not None 60 | # TODO: The assert below fails. 61 | # assert pytest.approx(predictval, abs=1e-3) == prediction 62 | -------------------------------------------------------------------------------- /tests/test_notebooks.py: -------------------------------------------------------------------------------- 1 | """ Code to automatically run all notebooks as a test. 
2 | Adapted from the same code for the Microsoft DoWhy library. 3 | """ 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import tempfile 9 | 10 | import nbformat 11 | import pytest 12 | 13 | NOTEBOOKS_PATH = "docs/source/notebooks/" 14 | 15 | # Adding the dice root folder to the python path so that jupyter notebooks 16 | if 'PYTHONPATH' not in os.environ: 17 | os.environ['PYTHONPATH'] = os.getcwd() 18 | elif os.getcwd() not in os.environ['PYTHONPATH'].split(os.pathsep): 19 | os.environ['PYTHONPATH'] = os.environ['PYTHONPATH'] + os.pathsep + os.getcwd() 20 | 21 | 22 | def get_notebook_parameter_list(): 23 | notebooks_list = [f.name for f in os.scandir(NOTEBOOKS_PATH) if f.name.endswith(".ipynb")] 24 | # notebooks that should not be run 25 | advanced_notebooks = [ 26 | "DiCE_with_advanced_options.ipynb", # requires tensorflow 1.x 27 | "DiCE_getting_started_feasible.ipynb", # needs changes after latest refactor 28 | "DiCE_with_private_data.ipynb", # needs compatible version of sklearn to load model 29 | "Benchmarking_different_CF_explanation_methods.ipynb" 30 | ] 31 | # notebooks that don't need to run on python 3.10 32 | torch_notebooks_not_3_10 = [ 33 | "DiCE_getting_started.ipynb" 34 | ] 35 | 36 | # Creating the list of notebooks to run 37 | parameter_list = [] 38 | for nb in notebooks_list: 39 | if nb in advanced_notebooks: 40 | param = pytest.param( 41 | nb, 42 | marks=[pytest.mark.skip, pytest.mark.advanced], 43 | id=nb) 44 | elif sys.version_info >= (3, 10) and nb in torch_notebooks_not_3_10: 45 | param = pytest.param( 46 | nb, 47 | marks=[pytest.mark.skip, pytest.mark.advanced], 48 | id=nb) 49 | else: 50 | param = pytest.param(nb, id=nb) 51 | parameter_list.append(param) 52 | 53 | return parameter_list 54 | 55 | 56 | def _check_notebook_cell_outputs(filepath): 57 | """Convert notebook via nbconvert, collect output and assert if any output cells are not empty. 
58 | 59 | :param filepath: file path for the notebook 60 | """ 61 | with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout: 62 | args = ["jupyter", "nbconvert", "--to", "notebook", 63 | "-y", "--no-prompt", 64 | "--output", fout.name, filepath] 65 | subprocess.check_call(args) 66 | fout.seek(0) 67 | nb = nbformat.read(fout, nbformat.current_nbformat) 68 | 69 | for cell in nb.cells: 70 | if "outputs" in cell: 71 | if len(cell['outputs']) > 0: 72 | raise AssertionError("Output cell found in notebook. Please clean your notebook") 73 | 74 | 75 | def _notebook_run(filepath): 76 | """Execute a notebook via nbconvert and collect output. 77 | 78 | Source of this function: http://www.christianmoscardi.com/blog/2016/01/20/jupyter-testing.html 79 | 80 | :param filepath: file path for the notebook 81 | :returns List of execution errors 82 | """ 83 | with tempfile.NamedTemporaryFile(suffix=".ipynb") as fout: 84 | args = ["jupyter", "nbconvert", "--to", "notebook", "--execute", 85 | # "--ExecutePreprocessor.timeout=600", 86 | "-y", "--no-prompt", 87 | "--output", fout.name, filepath] 88 | subprocess.check_call(args) 89 | 90 | fout.seek(0) 91 | nb = nbformat.read(fout, nbformat.current_nbformat) 92 | 93 | errors = [output for cell in nb.cells if "outputs" in cell 94 | for output in cell["outputs"] 95 | if output.output_type == "error"] 96 | 97 | return errors 98 | 99 | 100 | @pytest.mark.parametrize("notebook_filename", get_notebook_parameter_list()) 101 | @pytest.mark.notebook_tests 102 | def test_notebook(notebook_filename): 103 | _check_notebook_cell_outputs(NOTEBOOKS_PATH + notebook_filename) 104 | errors = _notebook_run(NOTEBOOKS_PATH + notebook_filename) 105 | assert len(errors) == 0 106 | --------------------------------------------------------------------------------