├── .github
│   └── workflows
│       ├── python-app.yml
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── artemis
│   ├── __init__.py
│   ├── _utilities
│   │   ├── __init__.py
│   │   ├── _handler.py
│   │   ├── domain.py
│   │   ├── exceptions.py
│   │   ├── ops.py
│   │   ├── pd_calculator.py
│   │   ├── performance_metrics.py
│   │   ├── split_score_metrics.py
│   │   └── zenplot.py
│   ├── additivity
│   │   ├── __init__.py
│   │   └── _additivity_meter.py
│   ├── comparison
│   │   ├── __init__.py
│   │   └── _method_comparator.py
│   ├── importance_methods
│   │   ├── __init__.py
│   │   ├── _method.py
│   │   ├── model_agnostic
│   │   │   ├── __init__.py
│   │   │   ├── _pdp.py
│   │   │   └── _permutational_importance.py
│   │   └── model_specific
│   │       ├── __init__.py
│   │       ├── _minimal_depth.py
│   │       └── _split_score.py
│   ├── interactions_methods
│   │   ├── __init__.py
│   │   ├── _method.py
│   │   ├── model_agnostic
│   │   │   ├── __init__.py
│   │   │   ├── partial_dependence_based
│   │   │   │   ├── __init__.py
│   │   │   │   ├── _friedman_h_statistic.py
│   │   │   │   ├── _greenwell.py
│   │   │   │   └── _pdp.py
│   │   │   └── performance_based
│   │   │       ├── __init__.py
│   │   │       └── _sejong_oh.py
│   │   └── model_specific
│   │       ├── __init__.py
│   │       ├── gb_trees
│   │       │   ├── __init__.py
│   │       │   └── _split_score.py
│   │       └── random_forest
│   │           ├── __init__.py
│   │           └── _conditional_minimal_depth.py
│   └── visualizer
│       ├── __init__.py
│       ├── _configuration.py
│       ├── _pdp_visualizer.py
│       └── _visualizer.py
├── demo.ipynb
├── docs
│   └── artemis
│       ├── additivity
│       │   └── index.html
│       ├── comparison
│       │   └── index.html
│       ├── importance_methods
│       │   ├── index.html
│       │   ├── model_agnostic
│       │   │   └── index.html
│       │   └── model_specific
│       │       └── index.html
│       ├── index.html
│       ├── interactions_methods
│       │   ├── index.html
│       │   ├── model_agnostic
│       │   │   ├── index.html
│       │   │   ├── partial_dependence_based
│       │   │   │   └── index.html
│       │   │   └── performance_based
│       │   │       └── index.html
│       │   └── model_specific
│       │       ├── gb_trees
│       │       │   └── index.html
│       │       ├── index.html
│       │       └── random_forest
│       │           └── index.html
│       └── visualizer
│           └── index.html
├── poetry.lock
├── pyproject.toml
└── test
    ├── __init__.py
    ├── test_additivity_meter.py
    ├── test_conditional_minimal_depth.py
    ├── test_friedman_h.py
    ├── test_greenwell_inter.py
    ├── test_method_comparator.py
    ├── test_pdp_calculator.py
    ├── test_sejong_oh_inter.py
    ├── test_split_score_inter.py
    ├── test_variable_importance.py
    └── util.py

/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | name: build
2 | 
3 | on:
4 |   push:
5 |     branches: [ "master", "dev"]
6 |   pull_request:
7 |     branches: [ "master", "dev"]
8 | 
9 | permissions:
10 |   contents: read
11 | 
12 | jobs:
13 |   build:
14 | 
15 |     runs-on: ubuntu-latest
16 | 
17 |     steps:
18 |     - uses: actions/checkout@v3
19 |     - name: Set up Python 3.8
20 |       uses: actions/setup-python@v3
21 |       with:
22 |         python-version: "3.8"
23 |     - name: Install dependencies
24 |       run: |
25 |         python -m pip install --upgrade pip
26 |         pip install poetry
27 |         poetry install
28 |     - name: Test with pytest
29 |       run: |
30 |         poetry run pytest
31 | 
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 | 
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | .idea/
131 | 
132 | # macOS files
133 | .DS_Store
134 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 Paweł Fijałkowski, Mateusz Krzyziński, Artur Żółkowski
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ARTEMIS: A Robust Toolkit of Explanation Methods for Interaction Spotting
2 | A Python package with explanation methods for extraction of feature interactions from predictive models
3 | 
4 | [![build](https://github.com/pyartemis/artemis/actions/workflows/python-app.yml/badge.svg)](https://github.com/pyartemis/artemis/actions/workflows/python-app.yml)
5 | [![PyPI version](https://badge.fury.io/py/pyartemis.svg)](https://pypi.org/project/pyartemis/)
6 | [![Downloads](https://static.pepy.tech/badge/pyartemis)](https://pepy.tech/project/pyartemis)
7 | 
8 | ## Overview
9 | `artemis` is a **Python** package for data scientists and machine learning practitioners which exposes a standardized API for extracting feature interactions from predictive models using a number of different methods described in the scientific literature.
10 | 
11 | The package provides both model-agnostic (no assumption about model structure) and model-specific (e.g., tree-based models) feature interaction methods, as well as other methods that can facilitate and support the analysis and exploration of the predictive model in the context of feature interactions.
12 | 
13 | The available methods are suited to tabular data and classification and regression problems.
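
For instance, a minimal workflow might look like this (a sketch: it assumes the `FriedmanHStatisticMethod` class exported from `artemis.interactions_methods.model_agnostic`, and `model`/`X` stand for any fitted model and its `pd.DataFrame` of features):

```python
from artemis.interactions_methods.model_agnostic import FriedmanHStatisticMethod

h_stat = FriedmanHStatisticMethod()
h_stat.fit(model, X)   # model: any fitted model with predict/predict_proba
h_stat.plot()          # default visualization: heatmap of pairwise interactions
```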
The main functionality lets users scrutinize a wide range of models by examining their feature interactions: finding the strongest ones (in terms of the numerical values of the implemented methods) and creating tailored visualizations.
14 | 
15 | ## Documentation
16 | Full documentation is available at [https://pyartemis.github.io/](https://pyartemis.github.io/).
17 | 
18 | ## Installation
19 | The latest released version of the `artemis` package is available on [Python Package Index (PyPI)](https://pypi.org/project/pyartemis/):
20 | 
21 | ```
22 | pip install -U pyartemis
23 | ```
24 | 
25 | The source code and development version are currently hosted on [GitHub](https://github.com/pyartemis/artemis).
26 | 
27 | ***
28 | 
29 | ## Authors
30 | 
31 | The package was created as a software project associated with the BSc thesis ***Methods for extraction of interactions from predictive models*** in the field of Data Science (pl. *Inżynieria i analiza danych*) at the Faculty of Mathematics and Information Science (MiNI), Warsaw University of Technology, in cooperation with NASK National Research Institute.
32 | 
33 | The authors of the `artemis` package are:
34 | - [Paweł Fijałkowski](https://github.com/pablo2811)
35 | - [Mateusz Krzyziński](https://github.com/krzyzinskim)
36 | - [Artur Żółkowski](https://github.com/arturzolkowski)
37 | 
38 | The BSc thesis and work on the `artemis` package were supervised by [Przemysław Biecek, PhD, DSc](https://github.com/pbiecek).
39 | 
40 | 
--------------------------------------------------------------------------------
/artemis/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | # `artemis`: A Robust Toolkit of Explanation Methods for Interaction Spotting
3 | 
4 | ## What is `artemis`?
5 | **`artemis` is a Python library for explanations of feature interactions in machine learning models.**
6 | - It provides methods for analyzing predictive models in terms of interactions between features and feature importance.
7 | - There are both model-agnostic methods that can work with any properly prepared model, and model-specific methods adapted to tree-based ones
8 | (due to their structure, which naturally influences the possibility of interaction).
9 | - It enables users to scrutinize a wide range of models by examining the strength of feature interactions and visualizing them.
10 | 
11 | """
--------------------------------------------------------------------------------
/artemis/_utilities/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pyartemis/artemis/fd6ed2018b6c9b8aa1de3e3f91a345f60d801271/artemis/_utilities/__init__.py
--------------------------------------------------------------------------------
/artemis/_utilities/_handler.py:
--------------------------------------------------------------------------------
1 | from re import search
2 | import pandas as pd
3 | 
4 | from artemis._utilities.exceptions import ModelNotSupportedException
5 | 
6 | 
7 | class GBTreesHandler:
8 |     """Class to unify the structure of gradient boosting decision tree models.
9 | 
10 |     Attributes
11 |     ----------
12 |     package : str
13 |         Name of the package used to train the model.
14 |     trees_df : pd.DataFrame
15 |         Unified structure of the trained trees.
16 |     """
17 |     def __init__(self, model=None) -> None:
18 |         """Constructor for GBTreesHandler.
19 | 20 | Parameters 21 | ---------- 22 | model : object, optional 23 | Trained model which structure will be unified.""" 24 | if model is not None: 25 | self.unify_structure(model) 26 | 27 | def unify_structure(self, model) -> None: 28 | model_class = search("(?<= np.ndarray: 39 | if feature_values is None: 40 | return self.pd_single[feature]["pd_values"] 41 | selected_values = np.zeros(len(feature_values)) 42 | for i, feature_value in enumerate(feature_values): 43 | f_index = get_index(self.pd_single[feature]["f_values"], feature_value) 44 | selected_values[i] = self.pd_single[feature]["pd_values"][f_index] 45 | return selected_values 46 | 47 | def get_pd_pairs(self, feature1: str, feature2: str, feature_values: Optional[List[Tuple[Any, Any]]] = None) -> np.ndarray: 48 | pair_key = self._get_pair_key((feature1, feature2)) 49 | all_matrix = self.pd_pairs[pair_key]["pd_values"] 50 | if feature_values is None: 51 | return all_matrix 52 | if pair_key != (feature1, feature2): 53 | feature_values = reorder_pair_values(feature_values) 54 | selected_values = np.zeros(len(feature_values)) 55 | for i, pair in enumerate(feature_values): 56 | f1_index = get_index(self.pd_pairs[pair_key]["f1_values"], pair[0]) 57 | f2_index = get_index(self.pd_pairs[pair_key]["f2_values"], pair[1]) 58 | selected_values[i] = all_matrix[f1_index, f2_index] 59 | return selected_values 60 | 61 | def get_pd_minus_single(self, feature: str) -> np.ndarray: 62 | return self.pd_minus_single[feature]["pd_values"] 63 | 64 | def calculate_pd_single(self, features: Optional[List[str]] = None, show_progress: bool = False, desc: str = ProgressInfoLog.CALC_VAR_IMP): 65 | if features is None: 66 | features = self.X.columns 67 | range_dict = {} 68 | current_len = 0 69 | X_full = pd.DataFrame() 70 | for feature in tqdm(features, desc=desc, disable=not show_progress): 71 | if np.isnan(self.pd_single[feature]["pd_values"]).any(): 72 | for value in self.pd_single[feature]["f_values"]: 73 | change_dict = {feature: value} 74 | X_changed = self.X.copy().assign(**change_dict) 75 | range_dict[(feature, value)] = (current_len, current_len+self.X_len) 76 | current_len += self.X_len 77 | X_full = pd.concat((X_full, X_changed)) 78 | if current_len > self.batchsize: 79 | self.fill_pd_single(range_dict, X_full) 80 | current_len = 0 81 | range_dict = {} 82 | X_full = pd.DataFrame() 83 | if current_len > 0: 84 | self.fill_pd_single(range_dict, X_full) 85 | 86 | def calculate_pd_pairs(self, feature_pairs = None, all_combinations=True, show_progress: bool = False, desc: str = ProgressInfoLog.CALC_OVO): 87 | if feature_pairs is None: 88 | feature_pairs = self.pd_pairs.keys() 89 | range_dict = {} 90 | current_len = 0 91 | X_full = pd.DataFrame() 92 | for feature1, feature2 in tqdm(feature_pairs, desc=desc, disable=not show_progress): 93 | feature1, feature2 = self._get_pair_key((feature1, feature2)) 94 | if all_combinations: 95 | feature_values = [(f1, f2) for f1 in self.pd_pairs[(feature1, feature2)]["f1_values"] for f2 in self.pd_pairs[(feature1, feature2)]["f2_values"]] 96 | else: 97 | feature_values = zip(self.X[feature1].values, self.X[feature2].values) 98 | for value1, value2 in feature_values: 99 | f1_ind = get_index(self.pd_pairs[(feature1, feature2)]["f1_values"], value1) 100 | f2_ind = get_index(self.pd_pairs[(feature1, feature2)]["f2_values"], value2) 101 | if np.isnan(self.pd_pairs[(feature1, feature2)]["pd_values"][f1_ind, f2_ind]): 102 | change_dict = {feature1: value1, feature2: value2} 103 | X_changed = self.X.copy().assign(**change_dict) 
104 | range_dict[(feature1, feature2, value1, value2)] = (current_len, current_len+self.X_len) 105 | current_len += self.X_len 106 | X_full = pd.concat((X_full, X_changed)) 107 | if current_len > self.batchsize: 108 | self.fill_pd_pairs(range_dict, X_full) 109 | current_len = 0 110 | range_dict = {} 111 | X_full = pd.DataFrame() 112 | if current_len > 0: 113 | self.fill_pd_pairs(range_dict, X_full) 114 | 115 | def calculate_pd_minus_single(self, features: Optional[List[str]] = None, show_progress: bool = False, desc: str = ProgressInfoLog.CALC_OVA): 116 | if features is None: 117 | features = self.X.columns 118 | range_dict = {} 119 | current_len = 0 120 | X_full = pd.DataFrame() 121 | for feature in tqdm(features, desc=desc, disable=not show_progress): 122 | if np.isnan(self.pd_minus_single[feature]["pd_values"]).any(): 123 | for i, row in self.X.copy().reset_index(drop=True).iterrows(): 124 | change_dict = {other_feature: row[other_feature] for other_feature in self.X.columns if other_feature != feature} 125 | X_changed = self.X.copy().assign(**change_dict) 126 | range_dict[(feature, i)] = (current_len, current_len+self.X_len) 127 | current_len += self.X_len 128 | X_full = pd.concat((X_full, X_changed)) 129 | if current_len > self.batchsize: 130 | self.fill_pd_minus_single(range_dict, X_full) 131 | current_len = 0 132 | range_dict = {} 133 | X_full = pd.DataFrame() 134 | if current_len > 0: 135 | self.fill_pd_minus_single(range_dict, X_full) 136 | 137 | def fill_pd_single(self, range_dict, X_full): 138 | y = self.predict_function(self.model, X_full) 139 | for var_name, var_val in range_dict.keys(): 140 | start, end = range_dict[(var_name, var_val)] 141 | value_index = get_index(self.pd_single[var_name]["f_values"], var_val) 142 | self.pd_single[var_name]["pd_values"][value_index] = np.mean(y[start:end]) 143 | 144 | def fill_pd_pairs(self, range_dict, X_full): 145 | y = self.predict_function(self.model, X_full) 146 | for var_name1, var_name2, var_val1, var_val2 in range_dict.keys(): 147 | start, end = range_dict[(var_name1, var_name2, var_val1, var_val2)] 148 | value_index1 = get_index(self.pd_pairs[(var_name1, var_name2)]["f1_values"], var_val1) 149 | value_index2 = get_index(self.pd_pairs[(var_name1, var_name2)]["f2_values"], var_val2) 150 | self.pd_pairs[(var_name1, var_name2)]["pd_values"][value_index1, value_index2] = np.mean(y[start:end]) 151 | 152 | def fill_pd_minus_single(self, range_dict, X_full): 153 | y = self.predict_function(self.model, X_full) 154 | for var_name, row_id in range_dict.keys(): 155 | start, end = range_dict[(var_name, row_id)] 156 | self.pd_minus_single[var_name]["pd_values"][row_id] = np.mean(y[start:end]) 157 | 158 | def _get_pair_key(self, pair: Tuple[str, str]) -> Tuple[str, str]: 159 | if pair in self.pd_pairs.keys(): 160 | return pair 161 | else: 162 | return (pair[1], pair[0]) 163 | 164 | 165 | def get_index(array, value) -> int: 166 | return np.where(array == value)[0][0] 167 | 168 | def reorder_pair_values(pair_values: List[Tuple[Any, Any]]) -> List[Tuple[Any, Any]]: 169 | return [(pair[1], pair[0]) for pair in pair_values] 170 | 171 | 172 | -------------------------------------------------------------------------------- /artemis/_utilities/performance_metrics.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod, ABC 2 | 3 | import numpy as np 4 | 5 | from .domain import ProblemType 6 | 7 | 8 | class Metric: 9 | 10 | def __init__(self, problem_type: str): 11 | self.problem_type = 
problem_type 12 | 13 | @abstractmethod 14 | def calculate(self, y_true: np.array, y_hat: np.array) -> float: 15 | ... 16 | 17 | def applicable_to(self, problem_type: str): 18 | return self.problem_type == problem_type 19 | 20 | 21 | class RMSE(Metric): 22 | 23 | def __init__(self): 24 | super().__init__(ProblemType.REGRESSION) 25 | 26 | def calculate(self, y_true: np.array, y_hat: np.array) -> float: 27 | return np.sqrt(MSE().calculate(y_true, y_hat)) 28 | 29 | 30 | class MSE(Metric): 31 | 32 | def __init__(self): 33 | super().__init__(ProblemType.REGRESSION) 34 | 35 | def calculate(self, y_true: np.array, y_hat: np.array) -> float: 36 | return np.square(y_hat - y_true).mean() 37 | 38 | 39 | class Accuracy(Metric): 40 | 41 | def __init__(self): 42 | super().__init__(ProblemType.CLASSIFICATION) 43 | 44 | def calculate(self, y_true: np.array, y_hat: np.array) -> float: 45 | return np.equal(y_true, y_hat).mean() 46 | -------------------------------------------------------------------------------- /artemis/_utilities/split_score_metrics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class SplitScoreInteractionMetric: 6 | SUM_GAIN: str = "sum_gain" 7 | SUM_COVER: str = "sum_cover" 8 | MEAN_GAIN: str = "mean_gain" 9 | MEAN_COVER: str = "mean_cover" 10 | MEAN_DEPTH: str = "mean_depth" 11 | 12 | 13 | @dataclass 14 | class SplitScoreImportanceMetric(SplitScoreInteractionMetric): 15 | MEAN_WEIGHTED_DEPTH: str = "mean_weighted_depth" 16 | ROOT_FREQUENCY: str = "root_frequency" 17 | WEIGHTED_ROOT_FREQUENCY: str = "weighted_root_frequency" 18 | 19 | _LGBM_UNSUPPORTED_METRICS = [SplitScoreInteractionMetric.MEAN_COVER, SplitScoreInteractionMetric.SUM_COVER] 20 | _ASCENDING_ORDER_METRICS = [SplitScoreInteractionMetric.MEAN_DEPTH] -------------------------------------------------------------------------------- /artemis/_utilities/zenplot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def find_next_pair_index(left_to_vis, f_to_join_by): 5 | filtered_subset = left_to_vis[(left_to_vis.iloc[:, 0] == f_to_join_by) | 6 | (left_to_vis.iloc[:, 1] == f_to_join_by)].iloc[:, 2] 7 | if len(filtered_subset) > 0: 8 | idxmax = filtered_subset.idxmax() 9 | return idxmax, True 10 | else: 11 | return left_to_vis.index[0], False 12 | 13 | 14 | def get_second_feature(prev_feature, row): 15 | if row["Feature 1"] == prev_feature: 16 | return row["Feature 2"] 17 | return row["Feature 1"] 18 | 19 | 20 | def get_pd_pairs_values(method, pair): 21 | pair_key = method.pd_calculator._get_pair_key(pair) 22 | pair_values = method.pd_calculator.pd_pairs[pair_key].copy() 23 | if pair_key[0] != pair[0]: 24 | pair_values["f1_values"], pair_values["f2_values"] = pair_values["f2_values"], pair_values["f1_values"] 25 | pair_values["pd_values"] = pair_values["pd_values"].T 26 | return pair_values 27 | 28 | 29 | def get_pd_dict(pd_calculator, to_vis): 30 | max_pd = 0 31 | min_pd = 1 32 | for i in range(len(to_vis)): 33 | pair = (to_vis.iloc[i, 0], to_vis.iloc[i, 1]) 34 | pair_key = pd_calculator._get_pair_key(pair) 35 | pair_values = pd_calculator.pd_pairs[pair_key].copy() 36 | max_pd = max(np.max(pair_values["pd_values"]), max_pd) 37 | min_pd = min(np.min(pair_values["pd_values"]), min_pd) 38 | return min_pd, max_pd 39 | -------------------------------------------------------------------------------- /artemis/additivity/__init__.py: 
--------------------------------------------------------------------------------
1 | from ._additivity_meter import AdditivityMeter
2 | 
3 | __all__ = ["AdditivityMeter"]
4 | 
--------------------------------------------------------------------------------
/artemis/additivity/_additivity_meter.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional
2 | import numpy as np
3 | import pandas as pd
4 | from artemis._utilities.domain import ProgressInfoLog
5 | from artemis._utilities.ops import get_predict_function, sample_if_not_none
6 | 
7 | from artemis._utilities.pd_calculator import PartialDependenceCalculator
8 | 
9 | 
10 | class AdditivityMeter:
11 |     """
12 |     AdditivityMeter is a class that calculates the additivity index of a model.
13 | 
14 |     Attributes
15 |     ----------
16 |     additivity_index : float
17 |         Additivity index of the model.
18 |     full_result : pd.DataFrame
19 |         Dataframe with the results of the additivity index calculation.
20 |         It contains centered partial dependence values and prediction for every observation and feature.
21 |     preds: np.ndarray
22 |         Predictions for the sampled data.
23 |     model : object
24 |         Model for which additivity index is calculated.
25 |     X_sampled: pd.DataFrame
26 |         Sampled data used for calculation.
27 |     pd_calculator : PartialDependenceCalculator
28 |         Object used to calculate and store partial dependence values.
29 |     """
30 |     def __init__(self, random_state: Optional[int] = None):
31 |         self._random_generator = np.random.default_rng(random_state)
32 |         self.additivity_index = None
33 |         self.full_result = None
34 |         self.preds = None
35 |         self.model = None
36 |         self.X_sampled = None
37 |         self.pd_calculator = None
38 | 
39 | 
40 |     def fit(
41 |         self,
42 |         model,
43 |         X: pd.DataFrame,
44 |         n: int = None,
45 |         predict_function: Optional[Callable] = None,
46 |         show_progress: bool = False,
47 |         batchsize: int = 2000,
48 |         pd_calculator: Optional[PartialDependenceCalculator] = None,
49 |     ):
50 |         """
51 |         Calculates the additivity index of the given model.
52 | 
53 |         Parameters
54 |         ----------
55 |         model : object
56 |             Model to calculate additivity index for, should have predict_proba or predict method, or predict_function should be provided.
57 |         X : pd.DataFrame
58 |             Data used to calculate the additivity index. If n is not None, n rows from X will be sampled.
59 |         n : int, optional
60 |             Number of samples to be used for calculation of the additivity index. If None, all rows from X will be used. Default is None.
61 |         predict_function : Callable, optional
62 |             Function used to predict model output. It should take a model and a dataset, and output predictions.
63 |             If None, `predict_proba` method will be used if it exists, otherwise `predict` method. Default is None.
64 |         show_progress : bool
65 |             If True, progress bar will be shown. Default is False.
66 |         batchsize : int
67 |             Batch size for calculating partial dependence. Prediction requests are collected until the batchsize is exceeded,
68 |             then the model is queried for predictions jointly for many observations. It speeds up the operation of the method.
69 |             Default is 2000.
70 |         pd_calculator : PartialDependenceCalculator, optional
71 |             PartialDependenceCalculator object containing partial dependence values for a given model and dataset.
72 |             Providing this object speeds up the calculation as partial dependence values do not need to be recalculated.
73 |             If None, it will be created from scratch. Default is None.
74 | 75 | Returns 76 | -------- 77 | additivity_index : float 78 | Additivity index of the model. Value from [0, 1] interval where 1 means that the model is additive, 79 | and 0 means that the model is not additive. 80 | """ 81 | self.predict_function = get_predict_function(model, predict_function) 82 | self.model = model 83 | self.X_sampled = sample_if_not_none(self._random_generator, X, n) 84 | 85 | if pd_calculator is None: 86 | self.pd_calculator = PartialDependenceCalculator( 87 | self.model, self.X_sampled, self.predict_function, batchsize 88 | ) 89 | else: 90 | if pd_calculator.model != self.model: 91 | raise ValueError( 92 | "Model in PDP calculator is different than the model in the method." 93 | ) 94 | if not pd_calculator.X.equals(self.X_sampled): 95 | raise ValueError( 96 | "Data in PDP calculator is different than the data in the method." 97 | ) 98 | self.pd_calculator = pd_calculator 99 | 100 | self.full_result = self.X_sampled.copy() 101 | self.additivity_index = self._calculate_additivity(show_progress=show_progress) 102 | return self.additivity_index 103 | 104 | def _calculate_additivity(self, show_progress: bool): 105 | self.pd_calculator.calculate_pd_single( 106 | show_progress=show_progress, desc=ProgressInfoLog.CALC_ADD 107 | ) 108 | 109 | self.preds = self.predict_function(self.model, self.X_sampled) 110 | for var in self.X_sampled.columns: 111 | self.full_result[var] = self.pd_calculator.get_pd_single(var, self.X_sampled[var].values) - np.mean(self.preds) 112 | 113 | self.full_result = self.full_result 114 | self.full_result["centered_prediction"] = self.preds - np.mean(self.preds) 115 | 116 | sum_first_order_effects = self.full_result.values[:, :-1].sum(axis=1) + np.mean(self.preds) 117 | return 1-np.sum((self.preds - sum_first_order_effects)**2) / np.sum((self.full_result["centered_prediction"])**2) 118 | -------------------------------------------------------------------------------- /artemis/comparison/__init__.py: -------------------------------------------------------------------------------- 1 | from ._method_comparator import FeatureInteractionMethodComparator 2 | 3 | __all__ = ["FeatureInteractionMethodComparator"] 4 | -------------------------------------------------------------------------------- /artemis/comparison/_method_comparator.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Tuple 3 | 4 | import pandas as pd 5 | from matplotlib import pyplot as plt 6 | 7 | from artemis._utilities.domain import CorrelationMethod 8 | from artemis._utilities.exceptions import MethodNotFittedException 9 | from artemis._utilities.ops import point_left_side_circle 10 | from artemis.interactions_methods._method import FeatureInteractionMethod 11 | from artemis.visualizer._configuration import InteractionGraphConfiguration 12 | 13 | 14 | class FeatureInteractionMethodComparator: 15 | """ 16 | Feature Interaction Method Comparator. 17 | It is used for statistical comparison of two different feature interaction methods. 18 | Calculates Pearson, Kendall and Spearman rank-correlation and plots one vs one profiles of two methods against 19 | each other. Monotonicity of the plot suggest cohesion in results. Both provided methods must be in fitted state. 20 | 21 | 22 | Attributes 23 | ---------- 24 | ovo_profiles_comparison_plot : Figure 25 | Matplotlib figure of comparison plots. 
26 |     correlations_df : pd.DataFrame
27 |         Pearson, Kendall and Spearman rank correlation values
28 | 
29 |     References
30 |     ----------
31 |     - https://en.wikipedia.org/wiki/Rank_correlation
32 |     """
33 | 
34 |     def __init__(self):
35 |         """Constructor for FeatureInteractionMethodComparator"""
36 |         self.ovo_profiles_comparison_plot = None
37 |         self.correlations_df = None
38 | 
39 |     def summary(self,
40 |                 method1: FeatureInteractionMethod,
41 |                 method2: FeatureInteractionMethod):
42 |         """
43 |         Calculates Feature Interaction Method comparison.
44 |         Used for asserting stability and cohesion of results for a pair of explanation methods.
45 | 
46 |         Parameters
47 |         ----------
48 |         method1 : FeatureInteractionMethod
49 |             First method for comparison
50 |         method2 : FeatureInteractionMethod
51 |             Second method for comparison
52 | 
53 |         Returns
54 |         -------
55 |         None
56 |         """
57 |         _assert_fitted_ovo(method1, method2)
58 |         self.correlations_df = self.correlations(method1, method2)
59 |         self.ovo_profiles_comparison_plot = self.comparison_plot(method1, method2, add_correlation_box=True)
60 | 
61 |     def correlations(self, method1: FeatureInteractionMethod, method2: FeatureInteractionMethod):
62 |         """
63 |         Calculates Pearson, Kendall and Spearman rank correlation DataFrame.
64 | 
65 |         Parameters
66 |         ----------
67 |         method1 : FeatureInteractionMethod
68 |             First method for comparison
69 |         method2 : FeatureInteractionMethod
70 |             Second method for comparison
71 | 
72 |         Returns
73 |         -------
74 |         pd.DataFrame with Pearson, Kendall and Spearman rank correlation values
75 |         """
76 |         correlations = []
77 |         for correlation_method in dataclasses.fields(CorrelationMethod):
78 |             correlation_method_name = correlation_method.default
79 |             correlations.append(
80 |                 {
81 |                     "method": correlation_method_name,
82 |                     "value": self.correlation(method1, method2, correlation_method_name)
83 |                 })
84 | 
85 |         return pd.DataFrame.from_records(correlations)
86 | 
87 |     def comparison_plot(self,
88 |                         method1: FeatureInteractionMethod,
89 |                         method2: FeatureInteractionMethod,
90 |                         n_labels: int = 3,
91 |                         add_correlation_box: bool = False,
92 |                         figsize: Tuple[float, float] = (8, 6)):
93 |         """
94 |         Creates comparison plot for comparing results of two feature interaction methods. Depending on the parameters,
95 |         rank correlation might be included on the plot.
96 | 97 | Parameters 98 | ---------- 99 | method1 : FeatureInteractionMethod 100 | First method for comparison 101 | method2 : FeatureInteractionMethod 102 | Second method for comparison 103 | n_labels: int 104 | Number of pairs of features with the greatest interaction values to show labels of, default = 3 105 | add_correlation_box: bool 106 | Flag indicating whether to show rank correlation values on the plot, default = False 107 | figsize: Tuple[float, float] 108 | Matplotlib size of the figure, default = (8, 6) 109 | 110 | 111 | Returns 112 | ------- 113 | Figure 114 | """ 115 | m1_name, m2_name = method1.method, method2.method 116 | fig, ax = plt.subplots(figsize=figsize) 117 | ax.set_axisbelow(True) 118 | plt.grid(True) 119 | circle_r = 0.2 * min(max(method1._compare_ovo[m1_name]), max(method2._compare_ovo[m2_name])) 120 | 121 | x, y = [], [] 122 | for index, row in method1._compare_ovo.iterrows(): 123 | 124 | f1, f2 = row["Feature 1"], row["Feature 2"] 125 | x_curr, y_curr = row[method1.method], method2.interaction_value(f1, f2) 126 | x.append(x_curr) 127 | y.append(y_curr) 128 | 129 | if index < n_labels: 130 | _add_arrow(ax, circle_r, f1, f2, x_curr, y_curr) 131 | 132 | ax.scatter(x, y, color=InteractionGraphConfiguration.NODE_COLOR) 133 | 134 | if method1._interactions_ascending_order: 135 | plt.gca().invert_xaxis() 136 | if method2._interactions_ascending_order: 137 | plt.gca().invert_yaxis() 138 | 139 | if add_correlation_box: 140 | 141 | corr = self.correlations_df 142 | if self.correlations_df is None: 143 | corr = self.correlations(method1, method2) 144 | 145 | _add_correlation_box(ax, corr) 146 | 147 | _title_x_y(ax, m1_name, m2_name) 148 | 149 | return fig, ax 150 | 151 | @staticmethod 152 | def correlation( 153 | method1: FeatureInteractionMethod, 154 | method2: FeatureInteractionMethod, 155 | correlation_method: str = CorrelationMethod.KENDALL): 156 | """ 157 | Calculates rank correlation of one vs one profiles using a given correlation method. 
158 | 159 | Parameters 160 | ---------- 161 | method1 : FeatureInteractionMethod 162 | First method for comparison 163 | method2 : FeatureInteractionMethod 164 | Second method for comparison 165 | correlation_method: str 166 | Correlation method to use, accepted values are ['pearson', 'kendall', 'spearman'], default = 'kendall' 167 | 168 | Returns 169 | ------- 170 | value of the correlation 171 | """ 172 | 173 | rank = _rank_interaction_values_encoded(method1, method2) 174 | 175 | return rank.corr(method=correlation_method).iloc[0, 1] 176 | 177 | 178 | def _rank_interaction_values_encoded(method1, method2): 179 | rank_features_m1 = method1._compare_ovo.apply(lambda row: _alphabetical_order_pair(row), axis=1) 180 | rank_features_m2 = method2._compare_ovo.apply(lambda row: _alphabetical_order_pair(row), axis=1) 181 | rank_features_encoded = pd.concat( 182 | [rank_features_m1.astype('category').cat.codes, rank_features_m2.astype('category').cat.codes], axis=1) 183 | 184 | return rank_features_encoded 185 | 186 | 187 | def _title_x_y(ax, m1_name, m2_name): 188 | ax.set_xlabel(m1_name) 189 | ax.set_ylabel(m2_name) 190 | ax.set_title(f"{m1_name}\nand\n{m2_name}\nComparison") 191 | 192 | 193 | def _add_correlation_box(ax, correlations): 194 | lines = [ 195 | f"{m.default.capitalize()}={round(correlations[correlations['method'] == m.default]['value'].values[0], 3)}" 196 | for m in dataclasses.fields(CorrelationMethod) 197 | ] 198 | lines.insert(0, "Feature pairs rank correlation") 199 | correlation_box_text = '\n'.join(lines) 200 | 201 | props = dict(boxstyle='round', alpha=0.5, color=InteractionGraphConfiguration.EDGE_COLOR) 202 | ax.text(0.95, 0.05, 203 | correlation_box_text, transform=ax.transAxes, fontsize=10, 204 | verticalalignment='bottom', horizontalalignment="right", bbox=props) 205 | 206 | 207 | def _add_arrow(ax, circle_r, f1, f2, v1, v2): 208 | ax.annotate("-".join([f1, f2]), 209 | xy=(v1, v2), 210 | xycoords='data', 211 | xytext=point_left_side_circle(v1, v2, circle_r), 212 | textcoords='data', 213 | size=8, 214 | bbox=dict(boxstyle="round", alpha=0.1, color=InteractionGraphConfiguration.EDGE_COLOR), 215 | arrowprops=dict( 216 | arrowstyle="simple", 217 | fc="0.6", 218 | connectionstyle="arc3", 219 | color=InteractionGraphConfiguration.EDGE_COLOR)) 220 | 221 | 222 | def _assert_fitted_ovo(method1: FeatureInteractionMethod, method2: FeatureInteractionMethod): 223 | if not _suitable_for_ovo(method1): 224 | raise MethodNotFittedException(method1.method) 225 | 226 | if not _suitable_for_ovo(method2): 227 | raise MethodNotFittedException(method2.method) 228 | 229 | 230 | def _suitable_for_ovo(method: FeatureInteractionMethod): 231 | return method._compare_ovo is not None 232 | 233 | 234 | def _alphabetical_order_pair(row): 235 | features_alphabetical = sorted([row["Feature 1"], row["Feature 2"]]) 236 | 237 | return "-".join(features_alphabetical) 238 | -------------------------------------------------------------------------------- /artemis/importance_methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyartemis/artemis/fd6ed2018b6c9b8aa1de3e3f91a345f60d801271/artemis/importance_methods/__init__.py -------------------------------------------------------------------------------- /artemis/importance_methods/_method.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Optional 3 | 4 | import pandas as pd 5 | import 
numpy as np 6 | 7 | 8 | class FeatureImportanceMethod: 9 | """ 10 | Abstract base class for Feature Importance methods. 11 | This class should not be used directly. Use derived classes instead. 12 | 13 | Attributes 14 | ---------- 15 | method : str 16 | Method name. 17 | feature_importance : pd.DataFrame 18 | Feature importance values. 19 | """ 20 | 21 | def __init__(self, method: str, random_state: Optional[int] = None): 22 | self.method = method 23 | self.feature_importance = None 24 | self._random_generator = np.random.default_rng(random_state) 25 | 26 | @property 27 | @abstractmethod 28 | def importance_ascending_order(self) -> bool: 29 | ... 30 | 31 | @abstractmethod 32 | def importance(self, model, X: pd.DataFrame, **kwargs) -> pd.DataFrame: 33 | ... 34 | -------------------------------------------------------------------------------- /artemis/importance_methods/model_agnostic/__init__.py: -------------------------------------------------------------------------------- 1 | from artemis.importance_methods.model_agnostic._pdp import PartialDependenceBasedImportance 2 | from artemis.importance_methods.model_agnostic._permutational_importance import PermutationImportance 3 | 4 | __all__ = ["PermutationImportance", "PartialDependenceBasedImportance"] 5 | -------------------------------------------------------------------------------- /artemis/importance_methods/model_agnostic/_pdp.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Callable, List, Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | from artemis.importance_methods._method import FeatureImportanceMethod 9 | from artemis._utilities.domain import ImportanceMethod, ProgressInfoLog 10 | from artemis._utilities.ops import ( 11 | all_if_none, 12 | get_predict_function, 13 | sample_if_not_none, 14 | split_features_num_cat, 15 | ) 16 | from artemis._utilities.pd_calculator import PartialDependenceCalculator 17 | 18 | 19 | class PartialDependenceBasedImportance(FeatureImportanceMethod): 20 | """ 21 | Partial Dependence Based Feature Importance. 22 | It is used for calculating feature importance for partial dependence based feature interaction methods: 23 | Friedman's H-statistic and Greenwell methods. 24 | 25 | 26 | Attributes 27 | ---------- 28 | method : str 29 | Method name. 30 | feature_importance : pd.DataFrame 31 | Feature importance values. 32 | features_included : List[str] 33 | List of features for which importance is calculated. 34 | X_sampled: pd.DataFrame 35 | Sampled data used for calculation. 36 | pd_calculator : PartialDependenceCalculator 37 | Object used to calculate and store partial dependence values. 38 | 39 | References 40 | ---------- 41 | - https://arxiv.org/abs/1805.04755 42 | """ 43 | 44 | def __init__(self): 45 | """Constructor for PartialDependenceBasedImportance""" 46 | super().__init__(ImportanceMethod.PDP_BASED_IMPORTANCE) 47 | 48 | def importance( 49 | self, 50 | model, 51 | X: pd.DataFrame, 52 | n: int = None, 53 | predict_function: Optional[Callable] = None, 54 | features: Optional[List[str]] = None, 55 | show_progress: bool = False, 56 | batchsize: int = 2000, 57 | pd_calculator: Optional[PartialDependenceCalculator] = None, 58 | ): 59 | """Calculates Partial Dependence Based Feature Importance. 
60 | 61 | Parameters 62 | ---------- 63 | model : object 64 | Model for which importance will be calculated, should have predict_proba or predict method, or predict_function should be provided. 65 | X : pd.DataFrame 66 | Data used to calculate importance. If n is not None, n rows from X will be sampled. 67 | n : int, optional 68 | Number of samples to be used for calculation of importance. If None, all rows from X will be used. Default is None. 69 | predict_function : Callable, optional 70 | Function used to predict model output. It should take model and dataset and outputs predictions. 71 | If None, `predict_proba` method will be used if it exists, otherwise `predict` method. Default is None. 72 | features : List[str], optional 73 | List of features for which importance will be calculated. If None, all features from X will be used. Default is None. 74 | show_progress : bool 75 | If True, progress bar will be shown. Default is False. 76 | batchsize : int 77 | Batch size for calculating partial dependence. Data for prediction are collected until the number of rows exceeds batchsize. 78 | Then, the `predict_function` is called, jointly for the entire batch of observations. It speeds up the operation of the method 79 | by reducing the number of `predict_function` calls. 80 | Default is 2000. 81 | pd_calculator : PartialDependenceCalculator, optional 82 | PartialDependenceCalculator object containing partial dependence values for a given model and dataset. 83 | Providing this object speeds up the calculation as partial dependence values do not need to be recalculated. 84 | If None, it will be created from scratch. Default is None. 85 | 86 | Returns 87 | ------- 88 | pd.DataFrame 89 | Result dataframe containing feature importance with columns: "Feature", "Importance" 90 | """ 91 | self.predict_function = get_predict_function(model, predict_function) 92 | self.X_sampled = sample_if_not_none(self._random_generator, X, n) 93 | self.features_included = all_if_none(X.columns, features) 94 | 95 | 96 | if pd_calculator is None: 97 | self.pd_calculator = PartialDependenceCalculator(model, self.X_sampled, self.predict_function, batchsize) 98 | else: 99 | if pd_calculator.model != model: 100 | raise ValueError("Model in PDP calculator is different than the model in the method.") 101 | if not pd_calculator.X.equals(self.X_sampled): 102 | raise ValueError("Data in PDP calculator is different than the data in the method.") 103 | self.pd_calculator = pd_calculator 104 | 105 | self.feature_importance = self._pdp_importance(show_progress) 106 | return self.feature_importance 107 | 108 | @property 109 | def importance_ascending_order(self): 110 | return False 111 | 112 | def _pdp_importance(self, show_progress: bool) -> pd.DataFrame: 113 | self.pd_calculator.calculate_pd_single(show_progress=show_progress) 114 | 115 | importance = [] 116 | num_features, _ = split_features_num_cat(self.X_sampled, self.features_included) 117 | 118 | for feature in self.features_included: 119 | pdp = self.pd_calculator.get_pd_single(feature) 120 | importance.append(_calc_importance(feature, pdp, feature in num_features)) 121 | 122 | return pd.DataFrame(importance, columns=["Feature", "Importance"]).sort_values( 123 | by="Importance", ascending=self.importance_ascending_order, ignore_index=True 124 | ).fillna(0) 125 | 126 | 127 | def _calc_importance(feature: str, pdp: np.ndarray, is_numerical: bool): 128 | return [feature, np.std(pdp) if is_numerical else (np.max(pdp) - np.min(pdp)) / 4] 129 | 130 | 
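
Usage sketch for `PartialDependenceBasedImportance` (illustrative only: the dataset is a toy example and `RandomForestRegressor` stands in for any model with a `predict` method):

    import pandas as pd
    from sklearn.datasets import make_friedman1
    from sklearn.ensemble import RandomForestRegressor
    from artemis.importance_methods.model_agnostic import PartialDependenceBasedImportance

    # toy regression data
    X, y = make_friedman1(n_samples=200, random_state=0)
    X = pd.DataFrame(X, columns=[f"x{i}" for i in range(X.shape[1])])
    model = RandomForestRegressor(random_state=0).fit(X, y)

    # importance derived from partial dependence profiles
    pdp_importance = PartialDependenceBasedImportance()
    result = pdp_importance.importance(model, X, n=100, show_progress=False)
    # result: pd.DataFrame with "Feature" and "Importance" columns, sorted descending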
--------------------------------------------------------------------------------
/artemis/importance_methods/model_agnostic/_permutational_importance.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | 
7 | from artemis.importance_methods._method import FeatureImportanceMethod
8 | from artemis._utilities.domain import ImportanceMethod, ProgressInfoLog, ProblemType
9 | from artemis._utilities.performance_metrics import Metric, RMSE
10 | 
11 | 
12 | class PermutationImportance(FeatureImportanceMethod):
13 |     """
14 |     Permutation-Based Feature Importance.
15 |     It is used for calculating feature importance for the performance-based feature interaction method (Sejong Oh method).
16 | 
17 |     Importance of a feature is defined as the change in the model's performance (measured by the selected metric) after the values of this feature are randomly permuted.
18 | 
19 |     Attributes
20 |     ----------
21 |     method : str
22 |         Method name.
23 |     metric: Metric
24 |         Metric used for calculating performance.
25 |     feature_importance : pd.DataFrame
26 |         Feature importance values.
27 | 
28 |     References
29 |     ----------
30 |     - https://jmlr.org/papers/v20/18-760.html
31 |     """
32 | 
33 |     def __init__(self, metric: Metric = RMSE(), random_state: Optional[int] = None):
34 |         """Constructor for PermutationImportance.
35 | 
36 |         Parameters
37 |         ----------
38 |         metric : Metric
39 |             Metric used to calculate model performance. Defaults to RMSE().
40 |         random_state : int, optional
41 |             Random state for reproducibility. Defaults to None.
42 |         """
43 |         super().__init__(ImportanceMethod.PERMUTATION_IMPORTANCE, random_state=random_state)
44 |         self.metric = metric
45 | 
46 |     def importance(
47 |         self,
48 |         model,
49 |         X: pd.DataFrame,
50 |         y_true: np.array,
51 |         n_repeat: int = 15,
52 |         features: Optional[List[str]] = None,
53 |         show_progress: bool = False,
54 |     ):
55 |         """Calculates Permutation Based Feature Importance.
56 | 
57 |         Parameters
58 |         ----------
59 |         model : object
60 |             Model for which importance will be calculated, should have predict method.
61 |         X : pd.DataFrame
62 |             Data used to calculate importance.
63 |         y_true : np.array or pd.Series
64 |             Target values for X data.
65 |         n_repeat : int, optional
66 |             Number of permutations. Default is 15.
67 |         features : List[str], optional
68 |             List of features for which importance will be calculated. If None, all features from X will be used. Default is None.
69 |         show_progress : bool
70 |             If True, progress bar will be shown. Default is False.
71 | 72 | Returns 73 | ------- 74 | pd.DataFrame 75 | Result dataframe containing feature importance with columns: "Feature", "Importance" 76 | """ 77 | self.feature_importance = _permutation_importance( 78 | model, X, y_true, self.metric, n_repeat, features, show_progress, self._random_generator 79 | ) 80 | return self.feature_importance 81 | @property 82 | def importance_ascending_order(self): 83 | return False 84 | 85 | 86 | def _permutation_importance( 87 | model, 88 | X: pd.DataFrame, 89 | y: np.array, 90 | metric: Metric, 91 | n_repeat: int, 92 | features: List[str], 93 | show_progress: bool, 94 | random_generator: np.random._generator.Generator 95 | ): 96 | base_score = metric.calculate(y, model.predict(X)) 97 | corrupted_scores = _corrupted_scores( 98 | model, X, y, features, metric, n_repeat, show_progress, random_generator 99 | ) 100 | 101 | feature_importance = [ 102 | { 103 | "Feature": f, 104 | "Importance": _neg_if_class(metric, np.mean(corrupted_scores[f]) - base_score), 105 | } 106 | for f in corrupted_scores.keys() 107 | ] 108 | 109 | return pd.DataFrame.from_records(feature_importance).sort_values( 110 | by="Importance", ascending=False, ignore_index=True 111 | ) 112 | 113 | 114 | def _corrupted_scores( 115 | model, 116 | X: pd.DataFrame, 117 | y: np.array, 118 | features: List[str], 119 | metric: Metric, 120 | n_repeat: int, 121 | show_progress: bool, 122 | random_generator: np.random._generator.Generator 123 | ): 124 | X_copy_permuted = X.copy() 125 | corrupted_scores = {f: [] for f in features} 126 | for _ in tqdm( 127 | range(n_repeat), disable=not show_progress, desc=ProgressInfoLog.CALC_VAR_IMP 128 | ): 129 | for feature in features: 130 | X_copy_permuted[feature] = random_generator.permutation(X_copy_permuted[feature]) 131 | corrupted_scores[feature].append( 132 | metric.calculate(y, model.predict(X_copy_permuted)) 133 | ) 134 | X_copy_permuted[feature] = X[feature] 135 | 136 | return corrupted_scores 137 | 138 | 139 | def _neg_if_class(metric: Metric, value: float): 140 | if metric.applicable_to(ProblemType.CLASSIFICATION): 141 | return -value 142 | 143 | return value 144 | -------------------------------------------------------------------------------- /artemis/importance_methods/model_specific/__init__.py: -------------------------------------------------------------------------------- 1 | from artemis.importance_methods.model_specific._minimal_depth import MinimalDepthImportance 2 | from artemis.importance_methods.model_specific._split_score import SplitScoreImportance 3 | 4 | __all__ = ["SplitScoreImportance", "MinimalDepthImportance"] 5 | -------------------------------------------------------------------------------- /artemis/importance_methods/model_specific/_minimal_depth.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from artemis.importance_methods._method import FeatureImportanceMethod 8 | from artemis._utilities.domain import ImportanceMethod, InteractionMethod 9 | from artemis._utilities.exceptions import FeatureImportanceWithoutInteractionException 10 | 11 | 12 | class MinimalDepthImportance(FeatureImportanceMethod): 13 | """ 14 | Minimal Depth Feature Importance. 15 | It applies to tree-based models like Random Forests. 16 | It uses data calculated in ConditionalMinimalDepth method from `interactions_methods` module and so needs to be calculated together. 
17 | 18 | Importance of a feature is defined as the lowest depth of node using this feature as a split feature in a tree, averaged over all trees. 19 | 20 | Attributes 21 | ---------- 22 | method : str 23 | Method name. 24 | feature_importance : pd.DataFrame 25 | Feature importance values. 26 | 27 | References 28 | ---------- 29 | - https://modeloriented.github.io/randomForestExplainer/ 30 | - https://doi.org/10.1198/jasa.2009.tm08622 31 | """ 32 | 33 | def __init__(self): 34 | """Constructor for MinimalDepthImportance""" 35 | super().__init__(ImportanceMethod.MINIMAL_DEPTH_IMPORTANCE) 36 | 37 | def importance( 38 | self, 39 | model, 40 | tree_id_to_depth_split: dict, 41 | ) -> pd.DataFrame: 42 | """Calculates Minimal Depth Feature Importance. 43 | 44 | Parameters 45 | ---------- 46 | model : object 47 | Model for which importance will be calculated, should have predict method. 48 | tree_id_to_depth_split : dict 49 | Dictionary containing minimal depth of each node in each tree. 50 | 51 | Returns 52 | ------- 53 | pd.DataFrame 54 | Result dataframe containing feature importance with columns: "Feature", "Importance" 55 | """ 56 | _check_preconditions(self.method, tree_id_to_depth_split) 57 | 58 | columns = _make_column_dict(model.feature_names_in_) 59 | feature_to_depth = defaultdict(list) 60 | for tree_id in tree_id_to_depth_split.keys(): 61 | depth_tree, split_tree = tree_id_to_depth_split[tree_id] 62 | for f in split_tree.keys(): 63 | feature_to_depth[f].append(depth_tree[split_tree[f][0]]) 64 | 65 | 66 | records_result = [] 67 | for f in feature_to_depth.keys(): 68 | records_result.append( 69 | {"Feature": columns[f], "Importance": np.mean(feature_to_depth[f])} 70 | ) 71 | 72 | self.feature_importance = pd.DataFrame.from_records( 73 | records_result 74 | ).sort_values(by="Importance", ignore_index=True) 75 | 76 | return self.feature_importance 77 | 78 | @property 79 | def importance_ascending_order(self): 80 | return True 81 | 82 | 83 | 84 | def _check_preconditions(method: str, tree_id_to_depth_split: dict): 85 | if tree_id_to_depth_split is None: 86 | raise FeatureImportanceWithoutInteractionException( 87 | method, InteractionMethod.CONDITIONAL_MINIMAL_DEPTH 88 | ) 89 | 90 | 91 | def _make_column_dict(columns: np.ndarray) -> dict: 92 | return dict(zip(range(len(columns)), list(columns))) 93 | -------------------------------------------------------------------------------- /artemis/importance_methods/model_specific/_split_score.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | from artemis.importance_methods._method import FeatureImportanceMethod 8 | from artemis._utilities.domain import ImportanceMethod 9 | from artemis._utilities.split_score_metrics import SplitScoreImportanceMetric 10 | 11 | 12 | class SplitScoreImportance(FeatureImportanceMethod): 13 | """ 14 | Split Score Feature Importance. 15 | It applies to gradient boosting tree-based models. 16 | It can use data calculated in SplitScore method from `interactions_methods` module and so needs to be calculated together. 17 | 18 | Importance of a feature is defined by the metric selected by user (default is sum of gains). 19 | 20 | Attributes 21 | ---------- 22 | method : str 23 | Method name. 24 | feature_importance : pd.DataFrame 25 | Feature importance values. 26 | selected_metric : str 27 | Metric used for calculating importance. 
28 | 
29 |     References
30 |     ----------
31 |     - https://modeloriented.github.io/EIX/
32 |     """
33 | 
34 |     def __init__(self):
35 |         """Constructor for SplitScoreImportance"""
36 |         super().__init__(ImportanceMethod.SPLIT_SCORE_IMPORTANCE)
37 |         self.selected_metric = None
38 | 
39 |     def importance(
40 |         self,
41 |         model,
42 |         features: Optional[List[str]] = None,
43 |         selected_metric: str = SplitScoreImportanceMetric.SUM_GAIN,
44 |         show_progress: bool = False,
45 |         trees_df: Optional[pd.DataFrame] = None,
46 |     ):
47 |         """Calculates Split Score Feature Importance.
48 | 
49 |         Parameters
50 |         ----------
51 |         model : object
52 |             Model for which importance will be calculated; its trained trees must be available as a unified `trees_df` structure, or `trees_df` must be passed explicitly.
53 |         features : List[str], optional
54 |             List of features for which importance will be calculated. If None, all features used for splitting in the trees will be used. Default is None.
55 |         selected_metric : str
56 |             Metric used to calculate feature importance,
57 |             one of ['sum_gain', 'sum_cover', 'mean_gain', 'mean_cover', 'mean_depth',
58 |             'mean_weighted_depth', 'root_frequency', 'weighted_root_frequency'].
59 |             Default is 'sum_gain'.
60 |         show_progress : bool
61 |             If True, progress bar will be shown. Default is False.
62 |         trees_df : pd.DataFrame, optional
63 |             DataFrame containing the unified structure of the trained trees; can be precalculated by the SplitScoreMethod. Default is None.
64 | 
65 |         Returns
66 |         -------
67 |         pd.DataFrame
68 |             Result dataframe containing feature importance with columns: "Feature", "Importance"
69 |         """
70 |         if trees_df is None:
71 |             trees_df = model.trees_df
72 | 
73 |         if trees_df["depth"].isnull().values.any():
74 |             trees_df = _calculate_depth(trees_df, show_progress)
75 |         self.full_result = _calculate_all_feature_importance(
76 |             trees_df, features, selected_metric
77 |         )
78 |         self.feature_importance = _select_metric(self.full_result, selected_metric)
79 |         self.selected_metric = selected_metric
80 | 
81 |         return self.feature_importance
82 | 
83 |     @property
84 |     def importance_ascending_order(self):
85 |         return self.selected_metric in [SplitScoreImportanceMetric.MEAN_DEPTH,
86 |                                         SplitScoreImportanceMetric.MEAN_WEIGHTED_DEPTH]
87 | 
88 | 
89 | def _calculate_all_feature_importance(
90 |     trees_df: pd.DataFrame,
91 |     features: Optional[List[str]] = None,
92 |     selected_metric: str = SplitScoreImportanceMetric.SUM_GAIN,
93 | ):
94 |     if features is not None:
95 |         trees_df = trees_df.loc[trees_df["split_feature"].isin(features)]
96 |     else:
97 |         trees_df = trees_df.loc[trees_df["split_feature"] != "Leaf"]
98 | 
99 |     basic_metrics = _calculate_basic_metrics(trees_df)
100 |     root_metrics = _calculate_root_metrics(trees_df)
101 |     mean_weighted_depth = _calculate_mean_weighted_depth_metric(trees_df)
102 | 
103 |     importance_full_result = basic_metrics.join(root_metrics)
104 |     importance_full_result = pd.concat(
105 |         [importance_full_result, mean_weighted_depth], axis=1
106 |     ).reset_index()
107 | 
108 |     return importance_full_result.sort_values(selected_metric, ascending=False)
109 | 
110 | 
111 | def _calculate_basic_metrics(trees_df: pd.DataFrame):
112 |     importance_full_result = trees_df.groupby("split_feature").agg(
113 |         mean_gain=("gain", "mean"),
114 |         sum_gain=("gain", "sum"),
115 |         mean_cover=("cover", "mean"),
116 |         sum_cover=("cover", "sum"),
117 |         mean_depth=("depth", "mean"),
118 |     )
119 |     return importance_full_result
120 | 
121 | 
122 | def _calculate_root_metrics(trees_df: pd.DataFrame):
123 |     root_freq_df = (
124 |         trees_df.loc[trees_df["depth"] == 0]
125 |         .groupby("split_feature")
126 |         .agg(root_frequency=("tree", "count"), sum_gain_root=("gain", "sum"))
127 |     )
128 |     sum_gain = np.sum(root_freq_df.sum_gain_root)
129 |     root_freq_df["weighted_root_frequency"] = (
130 |         root_freq_df["root_frequency"] * root_freq_df["sum_gain_root"] / sum_gain
131 |     )
132 |     return root_freq_df[["root_frequency", "weighted_root_frequency"]]
133 | 
134 | 
135 | def _calculate_mean_weighted_depth_metric(trees_df: pd.DataFrame):
136 |     return pd.Series(
137 |         trees_df.groupby("split_feature").apply(
138 |             lambda x: np.average(x.depth, weights=x.gain)
139 |         ),
140 |         name="mean_weighted_depth",
141 |     )
142 | 
143 | 
144 | def _select_metric(importance_full_result: pd.DataFrame, selected_metric: str):
145 |     feature_importance = importance_full_result[
146 |         ["split_feature", selected_metric]
147 |     ].rename(columns={"split_feature": "Feature", selected_metric: "Importance"})
148 |     return feature_importance.sort_values(
149 |         by="Importance", ascending=False, ignore_index=True
150 |     )
151 | 
152 | 
153 | def _calculate_depth(trees_df: pd.DataFrame, show_progress: bool = False):
154 |     tqdm.pandas(disable=not show_progress)
155 |     trees_df = (
156 |         trees_df.groupby("tree", group_keys=True)
157 |         .progress_apply(_calculate_depth_for_one_tree)
158 |         .reset_index(drop=True)
159 |     )
160 |     return trees_df
161 | 
162 | 
163 | def _calculate_depth_for_one_tree(tree):
164 |     non_leaf_nodes = tree.loc[~tree["leaf"]].index  # internal (non-leaf) nodes only
165 |     for i in non_leaf_nodes:
166 |         if tree.loc[i, "node"] == 0:
167 |             tree.loc[i, "depth"] = 0  # the root node sits at depth 0
168 |         left = tree.loc[i, "left_child"]
169 |         right = tree.loc[i, "right_child"]
170 |         if not tree.loc[tree["ID"] == left, "leaf"].values[0]:  # internal child: one level deeper than its parent
171 |             tree.loc[tree["ID"] == left, "depth"] = tree.loc[i, "depth"] + 1
172 |         if not tree.loc[tree["ID"] == right, "leaf"].values[0]:
173 |             tree.loc[tree["ID"] == right, "depth"] = tree.loc[i, "depth"] + 1
174 |     return tree
175 | 
-------------------------------------------------------------------------------- /artemis/interactions_methods/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyartemis/artemis/fd6ed2018b6c9b8aa1de3e3f91a345f60d801271/artemis/interactions_methods/__init__.py -------------------------------------------------------------------------------- /artemis/interactions_methods/_method.py: --------------------------------------------------------------------------------
1 | from abc import abstractmethod, ABC
2 | from typing import Callable, Optional, Tuple
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | 
7 | from artemis._utilities.domain import VisualizationType
8 | from artemis._utilities.exceptions import MethodNotFittedException
9 | from artemis.visualizer._configuration import VisualizationConfigurationProvider
10 | from artemis.visualizer._visualizer import Visualizer
11 | 
12 | 
13 | class FeatureInteractionMethod(ABC):
14 |     """
15 |     Abstract base class for Feature Interaction Extraction methods.
16 |     This class should not be used directly. Use derived classes instead.
17 | 
18 |     Attributes
19 |     ----------
20 |     method : str
21 |         Method name, used also for naming the column with results in the `ovo` pd.DataFrame.
22 |     visualizer : Visualizer
23 |         Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
24 |     ovo : pd.DataFrame
25 |         One versus one (pair) feature interaction values.
26 |     feature_importance : pd.DataFrame
27 |         Feature importance values.
28 |     model : object
29 |         Explained model.
30 |     X_sampled: pd.DataFrame
31 |         Sampled data used for calculation.
32 |     features_included: List[str]
33 |         List of features for which interactions are calculated.
34 |     pairs : List[List[str]]
35 |         List of pairs of features for which interactions are calculated.
36 |     """
37 |     def __init__(self, method: str, random_state: Optional[int] = None):
38 |         self.method = method
39 |         self.visualizer = Visualizer(method, VisualizationConfigurationProvider.get(method))
40 |         self.feature_importance = None
41 |         self._feature_importance_obj = None
42 |         self.ovo = None
43 |         self.model = None
44 |         self.X_sampled = None
45 |         self.features_included = None
46 |         self.pairs = None
47 |         self._random_generator = np.random.default_rng(random_state)
48 | 
49 |     @property
50 |     @abstractmethod
51 |     def _interactions_ascending_order(self):
52 |         ...
53 | 
54 |     @property
55 |     def _compare_ovo(self):
56 |         if self.ovo is None:
57 |             raise MethodNotFittedException(self.method)
58 |         return self.ovo.sort_values(self.method, ascending=self._interactions_ascending_order, ignore_index=True)
59 | 
60 |     @abstractmethod
61 |     def fit(self, model, **kwargs):
62 |         """
63 |         Base abstract method for calculating feature interaction values.
64 | 
65 |         Parameters
66 |         ----------
67 |         model : object
68 |             Model for which interactions will be extracted.
69 |         **kwargs : dict
70 |             Parameters specific to a given feature interaction method.
71 |         """
72 |         ...
73 | 
74 |     def plot(self,
75 |              vis_type: str = VisualizationType.HEATMAP,
76 |              title: str = "default",
77 |              figsize: Tuple[float, float] = (8, 6),
78 |              **kwargs):
79 |         """
80 |         Plot results of explanations.
81 | 
82 |         There are four types of plots available:
83 |         - heatmap - heatmap of feature interaction values with feature importance values on the diagonal (default)
84 |         - bar_chart - bar chart of top feature interaction values
85 |         - graph - graph of feature interaction values
86 |         - summary - combination of the other plots
87 | 
88 |         Parameters
89 |         ----------
90 |         vis_type : str
91 |             Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'summary']. Default is 'heatmap'.
92 |         title : str
93 |             Title of plot, default is 'default', which means that the title will be automatically generated for the selected visualization type.
94 |         figsize : (float, float)
95 |             Size of plot. Default is (8, 6).
96 |         **kwargs : Other Parameters
97 |             Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
98 |             For the 'summary' visualization, parameters for the respective plots should be given in a dict with keys corresponding to the visualization names.
99 |             See key parameters below.
100 | 
101 |         Other Parameters
102 |         ----------------
103 |         interaction_color_map : matplotlib colormap name or object, or list of colors
104 |             Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
105 |             depending on whether a greater value means a greater interaction strength or vice versa.
106 |         importance_color_map : matplotlib colormap name or object, or list of colors
107 |             Used for 'heatmap' visualization. The mapping from importance values to color space. Default is 'Greens' or 'Greens_r',
108 |             depending on whether a greater value means a greater importance or vice versa.
109 |         annot_fmt : str
110 |             Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
111 |         linewidths : float
112 |             Used for 'heatmap' visualization. Width of the lines that will divide each cell in the matrix. Default is 0.5.
113 |         linecolor : str
114 |             Used for 'heatmap' visualization. Color of the lines that will divide each cell in the matrix. Default is 'white'.
115 |         cbar_shrink : float
116 |             Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.
117 | 
118 |         top_k : int
119 |             Used for 'bar_chart' visualization. Maximum number of pairs that will be presented in the plot. Default is 10.
120 |         color : str
121 |             Used for 'bar_chart' visualization. Color of bars. Default is 'mediumpurple'.
122 | 
123 |         n_highest_with_labels : int
124 |             Used for 'graph' visualization. Number of top interactions to show as labels on edges. Default is 5.
125 |         edge_color: str
126 |             Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
127 |         node_color: str
128 |             Used for 'graph' visualization. Color of nodes. Default is 'green'.
129 |         node_size: int
130 |             Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
131 |         font_color: str
132 |             Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
133 |         font_weight: str
134 |             Used for 'graph' visualization. Font weight. Default is 'bold'.
135 |         font_size: int
136 |             Used for 'graph' visualization. Font size (networkX scale). Default is 10.
137 |         threshold_relevant_interaction : float
138 |             Used for 'graph' visualization. Minimum (or maximum, depending on the method) value of interaction required to display the
139 |             corresponding edge on the visualization. Default depends on the interaction method.
140 |         """
141 |         if self.ovo is None:
142 |             raise MethodNotFittedException(self.method)
143 | 
144 |         self.visualizer.plot(self.ovo,
145 |                              vis_type,
146 |                              feature_importance=self.feature_importance,
147 |                              title=title,
148 |                              figsize=figsize,
149 |                              interactions_ascending_order=self._interactions_ascending_order,
150 |                              importance_ascending_order=self._feature_importance_obj.importance_ascending_order,
151 |                              **kwargs)
152 | 
153 |     def interaction_value(self, f1: str, f2: str):
154 |         """Return the calculated interaction value for a pair of features, regardless of their order."""
155 |         if self._compare_ovo is None:
156 |             raise MethodNotFittedException(self.method)
157 | 
158 |         return self._compare_ovo[((self._compare_ovo["Feature 1"] == f1) & (self._compare_ovo["Feature 2"] == f2)) |
159 |                                  ((self._compare_ovo["Feature 1"] == f2) &
160 |                                   (self._compare_ovo["Feature 2"] == f1))][self.method].values[0]
161 | 
-------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/__init__.py: --------------------------------------------------------------------------------
1 | from .partial_dependence_based import FriedmanHStatisticMethod, GreenwellMethod
2 | from .performance_based import SejongOhMethod
3 | 
4 | __all__ = ["FriedmanHStatisticMethod", "GreenwellMethod", "SejongOhMethod"]
-------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/partial_dependence_based/__init__.py: --------------------------------------------------------------------------------
1 | from ._greenwell import GreenwellMethod
2 | from ._friedman_h_statistic import FriedmanHStatisticMethod
3 | 
4 | __all__ = ["FriedmanHStatisticMethod", "GreenwellMethod"]
5 | 
-------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/partial_dependence_based/_friedman_h_statistic.py: --------------------------------------------------------------------------------
1 | from typing import Callable, List,
Optional, Tuple
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | from tqdm import tqdm
6 | 
7 | from artemis._utilities.domain import VisualizationType, InteractionMethod, ProgressInfoLog
8 | from artemis._utilities.exceptions import MethodNotFittedException
9 | from artemis._utilities.ops import remove_element, center, partial_dependence_value
10 | from artemis._utilities.pd_calculator import PartialDependenceCalculator
11 | from ._pdp import PartialDependenceBasedMethod
12 | 
13 | 
14 | class FriedmanHStatisticMethod(PartialDependenceBasedMethod):
15 |     """
16 |     Friedman's H-statistic Method for Feature Interaction Extraction.
17 | 
18 |     Uses partial dependence values to calculate feature interaction strengths and feature importance.
19 | 
20 |     Attributes
21 |     ----------
22 |     method : str
23 |         Method name, used also for naming the column with results in the `ovo` pd.DataFrame.
24 |     visualizer : Visualizer
25 |         Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
26 |     ovo : pd.DataFrame
27 |         One versus one (pair) feature interaction values.
28 |     feature_importance : pd.DataFrame
29 |         Feature importance values.
30 |     ova : pd.DataFrame
31 |         One vs all feature interaction values.
32 |     normalized : bool
33 |         Flag determining whether interaction values are normalized.
34 |         The unnormalized version is proposed in https://www.tandfonline.com/doi/full/10.1080/10618600.2021.2007935
35 |     model : object
36 |         Explained model.
37 |     X_sampled: pd.DataFrame
38 |         Sampled data used for calculation.
39 |     features_included: List[str]
40 |         List of features for which interactions are calculated.
41 |     pairs : List[List[str]]
42 |         List of pairs of features for which interactions are calculated.
43 |     pd_calculator : PartialDependenceCalculator
44 |         Object used to calculate and store partial dependence values.
45 |     batchsize : int
46 |         Batch size used for calculation.
47 | 
48 |     References
49 |     ----------
50 |     - https://www.jstor.org/stable/pdf/30245114.pdf
51 |     - https://www.tandfonline.com/doi/full/10.1080/10618600.2021.2007935
52 |     """
53 |     def __init__(self, random_state: Optional[int] = None, normalized: bool = True):
54 |         """Constructor for FriedmanHStatisticMethod
55 | 
56 |         Parameters
57 |         ----------
58 |         random_state : int, optional
59 |             Random state for reproducibility. Defaults to None.
60 |         normalized : bool, optional
61 |             Flag determining whether to normalize the interaction values. The normalized version is the original H-statistic;
62 |             the unnormalized version is the square root of the numerator of the H-statistic. Defaults to True, which translates to the original H-statistic.
63 |         """
64 |         super().__init__(InteractionMethod.H_STATISTIC, random_state=random_state)
65 |         self.ova = None
66 |         self.normalized = normalized
67 | 
68 |     def fit(self,
69 |             model,
70 |             X: pd.DataFrame,
71 |             n: Optional[int] = None,
72 |             predict_function: Optional[Callable] = None,
73 |             features: Optional[List[str]] = None,
74 |             show_progress: bool = False,
75 |             batchsize: int = 2000,
76 |             pd_calculator: Optional[PartialDependenceCalculator] = None,
77 |             calculate_ova: bool = True):
78 |         """Calculates H-statistic Feature Interaction Strength and Feature Importance for the given model.
79 |         In addition to pairwise interactions, this method can also calculate one-vs-all interactions.
80 | 
81 |         Parameters
82 |         ----------
83 |         model : object
84 |             Model to be explained; should have a predict_proba or predict method, or predict_function should be provided.
85 |         X : pd.DataFrame
86 |             Data used to calculate interactions. If n is not None, n rows from X will be sampled.
87 |         n : int, optional
88 |             Number of samples to be used for calculation of interactions. If None, all rows from X will be used. Default is None.
89 |         predict_function : Callable, optional
90 |             Function used to predict model output. It should take a model and a dataset and return predictions.
91 |             If None, the `predict_proba` method will be used if it exists, otherwise the `predict` method. Default is None.
92 |         features : List[str], optional
93 |             List of features for which interactions will be calculated. If None, all features from X will be used. Default is None.
94 |         show_progress : bool
95 |             If True, progress bar will be shown. Default is False.
96 |         batchsize : int
97 |             Batch size for calculating partial dependence. Prediction requests are collected until the batchsize is exceeded,
98 |             then the model is queried for predictions jointly for many observations. This speeds up the operation of the method.
99 |             Default is 2000.
100 |         pd_calculator : PartialDependenceCalculator, optional
101 |             PartialDependenceCalculator object containing partial dependence values for a given model and dataset.
102 |             Providing this object speeds up the calculation, as partial dependence values do not need to be recalculated.
103 |             If None, it will be created from scratch. Default is None.
104 |         calculate_ova : bool
105 |             If True, one-vs-all interactions will be calculated. Default is True.
106 |         """
107 |         super().fit(model, X, n, predict_function, features, show_progress, batchsize, pd_calculator)
108 |         if calculate_ova:
109 |             self.ova = self._calculate_ova_interactions_from_pd(show_progress)
110 | 
111 |     def plot(self,
112 |              vis_type: str = VisualizationType.HEATMAP,
113 |              title: str = "default",
114 |              figsize: Tuple[float, float] = (8, 6),
115 |              **kwargs):
116 |         """
117 |         Plot results of explanations.
118 | 
119 |         There are five types of plots available:
120 |         - heatmap - heatmap of feature interaction values with feature importance values on the diagonal (default)
121 |         - bar_chart - bar chart of top feature interaction values
122 |         - graph - graph of feature interaction values
123 |         - bar_chart_ova - bar chart of top one-vs-all interaction values
124 |         - summary - combination of the other plots
125 | 
126 |         Parameters
127 |         ----------
128 |         vis_type : str
129 |             Type of visualization, one of ['heatmap', 'bar_chart', 'graph', 'bar_chart_ova', 'summary']. Default is 'heatmap'.
130 |         title : str
131 |             Title of plot, default is 'default', which means that the title will be automatically generated for the selected visualization type.
132 |         figsize : (float, float)
133 |             Size of plot. Default is (8, 6).
134 |         **kwargs : Other Parameters
135 |             Additional parameters for plot. Passed to suitable matplotlib or seaborn functions.
136 |             For the 'summary' visualization, parameters for the respective plots should be given in a dict with keys corresponding to the visualization names.
137 |             See key parameters below.
138 | 
139 |         Other Parameters
140 |         ----------------
141 |         interaction_color_map : matplotlib colormap name or object, or list of colors
142 |             Used for 'heatmap' visualization. The mapping from interaction values to color space. Default is 'Purples' or 'Purples_r',
143 |             depending on whether a greater value means a greater interaction strength or vice versa.
144 |         importance_color_map : matplotlib colormap name or object, or list of colors
145 |             Used for 'heatmap' visualization. The mapping from importance values to color space.
Default is 'Greens' or 'Greens_r',
146 |             depending on whether a greater value means a greater importance or vice versa.
147 |         annot_fmt : str
148 |             Used for 'heatmap' visualization. String formatting code to use when adding annotations with values. Default is '.3f'.
149 |         linewidths : float
150 |             Used for 'heatmap' visualization. Width of the lines that will divide each cell in the matrix. Default is 0.5.
151 |         linecolor : str
152 |             Used for 'heatmap' visualization. Color of the lines that will divide each cell in the matrix. Default is 'white'.
153 |         cbar_shrink : float
154 |             Used for 'heatmap' visualization. Fraction by which to multiply the size of the colorbar. Default is 1.
155 | 
156 |         top_k : int
157 |             Used for 'bar_chart' and 'bar_chart_ova' visualizations. Maximum number of pairs that will be presented in the plot. Default is 10.
158 |         color : str
159 |             Used for 'bar_chart' and 'bar_chart_ova' visualizations. Color of bars. Default is 'mediumpurple'.
160 | 
161 |         n_highest_with_labels : int
162 |             Used for 'graph' visualization. Number of top interactions to show as labels on edges. Default is 5.
163 |         edge_color: str
164 |             Used for 'graph' visualization. Color of the edges. Default is 'rebeccapurple'.
165 |         node_color: str
166 |             Used for 'graph' visualization. Color of nodes. Default is 'green'.
167 |         node_size: int
168 |             Used for 'graph' visualization. Size of the nodes (networkX scale). Default is 1800.
169 |         font_color: str
170 |             Used for 'graph' visualization. Font color. Default is '#3B1F2B'.
171 |         font_weight: str
172 |             Used for 'graph' visualization. Font weight. Default is 'bold'.
173 |         font_size: int
174 |             Used for 'graph' visualization. Font size (networkX scale). Default is 10.
175 |         threshold_relevant_interaction : float
176 |             Used for 'graph' visualization. Minimum (or maximum, depending on the method) value of interaction required to display the
177 |             corresponding edge on the visualization. Default depends on the interaction method.
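        Examples
        --------
        A minimal usage sketch (hedged: ``model`` and ``X`` are hypothetical stand-ins for a
        fitted model and a pd.DataFrame of its features; they are not defined in this module):

        >>> h_stat = FriedmanHStatisticMethod()  # doctest: +SKIP
        >>> h_stat.fit(model, X, n=1000)  # doctest: +SKIP
        >>> h_stat.plot()  # heatmap of pairwise H-statistic values  # doctest: +SKIP
        >>> h_stat.plot(vis_type='bar_chart_ova', top_k=5)  # doctest: +SKIP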
178 | """ 179 | if self.ova is None: 180 | raise MethodNotFittedException(self.method) 181 | 182 | self.visualizer.plot(self.ovo, 183 | vis_type, 184 | self.ova, 185 | feature_importance=self.feature_importance, 186 | title=title, 187 | figsize=figsize, 188 | interactions_ascending_order=self._interactions_ascending_order, 189 | importance_ascending_order=self._feature_importance_obj.importance_ascending_order, 190 | **kwargs) 191 | 192 | def _calculate_ova_interactions_from_pd(self, show_progress: bool) -> pd.DataFrame: 193 | self.pd_calculator.calculate_pd_minus_single(self.features_included, show_progress=show_progress) 194 | preds = self.predict_function(self.model, self.X_sampled) 195 | value_minus_single = [] 196 | for feature in self.features_included: 197 | pd_f = self.pd_calculator.get_pd_single(feature, feature_values=self.X_sampled[feature].values) 198 | pd_f_minus = self.pd_calculator.get_pd_minus_single(feature) 199 | value_minus_single.append([feature, _calculate_hstat_value(pd_f, pd_f_minus, preds, self.normalized)]) 200 | return pd.DataFrame(value_minus_single, columns=["Feature", InteractionMethod.H_STATISTIC 201 | ]).sort_values(by=InteractionMethod.H_STATISTIC, 202 | ascending=self._interactions_ascending_order, 203 | ignore_index=True).fillna(0) 204 | 205 | def _calculate_ovo_interactions_from_pd(self, show_progress: bool): 206 | self.pd_calculator.calculate_pd_pairs(self.pairs, show_progress=show_progress, all_combinations=False) 207 | self.pd_calculator.calculate_pd_single(self.features_included, show_progress=False) 208 | value_pairs = [] 209 | for pair in self.pairs: 210 | pd_f1 = self.pd_calculator.get_pd_single(pair[0], feature_values=self.X_sampled[pair[0]].values) 211 | pd_f2 = self.pd_calculator.get_pd_single(pair[1], feature_values=self.X_sampled[pair[1]].values) 212 | pair_feature_values = list(zip(self.X_sampled[pair[0]].values, self.X_sampled[pair[1]].values)) 213 | pd_pair = self.pd_calculator.get_pd_pairs(pair[0], pair[1], feature_values=pair_feature_values) 214 | value_pairs.append([pair[0], pair[1], _calculate_hstat_value(pd_f1, pd_f2, pd_pair, self.normalized)]) 215 | return pd.DataFrame(value_pairs, 216 | columns=["Feature 1", "Feature 2", 217 | self.method]).sort_values(by=self.method, 218 | ascending=self._interactions_ascending_order, 219 | ignore_index=True).fillna(0) 220 | 221 | 222 | def _calculate_hstat_value(pd_i: np.ndarray, pd_versus: np.ndarray, pd_i_versus: np.ndarray, normalized: bool = True): 223 | nominator = (center(pd_i_versus) - center(pd_i) - center(pd_versus))**2 224 | if normalized: 225 | denominator = center(pd_i_versus)**2 226 | return np.sum(nominator) / np.sum(denominator) if normalized else np.sqrt(np.mean(nominator)) 227 | -------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/partial_dependence_based/_greenwell.py: -------------------------------------------------------------------------------- 1 | from statistics import stdev 2 | from typing import List, Optional, Tuple, Callable 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | from tqdm import tqdm 8 | from artemis._utilities.domain import InteractionMethod, ProgressInfoLog, VisualizationType 9 | from artemis._utilities.ops import partial_dependence_value, split_features_num_cat 10 | from artemis._utilities.pd_calculator import PartialDependenceCalculator 11 | from ._pdp import PartialDependenceBasedMethod 12 | 13 | 14 | class GreenwellMethod(PartialDependenceBasedMethod): 15 | """ 16 | 
Greenwell Method for Feature Interaction Extraction. 17 | 18 | Uses partial dependence values to calculate feature interaction strengths and feature importance. 19 | 20 | Attributes 21 | ---------- 22 | method : str 23 | Method name, used also for naming column with results in `ovo` pd.DataFrame. 24 | visualizer : Visualizer 25 | Object providing visualization. Automatically created on the basis of a method and used to create visualizations. 26 | ovo : pd.DataFrame 27 | One versus one (pair) feature interaction values. 28 | feature_importance : pd.DataFrame 29 | Feature importance values. 30 | model : object 31 | Explained model. 32 | X_sampled: pd.DataFrame 33 | Sampled data used for calculation. 34 | features_included: List[str] 35 | List of features for which interactions are calculated. 36 | pairs : List[List[str]] 37 | List of pairs of features for which interactions are calculated. 38 | pd_calculator : PartialDependenceCalculator 39 | Object used to calculate and store partial dependence values. 40 | batchsize : int 41 | Batch size used for calculation. 42 | 43 | References 44 | ---------- 45 | - https://arxiv.org/pdf/1805.04755.pdf 46 | """ 47 | def __init__(self, random_state: Optional[int] = None): 48 | """Constructor for GreenwellMethod 49 | 50 | Parameters 51 | ---------- 52 | random_state : int, optional 53 | Random state for reproducibility. Defaults to None. 54 | """ 55 | super().__init__(InteractionMethod.VARIABLE_INTERACTION, random_state=random_state) 56 | 57 | def fit(self, 58 | model, 59 | X: pd.DataFrame, 60 | n: Optional[int] = None, 61 | predict_function: Optional[Callable] = None, 62 | features: Optional[List[str]] = None, 63 | show_progress: bool = False, 64 | batchsize: int = 2000, 65 | pd_calculator: Optional[PartialDependenceCalculator] = None): 66 | super().fit(model, X, n, predict_function, features, show_progress, batchsize, pd_calculator) 67 | 68 | def plot(self, 69 | vis_type: str = VisualizationType.HEATMAP, 70 | title: str = "default", 71 | figsize: Tuple[float, float] = (8, 6), 72 | **kwargs): 73 | super().plot(vis_type, title, figsize, **kwargs) 74 | 75 | 76 | def _calculate_ovo_interactions_from_pd(self, show_progress: bool = False): 77 | self.pd_calculator.calculate_pd_pairs(self.pairs, show_progress=show_progress) 78 | value_pairs = [] 79 | num_features, _ = split_features_num_cat(self.X_sampled, self.features_included) 80 | for pair in self.pairs: 81 | pair = self.pd_calculator._get_pair_key((pair[0], pair[1])) 82 | pd_values = self.pd_calculator.get_pd_pairs(pair[0], pair[1]) 83 | res_j = np.apply_along_axis(stdev, 0, np.apply_along_axis(_calc_conditional_imp, 1, pd_values, 84 | is_numerical=pair[1] in num_features)) 85 | res_i = np.apply_along_axis(stdev, 0, np.apply_along_axis(_calc_conditional_imp, 0, pd_values, 86 | is_numerical=pair[0] in num_features)) 87 | value_pairs.append([pair[0], pair[1], (res_j + res_i) / 2]) 88 | return pd.DataFrame(value_pairs, columns=["Feature 1", "Feature 2", self.method]).sort_values( 89 | by=self.method, ascending=self._interactions_ascending_order, ignore_index=True 90 | ).fillna(0) 91 | 92 | 93 | def _calc_conditional_imp(pd_values: np.ndarray, is_numerical: bool): 94 | return stdev(pd_values) if is_numerical else (np.max(pd_values) - np.min(pd_values)) / 4 95 | -------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/partial_dependence_based/_pdp.py: -------------------------------------------------------------------------------- 1 | from abc 
import abstractmethod
2 | from itertools import combinations
3 | from typing import Callable, List, Optional, Tuple
4 | 
5 | import pandas as pd
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import seaborn as sns
9 | 
10 | from artemis.importance_methods.model_agnostic import PartialDependenceBasedImportance
11 | from artemis.interactions_methods._method import FeatureInteractionMethod
12 | from artemis._utilities.domain import ProgressInfoLog, VisualizationType
13 | from artemis._utilities.ops import get_predict_function, sample_if_not_none, all_if_none
14 | from artemis._utilities.pd_calculator import PartialDependenceCalculator
15 | from artemis._utilities.zenplot import (
16 |     get_pd_dict,
17 |     get_pd_pairs_values,
18 |     get_second_feature,
19 |     find_next_pair_index,
20 | )
21 | 
22 | 
23 | class PartialDependenceBasedMethod(FeatureInteractionMethod):
24 |     def __init__(self, method: str, random_state: Optional[int] = None):
25 |         super().__init__(method, random_state=random_state)
26 |         self.pd_calculator = None
27 | 
28 |     @property
29 |     def _interactions_ascending_order(self):
30 |         return False
31 | 
32 |     def plot(
33 |         self,
34 |         vis_type: str = VisualizationType.HEATMAP,
35 |         title: str = "default",
36 |         figsize: Tuple[float, float] = (8, 6),
37 |         **kwargs
38 |     ):
39 |         super().plot(vis_type, title, figsize, **kwargs)
40 | 
41 |     def fit(
42 |         self,
43 |         model,
44 |         X: pd.DataFrame,
45 |         n: Optional[int] = None,
46 |         predict_function: Optional[Callable] = None,
47 |         features: Optional[List[str]] = None,
48 |         show_progress: bool = False,
49 |         batchsize: int = 2000,
50 |         pd_calculator: Optional[PartialDependenceCalculator] = None,
51 |     ):
52 |         """Calculates Partial Dependence Based Feature Interaction Strength and Feature Importance for the given model.
53 | 
54 |         Parameters
55 |         ----------
56 |         model : object
57 |             Model to be explained; should have a predict_proba or predict method, or predict_function should be provided.
58 |         X : pd.DataFrame
59 |             Data used to calculate interactions. If n is not None, n rows from X will be sampled.
60 |         n : int, optional
61 |             Number of samples to be used for calculation of interactions. If None, all rows from X will be used. Default is None.
62 |         predict_function : Callable, optional
63 |             Function used to predict model output. It should take a model and a dataset and return predictions.
64 |             If None, the `predict_proba` method will be used if it exists, otherwise the `predict` method. Default is None.
65 |         features : List[str], optional
66 |             List of features for which interactions will be calculated. If None, all features from X will be used. Default is None.
67 |         show_progress : bool
68 |             If True, progress bar will be shown. Default is False.
69 |         batchsize : int
70 |             Batch size for calculating partial dependence. Prediction requests are collected until the batchsize is exceeded,
71 |             then the model is queried for predictions jointly for many observations. This speeds up the operation of the method.
72 |             Default is 2000.
73 |         pd_calculator : PartialDependenceCalculator, optional
74 |             PartialDependenceCalculator object containing partial dependence values for a given model and dataset.
75 |             Providing this object speeds up the calculation, as partial dependence values do not need to be recalculated.
76 |             If None, it will be created from scratch. Default is None.
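        Examples
        --------
        A minimal usage sketch (hedged: this class is abstract, so the concrete subclass
        ``GreenwellMethod`` is shown; ``model`` and ``X`` are hypothetical stand-ins for a
        fitted model and a pd.DataFrame of its features):

        >>> method = GreenwellMethod(random_state=42)  # doctest: +SKIP
        >>> method.fit(model, X, n=500, show_progress=True)  # doctest: +SKIP
        >>> method.ovo.head()  # strongest pairwise interactions first  # doctest: +SKIP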
77 | """ 78 | self.predict_function = get_predict_function(model, predict_function) 79 | self.model = model 80 | 81 | self.X_sampled = sample_if_not_none(self._random_generator, X, n) 82 | self.features_included = all_if_none(X.columns, features) 83 | self.pairs = list(combinations(self.features_included, 2)) 84 | 85 | if pd_calculator is None: 86 | self.pd_calculator = PartialDependenceCalculator( 87 | self.model, self.X_sampled, self.predict_function, batchsize 88 | ) 89 | else: 90 | if pd_calculator.model != self.model: 91 | raise ValueError( 92 | "Model in PDP calculator is different than the model in the method." 93 | ) 94 | if not pd_calculator.X.equals(self.X_sampled): 95 | raise ValueError( 96 | "Data in PDP calculator is different than the data in the method." 97 | ) 98 | self.pd_calculator = pd_calculator 99 | 100 | self.ovo = self._calculate_ovo_interactions_from_pd(show_progress=show_progress) 101 | 102 | self._feature_importance_obj = PartialDependenceBasedImportance() 103 | self.feature_importance = self._feature_importance_obj.importance( 104 | self.model, 105 | self.X_sampled, 106 | features=self.features_included, 107 | show_progress=show_progress, 108 | pd_calculator=self.pd_calculator, 109 | ) 110 | 111 | def plot_profile( 112 | self, 113 | feature1: str, 114 | feature2: Optional[str] = None, 115 | kind: str = "colormesh", 116 | cmap: str = "RdYlBu_r", 117 | figsize: Tuple[float, float] = (6, 4), 118 | ): 119 | """ 120 | Plots partial dependence profile for a given feature/pair of features. 121 | 122 | Parameters 123 | ---------- 124 | feature1 : str 125 | First feature. 126 | feature2 : str, optional 127 | Second feature. If None, profile for a single feature will be plotted. Default is None. 128 | kind : str 129 | Kind of plot, used only for pair of features. Can be 'colormesh' or 'contour'. Default is 'colormesh'. 130 | cmap: str 131 | Colormap. Default is 'RdYlBu_r'. 132 | figsize : (float, float) 133 | Size of plot. Default is (8, 6). 
134 | """ 135 | plt.figure(figsize=figsize) 136 | if feature2 is not None: 137 | pair_key = self.pd_calculator._get_pair_key((feature1, feature2)) 138 | pair = self.pd_calculator.pd_pairs[pair_key] 139 | 140 | if kind == "contour": 141 | cs = plt.contour( 142 | pair["f2_values"], 143 | pair["f1_values"], 144 | pair["pd_values"], 145 | colors="black", 146 | linewidths=0.5, 147 | ) 148 | cs2 = plt.contourf( 149 | pair["f2_values"], pair["f1_values"], pair["pd_values"], cmap=cmap 150 | ) 151 | plt.clabel(cs, colors="black") 152 | clb = plt.colorbar(cs2) 153 | elif kind == "colormesh": 154 | cs = plt.pcolormesh( 155 | pair["f2_values"], 156 | pair["f1_values"], 157 | pair["pd_values"], 158 | linewidths=0.5, 159 | cmap=cmap, 160 | ) 161 | clb = plt.colorbar() 162 | clb.ax.set_title("PD value") 163 | plt.xlabel(pair_key[1]) 164 | plt.ylabel(pair_key[0]) 165 | sns.rugplot( 166 | self.pd_calculator.X, y=pair_key[0], x=pair_key[1], color="black" 167 | ) 168 | else: 169 | single = self.pd_calculator.pd_single[feature1] 170 | plt.plot(single["f_values"], single["pd_values"]) 171 | plt.xlabel(feature1) 172 | plt.ylabel("PD value") 173 | sns.rugplot(self.pd_calculator.X, x=feature1, color="black") 174 | 175 | def plot_zenplot( 176 | self, 177 | zenpath_length: int = 7, 178 | kind: str = "colormesh", 179 | cmap: str = "RdYlBu_r", 180 | figsize: Tuple[float, float] = (14, 12), 181 | ): 182 | """ 183 | Plots zenplot, a grid of charts where each panel contains a PD function visualization for a different pair of features 184 | 185 | Parameters 186 | ---------- 187 | zenpath_length : int 188 | Length of zenpath. Default is 7. 189 | kind : str 190 | Kind of plot. Can be 'colormesh' or 'contour'. Default is 'colormesh'. 191 | cmap: str 192 | Colormap. Default is 'RdYlBu_r'. 193 | figsize : (float, float) 194 | Size of plot. Default is (8, 6). 
195 | 196 | References 197 | ---------- 198 | - https://www.jstatsoft.org/article/view/v095i04 199 | """ 200 | fig = plt.figure(figsize=figsize) 201 | to_vis = self.ovo.copy().iloc[:(zenpath_length + 1)] 202 | min_pd, max_pd = get_pd_dict(self.pd_calculator, to_vis) 203 | pair = to_vis.iloc[0]["Feature 1"], to_vis.iloc[0]["Feature 2"] 204 | to_vis = to_vis.drop(0) 205 | 206 | id_row, id_col = 0, 0 207 | nrows, ncols = int(np.floor((zenpath_length + 1) / 2)), int( 208 | np.floor(zenpath_length / 2) + 1 209 | ) 210 | continued = False 211 | 212 | for i in range(zenpath_length): 213 | pair_values = get_pd_pairs_values(self, pair) 214 | ax = plt.subplot2grid((nrows, ncols), (id_row, id_col), rowspan=1) 215 | 216 | if id_col > id_row: 217 | if kind == "colormesh": 218 | cs = ax.pcolormesh( 219 | pair_values["f2_values"], 220 | pair_values["f1_values"], 221 | pair_values["pd_values"], 222 | vmin=min_pd, 223 | vmax=max_pd, 224 | cmap=cmap, 225 | ) 226 | elif kind == "contour": 227 | plt.contour( 228 | pair_values["f2_values"], 229 | pair_values["f1_values"], 230 | pair_values["pd_values"], 231 | vmin=min_pd, 232 | vmax=max_pd, 233 | colors="black", 234 | linewidths=0.5, 235 | ) 236 | cs = plt.contourf( 237 | pair_values["f2_values"], 238 | pair_values["f1_values"], 239 | pair_values["pd_values"], 240 | vmin=min_pd, 241 | vmax=max_pd, 242 | cmap=cmap, 243 | ) 244 | ax.set_ylabel(pair[0]) 245 | id_row += 1 246 | else: 247 | if kind == "colormesh": 248 | cs = ax.pcolormesh( 249 | pair_values["f1_values"], 250 | pair_values["f2_values"], 251 | pair_values["pd_values"].T, 252 | vmin=min_pd, 253 | vmax=max_pd, 254 | cmap=cmap, 255 | ) 256 | elif kind == "contour": 257 | plt.contour( 258 | pair_values["f1_values"], 259 | pair_values["f2_values"], 260 | pair_values["pd_values"].T, 261 | vmin=min_pd, 262 | vmax=max_pd, 263 | colors="black", 264 | linewidths=0.5, 265 | ) 266 | cs = plt.contourf( 267 | pair_values["f1_values"], 268 | pair_values["f2_values"], 269 | pair_values["pd_values"].T, 270 | vmin=min_pd, 271 | vmax=max_pd, 272 | cmap=cmap, 273 | ) 274 | ax.set_title(pair[0], size=10) 275 | id_col += 1 276 | 277 | idx, continued = find_next_pair_index(to_vis, pair[1]) 278 | if continued: 279 | pair = pair[1], get_second_feature(pair[1], to_vis.loc[idx]) 280 | if zenpath_length - 1 == i: 281 | if id_col > id_row: 282 | ax.set_ylabel(pair[1]) 283 | else: 284 | ax.set_title(pair[1], size=10) 285 | else: 286 | if id_col > id_row: 287 | ax.set_ylabel(pair[1]) 288 | else: 289 | ax.set_title(pair[1], size=10) 290 | pair = to_vis.loc[idx]["Feature 1"], to_vis.loc[idx]["Feature 2"] 291 | 292 | to_vis = to_vis.drop(idx) 293 | 294 | plt.tight_layout() 295 | cbar_ax = fig.add_axes([1, 0.25, 0.05, 0.5]) 296 | clb = plt.colorbar(cs, cax=cbar_ax) 297 | clb.ax.set_title("PD value") 298 | 299 | @abstractmethod 300 | def _calculate_ovo_interactions_from_pd(self, show_progress: bool): 301 | ... 
302 | 
-------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/performance_based/__init__.py: --------------------------------------------------------------------------------
1 | from ._sejong_oh import SejongOhMethod
-------------------------------------------------------------------------------- /artemis/interactions_methods/model_agnostic/performance_based/_sejong_oh.py: --------------------------------------------------------------------------------
1 | from itertools import combinations
2 | from typing import Callable, List, Optional, Union, Tuple
3 | 
4 | import numpy as np
5 | import pandas as pd
6 | from tqdm import tqdm
7 | 
8 | from artemis.importance_methods.model_agnostic import PermutationImportance
9 | from artemis.interactions_methods._method import FeatureInteractionMethod
10 | from artemis._utilities.domain import InteractionMethod, ProblemType, ProgressInfoLog, VisualizationType
11 | from artemis._utilities.performance_metrics import Metric, RMSE
12 | from artemis._utilities.ops import all_if_none, sample_both_if_not_none
13 | 
14 | 
15 | class SejongOhMethod(FeatureInteractionMethod):
16 |     """
17 |     Sejong Oh's Performance Based Method for Feature Interaction Extraction.
18 | 
19 |     Attributes
20 |     ----------
21 |     method : str
22 |         Method name, used also for naming the column with results in the `ovo` pd.DataFrame.
23 |     visualizer : Visualizer
24 |         Object providing visualization. Automatically created on the basis of a method and used to create visualizations.
25 |     ovo : pd.DataFrame
26 |         One versus one (pair) feature interaction values.
27 |     feature_importance : pd.DataFrame
28 |         Feature importance values.
29 |     metric : Metric
30 |         Metric used for calculating performance.
31 |     model : object
32 |         Explained model.
33 |     X_sampled: pd.DataFrame
34 |         Sampled data used for calculation.
35 |     y_sampled: np.ndarray or pd.Series
36 |         Sampled target values used for calculation.
37 |     features_included: List[str]
38 |         List of features for which interactions are calculated.
39 |     pairs : List[List[str]]
40 |         List of pairs of features for which interactions are calculated.
41 |     random_state : int
42 |         Random state used for reproducibility.
43 | 
44 |     References
45 |     ----------
46 |     - https://www.mdpi.com/2076-3417/9/23/5191
47 |     """
48 | 
49 |     def __init__(self, metric: Metric = RMSE(), random_state: Optional[int] = None):
50 |         """Constructor for SejongOhMethod
51 | 
52 |         Parameters
53 |         ----------
54 |         metric : Metric
55 |             Metric used to calculate model performance. Defaults to RMSE().
56 |         random_state : int, optional
57 |             Random state for reproducibility. Defaults to None.
58 |         """
59 |         super().__init__(InteractionMethod.PERFORMANCE_BASED, random_state)
60 |         self.metric = metric
61 |         self.y_sampled = None
62 | 
63 |     @property
64 |     def _interactions_ascending_order(self):
65 |         return False
66 | 
67 |     def plot(self, vis_type: str = VisualizationType.HEATMAP, title: str = "default",
68 |              figsize: Tuple[float, float] = (8, 6), **kwargs):
69 |         super().plot(vis_type, title, figsize, **kwargs)
70 | 
71 |     def fit(
72 |         self,
73 |         model,
74 |         X: pd.DataFrame,
75 |         y_true: Union[np.ndarray, pd.Series],
76 |         n: Optional[int] = None,
77 |         n_repeat: int = 10,
78 |         features: Optional[List[str]] = None,
79 |         show_progress: bool = False,
80 |     ):
81 |         """Calculates Performance Based Feature Interaction Strength and Permutation Based Feature Importance for the given model.
82 | 
83 |         Parameters
84 |         ----------
85 |         model : object
86 |             Model to be explained; should have a predict method.
87 |         X : pd.DataFrame
88 |             Data used to calculate interactions. If n is not None, n rows from X will be sampled.
89 |         y_true : np.ndarray or pd.Series
90 |             Target values for X data.
91 |         n : int, optional
92 |             Number of samples to be used for calculation of interactions. If None, all rows from X will be used. Default is None.
93 |         n_repeat : int, optional
94 |             Number of permutations. Default is 10.
95 |         features : List[str], optional
96 |             List of features for which interactions will be calculated. If None, all features from X will be used. Default is None.
97 |         show_progress : bool
98 |             If True, progress bar will be shown. Default is False.
99 |         """
100 |         self.X_sampled, self.y_sampled = sample_both_if_not_none(self._random_generator, X, y_true, n)
101 |         self.features_included = all_if_none(X.columns, features)
102 |         self.pairs = list(combinations(self.features_included, 2))
103 |         self.ovo = _perf_based_ovo(self, model, self.X_sampled, self.y_sampled, n_repeat, show_progress)
104 | 
105 |         # calculate feature importance
106 |         self._feature_importance_obj = PermutationImportance(self.metric)
107 |         self.feature_importance = self._feature_importance_obj.importance(model, X=self.X_sampled,
108 |                                                                           y_true=self.y_sampled,
109 |                                                                           n_repeat=n_repeat,
110 |                                                                           features=self.features_included,
111 |                                                                           show_progress=show_progress)
112 | 
113 | 
114 | def _perf_based_ovo(
115 |     method_class: SejongOhMethod, model, X: pd.DataFrame, y_true: np.ndarray, n_repeat: int, show_progress: bool
116 | ):
117 |     """For each pair of `features_included`, calculate the Sejong Oh performance-based interaction value."""
118 |     original_performance = method_class.metric.calculate(y_true, model.predict(X))
119 |     interactions = []
120 | 
121 |     for f1, f2 in tqdm(method_class.pairs, disable=not show_progress, desc=ProgressInfoLog.CALC_OVO):
122 |         inter = [
123 |             np.abs(_inter(method_class, model, X, y_true, f1, f2, original_performance)) for _ in range(n_repeat)
124 |         ]
125 |         interactions.append([f1, f2, np.mean(inter)])
126 | 
127 |     return pd.DataFrame(interactions, columns=["Feature 1", "Feature 2", method_class.method]).sort_values(
128 |         by=method_class.method, key=abs, ascending=method_class._interactions_ascending_order, ignore_index=True
129 |     )
130 | 
131 | 
132 | def _inter(
133 |     method_class: SejongOhMethod,
134 |     model,
135 |     X: pd.DataFrame,
136 |     y_true: np.ndarray,
137 |     f1: str,
138 |     f2: str,
139 |     reference_performance: float,
140 | ):
141 |     """
142 |     Calculates the performance-based interaction between features `f1` and `f2`.
143 |     Intuitively, it compares the impact on model performance when `f1` and `f2` are permuted
144 |     together with the impact when each of them is permuted separately.
145 | 
146 |     Specifics can be found in: https://www.mdpi.com/2076-3417/9/23/5191.
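    In pseudo-notation (a sketch of what the code below computes; delta_perf is a
    shorthand introduced here for illustration):

        inter(f1, f2) = delta_perf({f1, f2}) - delta_perf({f1}) - delta_perf({f2})

    where delta_perf(S) is the change in the performance metric after jointly permuting
    the features in S (sign-flipped for classification metrics, then taken in absolute
    value and averaged over `n_repeat` permutations by the caller).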
147 | """ 148 | score_f1_permuted = _permute_score(method_class, model, X, y_true, [f1], reference_performance) 149 | score_f2_permuted = _permute_score(method_class, model, X, y_true, [f2], reference_performance) 150 | score_f1_f2_permuted = _permute_score(method_class, model, X, y_true, [f1, f2], reference_performance) 151 | 152 | return _neg_if_class(method_class, score_f1_f2_permuted - score_f1_permuted - score_f2_permuted) 153 | 154 | 155 | def _permute_score( 156 | method_class: SejongOhMethod, 157 | model, 158 | X: pd.DataFrame, 159 | y_true: np.array, 160 | features: List[str], 161 | reference_performance: float, 162 | ): 163 | """Permute `features` list and assess performance of the model.""" 164 | X_copy_permuted = X.copy() 165 | p = method_class._random_generator.permutation(len(X)) 166 | 167 | for feature in features: 168 | X_copy_permuted[feature] = X_copy_permuted[feature].values[p] 169 | 170 | return _neg_if_class( 171 | method_class, 172 | method_class.metric.calculate(y_true, model.predict(X_copy_permuted)) - reference_performance, 173 | ) 174 | 175 | 176 | def _neg_if_class(method_class: SejongOhMethod, value: float): 177 | """Classification metrics should be maximized.""" 178 | if method_class.metric.applicable_to(ProblemType.CLASSIFICATION): 179 | return -value 180 | 181 | return value 182 | -------------------------------------------------------------------------------- /artemis/interactions_methods/model_specific/__init__.py: -------------------------------------------------------------------------------- 1 | from .gb_trees import SplitScoreMethod 2 | from .random_forest import ConditionalMinimalDepthMethod 3 | 4 | __all__ = ["SplitScoreMethod", "ConditionalMinimalDepthMethod"] 5 | -------------------------------------------------------------------------------- /artemis/interactions_methods/model_specific/gb_trees/__init__.py: -------------------------------------------------------------------------------- 1 | from ._split_score import SplitScoreMethod 2 | -------------------------------------------------------------------------------- /artemis/interactions_methods/model_specific/random_forest/__init__.py: -------------------------------------------------------------------------------- 1 | from ._conditional_minimal_depth import ConditionalMinimalDepthMethod -------------------------------------------------------------------------------- /artemis/visualizer/__init__.py: -------------------------------------------------------------------------------- 1 | from ._pdp_visualizer import PartialDependenceVisualizer 2 | 3 | __all__ = ["PartialDependenceVisualizer"] -------------------------------------------------------------------------------- /artemis/visualizer/_configuration.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import List 3 | 4 | from artemis._utilities.domain import InteractionMethod, VisualizationType 5 | from artemis._utilities.exceptions import MethodNotSupportedException 6 | 7 | 8 | @dataclass 9 | class InteractionGraphConfiguration: 10 | MAX_EDGE_WIDTH: int = 20 11 | N_HIGHEST_WITH_LABELS: int = 5 12 | FONT_COLOR: str = "#3B1F2B" 13 | FONT_WEIGHT: str = "bold" 14 | FONT_SIZE: int = 10 15 | EDGE_COLOR: str = "rebeccapurple" 16 | NODE_COLOR: str = "green" 17 | NODE_SIZE: int = 1800 18 | TITLE: str = "Interaction graph" 19 | THRESHOLD_RELEVANT_INTERACTION: float = 0.05 20 | 21 | 22 | @dataclass 23 | class InteractionMatrixConfiguration: 24 | TITLE: str = 
"Interaction matrix" 25 | INTERACTION_COLOR_MAP: str = "Purples" 26 | INTERACTION_COLOR_MAP_REVERSE: str = "Purples_r" 27 | IMPORTANCE_COLOR_MAP: str = "Greens" 28 | IMPORTANCE_COLOR_MAP_REVERSE: str = "Greens_r" 29 | ANNOT_FMT: str = ".3f" 30 | LINEWIDTHS: float = 0.5 31 | LINECOLOR: str = "white" 32 | CBAR_SHRINK: float = 0.8 33 | 34 | 35 | @dataclass 36 | class InteractionVersusAllConfiguration: 37 | TITLE: str = "Interaction with all other features" 38 | TOP_K: int = 10 39 | COLOR: str = "mediumpurple" 40 | 41 | 42 | @dataclass 43 | class InteractionVersusOneConfiguration: 44 | TITLE: str = "Pair interactions" 45 | TOP_K: int = 10 46 | COLOR: str = "mediumpurple" 47 | 48 | 49 | @dataclass 50 | class LollipopSplitScoreConfiguration: 51 | TITLE: str = "Lollipop boosting model summary" 52 | SCALE: str = "linear" 53 | MAX_TREES: float = 0.2 54 | LABEL_THRESHOLD: float = 0.1 55 | COLORS = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#ffff33"] 56 | SHAPES = ["o", ",", "v", "^", "<", ">"] 57 | MAX_DEPTH: int = 1 58 | LABELS: bool = True 59 | 60 | 61 | @dataclass 62 | class BarChartConditionalDepthConfiguration: 63 | TITLE: str = "Random forest model summary" 64 | COLOR_MAP: str = "Purples" 65 | TOP_K: int = 15 66 | COLOR: str = "black" 67 | 68 | 69 | class VisualizationConfigurationProvider: 70 | accepted_visualizations = { 71 | InteractionMethod.H_STATISTIC: [ 72 | VisualizationType.SUMMARY, 73 | VisualizationType.INTERACTION_GRAPH, 74 | VisualizationType.BAR_CHART_OVA, 75 | VisualizationType.BAR_CHART_OVO, 76 | VisualizationType.HEATMAP, 77 | ], 78 | InteractionMethod.PERFORMANCE_BASED: [ 79 | VisualizationType.SUMMARY, 80 | VisualizationType.INTERACTION_GRAPH, 81 | VisualizationType.BAR_CHART_OVO, 82 | VisualizationType.HEATMAP, 83 | ], 84 | InteractionMethod.VARIABLE_INTERACTION: [ 85 | VisualizationType.SUMMARY, 86 | VisualizationType.INTERACTION_GRAPH, 87 | VisualizationType.BAR_CHART_OVO, 88 | VisualizationType.HEATMAP, 89 | ], 90 | InteractionMethod.CONDITIONAL_MINIMAL_DEPTH: [ 91 | VisualizationType.SUMMARY, 92 | VisualizationType.INTERACTION_GRAPH, 93 | VisualizationType.BAR_CHART_OVO, 94 | VisualizationType.HEATMAP, 95 | VisualizationType.BAR_CHART_CONDITIONAL, 96 | ], 97 | InteractionMethod.SPLIT_SCORE: [ 98 | VisualizationType.SUMMARY, 99 | VisualizationType.INTERACTION_GRAPH, 100 | VisualizationType.BAR_CHART_OVO, 101 | VisualizationType.HEATMAP, 102 | VisualizationType.LOLLIPOP, 103 | ], 104 | } 105 | 106 | @classmethod 107 | def get(cls, method: str): 108 | if method == InteractionMethod.H_STATISTIC: 109 | return cls._h_stat_config() 110 | elif method == InteractionMethod.VARIABLE_INTERACTION: 111 | return cls._var_inter_config() 112 | elif method == InteractionMethod.PERFORMANCE_BASED: 113 | return cls._perf_based_config() 114 | elif method == InteractionMethod.SPLIT_SCORE: 115 | return cls._split_score_config() 116 | elif method == InteractionMethod.CONDITIONAL_MINIMAL_DEPTH: 117 | return cls._cond_depth_config() 118 | else: 119 | raise MethodNotSupportedException(method) 120 | 121 | @classmethod 122 | def _h_stat_config(cls): 123 | return VisualizationConfiguration( 124 | accepted_visualizations=cls.accepted_visualizations[ 125 | InteractionMethod.H_STATISTIC 126 | ] 127 | ) 128 | 129 | @classmethod 130 | def _var_inter_config(cls): 131 | return VisualizationConfiguration( 132 | accepted_visualizations=cls.accepted_visualizations[ 133 | InteractionMethod.VARIABLE_INTERACTION 134 | ] 135 | ) 136 | 137 | @classmethod 138 | def _perf_based_config(cls): 139 | 
graph_config = InteractionGraphConfiguration()
140 |         graph_config.THRESHOLD_RELEVANT_INTERACTION = 0.1
141 | 
142 |         return VisualizationConfiguration(
143 |             accepted_visualizations=cls.accepted_visualizations[
144 |                 InteractionMethod.PERFORMANCE_BASED
145 |             ],
146 |             interaction_graph=graph_config,
147 |         )
148 | 
149 |     @classmethod
150 |     def _split_score_config(cls):
151 |         graph_config = InteractionGraphConfiguration()
152 |         graph_config.THRESHOLD_RELEVANT_INTERACTION = 0.1
153 | 
154 |         return VisualizationConfiguration(
155 |             accepted_visualizations=cls.accepted_visualizations[
156 |                 InteractionMethod.SPLIT_SCORE
157 |             ],
158 |             interaction_graph=graph_config,
159 |         )
160 | 
161 |     @classmethod
162 |     def _cond_depth_config(cls):
163 |         graph_config = InteractionGraphConfiguration()
164 |         graph_config.THRESHOLD_RELEVANT_INTERACTION = 0.6
165 |         graph_config.MAX_EDGE_WIDTH = 3
166 | 
167 |         return VisualizationConfiguration(
168 |             accepted_visualizations=cls.accepted_visualizations[
169 |                 InteractionMethod.CONDITIONAL_MINIMAL_DEPTH
170 |             ],
171 |             interaction_graph=graph_config,
172 |         )
173 | 
174 | 
175 | @dataclass
176 | class VisualizationConfiguration:
177 |     accepted_visualizations: List[str]
178 |     interaction_graph: InteractionGraphConfiguration = field(default_factory=InteractionGraphConfiguration)
179 |     interaction_matrix: InteractionMatrixConfiguration = field(default_factory=InteractionMatrixConfiguration)
180 |     interaction_bar_chart_ova: InteractionVersusAllConfiguration = field(
181 |         default_factory=InteractionVersusAllConfiguration
182 |     )
183 |     interaction_bar_chart_ovo: InteractionVersusOneConfiguration = field(
184 |         default_factory=InteractionVersusOneConfiguration
185 |     )
186 |     lollipop: LollipopSplitScoreConfiguration = field(default_factory=LollipopSplitScoreConfiguration)
187 |     interaction_bar_chart_conditional: BarChartConditionalDepthConfiguration = field(
188 |         default_factory=BarChartConditionalDepthConfiguration
189 |     )
190 | 
-------------------------------------------------------------------------------- /artemis/visualizer/_pdp_visualizer.py: --------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | from matplotlib import pyplot as plt
3 | from sklearn.inspection import PartialDependenceDisplay
4 | from sklearn.base import BaseEstimator
5 | import pandas as pd
6 | 
7 | 
8 | class PartialDependenceVisualizer:
9 |     """
10 |     Visualizer of 1-dimensional and 2-dimensional partial dependence plots.
11 |     It wraps the scikit-learn PartialDependenceDisplay.from_estimator() method,
12 |     so only models implementing predict functions in the scikit-learn API are supported.
13 | 
14 |     Attributes
15 |     ----------
16 |     model : sklearn.BaseEstimator
17 |         Model for which partial dependence plots will be generated.
18 |     X : pd.DataFrame
19 |         Data used to calculate partial dependence functions.
20 |     """
21 |     def __init__(self, model: BaseEstimator, X: pd.DataFrame):
22 |         """Constructor for PartialDependenceVisualizer
23 | 
24 |         Parameters
25 |         ----------
26 |         model : sklearn.BaseEstimator
27 |             Model for which partial dependence plots will be generated.
28 |         X : pd.DataFrame
29 |             Data used to calculate partial dependence functions.
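        Examples
        --------
        A minimal usage sketch (hedged: ``rf`` and the feature names are hypothetical
        stand-ins for a fitted scikit-learn estimator and columns of ``X``):

        >>> viz = PartialDependenceVisualizer(rf, X)  # doctest: +SKIP
        >>> viz.plot(["age"])  # 1-dimensional PDP  # doctest: +SKIP
        >>> viz.plot([("age", "bmi")], grid_resolution=50)  # 2-dimensional PDP  # doctest: +SKIP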
30 | """ 31 | self.model = model 32 | self.X = X 33 | 34 | def plot(self, 35 | features: List[Union[int, str, Tuple[int, int], Tuple[str, str]]], 36 | grid_resolution: int = 100, 37 | title: str = "Partial Dependence", 38 | figsize: Tuple[float, float] = (12, 6), 39 | **kwargs): 40 | """Plot partial dependence plot. 41 | 42 | Parameters 43 | ---------- 44 | features : int, str, (int, int), or (str, str) 45 | Features for which partial dependence plot will be generated. 46 | If one feature is provided, 1-dimensional PDP will be returned, two features -- 2-dimensional PDP. 47 | grid_resolution : int 48 | The number of equally spaced points on the axes of the plots, for each target feature. Default is 100. 49 | title : str 50 | Title of plot. Default is 'Partial Dependence'. 51 | figsize : (float, float) 52 | Size of plot. Default is (12, 6). 53 | **kwargs : Other Parameters 54 | Additional parameters for plot. Passed to PartialDependenceDisplay.from_estimator() method. 55 | """ 56 | fig, ax = plt.subplots(figsize=figsize) 57 | PartialDependenceDisplay.from_estimator(self.model, 58 | self.X, 59 | features, 60 | grid_resolution=grid_resolution, 61 | _ax=ax, 62 | **kwargs) 63 | ax.set_title(title) 64 | -------------------------------------------------------------------------------- /docs/artemis/importance_methods/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis.importance_methods API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 46 | 65 |
66 | 69 | 70 | -------------------------------------------------------------------------------- /docs/artemis/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |
22 |

Package artemis

23 |
24 |
25 |

artemis: A Robust Toolkit of Explanation Methods for Interaction Spotting

26 |

What is artemis?

27 |

artemis is a Python library for explanations of feature interactions in machine learning models. 28 | - It provides methods for analyzing predictive models in terms of interactions between features and feature importance. 29 | - There are both model-agnostic methods that can work with any properly prepared model, and model-specific methods adapted to tree-based ones 30 | (due to their structure, which naturally influences the possibility of interaction). 31 | - It enables to scrutinize a wide range of models by examining strength of feature interactions and visualizing them.

32 |
33 | 34 | Expand source code 35 | 36 |
"""
 37 | # `artemis`: A Robust Toolkit of Explanation Methods for Interaction Spotting
 38 | 
 39 | ## What is `artemis`?  
 40 | **`artemis` is a Python library for explanations of feature interactions in machine learning models.**
 41 | - It provides methods for analyzing predictive models in terms of interactions between features and feature importance. 
 42 | - There are both model-agnostic methods that can work with any properly prepared model, and model-specific methods adapted to tree-based ones 
 43 | (due to their structure, which naturally influences the possibility of interaction).
 44 | - It enables to scrutinize a wide range of models by examining strength of feature interactions and visualizing them.
 45 | 
 46 | """
47 |
48 |
49 |
50 |

Sub-modules

51 |
52 |
artemis.additivity
53 |
54 |
55 |
56 |
artemis.comparison
57 |
58 |
59 |
60 |
artemis.importance_methods
61 |
62 |
63 |
64 |
artemis.interactions_methods
65 |
66 |
67 |
68 |
artemis.visualizer
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 | 103 |
104 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /docs/artemis/interactions_methods/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis.interactions_methods API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 46 | 65 |
66 | 69 | 70 | -------------------------------------------------------------------------------- /docs/artemis/interactions_methods/model_agnostic/performance_based/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis.interactions_methods.model_agnostic.performance_based API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |
22 |

Module artemis.interactions_methods.model_agnostic.performance_based

23 |
24 |
25 |
26 | 27 | Expand source code 28 | 29 |
from ._sejong_oh import SejongOhMethod
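# A usage sketch (illustrative; it follows the calls in test/test_sejong_oh_inter.py,
# where `model`, `X`, and `y` are placeholders for a fitted estimator and its data):
#
#     method = SejongOhMethod()
#     method.fit(model, X, y, n_repeat=3)
#     method.ovo                 # pairwise (one-vs-one) interaction strengths
#     method.feature_importance  # performance-based feature importance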
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | 54 |
55 | 58 | 59 | -------------------------------------------------------------------------------- /docs/artemis/interactions_methods/model_specific/gb_trees/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis.interactions_methods.model_specific.gb_trees API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |
22 |

Module artemis.interactions_methods.model_specific.gb_trees

23 |
24 |
25 |
26 | 27 | Expand source code 28 | 29 |
from ._split_score import SplitScoreMethod
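# A usage sketch (illustrative; it follows test/test_split_score_inter.py). The method
# is model-specific: it works with gradient boosting models such as XGBoost and LightGBM
# and raises ModelNotSupportedException for other model types; `model_xgb` is a placeholder.
#
#     method = SplitScoreMethod()
#     method.fit(model_xgb)
#     method.ovo                 # pairwise interaction values
#     method.feature_importance  # split-score-based feature importance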
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | 54 |
55 | 58 | 59 | -------------------------------------------------------------------------------- /docs/artemis/interactions_methods/model_specific/random_forest/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis.interactions_methods.model_specific.random_forest API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |
22 |

Module artemis.interactions_methods.model_specific.random_forest

23 |
24 |
25 |
26 | 27 | Expand source code 28 | 29 |
from ._conditional_minimal_depth import ConditionalMinimalDepthMethod
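# A usage sketch (illustrative; it follows test/test_conditional_minimal_depth.py).
# The method reads interactions directly from the tree structure of a fitted
# scikit-learn random forest, so fit() takes only the model; `model_rf` is a placeholder.
#
#     method = ConditionalMinimalDepthMethod()
#     method.fit(model_rf)
#     method.ovo                 # root_variable/variable pairs with occurrence counts
#     method.feature_importance  # minimal-depth-based feature importance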
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 | 54 |
55 | 58 | 59 | -------------------------------------------------------------------------------- /docs/artemis/visualizer/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | artemis.visualizer API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |
22 |

Module artemis.visualizer

23 |
24 |
25 |
26 | 27 | Expand source code 28 | 29 |
from ._pdp_visualizer import PartialDependenceVisualizer
 30 | 
 31 | __all__ = ["PartialDependenceVisualizer"]
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |

Classes

42 |
43 |
44 | class PartialDependenceVisualizer 45 | (model: sklearn.base.BaseEstimator, X: pandas.core.frame.DataFrame) 46 |
47 |
48 |

Visualizer of 1-dimensional and 2-dimensional partial dependence plots. 49 | It wraps the scikit-learn PartialDependenceDisplay.from_estimator() method, 50 | so only models implementing predict functions in the scikit-learn API are supported.

51 |

Attributes

52 |
53 |
model : sklearn.BaseEstimator
54 |
Model for which partial dependence plot will be generated.
55 |
X : pd.DataFrame
56 |
Data used to calculate partial dependence functions.
57 |
58 |

Constructor for PartialDependenceVisualizer

59 |

Parameters

60 |
61 |
model : sklearn.BaseEstimator
62 |
Model for which partial dependence plot will be generated.
63 |
X : pd.DataFrame
64 |
Data used to calculate partial dependence functions.
65 |
66 |
67 | 68 | Expand source code 69 | 70 |
class PartialDependenceVisualizer:
 71 |     """
72 |     Visualizer of 1-dimensional and 2-dimensional partial dependence plots. 
73 |     It wraps the scikit-learn PartialDependenceDisplay.from_estimator() method, 
74 |     so only models implementing predict functions in the scikit-learn API are supported.
 75 | 
 76 |     Attributes
 77 |     ----------
 78 |     model : sklearn.BaseEstimator 
 79 |         Model for which partial dependence plot will be generated.
 80 |     X : pd.DataFrame
 81 |         Data used to calculate partial dependence functions.
 82 |     """
 83 |     def __init__(self, model: BaseEstimator, X: pd.DataFrame):
 84 |         """Constructor for PartialDependenceVisualizer
 85 | 
 86 |         Parameters
 87 |         ----------
 88 |         model : sklearn.BaseEstimator 
 89 |             Model for which partial dependence plot will be generated.
 90 |         X : pd.DataFrame
 91 |             Data used to calculate partial dependence functions.
 92 |         """
 93 |         self.model = model
 94 |         self.X = X
 95 | 
 96 |     def plot(self,
 97 |              features: List[Union[int, str, Tuple[int, int], Tuple[str, str]]],
 98 |              grid_resolution: int = 100,
 99 |              title: str = "Partial Dependence",
100 |              figsize: Tuple[float, float] = (12, 6),
101 |              **kwargs):
102 |         """Plot partial dependence plot.
103 | 
104 |         Parameters
105 |         ----------
106 |         features : int, str, (int, int), or (str, str) 
107 |             Features for which partial dependence plot will be generated. 
108 |             If one feature is provided, 1-dimensional PDP will be returned, two features -- 2-dimensional PDP.
109 |         grid_resolution : int
110 |             The number of equally spaced points on the axes of the plots, for each target feature. Default is 100. 
111 |         title : str 
112 |             Title of plot. Default is 'Partial Dependence'.
113 |         figsize : (float, float) 
114 |             Size of plot. Default is (12, 6).
115 |         **kwargs : Other Parameters
116 |             Additional parameters for plot. Passed to PartialDependenceDisplay.from_estimator() method.
117 |         """
118 |         fig, ax = plt.subplots(figsize=figsize)
119 |         PartialDependenceDisplay.from_estimator(self.model,
120 |                                                 self.X,
121 |                                                 features,
122 |                                                 grid_resolution=grid_resolution,
123 |                                                 ax=ax,
124 |                                                 **kwargs)
125 |         ax.set_title(title)
126 |
127 |

Methods

128 |
129 |
130 | def plot(self, features: List[Union[int, str, Tuple[int, int], Tuple[str, str]]], grid_resolution: int = 100, title: str = 'Partial Dependence', figsize: Tuple[float, float] = (12, 6), **kwargs) 131 |
132 |
133 |

Plot partial dependence plot.

134 |

Parameters

135 |
136 |
features : int, str, (int, int), or (str, str)
137 |
Features for which partial dependence plot will be generated. 138 | If one feature is provided, 1-dimensional PDP will be returned, two features – 2-dimensional PDP.
139 |
grid_resolution : int
140 |
The number of equally spaced points on the axes of the plots, for each target feature. Default is 100.
141 |
title : str
142 |
Title of plot. Default is 'Partial Dependence'.
143 |
figsize : (float, float)
144 |
Size of plot. Default is (12, 6).
145 |
**kwargs : Other Parameters
146 |
Additional parameters for plot. Passed to PartialDependenceDisplay.from_estimator() method.
147 |
148 |
149 | 150 | Expand source code 151 | 152 |
def plot(self,
153 |          features: List[Union[int, str, Tuple[int, int], Tuple[str, str]]],
154 |          grid_resolution: int = 100,
155 |          title: str = "Partial Dependence",
156 |          figsize: Tuple[float, float] = (12, 6),
157 |          **kwargs):
158 |     """Plot partial dependence plot.
159 | 
160 |     Parameters
161 |     ----------
162 |     features : int, str, (int, int), or (str, str) 
163 |         Features for which partial dependence plot will be generated. 
164 |         If one feature is provided, 1-dimensional PDP will be returned, two features -- 2-dimensional PDP.
165 |     grid_resolution : int
166 |         The number of equally spaced points on the axes of the plots, for each target feature. Default is 100. 
167 |     title : str 
168 |         Title of plot. Default is 'Partial Dependence'.
169 |     figsize : (float, float) 
170 |         Size of plot. Default is (12, 6).
171 |     **kwargs : Other Parameters
172 |         Additional parameters for plot. Passed to PartialDependenceDisplay.from_estimator() method.
173 |     """
174 |     fig, ax = plt.subplots(figsize=figsize)
175 |     PartialDependenceDisplay.from_estimator(self.model,
176 |                                             self.X,
177 |                                             features,
178 |                                             grid_resolution=grid_resolution,
179 |                                             ax=ax,
180 |                                             **kwargs)
181 |     ax.set_title(title)
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 | 212 |
213 | 216 | 217 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "pyartemis" 3 | packages = [ 4 | { include = "artemis" } 5 | ] 6 | version = "0.1.5" 7 | description = "A Python package with explanation methods for extraction of feature interactions from predictive models" 8 | readme = "README.md" 9 | documentation = "https://pyartemis.github.io/" 10 | keywords = ["interactions", "xai", "Explainable Artificial Intelligence", "explanations", "machine learning", "iml", "Interpretable Machine Learning"] 11 | authors = ["Artur Żółkowski ", "Mateusz Krzyziński ", "Paweł Fijałkowski "] 12 | 13 | [tool.poetry.dependencies] 14 | python = "^3.8" 15 | pandas = "^1.5.1" 16 | numpy = "^1.22.0" 17 | scikit-learn = "^1.1.3" 18 | seaborn = "^0.12.1" 19 | ipykernel = "^6.17.0" 20 | tqdm = "^4.64.1" 21 | networkx = "^2.8.8" 22 | 23 | [tool.poetry.dev-dependencies] 24 | black = {version = "^22.10.0", allow-prereleases = true} 25 | xgboost = "^1.7.1" 26 | lightgbm = "^3.2.0" 27 | parameterized = "^0.8.1" 28 | pytest = "^7.2.0" 29 | pdoc3 = "^0.10.0" 30 | 31 | [build-system] 32 | requires = ["poetry-core>=1.0.0"] 33 | build-backend = "poetry.core.masonry.api" 34 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pyartemis/artemis/fd6ed2018b6c9b8aa1de3e3f91a345f60d801271/test/__init__.py -------------------------------------------------------------------------------- /test/test_additivity_meter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from artemis.additivity import AdditivityMeter 5 | from test.util import california_housing_random_forest, california_housing_linear_regression 6 | 7 | 8 | MODEL_REG, X_REG, _ = california_housing_random_forest() 9 | LINEAR_MODEL_REG, _, _ = california_housing_linear_regression() 10 | 11 | @parameterized_class([ 12 | { 13 | "model": MODEL_REG, 14 | "X": X_REG, 15 | "linear_model": LINEAR_MODEL_REG 16 | }, 17 | ]) 18 | class AdditivityMeterUnitTest(unittest.TestCase): 19 | model = None 20 | X = None 21 | linear_model = None 22 | 23 | def setUp(self) -> None: 24 | X_sample = self.X.sample(n=100) 25 | self.additivity_meter_rf = AdditivityMeter() 26 | self.additivity_meter_rf.fit(self.model, X_sample) 27 | 28 | self.additivity_meter_linear = AdditivityMeter() 29 | self.additivity_meter_linear.fit(self.linear_model, X_sample) 30 | 31 | def test_additivity_index_values(self): 32 | self.assertLessEqual(self.additivity_meter_linear.additivity_index, 1) 33 | self.assertLessEqual(self.additivity_meter_rf.additivity_index, 1) 34 | self.assertGreater(self.additivity_meter_linear.additivity_index, self.additivity_meter_rf.additivity_index) 35 | self.assertEqual(self.additivity_meter_linear.additivity_index, 1) 36 | 37 | if __name__ == '__main__': 38 | unittest.main() 39 | -------------------------------------------------------------------------------- /test/test_conditional_minimal_depth.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from artemis.interactions_methods.model_specific import ConditionalMinimalDepthMethod 5 | from 
artemis._utilities.domain import InteractionMethod 6 | from artemis.visualizer._configuration import VisualizationConfigurationProvider 7 | from test.util import california_housing_random_forest, wine_random_forest 8 | 9 | MODEL_REG, X_REG, _ = california_housing_random_forest() 10 | MODEL_CLS, X_CLS, _ = wine_random_forest() 11 | 12 | 13 | @parameterized_class([ 14 | { 15 | "model": MODEL_REG, 16 | "X": X_REG 17 | }, 18 | { 19 | "model": MODEL_CLS, 20 | "X": X_CLS 21 | }, 22 | ]) 23 | class ConditionalMinimalDepthTestCase(unittest.TestCase): 24 | model = None 25 | X = None 26 | 27 | def test_ovo_all_features(self): 28 | # when 29 | cond_min = ConditionalMinimalDepthMethod() 30 | cond_min.fit(self.model) 31 | 32 | # then 33 | self.assertSetEqual(set(cond_min.ovo.columns), {"root_variable", "variable", "n_occurences", cond_min.method}) 34 | p = len(self.X.columns) 35 | self.assertEqual(len(cond_min.ovo), p * p - p) 36 | 37 | def test_plot(self): 38 | # when 39 | cond_min = ConditionalMinimalDepthMethod() 40 | cond_min.fit(self.model) 41 | 42 | # allowed plots are generated without exception 43 | accepted_vis = VisualizationConfigurationProvider.get( 44 | InteractionMethod.CONDITIONAL_MINIMAL_DEPTH).accepted_visualizations 45 | for vis in accepted_vis: 46 | cond_min.plot(vis, show=False) 47 | 48 | # then 49 | # nothing crashes! 50 | 51 | def test_minimal_depth_feature_importance(self): 52 | # when 53 | cond_min = ConditionalMinimalDepthMethod() 54 | cond_min.fit(self.model) 55 | 56 | # then 57 | self.assertSetEqual(set(cond_min.feature_importance["Feature"]), set(self.X.columns)) 58 | 59 | 60 | if __name__ == '__main__': 61 | unittest.main() 62 | -------------------------------------------------------------------------------- /test/test_friedman_h.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from .util import california_housing_random_forest, has_decreasing_order, CALIFORNIA_SUBSET, SAMPLE_SIZE, wine_random_forest, WINE_SUBSET 5 | from artemis._utilities.domain import InteractionMethod 6 | from artemis.interactions_methods.model_agnostic import FriedmanHStatisticMethod 7 | from artemis.visualizer._configuration import VisualizationConfigurationProvider 8 | 9 | 10 | MODEL_REG, X_REG, _ = california_housing_random_forest() 11 | MODEL_CLS, X_CLS, _ = wine_random_forest() 12 | 13 | 14 | @parameterized_class([ 15 | { 16 | "model": MODEL_REG, 17 | "X": X_REG, 18 | "SUBSET": CALIFORNIA_SUBSET 19 | }, 20 | { 21 | "model": MODEL_CLS, 22 | "X": X_CLS, 23 | "SUBSET": WINE_SUBSET 24 | }, 25 | ]) 26 | class FriedmanHStatisticMethodTestCase(unittest.TestCase): 27 | model = None 28 | X = None 29 | SUBSET = None 30 | SAMPLE_SIZE = 5 31 | 32 | def test_all_features_sampled(self): 33 | # when 34 | h_stat = FriedmanHStatisticMethod() 35 | h_stat.fit(self.model, self.X, SAMPLE_SIZE) 36 | 37 | # then 38 | 39 | # expected columns 40 | self.assertSetEqual(set(h_stat.ova.columns), {"Feature", InteractionMethod.H_STATISTIC}) 41 | self.assertSetEqual(set(h_stat.ovo.columns), {"Feature 1", "Feature 2", InteractionMethod.H_STATISTIC}) 42 | 43 | # ova calculated for all columns 44 | self.assertSetEqual(set(self.X.columns), set(h_stat.ova["Feature"])) 45 | 46 | # sample size taken into account 47 | self.assertEqual(len(h_stat.X_sampled), SAMPLE_SIZE) 48 | 49 | def test_subset_of_features_sampled(self): 50 | # when 51 | h_stat = FriedmanHStatisticMethod() 52 | h_stat.fit(self.model, self.X, SAMPLE_SIZE, 
features=self.SUBSET) 53 | 54 | # then 55 | 56 | # features parameter taken into account 57 | self.assertEqual(len(h_stat.ova), 4) 58 | self.assertEqual(len(h_stat.ovo), 6) 59 | self.assertEqual(h_stat.features_included, self.SUBSET) 60 | 61 | # sample size taken into account 62 | self.assertEqual(len(h_stat.X_sampled), SAMPLE_SIZE) 63 | 64 | def test_decreasing_order(self): 65 | # when 66 | h_stat = FriedmanHStatisticMethod() 67 | h_stat.fit(self.model, self.X, SAMPLE_SIZE) 68 | 69 | # then 70 | ovo_vals = list(h_stat.ovo[InteractionMethod.H_STATISTIC]) 71 | ova_vals = list(h_stat.ova[InteractionMethod.H_STATISTIC]) 72 | 73 | # both ovo and ova have values sorted in decreasing order 74 | self.assertTrue(has_decreasing_order(ovo_vals)) 75 | self.assertTrue(has_decreasing_order(ova_vals)) 76 | 77 | def test_plot(self): 78 | # when 79 | h_stat = FriedmanHStatisticMethod() 80 | h_stat.fit(self.model, self.X, SAMPLE_SIZE, features=self.SUBSET) 81 | 82 | # allowed plots are generated without exception 83 | accepted_vis = VisualizationConfigurationProvider.get(InteractionMethod.H_STATISTIC).accepted_visualizations 84 | for vis in accepted_vis: 85 | h_stat.plot(vis, show=False) 86 | 87 | # then 88 | # nothing crashes! 89 | 90 | if __name__ == '__main__': 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /test/test_greenwell_inter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from .util import california_housing_random_forest, has_decreasing_order, CALIFORNIA_SUBSET, SAMPLE_SIZE, wine_random_forest, WINE_SUBSET 5 | from artemis._utilities.domain import InteractionMethod, VisualizationType 6 | from artemis.interactions_methods.model_agnostic import GreenwellMethod 7 | from artemis._utilities.exceptions import VisualizationNotSupportedException 8 | from artemis.visualizer._configuration import VisualizationConfigurationProvider 9 | 10 | MODEL_REG, X_REG, _ = california_housing_random_forest() 11 | MODEL_CLS, X_CLS, _ = wine_random_forest() 12 | 13 | 14 | @parameterized_class([ 15 | { 16 | "model": MODEL_REG, 17 | "X": X_REG, 18 | "SUBSET": CALIFORNIA_SUBSET 19 | }, 20 | { 21 | "model": MODEL_CLS, 22 | "X": X_CLS, 23 | "SUBSET": WINE_SUBSET 24 | }, 25 | ]) 26 | class GreenwellMethodUnitTest(unittest.TestCase): 27 | model = None 28 | X = None 29 | SUBSET = None 30 | 31 | def test_all_features_sampled(self): 32 | # when 33 | greenwell_inter = GreenwellMethod() 34 | greenwell_inter.fit(self.model, self.X, SAMPLE_SIZE) 35 | 36 | # then 37 | 38 | # expected columns 39 | self.assertSetEqual(set(greenwell_inter.ovo.columns), 40 | {"Feature 1", "Feature 2", InteractionMethod.VARIABLE_INTERACTION}) 41 | 42 | # sample size taken into account 43 | self.assertEqual(len(greenwell_inter.X_sampled), SAMPLE_SIZE) 44 | 45 | # feature importance calculated 46 | self.assertIsNotNone(greenwell_inter.feature_importance) 47 | 48 | def test_subset_of_features_sampled(self): 49 | # when 50 | greenwell_inter = GreenwellMethod() 51 | greenwell_inter.fit(self.model, self.X, SAMPLE_SIZE, features=self.SUBSET) 52 | 53 | # then 54 | 55 | # features parameter taken into account 56 | self.assertEqual(len(greenwell_inter.ovo), 6) 57 | self.assertEqual(greenwell_inter.features_included, self.SUBSET) 58 | 59 | # sample size taken into account 60 | self.assertEqual(len(greenwell_inter.X_sampled), SAMPLE_SIZE) 61 | 62 | def test_decreasing_order(self): 63 | # when 
64 | greenwell_inter = GreenwellMethod() 65 | greenwell_inter.fit(self.model, self.X, SAMPLE_SIZE) 66 | 67 | # then 68 | ovo_vals = list(greenwell_inter.ovo[InteractionMethod.VARIABLE_INTERACTION]) 69 | 70 | # ovo has values sorted in decreasing order 71 | self.assertTrue(has_decreasing_order(ovo_vals)) 72 | 73 | def test_plot(self): 74 | # when 75 | greenwell_inter = GreenwellMethod() 76 | greenwell_inter.fit(self.model, self.X, SAMPLE_SIZE, features=self.SUBSET) 77 | 78 | # allowed plots are generated without exception 79 | accepted_vis = VisualizationConfigurationProvider.get( 80 | InteractionMethod.VARIABLE_INTERACTION).accepted_visualizations 81 | for vis in accepted_vis: 82 | greenwell_inter.plot(vis, show=False) 83 | 84 | # then 85 | # nothing crashes! 86 | 87 | def test_should_raise_VisualizationNotSupportedException(self): 88 | # when 89 | greenwell_inter = GreenwellMethod() 90 | greenwell_inter.fit(self.model, self.X, SAMPLE_SIZE, features=self.SUBSET) 91 | 92 | # barchart is not supported for greenwell (no OvA), so this should raise VisualizationNotSupportedException 93 | with self.assertRaises(VisualizationNotSupportedException): 94 | greenwell_inter.plot(VisualizationType.BAR_CHART_OVA) 95 | 96 | 97 | if __name__ == '__main__': 98 | unittest.main() 99 | -------------------------------------------------------------------------------- /test/test_method_comparator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from artemis.comparison import FeatureInteractionMethodComparator 5 | from artemis.interactions_methods.model_agnostic import FriedmanHStatisticMethod, GreenwellMethod 6 | from artemis._utilities.exceptions import MethodNotFittedException 7 | from test.util import california_housing_random_forest, wine_random_forest 8 | 9 | 10 | MODEL_REG, X_REG, _ = california_housing_random_forest() 11 | MODEL_CLS, X_CLS, _ = wine_random_forest() 12 | 13 | 14 | @parameterized_class([ 15 | { 16 | "model": MODEL_REG, 17 | "X": X_REG 18 | }, 19 | { 20 | "model": MODEL_CLS, 21 | "X": X_CLS 22 | }, 23 | ]) 24 | class MethodComparatorUnitTest(unittest.TestCase): 25 | model = None 26 | X = None 27 | 28 | def setUp(self) -> None: 29 | self.method_1, self.method_2 = FriedmanHStatisticMethod(), GreenwellMethod() 30 | 31 | self.method_1_fitted, self.method_2_fitted = FriedmanHStatisticMethod(), GreenwellMethod() 32 | 33 | self.method_1_fitted.fit(self.model, self.X, n=10) 34 | self.method_2_fitted.fit(self.model, self.X, n=10) 35 | 36 | def test_method_not_fitted_exception(self): 37 | comparator = FeatureInteractionMethodComparator() 38 | 39 | with self.assertRaises(MethodNotFittedException): 40 | comparator.summary(self.method_1, self.method_2_fitted) 41 | 42 | with self.assertRaises(MethodNotFittedException): 43 | comparator.summary(self.method_1_fitted, self.method_2) 44 | 45 | def test_should_calculate_correlations(self): 46 | comparator = FeatureInteractionMethodComparator() 47 | 48 | correlations = comparator.correlations(self.method_1_fitted, self.method_2_fitted) 49 | 50 | self.assertSetEqual(set(correlations["method"]), {"pearson", "kendall", "spearman"}) 51 | 52 | def test_should_calculate_comparison_plot(self): 53 | comparator = FeatureInteractionMethodComparator() 54 | 55 | fig, ax = comparator.comparison_plot(self.method_1_fitted, self.method_2_fitted) 56 | 57 | self.assertIsNotNone(fig) 58 | self.assertIsNotNone(ax) 59 | 60 | 61 | if __name__ == '__main__': 62 |
unittest.main() 63 | -------------------------------------------------------------------------------- /test/test_pdp_calculator.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | import unittest 3 | import numpy as np 4 | from parameterized import parameterized_class 5 | from artemis._utilities.ops import get_predict_function 6 | 7 | from artemis._utilities.pd_calculator import PartialDependenceCalculator 8 | 9 | from .util import california_housing_random_forest, has_decreasing_order, CALIFORNIA_SUBSET, SAMPLE_SIZE, wine_random_forest, WINE_SUBSET 10 | from artemis._utilities.domain import InteractionMethod 11 | from artemis.interactions_methods.model_agnostic import FriedmanHStatisticMethod 12 | from artemis.visualizer._configuration import VisualizationConfigurationProvider 13 | 14 | 15 | MODEL_REG, X_REG, _ = california_housing_random_forest() 16 | MODEL_CLS, X_CLS, _ = wine_random_forest() 17 | 18 | 19 | @parameterized_class([ 20 | { 21 | "model": MODEL_REG, 22 | "X": X_REG, 23 | "SUBSET": CALIFORNIA_SUBSET 24 | }, 25 | { 26 | "model": MODEL_CLS, 27 | "X": X_CLS, 28 | "SUBSET": WINE_SUBSET 29 | }, 30 | ]) 31 | class PartialDependenceCalculatorTestCase(unittest.TestCase): 32 | model = None 33 | X = None 34 | SUBSET = None 35 | SAMPLE_SIZE = 5 36 | 37 | def test_all_features(self): 38 | X_sampled = self.X.sample(SAMPLE_SIZE) 39 | # when 40 | pdp_calc = PartialDependenceCalculator(self.model, X_sampled, get_predict_function(self.model)) 41 | pdp_calc.calculate_pd_pairs() 42 | pdp_calc.calculate_pd_single() 43 | pdp_calc.calculate_pd_minus_single() 44 | 45 | # then 46 | 47 | # expected columns 48 | self.assertSetEqual(set(pdp_calc.pd_single.keys()), set(X_sampled.columns)) 49 | self.assertSetEqual(set(pdp_calc.pd_minus_single.keys()), set(X_sampled.columns)) 50 | self.assertSetEqual(set(pdp_calc.pd_pairs.keys()), set(combinations(X_sampled.columns, 2))) 51 | 52 | # expect non nan values 53 | for var in X_sampled.columns: 54 | self.assertFalse(np.isnan(pdp_calc.get_pd_single(var)).any()) 55 | self.assertFalse(np.isnan(pdp_calc.get_pd_minus_single(var)).any()) 56 | 57 | for var1, var2 in combinations(X_sampled.columns, 2): 58 | self.assertFalse(np.isnan(pdp_calc.get_pd_pairs(var1, var2)).any()) 59 | 60 | def test_subset_of_features(self): 61 | X_sampled = self.X.sample(SAMPLE_SIZE) 62 | # when 63 | pdp_calc = PartialDependenceCalculator(self.model, X_sampled, get_predict_function(self.model)) 64 | pdp_calc.calculate_pd_pairs(feature_pairs = combinations(self.SUBSET, 2)) 65 | pdp_calc.calculate_pd_single(features=self.SUBSET) 66 | pdp_calc.calculate_pd_minus_single(features=self.SUBSET) 67 | 68 | # then 69 | 70 | # expected columns 71 | self.assertSetEqual(set(pdp_calc.pd_single.keys()), set(X_sampled.columns)) 72 | self.assertSetEqual(set(pdp_calc.pd_minus_single.keys()), set(X_sampled.columns)) 73 | self.assertSetEqual(set(pdp_calc.pd_pairs.keys()), set(combinations(X_sampled.columns, 2))) 74 | 75 | # expect non nan values in subset 76 | for var in self.SUBSET: 77 | self.assertFalse(np.isnan(pdp_calc.get_pd_single(var)).any()) 78 | self.assertFalse(np.isnan(pdp_calc.get_pd_minus_single(var)).any()) 79 | 80 | for var1, var2 in combinations(self.SUBSET, 2): 81 | self.assertFalse(np.isnan(pdp_calc.get_pd_pairs(var1, var2)).any()) 82 | 83 | # expect nan values in other features 84 | for var in X_sampled.columns: 85 | if var not in self.SUBSET: 86 | self.assertTrue(np.isnan(pdp_calc.get_pd_single(var)).all()) 87 | 
self.assertTrue(np.isnan(pdp_calc.get_pd_minus_single(var)).all()) 88 | 89 | for var1, var2 in combinations(X_sampled.columns, 2): 90 | if var1 not in self.SUBSET or var2 not in self.SUBSET: 91 | self.assertTrue(np.isnan(pdp_calc.get_pd_pairs(var1, var2)).all()) 92 | 93 | if __name__ == '__main__': 94 | unittest.main() 95 | -------------------------------------------------------------------------------- /test/test_sejong_oh_inter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from artemis._utilities.domain import InteractionMethod 5 | from artemis.interactions_methods.model_agnostic import SejongOhMethod 6 | from artemis.visualizer._configuration import VisualizationConfigurationProvider 7 | from .util import california_housing_random_forest, SAMPLE_SIZE, N_REPEAT, CALIFORNIA_SUBSET, has_decreasing_order, wine_random_forest, WINE_SUBSET 8 | 9 | 10 | MODEL_REG, X_REG, Y_REG = california_housing_random_forest() 11 | MODEL_CLS, X_CLS, Y_CLS = wine_random_forest() 12 | 13 | 14 | @parameterized_class([ 15 | { 16 | "model": MODEL_REG, 17 | "X": X_REG, 18 | "y": Y_REG, 19 | "SUBSET": CALIFORNIA_SUBSET 20 | }, 21 | { 22 | "model": MODEL_CLS, 23 | "X": X_CLS, 24 | "y": Y_CLS, 25 | "SUBSET": WINE_SUBSET 26 | }, 27 | ]) 28 | class SejongOhMethodTestCase(unittest.TestCase): 29 | model = None 30 | X = None 31 | y = None 32 | SUBSET = None 33 | 34 | def test_all_features_sampled(self): 35 | # when 36 | sejong_oh_inter = SejongOhMethod() 37 | sejong_oh_inter.fit(self.model, self.X, self.y, SAMPLE_SIZE, n_repeat=N_REPEAT) 38 | 39 | # then 40 | 41 | # expected columns 42 | self.assertSetEqual(set(sejong_oh_inter.ovo.columns), {"Feature 1", "Feature 2", InteractionMethod.PERFORMANCE_BASED}) 43 | 44 | # sample size taken into account 45 | self.assertEqual(len(sejong_oh_inter.X_sampled), SAMPLE_SIZE) 46 | 47 | # feature importance calculated 48 | self.assertIsNotNone(sejong_oh_inter.feature_importance) 49 | 50 | def test_subset_of_features_sampled(self): 51 | # when 52 | sejong_oh_inter = SejongOhMethod() 53 | sejong_oh_inter.fit(self.model, self.X, self.y, SAMPLE_SIZE, features=self.SUBSET) 54 | 55 | # then 56 | 57 | # features parameter taken into account 58 | self.assertEqual(len(sejong_oh_inter.ovo), 6) 59 | self.assertEqual(sejong_oh_inter.features_included, self.SUBSET) 60 | 61 | # sample size taken into account 62 | self.assertEqual(len(sejong_oh_inter.X_sampled), SAMPLE_SIZE) 63 | 64 | def test_decreasing_order(self): 65 | # when 66 | sejong_oh_inter = SejongOhMethod() 67 | sejong_oh_inter.fit(self.model, self.X, self.y, SAMPLE_SIZE) 68 | 69 | # then 70 | ovo_vals = list(sejong_oh_inter.ovo[InteractionMethod.PERFORMANCE_BASED].abs()) 71 | 72 | # ovo has values sorted in decreasing order 73 | self.assertTrue(has_decreasing_order(ovo_vals)) 74 | 75 | def test_plot(self): 76 | # when 77 | sejong_oh_inter = SejongOhMethod() 78 | sejong_oh_inter.fit(self.model, self.X, self.y, SAMPLE_SIZE, features=self.SUBSET) 79 | 80 | # allowed plots are generated without exception 81 | accepted_vis = VisualizationConfigurationProvider.get(InteractionMethod.PERFORMANCE_BASED).accepted_visualizations 82 | for vis in accepted_vis: 83 | sejong_oh_inter.plot(vis, show=False) 84 | 85 | # then 86 | # nothing crashes!
87 | 88 | 89 | if __name__ == '__main__': 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /test/test_split_score_inter.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | from .util import california_housing_random_forest, california_housing_boosting_models, has_decreasing_order, wine_random_forest, wine_boosting_models 5 | from artemis._utilities.domain import InteractionMethod, VisualizationType 6 | from artemis._utilities.split_score_metrics import SplitScoreInteractionMetric, SplitScoreImportanceMetric, _LGBM_UNSUPPORTED_METRICS 7 | from artemis.interactions_methods.model_specific import SplitScoreMethod 8 | from artemis._utilities.exceptions import VisualizationNotSupportedException, MetricNotSupportedException, ModelNotSupportedException 9 | from artemis.visualizer._configuration import VisualizationConfigurationProvider 10 | from dataclasses import fields 11 | 12 | MODEL_REG, X_REG, Y_REG = california_housing_random_forest() 13 | MODEL_CLS, X_CLS, Y_CLS = wine_random_forest() 14 | MODEL_XGB_REG, MODEL_LGBM_REG, MODEL_XGB_BIS_REG, MODEL_LGBM_BIS_REG, _, _ = california_housing_boosting_models() 15 | MODEL_XGB_CLS, MODEL_LGBM_CLS, MODEL_XGB_BIS_CLS, MODEL_LGBM_BIS_CLS, _, _ = wine_boosting_models() 16 | 17 | 18 | @parameterized_class([ 19 | { 20 | "model_rf": MODEL_REG, 21 | "model_xgb": MODEL_XGB_REG, 22 | "model_lgbm": MODEL_LGBM_REG, 23 | "model_xgb_bis": MODEL_XGB_BIS_REG, 24 | "model_lgbm_bis": MODEL_LGBM_BIS_REG, 25 | }, 26 | { 27 | "model_rf": MODEL_CLS, 28 | "model_xgb": MODEL_XGB_CLS, 29 | "model_lgbm": MODEL_LGBM_CLS, 30 | "model_xgb_bis": MODEL_XGB_BIS_CLS, 31 | "model_lgbm_bis": MODEL_LGBM_BIS_CLS, 32 | }, 33 | ]) 34 | class SplitScoreMethodUnitTest(unittest.TestCase): 35 | model_rf = None 36 | model_xgb = None 37 | model_lgbm = None 38 | model_xgb_bis = None 39 | model_lgbm_bis = None 40 | 41 | def test_all_metric_combinations_xgb(self): 42 | for int_method in fields(SplitScoreInteractionMetric): 43 | for imp_method in fields(SplitScoreImportanceMetric): 44 | # when 45 | inter = SplitScoreMethod() 46 | inter.fit(self.model_xgb, 47 | interaction_selected_metric=int_method.default, 48 | importance_selected_metric=imp_method.default) 49 | 50 | # then 51 | 52 | # expected columns 53 | self.assertListEqual(list(inter.ovo.columns), ["Feature 1", "Feature 2", InteractionMethod.SPLIT_SCORE]) 54 | self.assertListEqual(list(inter.feature_importance.columns), ["Feature", "Importance"]) 55 | 56 | # feature importance calculated 57 | self.assertIsNotNone(inter.full_ovo) 58 | self.assertIsNotNone(inter.full_result) 59 | 60 | def test_all_metric_combinations_lgbm(self): 61 | for int_method in fields(SplitScoreInteractionMetric): 62 | for imp_method in fields(SplitScoreImportanceMetric): 63 | # when 64 | if int_method.default in _LGBM_UNSUPPORTED_METRICS or imp_method.default in _LGBM_UNSUPPORTED_METRICS: 65 | # then expect exception 66 | with self.assertRaises(MetricNotSupportedException): 67 | inter = SplitScoreMethod() 68 | inter.fit(self.model_lgbm, 69 | interaction_selected_metric=int_method.default, 70 | importance_selected_metric=imp_method.default) 71 | else: 72 | inter = SplitScoreMethod() 73 | inter.fit(self.model_lgbm, 74 | interaction_selected_metric=int_method.default, 75 | importance_selected_metric=imp_method.default) 76 | 77 | # expected columns 78 | self.assertListEqual(list(inter.ovo.columns), 79 | 
["Feature 1", "Feature 2", InteractionMethod.SPLIT_SCORE]) 80 | self.assertListEqual(list(inter.feature_importance.columns), ["Feature", "Importance"]) 81 | 82 | # feature importance calculated 83 | self.assertIsNotNone(inter.full_ovo) 84 | self.assertIsNotNone(inter.full_result) 85 | 86 | def test_decreasing_order(self): 87 | # when 88 | inter = SplitScoreMethod() 89 | inter.fit(self.model_xgb) 90 | inter2 = SplitScoreMethod() 91 | inter2.fit(self.model_lgbm) 92 | 93 | # then 94 | ovo_vals = list(inter.ovo[InteractionMethod.SPLIT_SCORE]) 95 | ovo_vals2 = list(inter2.ovo[InteractionMethod.SPLIT_SCORE]) 96 | 97 | # ovo have values sorted in decreasing order 98 | self.assertTrue(has_decreasing_order(ovo_vals)) 99 | self.assertTrue(has_decreasing_order(ovo_vals2)) 100 | 101 | def test_plot(self): 102 | # when 103 | inter = SplitScoreMethod() 104 | inter.fit(self.model_xgb_bis) 105 | inter2 = SplitScoreMethod() 106 | inter2.fit(self.model_lgbm_bis) 107 | # allowed plots are generated without exception 108 | accepted_vis = VisualizationConfigurationProvider.get(InteractionMethod.SPLIT_SCORE).accepted_visualizations 109 | for vis in accepted_vis: 110 | inter.plot(vis, show=False) 111 | inter2.plot(vis, show=False) 112 | # then 113 | # nothing crashes! 114 | 115 | def test_progress_bar(self): 116 | # when progress bar i shown 117 | inter = SplitScoreMethod() 118 | inter.fit(self.model_xgb, show_progress=True) 119 | inter.fit(self.model_lgbm, show_progress=True) 120 | # then 121 | # nothing crashes! 122 | 123 | def test_not_only_def_interactions(self): 124 | # when not only interactions by definition are calculated 125 | inter = SplitScoreMethod() 126 | inter.fit(self.model_xgb, only_def_interactions=False) 127 | inter.fit(self.model_lgbm, only_def_interactions=False) 128 | # then 129 | # nothing crashes! 
130 | 131 | def test_should_raise_VisualizationNotSupportedException(self): 132 | # when 133 | inter = SplitScoreMethod() 134 | inter.fit(self.model_xgb) 135 | inter2 = SplitScoreMethod() 136 | inter2.fit(self.model_lgbm) 137 | 138 | # barchart (OvA) is not supported 139 | with self.assertRaises(VisualizationNotSupportedException): 140 | inter.plot(VisualizationType.BAR_CHART_OVA) 141 | with self.assertRaises(VisualizationNotSupportedException): 142 | inter2.plot(VisualizationType.BAR_CHART_OVA) 143 | 144 | def test_should_raise_ModelNotSupportedException(self): 145 | with self.assertRaises(ModelNotSupportedException): 146 | inter = SplitScoreMethod() 147 | inter.fit(self.model_rf) 148 | 149 | 150 | if __name__ == '__main__': 151 | unittest.main() 152 | -------------------------------------------------------------------------------- /test/test_variable_importance.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from parameterized import parameterized_class 3 | 4 | import pandas as pd 5 | from pandas.testing import assert_frame_equal 6 | 7 | from artemis.importance_methods.model_agnostic import PermutationImportance, PartialDependenceBasedImportance 8 | from artemis.interactions_methods.model_agnostic import FriedmanHStatisticMethod 9 | from test.util import toy_input_reg, toy_input_cls 10 | 11 | 12 | MODEL_REG, X_REG, Y_REG = toy_input_reg() 13 | MODEL_CLS, X_CLS, Y_CLS = toy_input_cls() 14 | 15 | 16 | @parameterized_class([ 17 | { 18 | "model": MODEL_REG, 19 | "X": X_REG, 20 | "y": Y_REG, 21 | }, 22 | { 23 | "model": MODEL_CLS, 24 | "X": X_CLS, 25 | "y": Y_CLS, 26 | }, 27 | ]) 28 | class VariableImportanceUnitTest(unittest.TestCase): 29 | model = None 30 | X = None 31 | y = None 32 | 33 | def test_calculate_permutation_feature_importance(self): 34 | calculator = PermutationImportance() 35 | importance = calculator.importance(self.model, self.X, self.y, features=list(self.X.columns)) 36 | 37 | self._assert_var_imp_calculated_correctly(importance) 38 | 39 | def test_calculate_pdp_based_feature_importance(self): 40 | calculator = PartialDependenceBasedImportance() 41 | importance = calculator.importance(self.model, self.X, features=list(self.X.columns)) 42 | self._assert_var_imp_calculated_correctly(importance) 43 | 44 | def test_use_feature_importance_in_pdp_method(self): 45 | importance_single = PartialDependenceBasedImportance().importance(self.model, self.X, 46 | features=list(self.X.columns)) 47 | 48 | h_stat = FriedmanHStatisticMethod() 49 | h_stat.fit(self.model, self.X) 50 | importance_h_stat = h_stat.feature_importance 51 | 52 | assert_frame_equal(importance_h_stat, importance_single, rtol=1e-1) # up to first decimal point 53 | 54 | def _assert_var_imp_calculated_correctly(self, importance): 55 | self.assertEqual(type(importance), pd.DataFrame) # resulting type - dataframe 56 | self.assertSetEqual(set(importance["Feature"]), 57 | set(self.X.columns)) # var imp for all features is calculated 58 | self.assertGreater(importance[importance["Feature"] == "important_feature"]["Importance"].values[0], 59 | importance[importance["Feature"] == "noise_feature"]["Importance"].values[0]) 60 | 61 | 62 | if __name__ == '__main__': 63 | unittest.main() 64 | -------------------------------------------------------------------------------- /test/util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from lightgbm import LGBMRegressor, LGBMClassifier 3 | from sklearn.datasets 
import fetch_california_housing, load_wine 4 | from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier 5 | from sklearn.linear_model import LinearRegression, LogisticRegression 6 | from xgboost import XGBRegressor, XGBClassifier 7 | 8 | CALIFORNIA_SUBSET = ["Longitude", "Latitude", "MedInc", "AveRooms"] 9 | WINE_SUBSET = ["alcohol", "malic_acid", "ash", "alcalinity_of_ash"] 10 | SAMPLE_SIZE = 5 11 | N_REPEAT = 3 12 | N = 100 13 | 14 | 15 | def california_housing_random_forest(max_depth: int = 6, n_estimators: int = 80): 16 | california = fetch_california_housing() 17 | X = pd.DataFrame(california.data, columns=california.feature_names) 18 | y = california.target 19 | model = RandomForestRegressor(max_depth=max_depth, n_estimators=n_estimators).fit(X, y) 20 | return model, X, y 21 | 22 | 23 | def wine_random_forest(max_depth: int = 6, n_estimators: int = 80): 24 | wine = load_wine() 25 | X = pd.DataFrame(wine.data, columns=wine.feature_names) 26 | y = wine.target 27 | model = RandomForestClassifier(max_depth=max_depth, n_estimators=n_estimators).fit(X, y) 28 | return model, X, y 29 | 30 | 31 | def california_housing_boosting_models(): 32 | california = fetch_california_housing() 33 | X = pd.DataFrame(california.data, columns=california.feature_names) 34 | y = california.target 35 | model_xgb = XGBRegressor(n_estimators=10, max_depth=4).fit(X, y) 36 | model_lgbm = LGBMRegressor(n_estimators=10, max_depth=4).fit(X, y) 37 | model_xgb_bis = XGBRegressor(n_estimators=40, max_depth=8).fit(X.iloc[:, :3], y) 38 | model_lgbm_bis = LGBMRegressor(n_estimators=40, max_depth=8).fit(X.iloc[:, :3], y) 39 | return model_xgb, model_lgbm, model_xgb_bis, model_lgbm_bis, X, y 40 | 41 | 42 | def wine_boosting_models(): 43 | wine = load_wine() 44 | X = pd.DataFrame(wine.data, columns=wine.feature_names)[:N] 45 | y = wine.target[:N] 46 | model_xgb = XGBClassifier(n_estimators=10, max_depth=4).fit(X, y) 47 | model_lgbm = LGBMClassifier(n_estimators=10, max_depth=4).fit(X, y) 48 | model_xgb_bis = XGBClassifier(n_estimators=40, max_depth=8).fit(X.iloc[:, :3], y) 49 | model_lgbm_bis = LGBMClassifier(n_estimators=40, max_depth=8).fit(X.iloc[:, :3], y) 50 | return model_xgb, model_lgbm, model_xgb_bis, model_lgbm_bis, X, y 51 | 52 | 53 | def california_housing_linear_regression(): 54 | california = fetch_california_housing() 55 | X = pd.DataFrame(california.data, columns=california.feature_names) 56 | y = california.target 57 | model = LinearRegression().fit(X, y) 58 | return model, X, y 59 | 60 | 61 | def wine_logistic_regression(): 62 | wine = load_wine() 63 | X = pd.DataFrame(wine.data, columns=wine.feature_names) 64 | y = wine.target 65 | model = LogisticRegression().fit(X, y) 66 | return model, X, y 67 | 68 | 69 | def toy_input_reg(): 70 | target = list(range(N)) 71 | X = pd.DataFrame({"important_feature": target, "noise_feature": [1 for _ in range(N)]}) 72 | y = target 73 | model = RandomForestRegressor().fit(X, y) 74 | 75 | return model, X, y 76 | 77 | def toy_input_cls(): 78 | target = [0, 1] * (N // 2) 79 | X = pd.DataFrame({"important_feature": target, "noise_feature": [1 for _ in range(len(target))]}) 80 | y = target 81 | model = RandomForestClassifier().fit(X, y) 82 | 83 | return model, X, y 84 | 85 | 86 | def has_decreasing_order(vals): 87 | return all(earlier >= later for earlier, later in zip(vals, vals[1:])) 88 | --------------------------------------------------------------------------------