├── .python-version ├── .heroku └── run.sh ├── Procfile ├── .DS_Store ├── assets ├── custom.png ├── favicon.ico ├── port.jpeg ├── ticket.jpeg └── titanic.jpeg ├── requirements.txt ├── pkls ├── reg_explainer.joblib ├── clas_explainer.joblib └── multi_explainer.joblib ├── Makefile ├── LICENSE ├── README.md ├── .gitignore ├── generate_explainers.py ├── dashboard.py ├── custom.py └── index_layout.py /.python-version: -------------------------------------------------------------------------------- 1 | 3.11 2 | -------------------------------------------------------------------------------- /.heroku/run.sh: -------------------------------------------------------------------------------- 1 | pip uninstall -y xgboost -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: gunicorn --preload --timeout 60 -w 3 dashboard:app 2 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/.DS_Store -------------------------------------------------------------------------------- /assets/custom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/assets/custom.png -------------------------------------------------------------------------------- /assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/assets/favicon.ico -------------------------------------------------------------------------------- /assets/port.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/assets/port.jpeg -------------------------------------------------------------------------------- /assets/ticket.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/assets/ticket.jpeg -------------------------------------------------------------------------------- /assets/titanic.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/assets/titanic.jpeg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | explainerdashboard==0.4.7 2 | pandas>=1.3.0 3 | joblib 4 | gunicorn 5 | requests 6 | -------------------------------------------------------------------------------- /pkls/reg_explainer.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/pkls/reg_explainer.joblib -------------------------------------------------------------------------------- /pkls/clas_explainer.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/pkls/clas_explainer.joblib -------------------------------------------------------------------------------- /pkls/multi_explainer.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oegedijk/explainingtitanic/HEAD/pkls/multi_explainer.joblib -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .ONESHELL: 2 | SHELL := /bin/bash 3 | 4 | all: build 5 | 6 | build: 7 | python 
generate_explainers.py 8 | 9 | run: 10 | gunicorn --preload dashboard:app -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Oege Dijk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # explainingtitanic 2 | Demonstration of [explainerdashboard](http://www.github.com/oegedijk/explainerdashboard) package. 3 | 4 | A Dash dashboard app that displays model quality, permutation importances, SHAP values and interactions, and individual trees for sklearn compatible models. 
5 | 6 | ## Installation 7 | install with `pip install explainerdashboard` 8 | 9 | ## Github 10 | 11 | [www.github.com/oegedijk/explainerdashboard](http://www.github.com/oegedijk/explainerdashboard) 12 | 13 | ## graphviz buildpack 14 | 15 | In order to enable graphviz on heroku enable the following buildpack: 16 | 17 | [https://github.com/weibeld/heroku-buildpack-graphviz.git](https://github.com/weibeld/heroku-buildpack-graphviz.git) 18 | 19 | ## uninstalling xgboost 20 | 21 | dtreeviz comes with a xgboost dependency that takes a lot of space, making your slug size >500MB. 22 | To uninstall it, first enable the shell buildpack: https://github.com/niteoweb/heroku-buildpack-shell.git 23 | 24 | and then add `pip uninstall -y xgboost` to `.heroku/run.sh` 25 | ## Documentation 26 | 27 | [explainerdashboard.readthedocs.io](http://explainerdashboard.readthedocs.io). 28 | 29 | Example [notebook](http://www.github.com/oegedijk/explainerdashboard/dashboard_examples.ipynb). 30 | 31 | ## Heroku deployment 32 | 33 | Deployed at [titanicexplainer.herokuapp.com](http://titanicexplainer.herokuapp.com) 34 | 35 | Automatically deploys with each commit or merge to master. 36 | 37 | 38 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | # .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | scratch_notebook.ipynb 108 | scratch_imports.py 109 | .DS_Store 110 | #pkls/* 111 | -------------------------------------------------------------------------------- /generate_explainers.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor 3 | 4 | from explainerdashboard import ( 5 | ClassifierExplainer, 6 | RegressionExplainer, 7 | ExplainerDashboard, 8 | ) 9 | from explainerdashboard.datasets import * 10 | 11 | pkl_dir = Path.cwd() / "pkls" 12 | 13 | # classifier 14 | print("Generating titanic explainers") 15 | print("Generating classifier explainer") 16 | X_train, y_train, X_test, y_test = titanic_survive() 17 | model = RandomForestClassifier(n_estimators=50, max_depth=5).fit(X_train, y_train) 18 | 
clas_explainer = ClassifierExplainer( 19 | model, 20 | X_test, 21 | y_test, 22 | cats=["Sex", "Deck", "Embarked"], 23 | descriptions=feature_descriptions, 24 | labels=["Not survived", "Survived"], 25 | ) 26 | _ = ExplainerDashboard(clas_explainer) 27 | clas_explainer.dump(pkl_dir / "clas_explainer.joblib") 28 | 29 | 30 | # regression 31 | print("Generating titanic fare explainer") 32 | X_train, y_train, X_test, y_test = titanic_fare() 33 | model = RandomForestRegressor(n_estimators=50, max_depth=5).fit(X_train, y_train) 34 | reg_explainer = RegressionExplainer( 35 | model, 36 | X_test, 37 | y_test, 38 | cats=["Sex", "Deck", "Embarked"], 39 | descriptions=feature_descriptions, 40 | units="$", 41 | ) 42 | _ = ExplainerDashboard(reg_explainer) 43 | reg_explainer.dump(pkl_dir / "reg_explainer.joblib") 44 | 45 | # multiclass 46 | print("Generating titanic embarked multiclass explainer") 47 | X_train, y_train, X_test, y_test = titanic_embarked() 48 | model = RandomForestClassifier(n_estimators=50, max_depth=5).fit(X_train, y_train) 49 | multi_explainer = ClassifierExplainer( 50 | model, 51 | X_test, 52 | y_test, 53 | cats=["Sex", "Deck"], 54 | descriptions=feature_descriptions, 55 | labels=["Queenstown", "Southampton", "Cherbourg"], 56 | ) 57 | _ = ExplainerDashboard(multi_explainer) 58 | multi_explainer.dump(pkl_dir / "multi_explainer.joblib") 59 | -------------------------------------------------------------------------------- /dashboard.py: -------------------------------------------------------------------------------- 1 | 2 | # xgboost is a dependency of dtreeviz, but too large (>350M) for heroku 3 | # so we uninstall it and mock it here: 4 | from unittest.mock import MagicMock 5 | import sys 6 | sys.modules["xgboost"] = MagicMock() 7 | 8 | from pathlib import Path 9 | from flask import Flask 10 | 11 | import dash 12 | from dash_bootstrap_components.themes import FLATLY, BOOTSTRAP # bootstrap theme 13 | from explainerdashboard import * 14 | 15 | from index_layout 
import index_layout, register_callbacks 16 | from custom import CustomModelTab, CustomPredictionsTab 17 | 18 | pkl_dir = Path.cwd() / "pkls" 19 | 20 | app = Flask(__name__) 21 | 22 | clas_explainer = ClassifierExplainer.from_file(pkl_dir / "clas_explainer.joblib") 23 | clas_dashboard = ExplainerDashboard(clas_explainer, 24 | title="Classifier Explainer: Predicting survival on the Titanic", 25 | server=app, url_base_pathname="/classifier/", 26 | header_hide_selector=True) 27 | 28 | reg_explainer = RegressionExplainer.from_file(pkl_dir / "reg_explainer.joblib") 29 | reg_dashboard = ExplainerDashboard(reg_explainer, 30 | title="Regression Explainer: Predicting ticket fare", 31 | server=app, url_base_pathname="/regression/") 32 | 33 | multi_explainer = ClassifierExplainer.from_file(pkl_dir / "multi_explainer.joblib") 34 | multi_dashboard = ExplainerDashboard(multi_explainer, 35 | title="Multiclass Explainer: Predicting departure port", 36 | server=app, url_base_pathname="/multiclass/") 37 | 38 | custom_dashboard = ExplainerDashboard(clas_explainer, 39 | [CustomModelTab, CustomPredictionsTab], 40 | title='Titanic Explainer', header_hide_selector=True, 41 | bootstrap=FLATLY, 42 | server=app, url_base_pathname="/custom/") 43 | 44 | simple_classifier_dashboard = ExplainerDashboard(clas_explainer, 45 | title="Simplified Classifier Dashboard", simple=True, 46 | server=app, url_base_pathname="/simple_classifier/") 47 | 48 | simple_regression_dashboard = ExplainerDashboard(reg_explainer, 49 | title="Simplified Classifier Dashboard", simple=True, 50 | server=app, url_base_pathname="/simple_regression/") 51 | 52 | 53 | index_app = dash.Dash( 54 | __name__, 55 | server=app, 56 | url_base_pathname="/", 57 | external_stylesheets=[BOOTSTRAP]) 58 | 59 | index_app.title = 'explainerdashboard' 60 | index_app.layout = index_layout 61 | register_callbacks(index_app) 62 | 63 | @app.route("/") 64 | def index(): 65 | return index_app.index() 66 | 67 | @app.route('/classifier') 68 | def 
classifier_dashboard(): 69 | return clas_dashboard.app.index() 70 | 71 | @app.route('/regression') 72 | def regression_dashboard(): 73 | return reg_dashboard.app.index() 74 | 75 | @app.route('/multiclass') 76 | def multiclass_dashboard(): 77 | return multi_dashboard.app.index() 78 | 79 | @app.route('/custom') 80 | def custom_dashboard(): 81 | return custom_dashboard.app.index() 82 | 83 | @app.route('/simple_classifier') 84 | def simple_classifier_dashboard(): 85 | return simple_classifier_dashboard.app.index() 86 | 87 | @app.route('/simple_regression') 88 | def simple_regression_dashboard(): 89 | return simple_regression_dashboard.app.index() 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /custom.py: -------------------------------------------------------------------------------- 1 | from explainerdashboard.custom import * 2 | 3 | 4 | class CustomModelTab(ExplainerComponent): 5 | def __init__(self, explainer): 6 | super().__init__(explainer, title="Model Summary", name=None) 7 | self.precision = PrecisionComponent(explainer, 8 | title='Precision', 9 | hide_subtitle=True, hide_footer=True, 10 | hide_selector=True, 11 | cutoff=None) 12 | self.shap_summary = ShapSummaryComponent(explainer, 13 | title='Impact', 14 | hide_subtitle=True, hide_selector=True, 15 | hide_depth=True, depth=8, 16 | hide_cats=True, cats=True) 17 | self.shap_dependence = ShapDependenceComponent(explainer, 18 | title='Dependence', 19 | hide_subtitle=True, hide_selector=True, 20 | hide_cats=True, cats=True, 21 | hide_index=True, 22 | col='Fare', color_col="PassengerClass") 23 | self.connector = ShapSummaryDependenceConnector( 24 | self.shap_summary, self.shap_dependence) 25 | 26 | self.register_components() 27 | 28 | def layout(self): 29 | return dbc.Container([ 30 | dbc.Row([ 31 | dbc.Col([ 32 | html.H3("Model Performance"), 33 | html.Div("As you can see on the right, the model performs quite well."), 34 | html.Div("The higher the predicted 
probability of survival predicted by " 35 | "the model on the basis of learning from examples in the training set" 36 | ", the higher is the actual percentage of passengers surviving in " 37 | "the test set"), 38 | ], width=4, style=dict(margin=30)), 39 | dbc.Col([ 40 | self.precision.layout() 41 | ], style=dict(margin=30)) 42 | ]), 43 | dbc.Row([ 44 | dbc.Col([ 45 | self.shap_summary.layout() 46 | ], style=dict(margin=30)), 47 | dbc.Col([ 48 | html.H3("Feature Importances"), 49 | html.Div("On the left you can check out for yourself which parameters were the most important."), 50 | html.Div(f"Clearly {self.explainer.columns_ranked_by_shap()[0]} was the most important" 51 | f", followed by {self.explainer.columns_ranked_by_shap()[1]}" 52 | f" and {self.explainer.columns_ranked_by_shap()[2]}."), 53 | html.Div("If you select 'detailed' you can see the impact of that variable on " 54 | "each individual prediction. With 'aggregate' you see the average impact size " 55 | "of that variable on the final prediction."), 56 | html.Div("With the detailed view you can clearly see that the the large impact from Sex " 57 | "stems both from males having a much lower chance of survival and females a much " 58 | "higher chance.") 59 | ], width=4, style=dict(margin=30)), 60 | ]), 61 | dbc.Row([ 62 | dbc.Col([ 63 | html.H3("Feature dependence"), 64 | html.Div("In the plot to the right you can see that the higher the cost " 65 | "of the fare that passengers paid, the higher the chance of survival. 
" 66 | "Probably the people with more expensive tickets were in higher up cabins, " 67 | "and were more likely to make it to a lifeboat."), 68 | html.Div("When you color the impacts by PassengerClass, you can clearly see that " 69 | "the more expensive tickets were mostly 1st class, and the cheaper tickets " 70 | "mostly 3rd class."), 71 | html.Div("On the right you can check out for yourself how different features impacted " 72 | "the model output."), 73 | ], width=4, style=dict(margin=30)), 74 | dbc.Col([ 75 | self.shap_dependence.layout() 76 | ], style=dict(margin=30)), 77 | ]) 78 | ]) 79 | 80 | class CustomPredictionsTab(ExplainerComponent): 81 | def __init__(self, explainer): 82 | super().__init__(explainer, title="Predictions", name=None) 83 | 84 | self.index = ClassifierRandomIndexComponent(explainer, 85 | hide_title=True, hide_index=False, 86 | hide_slider=True, hide_labels=True, 87 | hide_pred_or_perc=True, 88 | hide_selector=True, hide_button=False) 89 | 90 | self.contributions = ShapContributionsGraphComponent(explainer, 91 | hide_title=True, hide_index=True, 92 | hide_depth=True, hide_sort=True, 93 | hide_orientation=True, hide_cats=True, 94 | hide_selector=True, 95 | sort='importance') 96 | 97 | self.trees = DecisionTreesComponent(explainer, 98 | hide_title=True, hide_index=True, 99 | hide_highlight=True, hide_selector=True) 100 | 101 | 102 | self.connector = IndexConnector(self.index, [self.contributions, self.trees]) 103 | 104 | self.register_components() 105 | 106 | def layout(self): 107 | return dbc.Container([ 108 | dbc.Row([ 109 | dbc.Col([ 110 | html.H3("Enter name:"), 111 | self.index.layout() 112 | ]) 113 | ]), 114 | dbc.Row([ 115 | dbc.Col([ 116 | html.H3("Contributions to prediction:"), 117 | self.contributions.layout() 118 | ]), 119 | 120 | ]), 121 | dbc.Row([ 122 | 123 | dbc.Col([ 124 | html.H3("Every tree in the Random Forest:"), 125 | self.trees.layout() 126 | ]), 127 | ]) 128 | ]) 
-------------------------------------------------------------------------------- /index_layout.py: -------------------------------------------------------------------------------- 1 | import dash_core_components as dcc 2 | import dash_html_components as html 3 | import dash_bootstrap_components as dbc 4 | 5 | from dash.dependencies import Input, Output, State 6 | 7 | navbar = dbc.NavbarSimple( 8 | children=[ 9 | dbc.DropdownMenu( 10 | children=[ 11 | dbc.DropdownMenuItem( 12 | "github", href="https://github.com/oegedijk/explainingtitanic" 13 | ), 14 | ], 15 | nav=True, 16 | in_navbar=True, 17 | label="Source", 18 | ), 19 | dbc.DropdownMenu( 20 | children=[ 21 | dbc.DropdownMenuItem( 22 | "github", href="https://github.com/oegedijk/explainerdashboard" 23 | ), 24 | dbc.DropdownMenuItem( 25 | "readthedocs", 26 | href="http://explainerdashboard.readthedocs.io/en/latest/", 27 | ), 28 | dbc.DropdownMenuItem( 29 | "pypi", href="https://pypi.org/project/explainerdashboard/" 30 | ), 31 | ], 32 | nav=True, 33 | in_navbar=True, 34 | label="explainerdashboard", 35 | ), 36 | ], 37 | brand="Titanic Explainer", 38 | brand_href="https://github.com/oegedijk/explainingtitanic", 39 | color="primary", 40 | dark=True, 41 | fluid=True, 42 | ) 43 | 44 | survive_card = dbc.Card( 45 | [ 46 | dbc.CardImg(src="assets/titanic.jpeg", top=True), 47 | dbc.CardBody( 48 | [ 49 | html.H4("Classifier Dashboard", className="card-title"), 50 | html.P( 51 | "Predicting the probability of surviving " 52 | "the titanic. 
Showing the full default dashboard.", 53 | className="card-text", 54 | ), 55 | html.A( 56 | dbc.Button("Go to dashboard", color="primary"), href="/classifier" 57 | ), 58 | dbc.Button("Show Code", id="clas-code-modal-open", className="mr-1"), 59 | dbc.Modal( 60 | [ 61 | dbc.ModalHeader("Code needed for this Classifier Dashboard"), 62 | dcc.Markdown( 63 | """ 64 | ```python 65 | 66 | from sklearn.ensemble import RandomForestClassifier 67 | 68 | from explainerdashboard import ClassifierExplainer, ExplainerDashboard 69 | from explainerdashboard.datasets import titanic_survive, feature_descriptions 70 | 71 | X_train, y_train, X_test, y_test = titanic_survive() 72 | model = RandomForestClassifier(n_estimators=50, max_depth=10).fit(X_train, y_train) 73 | 74 | explainer = ClassifierExplainer(model, X_test, y_test, 75 | cats=['Sex', 'Deck', 'Embarked'], 76 | descriptions=feature_descriptions, 77 | labels=['Not survived', 'Survived']) 78 | 79 | ExplainerDashboard(explainer).run() 80 | ``` 81 | """ 82 | ), 83 | dbc.ModalFooter( 84 | dbc.Button( 85 | "Close", id="clas-code-modal-close", className="ml-auto" 86 | ) 87 | ), 88 | ], 89 | id="clas-code-modal", 90 | size="lg", 91 | ), 92 | ] 93 | ), 94 | ], 95 | style={"width": "18rem"}, 96 | ) 97 | 98 | ticket_card = dbc.Card( 99 | [ 100 | dbc.CardImg(src="assets/ticket.jpeg", top=True), 101 | dbc.CardBody( 102 | [ 103 | html.H4("Regression Dashboard", className="card-title"), 104 | html.P( 105 | "Predicting the fare paid for a ticket on the titanic. 
" 106 | "Showing the full default dashboard.", 107 | className="card-text", 108 | ), 109 | html.A( 110 | dbc.Button("Go to dashboard", color="primary"), href="/regression" 111 | ), 112 | dbc.Button("Show Code", id="reg-code-modal-open", className="mr-1"), 113 | dbc.Modal( 114 | [ 115 | dbc.ModalHeader("Code needed for this Regression Dashboard"), 116 | dcc.Markdown( 117 | """ 118 | ```python 119 | from sklearn.ensemble import RandomForestRegressor 120 | 121 | from explainerdashboard import RegressionExplainer, ExplainerDashboard 122 | from explainerdashboard.datasets import titanic_fare, feature_descriptions 123 | 124 | X_train, y_train, X_test, y_test = titanic_fare() 125 | model = RandomForestRegressor(n_estimators=50, max_depth=10).fit(X_train, y_train) 126 | 127 | explainer = RegressionExplainer(model, X_test, y_test, 128 | cats=['Sex', 'Deck', 'Embarked'], 129 | descriptions=feature_descriptions, 130 | units="$") 131 | 132 | ExplainerDashboard(explainer).run() 133 | ``` 134 | """ 135 | ), 136 | dbc.ModalFooter( 137 | dbc.Button( 138 | "Close", id="reg-code-modal-close", className="ml-auto" 139 | ) 140 | ), 141 | ], 142 | id="reg-code-modal", 143 | size="lg", 144 | ), 145 | ] 146 | ), 147 | ], 148 | style={"width": "18rem"}, 149 | ) 150 | 151 | port_card = dbc.Card( 152 | [ 153 | dbc.CardImg(src="assets/port.jpeg", top=True), 154 | dbc.CardBody( 155 | [ 156 | html.H4("Multiclass Dashboard", className="card-title"), 157 | html.P( 158 | "Predicting the departure port for passengers on the titanic. 
" 159 | "Showing the full default dashboard.", 160 | className="card-text", 161 | ), 162 | html.A( 163 | dbc.Button("Go to dashboard", color="primary"), href="/multiclass" 164 | ), 165 | dbc.Button("Show Code", id="multi-code-modal-open", className="mr-1"), 166 | dbc.Modal( 167 | [ 168 | dbc.ModalHeader( 169 | "Code needed for this Multi Classifier Dashboard" 170 | ), 171 | dcc.Markdown( 172 | """ 173 | ```python 174 | 175 | from sklearn.ensemble import RandomForestClassifier 176 | 177 | from explainerdashboard import ClassifierExplainer, ExplainerDashboard 178 | from explainerdashboard.datasets import titanic_embarked, feature_descriptions 179 | 180 | X_train, y_train, X_test, y_test = titanic_embarked() 181 | model = RandomForestClassifier(n_estimators=50, max_depth=10).fit(X_train, y_train) 182 | 183 | explainer = ClassifierExplainer(model, X_test, y_test, 184 | cats=['Sex', 'Deck'], 185 | descriptions=feature_descriptions, 186 | labels=['Queenstown', 'Southampton', 'Cherbourg']) 187 | 188 | ExplainerDashboard(explainer).run() 189 | ``` 190 | """ 191 | ), 192 | dbc.ModalFooter( 193 | dbc.Button( 194 | "Close", 195 | id="multi-code-modal-close", 196 | className="ml-auto", 197 | ) 198 | ), 199 | ], 200 | id="multi-code-modal", 201 | size="lg", 202 | ), 203 | ] 204 | ), 205 | ], 206 | style={"width": "18rem"}, 207 | ) 208 | 209 | custom_card = dbc.Card( 210 | [ 211 | dbc.CardImg(src="assets/custom.png", top=True), 212 | dbc.CardBody( 213 | [ 214 | html.H4("Customized Classifier Dashboard", className="card-title"), 215 | html.P( 216 | "You can also completely customize the layout and elements of your " 217 | "dashboard using a low-code approach.", 218 | className="card-text", 219 | ), 220 | # dbc.CardLink("Source code", 221 | # href="https://github.com/oegedijk/explainingtitanic/blob/master/custom.py", 222 | # target="_blank"), 223 | html.P(), 224 | html.A(dbc.Button("Go to dashboard", color="primary"), href="/custom"), 225 | dbc.Button("Show Code", 
id="custom-code-modal-open", className="mr-1"), 226 | dbc.Modal( 227 | [ 228 | dbc.ModalHeader("Code needed for this Custom Dashboard"), 229 | dcc.Markdown( 230 | """ 231 | ```python 232 | from explainerdashboard import ExplainerDashboard 233 | from explainerdashboard.custom import * 234 | 235 | 236 | class CustomModelTab(ExplainerComponent): 237 | def __init__(self, explainer): 238 | super().__init__(explainer, title="Model Summary") 239 | self.precision = PrecisionComponent(explainer, 240 | title='Precision', 241 | hide_subtitle=True, hide_footer=True, 242 | hide_selector=True, 243 | cutoff=None) 244 | self.shap_summary = ShapSummaryComponent(explainer, 245 | title='Impact', 246 | hide_subtitle=True, hide_selector=True, 247 | hide_depth=True, depth=8, 248 | hide_cats=True, cats=True) 249 | self.shap_dependence = ShapDependenceComponent(explainer, 250 | title='Dependence', 251 | hide_subtitle=True, hide_selector=True, 252 | hide_cats=True, cats=True, 253 | hide_index=True, 254 | col='Fare', color_col="PassengerClass") 255 | self.connector = ShapSummaryDependenceConnector( 256 | self.shap_summary, self.shap_dependence) 257 | 258 | self.register_components() 259 | 260 | def layout(self): 261 | return dbc.Container([ 262 | dbc.Row([ 263 | dbc.Col([ 264 | html.H3("Model Performance"), 265 | html.Div("As you can see on the right, the model performs quite well."), 266 | html.Div("The higher the predicted probability of survival predicted by " 267 | "the model on the basis of learning from examples in the training set" 268 | ", the higher is the actual percentage of passengers surviving in " 269 | "the test set"), 270 | ], width=4, style=dict(margin=30)), 271 | dbc.Col([ 272 | self.precision.layout() 273 | ], style=dict(margin=30)) 274 | ]), 275 | dbc.Row([ 276 | dbc.Col([ 277 | self.shap_summary.layout() 278 | ], style=dict(margin=30)), 279 | dbc.Col([ 280 | html.H3("Feature Importances"), 281 | html.Div("On the left you can check out for yourself which parameters were 
the most important."), 282 | html.Div(f"Clearly {self.explainer.columns_ranked_by_shap()[0]} was the most important" 283 | f", followed by {self.explainer.columns_ranked_by_shap()[1]}" 284 | f" and {self.explainer.columns_ranked_by_shap()[2]}."), 285 | html.Div("If you select 'detailed' you can see the impact of that variable on " 286 | "each individual prediction. With 'aggregate' you see the average impact size " 287 | "of that variable on the final prediction."), 288 | html.Div("With the detailed view you can clearly see that the the large impact from Sex " 289 | "stems both from males having a much lower chance of survival and females a much " 290 | "higher chance.") 291 | ], width=4, style=dict(margin=30)), 292 | ]), 293 | dbc.Row([ 294 | dbc.Col([ 295 | html.H3("Feature dependence"), 296 | html.Div("In the plot to the right you can see that the higher the cost " 297 | "of the fare that passengers paid, the higher the chance of survival. " 298 | "Probably the people with more expensive tickets were in higher up cabins, " 299 | "and were more likely to make it to a lifeboat."), 300 | html.Div("When you color the impacts by PassengerClass, you can clearly see that " 301 | "the more expensive tickets were mostly 1st class, and the cheaper tickets " 302 | "mostly 3rd class."), 303 | html.Div("On the right you can check out for yourself how different features impacted " 304 | "the model output."), 305 | ], width=4, style=dict(margin=30)), 306 | dbc.Col([ 307 | self.shap_dependence.layout() 308 | ], style=dict(margin=30)), 309 | ]) 310 | ]) 311 | 312 | class CustomPredictionsTab(ExplainerComponent): 313 | def __init__(self, explainer): 314 | super().__init__(explainer, title="Predictions") 315 | 316 | self.index = ClassifierRandomIndexComponent(explainer, 317 | hide_title=True, hide_index=False, 318 | hide_slider=True, hide_labels=True, 319 | hide_pred_or_perc=True, 320 | hide_selector=True, hide_button=False) 321 | 322 | self.contributions = 
ShapContributionsGraphComponent(explainer, 323 | hide_title=True, hide_index=True, 324 | hide_depth=True, hide_sort=True, 325 | hide_orientation=True, hide_cats=True, 326 | hide_selector=True, 327 | sort='importance') 328 | 329 | self.trees = DecisionTreesComponent(explainer, 330 | hide_title=True, hide_index=True, 331 | hide_highlight=True, hide_selector=True) 332 | 333 | 334 | self.connector = IndexConnector(self.index, [self.contributions, self.trees]) 335 | 336 | self.register_components() 337 | 338 | def layout(self): 339 | return dbc.Container([ 340 | dbc.Row([ 341 | dbc.Col([ 342 | html.H3("Enter name:"), 343 | self.index.layout() 344 | ]) 345 | ]), 346 | dbc.Row([ 347 | dbc.Col([ 348 | html.H3("Contributions to prediction:"), 349 | self.contributions.layout() 350 | ]), 351 | 352 | ]), 353 | dbc.Row([ 354 | 355 | dbc.Col([ 356 | html.H3("Every tree in the Random Forest:"), 357 | self.trees.layout() 358 | ]), 359 | ]) 360 | ]) 361 | 362 | ExplainerDashboard(explainer, [CustomModelTab, CustomPredictionsTab], 363 | title='Titanic Explainer', header_hide_selector=True, 364 | bootstrap=FLATLY).run() 365 | ``` 366 | """ 367 | ), 368 | dbc.ModalFooter( 369 | dbc.Button( 370 | "Close", 371 | id="custom-code-modal-close", 372 | className="ml-auto", 373 | ) 374 | ), 375 | ], 376 | id="custom-code-modal", 377 | size="xl", 378 | scrollable=False, 379 | ), 380 | ] 381 | ), 382 | ], 383 | style={"width": "18rem"}, 384 | ) 385 | 386 | simple_survive_card = dbc.Card( 387 | [ 388 | dbc.CardImg(src="assets/titanic.jpeg", top=True), 389 | dbc.CardBody( 390 | [ 391 | html.H4("Simplified Classifier Dashboard", className="card-title"), 392 | html.P( 393 | "You can generate a simplified single page dashboard " 394 | "by passing simple=True to ExplainerDashboard.", 395 | className="card-text", 396 | ), 397 | html.A( 398 | dbc.Button("Go to dashboard", color="primary"), 399 | href="/simple_classifier", 400 | ), 401 | dbc.Button( 402 | "Show Code", id="simple-clas-code-modal-open", 
className="mr-1" 403 | ), 404 | dbc.Modal( 405 | [ 406 | dbc.ModalHeader("Code needed for this Classifier Dashboard"), 407 | dcc.Markdown( 408 | """ 409 | ```python 410 | 411 | from sklearn.ensemble import RandomForestClassifier 412 | 413 | from explainerdashboard import ClassifierExplainer, ExplainerDashboard 414 | from explainerdashboard.datasets import titanic_survive, feature_descriptions 415 | 416 | X_train, y_train, X_test, y_test = titanic_survive() 417 | model = RandomForestClassifier(n_estimators=50, max_depth=10).fit(X_train, y_train) 418 | 419 | explainer = ClassifierExplainer(model, X_test, y_test, 420 | cats=['Sex', 'Deck', 'Embarked'], 421 | descriptions=feature_descriptions, 422 | labels=['Not survived', 'Survived']) 423 | 424 | ExplainerDashboard(explainer, title="Simplified Classifier Dashboard", simple=True).run() 425 | ``` 426 | """ 427 | ), 428 | dbc.ModalFooter( 429 | dbc.Button( 430 | "Close", 431 | id="simple-clas-code-modal-close", 432 | className="ml-auto", 433 | ) 434 | ), 435 | ], 436 | id="simple-clas-code-modal", 437 | size="lg", 438 | ), 439 | ] 440 | ), 441 | ], 442 | style={"width": "18rem"}, 443 | ) 444 | 445 | simple_ticket_card = dbc.Card( 446 | [ 447 | dbc.CardImg(src="assets/ticket.jpeg", top=True), 448 | dbc.CardBody( 449 | [ 450 | html.H4("Simplified Regression Dashboard", className="card-title"), 451 | html.P( 452 | "You can generate a simplified single page dashboard " 453 | "by passing simple=True to ExplainerDashboard.", 454 | className="card-text", 455 | ), 456 | html.A( 457 | dbc.Button("Go to dashboard", color="primary"), 458 | href="/simple_regression", 459 | ), 460 | dbc.Button( 461 | "Show Code", id="simple-reg-code-modal-open", className="mr-1" 462 | ), 463 | dbc.Modal( 464 | [ 465 | dbc.ModalHeader("Code needed for this Regression Dashboard"), 466 | dcc.Markdown( 467 | """ 468 | ```python 469 | from sklearn.ensemble import RandomForestRegressor 470 | 471 | from explainerdashboard import RegressionExplainer, 
ExplainerDashboard 472 | from explainerdashboard.datasets import titanic_fare, feature_descriptions 473 | 474 | X_train, y_train, X_test, y_test = titanic_fare() 475 | model = RandomForestRegressor(n_estimators=50, max_depth=10).fit(X_train, y_train) 476 | 477 | explainer = RegressionExplainer(model, X_test, y_test, 478 | cats=['Sex', 'Deck', 'Embarked'], 479 | descriptions=feature_descriptions, 480 | units="$") 481 | 482 | ExplainerDashboard(explainer, title="Simplified Regression Dashboard", simple=True).run() 483 | ``` 484 | """ 485 | ), 486 | dbc.ModalFooter( 487 | dbc.Button( 488 | "Close", 489 | id="simple-reg-code-modal-close", 490 | className="ml-auto", 491 | ) 492 | ), 493 | ], 494 | id="simple-reg-code-modal", 495 | size="lg", 496 | ), 497 | ] 498 | ), 499 | ], 500 | style={"width": "18rem"}, 501 | ) 502 | 503 | default_cards = dbc.Row( 504 | [ 505 | dbc.Col(survive_card, width=12, md=4, className="mb-4"), 506 | dbc.Col(ticket_card, width=12, md=4, className="mb-4"), 507 | dbc.Col(port_card, width=12, md=4, className="mb-4"), 508 | ], 509 | className="g-4", 510 | ) 511 | 512 | custom_cards = dbc.Row( 513 | [ 514 | dbc.Col(simple_survive_card, width=12, md=4, className="mb-4"), 515 | dbc.Col(simple_ticket_card, width=12, md=4, className="mb-4"), 516 | dbc.Col(custom_card, width=12, md=4, className="mb-4"), 517 | ], 518 | className="g-4", 519 | ) 520 | 521 | index_layout = dbc.Container( 522 | [ 523 | navbar, 524 | dbc.Row( 525 | [ 526 | dbc.Col( 527 | [ 528 | html.H3("explainerdashboard"), 529 | dcc.Markdown( 530 | "`explainerdashboard` is a python package that makes it easy" 531 | " to quickly build an interactive dashboard that explains the inner " 532 | "workings of a fitted machine learning model. This allows you to " 533 | "open up the 'black box' and show customers, managers, " 534 | "stakeholders, regulators (and yourself) exactly how " 535 | "the machine learning algorithm generates its predictions." 
536 | ), 537 | dcc.Markdown( 538 | "You can explore model performance, feature importances, " 539 | "feature contributions (SHAP values), what-if scenarios, " 540 | "(partial) dependences, feature interactions, individual predictions, " 541 | "permutation importances and even individual decision trees. " 542 | "All interactively. All with a minimum amount of code." 543 | ), 544 | dcc.Markdown( 545 | "Works with all scikit-learn compatible models, including XGBoost, Catboost and LightGBM." 546 | ), 547 | dcc.Markdown( 548 | "Due to the modular design, it is also really easy to design your " 549 | "own custom dashboards, such as the custom example below." 550 | ), 551 | ] 552 | ) 553 | ], 554 | justify="center", 555 | ), 556 | dbc.Row( 557 | [ 558 | dbc.Col( 559 | [ 560 | html.H3("Installation"), 561 | dcc.Markdown( 562 | """ 563 | You can install the library with: 564 | 565 | ``` 566 | pip install explainerdashboard 567 | ``` 568 | 569 | or: 570 | 571 | ``` 572 | conda install -c conda-forge explainerdashboard 573 | ``` 574 | 575 | """ 576 | ), 577 | ] 578 | ) 579 | ], 580 | justify="center", 581 | ), 582 | dbc.Row( 583 | [ 584 | dbc.Col( 585 | [ 586 | dcc.Markdown( 587 | """ 588 | More information can be found in the [github repo](http://github.com/oegedijk/explainerdashboard) 589 | and the documentation on [explainerdashboard.readthedocs.io](http://explainerdashboard.readthedocs.io). 590 | """ 591 | ) 592 | ] 593 | ) 594 | ], 595 | justify="center", 596 | ), 597 | dbc.Row( 598 | [ 599 | dbc.Col( 600 | [ 601 | html.H3("Examples"), 602 | dcc.Markdown(""" 603 | Below you can find demonstrations of the three default dashboards for classification, 604 | regression and multi class classification problems, plus one demonstration of 605 | a custom dashboard. 
def register_callbacks(app):
    """Register the "Show Code" modal toggle callbacks on ``app``.

    One callback is registered per code modal. Component ids follow a single
    naming scheme: ``<prefix>-code-modal`` for the modal itself, with
    ``-open`` and ``-close`` suffixes for the two buttons that control it.

    Each callback returns the negation of the modal's current ``is_open``
    state when either button's ``n_clicks`` is truthy, and returns the
    current state unchanged otherwise (i.e. when neither button has been
    clicked yet).

    Args:
        app: Object exposing a Dash-style ``callback`` decorator factory;
            only ``app.callback(...)`` is used here.

    Returns:
        None. The callbacks are registered on ``app`` as a side effect.
    """
    # Prefixes of every code modal wired up here, in registration order.
    # NOTE(review): every prefix needs matching component ids in the page
    # layout defined in this module -- confirm when adding a new card.
    modal_prefixes = (
        "clas",
        "reg",
        "multi",
        "custom",
        "simple-clas",
        "simple-reg",
    )

    for prefix in modal_prefixes:
        modal_id = f"{prefix}-code-modal"

        # A fresh function object is created on every iteration, exactly as
        # with six separately written callbacks. The body reads only its own
        # arguments, so nothing is captured from the loop (no late binding).
        @app.callback(
            Output(modal_id, "is_open"),
            Input(f"{modal_id}-open", "n_clicks"),
            Input(f"{modal_id}-close", "n_clicks"),
            State(modal_id, "is_open"),
        )
        def toggle_modal(click_open, click_close, is_open):
            if click_open or click_close:
                return not is_open
            return is_open