├── slime ├── __init__.py ├── tests │ ├── __init__.py │ ├── test_generic_utils.py │ ├── test_scikit_image.py │ ├── test_lime_text.py │ ├── test_discretize.py │ └── test_lime_tabular.py ├── utils │ ├── __init__.py │ └── generic_utils.py ├── wrappers │ ├── __init__.py │ └── scikit_image.py ├── .DS_Store ├── exceptions.py ├── js │ ├── main.js │ ├── predict_proba.js │ ├── bar_chart.js │ ├── predicted_value.js │ └── explanation.js ├── test_table.html ├── style.css ├── webpack.config.js ├── package.json ├── submodular_pick.py ├── discretize.py ├── explanation.py ├── lime_base.py ├── lime_text.py └── lime_image.py ├── slime_lm ├── __init__.py └── _least_angle.py ├── MANIFEST.in ├── .DS_Store ├── doc └── images │ ├── demo1.png │ └── demo2.png ├── setup.py ├── LICENSE ├── .gitignore └── README.md /slime/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime_lm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /slime/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include slime/*.js 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/.DS_Store -------------------------------------------------------------------------------- /slime/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/slime/.DS_Store -------------------------------------------------------------------------------- /slime/exceptions.py: -------------------------------------------------------------------------------- 1 | class LimeError(Exception): 2 | """Raise for errors""" 3 | -------------------------------------------------------------------------------- /doc/images/demo1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/doc/images/demo1.png -------------------------------------------------------------------------------- /doc/images/demo2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/doc/images/demo2.png -------------------------------------------------------------------------------- /slime/js/main.js: -------------------------------------------------------------------------------- 1 | if (!global._babelPolyfill) { 2 | require('babel-polyfill') 3 | } 4 | 5 | 6 | import Explanation from './explanation.js'; 7 | import Barchart from './bar_chart.js'; 8 | import PredictProba from 
'./predict_proba.js'; 9 | import PredictedValue from './predicted_value.js'; 10 | require('../style.css'); 11 | 12 | export {Explanation, Barchart, PredictProba, PredictedValue}; 13 | //require('style-loader'); 14 | 15 | 16 | -------------------------------------------------------------------------------- /slime/test_table.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /slime/style.css: -------------------------------------------------------------------------------- 1 | .lime { 2 | all: initial; 3 | } 4 | .lime.top_div { 5 | display: flex; 6 | flex-wrap: wrap; 7 | } 8 | .lime.predict_proba { 9 | width: 245px; 10 | } 11 | .lime.predicted_value { 12 | width: 245px; 13 | } 14 | .lime.explanation { 15 | width: 350px; 16 | } 17 | 18 | .lime.text_div { 19 | max-height:300px; 20 | flex: 1 0 300px; 21 | overflow:scroll; 22 | } 23 | .lime.table_div { 24 | max-height:300px; 25 | flex: 1 0 300px; 26 | overflow:scroll; 27 | } 28 | .lime.table_div table { 29 | border-collapse: collapse; 30 | color: white; 31 | border-style: hidden; 32 | margin: 0 auto; 33 | } 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='slime', 4 | version='0.1', 5 | description='Stabilized-LIME for Model Explanation', 6 | url='https://github.com/ZhengzeZhou/slime', 7 | author='Zhengze Zhou', 8 | author_email='zz433@cornell.edu', 9 | license='BSD', 10 | packages=find_packages(exclude=['js', 'node_modules', 'tests']), 11 | python_requires='>=3.5', 12 | install_requires=[ 13 | 'lime', 14 | 'matplotlib', 15 | 'numpy', 16 | 'scipy', 17 | 'tqdm >= 4.29.1', 18 | 'scikit-learn>=0.18', 19 | 'scikit-image>=0.12', 20 | 'pyDOE2==1.3.0' 21 | ], 22 | extras_require={ 23 | 'dev': ['pytest', 'flake8'], 24 | }, 25 | include_package_data=True, 26 | zip_safe=False) 27 | -------------------------------------------------------------------------------- /slime/webpack.config.js: -------------------------------------------------------------------------------- 1 | var path = require('path'); 2 | var webpack = require('webpack'); 3 | 4 | module.exports = { 5 | entry: './js/main.js', 6 | output: { 7 | path: __dirname, 8 | filename: 'bundle.js', 9 | library: 'lime' 10 | }, 11 | module: { 12 | loaders: [ 13 | { 14 | loader: 'babel-loader', 15 | test: path.join(__dirname, 'js'), 16 | query: { 17 | presets: 'es2015-ie', 18 | }, 19 | 20 | }, 21 | { 22 | test: /\.css$/, 23 | loaders: ['style-loader', 'css-loader'], 24 | 25 | } 26 | 27 | ] 28 | }, 29 | plugins: [ 30 | // Avoid publishing files when compilation fails 31 | new webpack.NoErrorsPlugin() 32 | ], 33 | stats: { 34 | // Nice colored output 35 | colors: true 36 | }, 37 | // Create Sourcemaps for the bundle 38 | devtool: 'source-map', 39 | }; 40 | 41 | -------------------------------------------------------------------------------- /slime/utils/generic_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import inspect 3 | import types 4 | 5 | 6 | def has_arg(fn, arg_name): 7 | """Checks if a callable accepts a given keyword argument. 8 | 9 | Args: 10 | fn: callable to inspect 11 | arg_name: string, keyword argument name to check 12 | 13 | Returns: 14 | bool, whether `fn` accepts a `arg_name` keyword argument. 
15 | """ 16 | if sys.version_info < (3,): 17 | if isinstance(fn, types.FunctionType) or isinstance(fn, types.MethodType): 18 | arg_spec = inspect.getargspec(fn) 19 | else: 20 | try: 21 | arg_spec = inspect.getargspec(fn.__call__) 22 | except AttributeError: 23 | return False 24 | return (arg_name in arg_spec.args) 25 | elif sys.version_info < (3, 6): 26 | arg_spec = inspect.getfullargspec(fn) 27 | return (arg_name in arg_spec.args or 28 | arg_name in arg_spec.kwonlyargs) 29 | else: 30 | try: 31 | signature = inspect.signature(fn) 32 | except ValueError: 33 | # handling Cython 34 | signature = inspect.signature(fn.__call__) 35 | parameter = signature.parameters.get(arg_name) 36 | if parameter is None: 37 | return False 38 | return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, 39 | inspect.Parameter.KEYWORD_ONLY)) 40 | -------------------------------------------------------------------------------- /slime/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lime", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "main.js", 6 | "scripts": { 7 | "build": "webpack", 8 | "watch": "webpack --watch", 9 | "start": "webpack-dev-server --hot --inline", 10 | "lint": "eslint js" 11 | }, 12 | "repository": { 13 | "type": "git", 14 | "url": "git+https://github.com/marcotcr/lime.git" 15 | }, 16 | "author": "Marco Tulio Ribeiro ", 17 | "license": "TODO", 18 | "bugs": { 19 | "url": "https://github.com/marcotcr/lime/issues" 20 | }, 21 | "homepage": "https://github.com/marcotcr/lime#readme", 22 | "devDependencies": { 23 | "babel-cli": "^6.8.0", 24 | "babel-core": "^6.17.0", 25 | "babel-eslint": "^6.1.0", 26 | "babel-loader": "^6.2.4", 27 | "babel-polyfill": "^6.16.0", 28 | "babel-preset-es2015": "^6.0.15", 29 | "babel-preset-es2015-ie": "^6.6.2", 30 | "css-loader": "^0.23.1", 31 | "eslint": "^6.6.0", 32 | "node-libs-browser": "^0.5.3", 33 | "style-loader": "^0.13.1", 34 | "webpack": "^1.13.0", 35 | "webpack-dev-server": "^1.14.1" 36 | }, 37 | "dependencies": { 38 | "d3": "^3.5.17", 39 | "lodash": "^4.11.2" 40 | }, 41 | "eslintConfig": { 42 | "parser": "babel-eslint", 43 | "parserOptions": { 44 | "ecmaVersion": 6, 45 | "sourceType": "module", 46 | "ecmaFeatures": { 47 | "jsx": true 48 | } 49 | }, 50 | "extends": "eslint:recommended" 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, Zhengze Zhou 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Setuptools distribution folder. 5 | /dist/ 6 | 7 | /lime/node_modules 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | 12 | # Unit test / coverage reports 13 | .cache 14 | 15 | # Created by https://www.gitignore.io/api/pycharm 16 | 17 | ### PyCharm ### 18 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 19 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 20 | 21 | # User-specific stuff: 22 | .idea/workspace.xml 23 | .idea/tasks.xml 24 | .idea/dictionaries 25 | .idea/vcs.xml 26 | .idea/jsLibraryMappings.xml 27 | 28 | # Sensitive or high-churn files: 29 | .idea/dataSources.ids 30 | .idea/dataSources.xml 31 | .idea/dataSources.local.xml 32 | .idea/sqlDataSources.xml 33 | .idea/dynamic.xml 34 | .idea/uiDesigner.xml 35 | 36 | # Gradle: 37 | .idea/gradle.xml 38 | .idea/libraries 39 | 40 | # Mongo Explorer plugin: 41 | .idea/mongoSettings.xml 42 | 43 | ## File-based project format: 44 | *.iws 45 | 46 | ## Plugin-specific files: 47 | 48 | # IntelliJ 49 | /out/ 50 | 51 | # mpeltonen/sbt-idea plugin 52 | .idea_modules/ 53 | 54 | # JIRA plugin 55 | atlassian-ide-plugin.xml 56 | 57 | # Crashlytics plugin (for Android Studio and IntelliJ) 58 | com_crashlytics_export_strings.xml 59 | crashlytics.properties 60 | crashlytics-build.properties 61 | fabric.properties 62 | 63 | ### PyCharm Patch ### 64 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 65 | 66 | # *.iml 67 | # modules.xml 68 | # .idea/misc.xml 69 | # *.ipr 70 | 71 | # Pycharm 72 | .idea 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # slime 2 | 3 | This repository holds code for replicating experiments in the paper [S-LIME: Stabilized-LIME for Model Explanation]() to appear in [KDD2021](https://www.kdd.org/kdd2021/). 4 | 5 | It is built on the implementation of [LIME](https://github.com/marcotcr/lime) with added functionality. 6 | 7 | ## Introduction 8 | 9 | It has been shown that post hoc explanations based on perturbations (such as LIME) exhibit large instability, posing serious challenges to the effectiveness of the method itself and harming user trust. 
S-LIME stands for Stabilized-LIME, which utilizes a hypothesis testing framework based on the central limit theorem to determine the number of perturbation points needed to guarantee the stability of the resulting explanation. 10 | 11 | ## Installation 12 | 13 | Clone the repository and install it using pip: 14 | 15 | ```sh 16 | git clone https://github.com/ZhengzeZhou/slime.git 17 | cd slime 18 | pip install . 19 | ``` 20 | 21 | ## Usage 22 | 23 | Currently, S-LIME only supports tabular data, and only when the feature selection method is set to "lasso_path". We are working on extending the use cases to other data types and feature selection methods. 24 | 25 | The following screenshot shows a typical usage of LIME on the breast cancer data. We can easily observe that two runs of the explanation algorithm result in different features being selected. 26 | 27 | ![demo1](doc/images/demo1.png) 28 | 29 | S-LIME is invoked by calling **explainer.slime** instead of **explainer.explain_instance**. *n_max* indicates the maximum number of synthetic samples to generate, and *alpha* denotes the significance level of the hypothesis testing. S-LIME explanations are guaranteed to be stable with high probability. 30 | 31 | ![demo2](doc/images/demo2.png) 32 | 33 | ## Notebooks 34 | 35 | - [Breast Cancer Data](https://github.com/ZhengzeZhou/slime/blob/main/doc/notebooks/Breast%20Cancer%20Data.ipynb) 36 | - [MARS](https://github.com/ZhengzeZhou/slime/blob/main/doc/notebooks/MARS.ipynb) 37 | - [Dog Images](https://github.com/joangog/slime/blob/main/doc/notebooks/Dogs.ipynb) 38 | 39 | 40 | -------------------------------------------------------------------------------- /slime/tests/test_generic_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import sys 3 | from lime.utils.generic_utils import has_arg 4 | 5 | 6 | class TestGenericUtils(unittest.TestCase): 7 | 8 | def test_has_arg(self): 9 | # fn is callable / is not callable 10 | 11 | class FooNotCallable: 12 | 13 | def __init__(self, word): 14 | self.message = word 15 | 16 | class FooCallable: 17 | 18 | def __init__(self, word): 19 | self.message = word 20 | 21 | def __call__(self, message): 22 | return message 23 | 24 | def positional_argument_call(self, arg1): 25 | return self.message 26 | 27 | def multiple_positional_arguments_call(self, *args): 28 | res = [] 29 | for a in args: 30 | res.append(a) 31 | return res 32 | 33 | def keyword_argument_call(self, filter_=True): 34 | res = self.message 35 | if filter_: 36 | res = 'KO' 37 | return res 38 | 39 | def multiple_keyword_arguments_call(self, arg1='1', arg2='2'): 40 | return self.message + arg1 + arg2 41 | 42 | def undefined_keyword_arguments_call(self, **kwargs): 43 | res = self.message 44 | for a in kwargs: 45 | res = res + a 46 | return a 47 | 48 | foo_callable = FooCallable('OK') 49 | self.assertTrue(has_arg(foo_callable, 'message')) 50 | 51 | if sys.version_info < (3,): 52 | foo_not_callable = FooNotCallable('KO') 53 | self.assertFalse(has_arg(foo_not_callable, 'message')) 54 | elif sys.version_info < (3, 6): 55 | with self.assertRaises(TypeError): 56 | foo_not_callable = FooNotCallable('KO') 57 | has_arg(foo_not_callable, 'message') 58 | 59 | # Python 2, argument in / not in valid arguments / keyword arguments 60 | if sys.version_info < (3,): 61 | self.assertFalse(has_arg(foo_callable, 'invalid_arg')) 62 | self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1')) 63 | self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX')) 64 
| self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX')) 65 | self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_')) 66 | self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2')) 67 | self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3')) 68 | self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX')) 69 | # Python 3, argument in / not in valid arguments / keyword arguments 70 | elif sys.version_info < (3, 6): 71 | self.assertFalse(has_arg(foo_callable, 'invalid_arg')) 72 | self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1')) 73 | self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX')) 74 | self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX')) 75 | self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_')) 76 | self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2')) 77 | self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3')) 78 | self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX')) 79 | else: 80 | self.assertFalse(has_arg(foo_callable, 'invalid_arg')) 81 | self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1')) 82 | self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX')) 83 | self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX')) 84 | self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_')) 85 | self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2')) 86 | self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3')) 87 | self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX')) 88 | # argname is None 89 | self.assertFalse(has_arg(foo_callable, None)) 90 | 91 | 92 | if __name__ == '__main__': 93 | unittest.main() 94 | -------------------------------------------------------------------------------- /slime/js/predict_proba.js: -------------------------------------------------------------------------------- 1 | import d3 from 'd3'; 2 | import {range, sortBy} from 'lodash'; 3 | 4 | class PredictProba { 5 | // svg: d3 object with the svg in question 6 | // class_names: array of class names 7 | // predict_probas: array of prediction probabilities 8 | constructor(svg, class_names, predict_probas, title='Prediction probabilities') { 9 | let width = parseInt(svg.style('width')); 10 | this.names = class_names; 11 | this.names.push('Other'); 12 | if (class_names.length < 10) { 13 | this.colors = d3.scale.category10().domain(this.names); 14 | this.colors_i = d3.scale.category10().domain(range(this.names.length)); 15 | } 16 | else { 17 | this.colors = d3.scale.category20().domain(this.names); 18 | this.colors_i = d3.scale.category20().domain(range(this.names.length)); 19 | } 20 | let [names, data] = this.map_classes(this.names, predict_probas); 21 | let bar_x = width - 125; 22 | let class_names_width = bar_x; 23 | let bar_width = width - bar_x - 32; 24 | let x_scale = d3.scale.linear().range([0, bar_width]); 25 | let bar_height = 17; 26 | let space_between_bars = 5; 27 | let bar_yshift= title === '' ? 
0 : 35; 28 | let n_bars = Math.min(5, data.length); 29 | this.svg_height = n_bars * (bar_height + space_between_bars) + bar_yshift; 30 | svg.style('height', this.svg_height + 'px'); 31 | let this_object = this; 32 | if (title !== '') { 33 | svg.append('text') 34 | .text(title) 35 | .attr('x', 20) 36 | .attr('y', 20); 37 | } 38 | let bar_y = i => (bar_height + space_between_bars) * i + bar_yshift; 39 | let bar = svg.append("g"); 40 | 41 | for (let i of range(data.length)) { 42 | var color = this.colors(names[i]); 43 | if (names[i] == 'Other' && this.names.length > 20) { 44 | color = '#5F9EA0'; 45 | } 46 | let rect = bar.append("rect"); 47 | rect.attr("x", bar_x) 48 | .attr("y", bar_y(i)) 49 | .attr("height", bar_height) 50 | .attr("width", x_scale(data[i])) 51 | .style("fill", color); 52 | bar.append("rect").attr("x", bar_x) 53 | .attr("y", bar_y(i)) 54 | .attr("height", bar_height) 55 | .attr("width", bar_width - 1) 56 | .attr("fill-opacity", 0) 57 | .attr("stroke", "black"); 58 | let text = bar.append("text"); 59 | text.classed("prob_text", true); 60 | text.attr("y", bar_y(i) + bar_height - 3).attr("fill", "black").style("font", "14px tahoma, sans-serif"); 61 | text = bar.append("text"); 62 | text.attr("x", bar_x + x_scale(data[i]) + 5) 63 | .attr("y", bar_y(i) + bar_height - 3) 64 | .attr("fill", "black") 65 | .style("font", "14px tahoma, sans-serif") 66 | .text(data[i].toFixed(2)); 67 | text = bar.append("text"); 68 | text.attr("x", bar_x - 10) 69 | .attr("y", bar_y(i) + bar_height - 3) 70 | .attr("fill", "black") 71 | .attr("text-anchor", "end") 72 | .style("font", "14px tahoma, sans-serif") 73 | .text(names[i]); 74 | while (text.node().getBBox()['width'] + 1 > (class_names_width - 10)) { 75 | // TODO: ta mostrando só dois, e talvez quando hover mostrar o texto 76 | // todo 77 | let cur_text = text.text().slice(0, text.text().length - 5); 78 | text.text(cur_text + '...'); 79 | if (cur_text === '') { 80 | break 81 | } 82 | } 83 | } 84 | } 85 | map_classes(class_names, predict_proba) { 86 | if (class_names.length <= 6) { 87 | return [class_names, predict_proba]; 88 | } 89 | let class_dict = range(predict_proba.length).map(i => ({'name': class_names[i], 'prob': predict_proba[i], 'i' : i})); 90 | let sorted = sortBy(class_dict, d => -d.prob); 91 | let other = new Set(); 92 | range(4, sorted.length).map(d => other.add(sorted[d].name)); 93 | let other_prob = 0; 94 | let ret_probs = []; 95 | let ret_names = []; 96 | for (let d of range(sorted.length)) { 97 | if (other.has(sorted[d].name)) { 98 | other_prob += sorted[d].prob; 99 | } 100 | else { 101 | ret_probs.push(sorted[d].prob); 102 | ret_names.push(sorted[d].name); 103 | } 104 | }; 105 | ret_names.push("Other"); 106 | ret_probs.push(other_prob); 107 | return [ret_names, ret_probs]; 108 | } 109 | 110 | } 111 | export default PredictProba; 112 | 113 | 114 | -------------------------------------------------------------------------------- /slime/js/bar_chart.js: -------------------------------------------------------------------------------- 1 | import d3 from 'd3'; 2 | class Barchart { 3 | // svg: d3 object with the svg in question 4 | // exp_array: list of (feature_name, weight) 5 | constructor(svg, exp_array, two_sided=true, titles=undefined, colors=['red', 'green'], show_numbers=false, bar_height=5) { 6 | let svg_width = Math.min(600, parseInt(svg.style('width'))); 7 | let bar_width = two_sided ? svg_width / 2 : svg_width; 8 | if (titles === undefined) { 9 | titles = two_sided ? 
['Cons', 'Pros'] : 'Pros'; 10 | } 11 | if (show_numbers) { 12 | bar_width = bar_width - 30; 13 | } 14 | let x_offset = two_sided ? svg_width / 2 : 10; 15 | // 13.1 is +- the width of W, the widest letter. 16 | if (two_sided && titles.length == 2) { 17 | svg.append('text') 18 | .attr('x', svg_width / 4) 19 | .attr('y', 15) 20 | .attr('font-size', '20') 21 | .attr('text-anchor', 'middle') 22 | .style('fill', colors[0]) 23 | .text(titles[0]); 24 | 25 | svg.append('text') 26 | .attr('x', svg_width / 4 * 3) 27 | .attr('y', 15) 28 | .attr('font-size', '20') 29 | .attr('text-anchor', 'middle') 30 | .style('fill', colors[1]) 31 | .text(titles[1]); 32 | } 33 | else { 34 | let pos = two_sided ? svg_width / 2 : x_offset; 35 | let anchor = two_sided ? 'middle' : 'begin'; 36 | svg.append('text') 37 | .attr('x', pos) 38 | .attr('y', 15) 39 | .attr('font-size', '20') 40 | .attr('text-anchor', anchor) 41 | .text(titles); 42 | } 43 | let yshift = 20; 44 | let space_between_bars = 0; 45 | let text_height = 16; 46 | let space_between_bar_and_text = 3; 47 | let total_bar_height = text_height + space_between_bar_and_text + bar_height + space_between_bars; 48 | let total_height = (total_bar_height) * exp_array.length; 49 | this.svg_height = total_height + yshift; 50 | let yscale = d3.scale.linear() 51 | .domain([0, exp_array.length]) 52 | .range([yshift, yshift + total_height]) 53 | let names = exp_array.map(v => v[0]); 54 | let weights = exp_array.map(v => v[1]); 55 | let max_weight = Math.max(...(weights.map(v=>Math.abs(v)))); 56 | let xscale = d3.scale.linear() 57 | .domain([0,Math.max(1, max_weight)]) 58 | .range([0, bar_width]); 59 | 60 | for (var i = 0; i < exp_array.length; ++i) { 61 | let name = names[i]; 62 | let weight = weights[i]; 63 | var size = xscale(Math.abs(weight)); 64 | let to_the_right = (weight > 0 || !two_sided) 65 | let text = svg.append('text') 66 | .attr('x', to_the_right ? x_offset + 2 : x_offset - 2) 67 | .attr('y', yscale(i) + text_height) 68 | .attr('text-anchor', to_the_right ? 'begin' : 'end') 69 | .attr('font-size', '14') 70 | .text(name); 71 | while (text.node().getBBox()['width'] + 1 > bar_width) { 72 | let cur_text = text.text().slice(0, text.text().length - 5); 73 | text.text(cur_text + '...'); 74 | if (text === '...') { 75 | break; 76 | } 77 | } 78 | let bar = svg.append('rect') 79 | .attr('height', bar_height) 80 | .attr('x', to_the_right ? x_offset : x_offset - size) 81 | .attr('y', text_height + yscale(i) + space_between_bar_and_text)// + bar_height) 82 | .attr('width', size) 83 | .style('fill', weight > 0 ? colors[1] : colors[0]); 84 | if (show_numbers) { 85 | let bartext = svg.append('text') 86 | .attr('x', to_the_right ? x_offset + size + 1 : x_offset - size - 1) 87 | .attr('text-anchor', (weight > 0 || !two_sided) ? 
'begin' : 'end') 88 | .attr('y', bar_height + yscale(i) + text_height + space_between_bar_and_text) 89 | .attr('font-size', '10') 90 | .text(Math.abs(weight).toFixed(2)); 91 | } 92 | } 93 | let line = svg.append("line") 94 | .attr("x1", x_offset) 95 | .attr("x2", x_offset) 96 | .attr("y1", bar_height + yshift) 97 | .attr("y2", Math.max(bar_height, yscale(exp_array.length))) 98 | .style("stroke-width",2) 99 | .style("stroke", "black"); 100 | } 101 | 102 | } 103 | export default Barchart; 104 | -------------------------------------------------------------------------------- /slime/wrappers/scikit_image.py: -------------------------------------------------------------------------------- 1 | import types 2 | from lime.utils.generic_utils import has_arg 3 | from skimage.segmentation import felzenszwalb, slic, quickshift 4 | 5 | 6 | class BaseWrapper(object): 7 | """Base class for LIME Scikit-Image wrapper 8 | 9 | 10 | Args: 11 | target_fn: callable function or class instance 12 | target_params: dict, parameters to pass to the target_fn 13 | 14 | 15 | 'target_params' takes parameters required to instanciate the 16 | desired Scikit-Image class/model 17 | """ 18 | 19 | def __init__(self, target_fn=None, **target_params): 20 | self.target_fn = target_fn 21 | self.target_params = target_params 22 | 23 | def _check_params(self, parameters): 24 | """Checks for mistakes in 'parameters' 25 | 26 | Args : 27 | parameters: dict, parameters to be checked 28 | 29 | Raises : 30 | ValueError: if any parameter is not a valid argument for the target function 31 | or the target function is not defined 32 | TypeError: if argument parameters is not iterable 33 | """ 34 | a_valid_fn = [] 35 | if self.target_fn is None: 36 | if callable(self): 37 | a_valid_fn.append(self.__call__) 38 | else: 39 | raise TypeError('invalid argument: tested object is not callable,\ 40 | please provide a valid target_fn') 41 | elif isinstance(self.target_fn, types.FunctionType) \ 42 | or isinstance(self.target_fn, types.MethodType): 43 | a_valid_fn.append(self.target_fn) 44 | else: 45 | a_valid_fn.append(self.target_fn.__call__) 46 | 47 | if not isinstance(parameters, str): 48 | for p in parameters: 49 | for fn in a_valid_fn: 50 | if has_arg(fn, p): 51 | pass 52 | else: 53 | raise ValueError('{} is not a valid parameter'.format(p)) 54 | else: 55 | raise TypeError('invalid argument: list or dictionnary expected') 56 | 57 | def set_params(self, **params): 58 | """Sets the parameters of this estimator. 59 | Args: 60 | **params: Dictionary of parameter names mapped to their values. 61 | 62 | Raises : 63 | ValueError: if any parameter is not a valid argument 64 | for the target function 65 | """ 66 | self._check_params(params) 67 | self.target_params = params 68 | 69 | def filter_params(self, fn, override=None): 70 | """Filters `target_params` and return those in `fn`'s arguments. 71 | Args: 72 | fn : arbitrary function 73 | override: dict, values to override target_params 74 | Returns: 75 | result : dict, dictionary containing variables 76 | in both target_params and fn's arguments. 
77 | """ 78 | override = override or {} 79 | result = {} 80 | for name, value in self.target_params.items(): 81 | if has_arg(fn, name): 82 | result.update({name: value}) 83 | result.update(override) 84 | return result 85 | 86 | 87 | class SegmentationAlgorithm(BaseWrapper): 88 | """ Define the image segmentation function based on Scikit-Image 89 | implementation and a set of provided parameters 90 | 91 | Args: 92 | algo_type: string, segmentation algorithm among the following: 93 | 'quickshift', 'slic', 'felzenszwalb' 94 | target_params: dict, algorithm parameters (valid model paramters 95 | as define in Scikit-Image documentation) 96 | """ 97 | 98 | def __init__(self, algo_type, **target_params): 99 | self.algo_type = algo_type 100 | if (self.algo_type == 'quickshift'): 101 | BaseWrapper.__init__(self, quickshift, **target_params) 102 | kwargs = self.filter_params(quickshift) 103 | self.set_params(**kwargs) 104 | elif (self.algo_type == 'felzenszwalb'): 105 | BaseWrapper.__init__(self, felzenszwalb, **target_params) 106 | kwargs = self.filter_params(felzenszwalb) 107 | self.set_params(**kwargs) 108 | elif (self.algo_type == 'slic'): 109 | BaseWrapper.__init__(self, slic, **target_params) 110 | kwargs = self.filter_params(slic) 111 | self.set_params(**kwargs) 112 | 113 | def __call__(self, *args): 114 | return self.target_fn(args[0], **self.target_params) 115 | -------------------------------------------------------------------------------- /slime/tests/test_scikit_image.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from lime.wrappers.scikit_image import BaseWrapper 3 | from lime.wrappers.scikit_image import SegmentationAlgorithm 4 | from skimage.segmentation import quickshift 5 | from skimage.data import chelsea 6 | from skimage.util import img_as_float 7 | import numpy as np 8 | 9 | 10 | class TestBaseWrapper(unittest.TestCase): 11 | 12 | def test_base_wrapper(self): 13 | 14 | obj_with_params = BaseWrapper(a=10, b='message') 15 | obj_without_params = BaseWrapper() 16 | 17 | def foo_fn(): 18 | return 'bar' 19 | 20 | obj_with_fn = BaseWrapper(foo_fn) 21 | self.assertEqual(obj_with_params.target_params, {'a': 10, 'b': 'message'}) 22 | self.assertEqual(obj_without_params.target_params, {}) 23 | self.assertEqual(obj_with_fn.target_fn(), 'bar') 24 | 25 | def test__check_params(self): 26 | 27 | def bar_fn(a): 28 | return str(a) 29 | 30 | class Pipo(): 31 | 32 | def __init__(self): 33 | self.name = 'pipo' 34 | 35 | def __call__(self, message): 36 | return message 37 | 38 | pipo = Pipo() 39 | obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message') 40 | obj_with_valid_callable_fn = BaseWrapper(pipo, c=10, d='message') 41 | obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid') 42 | 43 | # target_fn is not a callable or function/method 44 | with self.assertRaises(AttributeError): 45 | obj_with_invalid_fn._check_params('fn_name') 46 | 47 | # parameters is not in target_fn args 48 | with self.assertRaises(ValueError): 49 | obj_with_valid_fn._check_params(['c']) 50 | obj_with_valid_callable_fn._check_params(['e']) 51 | 52 | # params is in target_fn args 53 | try: 54 | obj_with_valid_fn._check_params(['a']) 55 | obj_with_valid_callable_fn._check_params(['message']) 56 | except Exception: 57 | self.fail("_check_params() raised an unexpected exception") 58 | 59 | # params is not a dict or list 60 | with self.assertRaises(TypeError): 61 | obj_with_valid_fn._check_params(None) 62 | with self.assertRaises(TypeError): 63 | 
obj_with_valid_fn._check_params('param_name') 64 | 65 | def test_set_params(self): 66 | 67 | class Pipo(): 68 | 69 | def __init__(self): 70 | self.name = 'pipo' 71 | 72 | def __call__(self, message): 73 | return message 74 | pipo = Pipo() 75 | obj = BaseWrapper(pipo) 76 | 77 | # argument is set accordingly 78 | obj.set_params(message='OK') 79 | self.assertEqual(obj.target_params, {'message': 'OK'}) 80 | self.assertEqual(obj.target_fn(**obj.target_params), 'OK') 81 | 82 | # invalid argument is passed 83 | try: 84 | obj = BaseWrapper(Pipo()) 85 | obj.set_params(invalid='KO') 86 | except Exception: 87 | self.assertEqual(obj.target_params, {}) 88 | 89 | def test_filter_params(self): 90 | 91 | # right arguments are kept and wrong dismmissed 92 | def baz_fn(a, b, c=True): 93 | if c: 94 | return a + b 95 | else: 96 | return a 97 | obj_ = BaseWrapper(baz_fn, a=10, b=100, d=1000) 98 | self.assertEqual(obj_.filter_params(baz_fn), {'a': 10, 'b': 100}) 99 | 100 | # target_params is overriden using 'override' argument 101 | self.assertEqual(obj_.filter_params(baz_fn, override={'c': False}), 102 | {'a': 10, 'b': 100, 'c': False}) 103 | 104 | 105 | class TestSegmentationAlgorithm(unittest.TestCase): 106 | 107 | def test_instanciate_segmentation_algorithm(self): 108 | img = img_as_float(chelsea()[::2, ::2]) 109 | 110 | # wrapped functions provide the same result 111 | fn = SegmentationAlgorithm('quickshift', kernel_size=3, max_dist=6, 112 | ratio=0.5, random_seed=133) 113 | fn_result = fn(img) 114 | original_result = quickshift(img, kernel_size=3, max_dist=6, ratio=0.5, 115 | random_seed=133) 116 | 117 | # same segments 118 | self.assertTrue(np.array_equal(fn_result, original_result)) 119 | 120 | def test_instanciate_slic(self): 121 | pass 122 | 123 | def test_instanciate_felzenszwalb(self): 124 | pass 125 | 126 | 127 | if __name__ == '__main__': 128 | unittest.main() 129 | -------------------------------------------------------------------------------- /slime/js/predicted_value.js: -------------------------------------------------------------------------------- 1 | import d3 from 'd3'; 2 | import {range, sortBy} from 'lodash'; 3 | 4 | class PredictedValue { 5 | // svg: d3 object with the svg in question 6 | // class_names: array of class names 7 | // predict_probas: array of prediction probabilities 8 | constructor(svg, predicted_value, min_value, max_value, title='Predicted value', log_coords = false) { 9 | 10 | if (min_value == max_value){ 11 | var width_proportion = 1.0; 12 | } else { 13 | var width_proportion = (predicted_value - min_value) / (max_value - min_value); 14 | } 15 | 16 | 17 | let width = parseInt(svg.style('width')) 18 | 19 | this.color = d3.scale.category10() 20 | this.color('predicted_value') 21 | // + 2 is due to it being a float 22 | let num_digits = Math.floor(Math.max(Math.log10(Math.abs(min_value)), Math.log10(Math.abs(max_value)))) + 2 23 | num_digits = Math.max(num_digits, 3) 24 | 25 | let corner_width = 12 * num_digits; 26 | let corner_padding = 5.5 * num_digits; 27 | let bar_x = corner_width + corner_padding; 28 | let bar_width = width - corner_width * 2 - corner_padding * 2; 29 | let x_scale = d3.scale.linear().range([0, bar_width]); 30 | let bar_height = 17; 31 | let bar_yshift= title === '' ? 
0 : 35; 32 | let n_bars = 1; 33 | let this_object = this; 34 | if (title !== '') { 35 | svg.append('text') 36 | .text(title) 37 | .attr('x', 20) 38 | .attr('y', 20); 39 | } 40 | let bar_y = bar_yshift; 41 | let bar = svg.append("g"); 42 | 43 | //filled in bar representing predicted value in range 44 | let rect = bar.append("rect"); 45 | rect.attr("x", bar_x) 46 | .attr("y", bar_y) 47 | .attr("height", bar_height) 48 | .attr("width", x_scale(width_proportion)) 49 | .style("fill", this.color); 50 | 51 | //empty box representing range 52 | bar.append("rect").attr("x", bar_x) 53 | .attr("y", bar_y) 54 | .attr("height", bar_height) 55 | .attr("width",x_scale(1)) 56 | .attr("fill-opacity", 0) 57 | .attr("stroke", "black"); 58 | let text = bar.append("text"); 59 | text.classed("prob_text", true); 60 | text.attr("y", bar_y + bar_height - 3).attr("fill", "black").style("font", "14px tahoma, sans-serif"); 61 | 62 | 63 | //text for min value 64 | text = bar.append("text"); 65 | text.attr("x", bar_x - corner_padding) 66 | .attr("y", bar_y + bar_height - 3) 67 | .attr("fill", "black") 68 | .attr("text-anchor", "end") 69 | .style("font", "14px tahoma, sans-serif") 70 | .text(min_value.toFixed(2)); 71 | 72 | //text for range min annotation 73 | let v_adjust_min_value_annotation = text.node().getBBox().height; 74 | text = bar.append("text"); 75 | text.attr("x", bar_x - corner_padding) 76 | .attr("y", bar_y + bar_height - 3 + v_adjust_min_value_annotation) 77 | .attr("fill", "black") 78 | .attr("text-anchor", "end") 79 | .style("font", "14px tahoma, sans-serif") 80 | .text("(min)"); 81 | 82 | 83 | //text for predicted value 84 | // console.log('bar height: ' + bar_height) 85 | text = bar.append("text"); 86 | text.text(predicted_value.toFixed(2)); 87 | // let h_adjust_predicted_value_text = text.node().getBBox().width / 2; 88 | let v_adjust_predicted_value_text = text.node().getBBox().height; 89 | text.attr("x", bar_x + x_scale(width_proportion)) 90 | .attr("y", bar_y + bar_height + v_adjust_predicted_value_text) 91 | .attr("fill", "black") 92 | .attr("text-anchor", "middle") 93 | .style("font", "14px tahoma, sans-serif") 94 | 95 | 96 | 97 | 98 | 99 | //text for max value 100 | text = bar.append("text"); 101 | text.text(max_value.toFixed(2)); 102 | // let h_adjust = text.node().getBBox().width; 103 | text.attr("x", bar_x + bar_width + corner_padding) 104 | .attr("y", bar_y + bar_height - 3) 105 | .attr("fill", "black") 106 | .attr("text-anchor", "begin") 107 | .style("font", "14px tahoma, sans-serif"); 108 | 109 | 110 | //text for range max annotation 111 | let v_adjust_max_value_annotation = text.node().getBBox().height; 112 | text = bar.append("text"); 113 | text.attr("x", bar_x + bar_width + corner_padding) 114 | .attr("y", bar_y + bar_height - 3 + v_adjust_min_value_annotation) 115 | .attr("fill", "black") 116 | .attr("text-anchor", "begin") 117 | .style("font", "14px tahoma, sans-serif") 118 | .text("(max)"); 119 | 120 | 121 | //readjust svg size 122 | // let svg_width = width + 1 * h_adjust; 123 | // svg.style('width', svg_width + 'px'); 124 | 125 | this.svg_height = n_bars * (bar_height) + bar_yshift + (2 * text.node().getBBox().height) + 10; 126 | svg.style('height', this.svg_height + 'px'); 127 | if (log_coords) { 128 | console.log("svg width: " + svg_width); 129 | console.log("svg height: " + this.svg_height); 130 | console.log("bar_y: " + bar_y); 131 | console.log("bar_x: " + bar_x); 132 | console.log("Min value: " + min_value); 133 | console.log("Max value: " + max_value); 134 | 
console.log("Pred value: " + predicted_value); 135 | } 136 | } 137 | } 138 | 139 | 140 | export default PredictedValue; 141 | -------------------------------------------------------------------------------- /slime/js/explanation.js: -------------------------------------------------------------------------------- 1 | import d3 from 'd3'; 2 | import Barchart from './bar_chart.js'; 3 | import {range, sortBy} from 'lodash'; 4 | class Explanation { 5 | constructor(class_names) { 6 | this.names = class_names; 7 | if (class_names.length < 10) { 8 | this.colors = d3.scale.category10().domain(this.names); 9 | this.colors_i = d3.scale.category10().domain(range(this.names.length)); 10 | } 11 | else { 12 | this.colors = d3.scale.category20().domain(this.names); 13 | this.colors_i = d3.scale.category20().domain(range(this.names.length)); 14 | } 15 | } 16 | // exp: [(feature-name, weight), ...] 17 | // label: int 18 | // div: d3 selection 19 | show(exp, label, div) { 20 | let svg = div.append('svg').style('width', '100%'); 21 | let colors=['#5F9EA0', this.colors_i(label)]; 22 | let names = [`NOT ${this.names[label]}`, this.names[label]]; 23 | if (this.names.length == 2) { 24 | colors=[this.colors_i(0), this.colors_i(1)]; 25 | names = this.names; 26 | } 27 | let plot = new Barchart(svg, exp, true, names, colors, true, 10); 28 | svg.style('height', plot.svg_height + 'px'); 29 | } 30 | // exp has all ocurrences of words, with start index and weight: 31 | // exp = [('word', 132, -0.13), ('word3', 111, 1.3) 32 | show_raw_text(exp, label, raw, div, opacity=true) { 33 | //let colors=['#5F9EA0', this.colors(this.exp['class'])]; 34 | let colors=['#5F9EA0', this.colors_i(label)]; 35 | if (this.names.length == 2) { 36 | colors=[this.colors_i(0), this.colors_i(1)]; 37 | } 38 | let word_lists = [[], []]; 39 | let max_weight = -1; 40 | for (let [word, start, weight] of exp) { 41 | if (weight > 0) { 42 | word_lists[1].push([start, start + word.length, weight]); 43 | } 44 | else { 45 | word_lists[0].push([start, start + word.length, -weight]); 46 | } 47 | max_weight = Math.max(max_weight, Math.abs(weight)); 48 | } 49 | if (!opacity) { 50 | max_weight = 0; 51 | } 52 | this.display_raw_text(div, raw, word_lists, colors, max_weight, true); 53 | } 54 | // exp is list of (feature_name, value, weight) 55 | show_raw_tabular(exp, label, div) { 56 | div.classed('lime', true).classed('table_div', true); 57 | let colors=['#5F9EA0', this.colors_i(label)]; 58 | if (this.names.length == 2) { 59 | colors=[this.colors_i(0), this.colors_i(1)]; 60 | } 61 | const table = div.append('table'); 62 | const thead = table.append('tr'); 63 | thead.append('td').text('Feature'); 64 | thead.append('td').text('Value'); 65 | thead.style('color', 'black') 66 | .style('font-size', '20px'); 67 | for (let [fname, value, weight] of exp) { 68 | const tr = table.append('tr'); 69 | tr.style('border-style', 'hidden'); 70 | tr.append('td').text(fname); 71 | tr.append('td').text(value); 72 | if (weight > 0) { 73 | tr.style('background-color', colors[1]); 74 | } 75 | else if (weight < 0) { 76 | tr.style('background-color', colors[0]); 77 | } 78 | else { 79 | tr.style('color', 'black'); 80 | } 81 | } 82 | } 83 | hexToRgb(hex) { 84 | let result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex); 85 | return result ? 
{ 86 | r: parseInt(result[1], 16), 87 | g: parseInt(result[2], 16), 88 | b: parseInt(result[3], 16) 89 | } : null; 90 | } 91 | applyAlpha(hex, alpha) { 92 | let components = this.hexToRgb(hex); 93 | return 'rgba(' + components.r + "," + components.g + "," + components.b + "," + alpha.toFixed(3) + ")" 94 | } 95 | // sord_lists is an array of arrays, of length (colors). if with_positions is true, 96 | // word_lists is an array of [start,end] positions instead 97 | display_raw_text(div, raw_text, word_lists=[], colors=[], max_weight=1, positions=false) { 98 | div.classed('lime', true).classed('text_div', true); 99 | div.append('h3').text('Text with highlighted words'); 100 | let highlight_tag = 'span'; 101 | let text_span = div.append('span').style('white-space', 'pre-wrap').text(raw_text); 102 | let position_lists = word_lists; 103 | if (!positions) { 104 | position_lists = this.wordlists_to_positions(word_lists, raw_text); 105 | } 106 | let objects = [] 107 | for (let i of range(position_lists.length)) { 108 | position_lists[i].map(x => objects.push({'label' : i, 'start': x[0], 'end': x[1], 'alpha': max_weight === 0 ? 1: x[2] / max_weight})); 109 | } 110 | objects = sortBy(objects, x=>x['start']); 111 | let node = text_span.node().childNodes[0]; 112 | let subtract = 0; 113 | for (let obj of objects) { 114 | let word = raw_text.slice(obj.start, obj.end); 115 | let start = obj.start - subtract; 116 | let end = obj.end - subtract; 117 | let match = document.createElement(highlight_tag); 118 | match.appendChild(document.createTextNode(word)); 119 | match.style.backgroundColor = this.applyAlpha(colors[obj.label], obj.alpha); 120 | let after = node.splitText(start); 121 | after.nodeValue = after.nodeValue.substring(word.length); 122 | node.parentNode.insertBefore(match, after); 123 | subtract += end; 124 | node = after; 125 | } 126 | } 127 | wordlists_to_positions(word_lists, raw_text) { 128 | let ret = [] 129 | for(let words of word_lists) { 130 | if (words.length === 0) { 131 | ret.push([]); 132 | continue; 133 | } 134 | let re = new RegExp("\\b(" + words.join('|') + ")\\b",'gm') 135 | let temp; 136 | let list = []; 137 | while ((temp = re.exec(raw_text)) !== null) { 138 | list.push([temp.index, temp.index + temp[0].length]); 139 | } 140 | ret.push(list); 141 | } 142 | return ret; 143 | } 144 | 145 | } 146 | export default Explanation; 147 | -------------------------------------------------------------------------------- /slime/submodular_pick.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | 4 | 5 | class SubmodularPick(object): 6 | """Class for submodular pick 7 | 8 | Saves a representative sample of explanation objects using SP-LIME, 9 | as well as saving all generated explanations 10 | 11 | First, a collection of candidate explanations are generated 12 | (see explain_instance). From these candidates, num_exps_desired are 13 | chosen using submodular pick. (see marcotcr et al paper).""" 14 | 15 | def __init__(self, 16 | explainer, 17 | data, 18 | predict_fn, 19 | method='sample', 20 | sample_size=1000, 21 | num_exps_desired=5, 22 | num_features=10, 23 | **kwargs): 24 | 25 | """ 26 | Args: 27 | data: a numpy array where each row is a single input into predict_fn 28 | predict_fn: prediction function. For classifiers, this should be a 29 | function that takes a numpy array and outputs prediction 30 | probabilities. For regressors, this takes a numpy array and 31 | returns the predictions. 
For ScikitClassifiers, this is 32 | `classifier.predict_proba()`. For ScikitRegressors, this 33 | is `regressor.predict()`. The prediction function needs to work 34 | on multiple feature vectors (the vectors randomly perturbed 35 | from the data_row). 36 | method: The method to use to generate candidate explanations 37 | method == 'sample' will sample the data uniformly at 38 | random. The sample size is given by sample_size. Otherwise 39 | if method == 'full' then explanations will be generated for the 40 | entire data. l 41 | sample_size: The number of instances to explain if method == 'sample' 42 | num_exps_desired: The number of explanation objects returned 43 | num_features: maximum number of features present in explanation 44 | 45 | 46 | Sets value: 47 | sp_explanations: A list of explanation objects that has a high coverage 48 | explanations: All the candidate explanations saved for potential future use. 49 | """ 50 | 51 | top_labels = kwargs.get('top_labels', 1) 52 | if 'top_labels' in kwargs: 53 | del kwargs['top_labels'] 54 | # Parse args 55 | if method == 'sample': 56 | if sample_size > len(data): 57 | warnings.warn("""Requested sample size larger than 58 | size of input data. Using all data""") 59 | sample_size = len(data) 60 | all_indices = np.arange(len(data)) 61 | np.random.shuffle(all_indices) 62 | sample_indices = all_indices[:sample_size] 63 | elif method == 'full': 64 | sample_indices = np.arange(len(data)) 65 | else: 66 | raise ValueError('Method must be \'sample\' or \'full\'') 67 | 68 | # Generate Explanations 69 | self.explanations = [] 70 | for i in sample_indices: 71 | self.explanations.append( 72 | explainer.explain_instance( 73 | data[i], predict_fn, num_features=num_features, 74 | top_labels=top_labels, 75 | **kwargs)) 76 | # Error handling 77 | try: 78 | num_exps_desired = int(num_exps_desired) 79 | except TypeError: 80 | return("Requested number of explanations should be an integer") 81 | if num_exps_desired > len(self.explanations): 82 | warnings.warn("""Requested number of explanations larger than 83 | total number of explanations, returning all 84 | explanations instead.""") 85 | num_exps_desired = min(num_exps_desired, len(self.explanations)) 86 | 87 | # Find all the explanation model features used. 
Defines the dimension d' 88 | features_dict = {} 89 | feature_iter = 0 90 | for exp in self.explanations: 91 | labels = exp.available_labels() if exp.mode == 'classification' else [1] 92 | for label in labels: 93 | for feature, _ in exp.as_list(label=label): 94 | if feature not in features_dict.keys(): 95 | features_dict[feature] = (feature_iter) 96 | feature_iter += 1 97 | d_prime = len(features_dict.keys()) 98 | 99 | # Create the n x d' dimensional 'explanation matrix', W 100 | W = np.zeros((len(self.explanations), d_prime)) 101 | for i, exp in enumerate(self.explanations): 102 | labels = exp.available_labels() if exp.mode == 'classification' else [1] 103 | for label in labels: 104 | for feature, value in exp.as_list(label): 105 | W[i, features_dict[feature]] += value 106 | 107 | # Create the global importance vector, I_j described in the paper 108 | importance = np.sum(abs(W), axis=0)**.5 109 | 110 | # Now run the SP-LIME greedy algorithm 111 | remaining_indices = set(range(len(self.explanations))) 112 | V = [] 113 | for _ in range(num_exps_desired): 114 | best = 0 115 | best_ind = None 116 | current = 0 117 | for i in remaining_indices: 118 | current = np.dot( 119 | (np.sum(abs(W)[V + [i]], axis=0) > 0), importance 120 | ) # coverage function 121 | if current >= best: 122 | best = current 123 | best_ind = i 124 | V.append(best_ind) 125 | remaining_indices -= {best_ind} 126 | 127 | self.sp_explanations = [self.explanations[i] for i in V] 128 | self.V = V 129 | -------------------------------------------------------------------------------- /slime/tests/test_lime_text.py: -------------------------------------------------------------------------------- 1 | import re 2 | import unittest 3 | 4 | import sklearn # noqa 5 | from sklearn.datasets import fetch_20newsgroups 6 | from sklearn.feature_extraction.text import TfidfVectorizer 7 | from sklearn.metrics import f1_score 8 | from sklearn.naive_bayes import MultinomialNB 9 | from sklearn.pipeline import make_pipeline 10 | 11 | import numpy as np 12 | 13 | from lime.lime_text import LimeTextExplainer 14 | from lime.lime_text import IndexedCharacters, IndexedString 15 | 16 | 17 | class TestLimeText(unittest.TestCase): 18 | 19 | def test_lime_text_explainer_good_regressor(self): 20 | categories = ['alt.atheism', 'soc.religion.christian'] 21 | newsgroups_train = fetch_20newsgroups(subset='train', 22 | categories=categories) 23 | newsgroups_test = fetch_20newsgroups(subset='test', 24 | categories=categories) 25 | class_names = ['atheism', 'christian'] 26 | vectorizer = TfidfVectorizer(lowercase=False) 27 | train_vectors = vectorizer.fit_transform(newsgroups_train.data) 28 | test_vectors = vectorizer.transform(newsgroups_test.data) 29 | nb = MultinomialNB(alpha=.01) 30 | nb.fit(train_vectors, newsgroups_train.target) 31 | pred = nb.predict(test_vectors) 32 | f1_score(newsgroups_test.target, pred, average='weighted') 33 | c = make_pipeline(vectorizer, nb) 34 | explainer = LimeTextExplainer(class_names=class_names) 35 | idx = 83 36 | exp = explainer.explain_instance(newsgroups_test.data[idx], 37 | c.predict_proba, num_features=6) 38 | self.assertIsNotNone(exp) 39 | self.assertEqual(6, len(exp.as_list())) 40 | 41 | def test_lime_text_tabular_equal_random_state(self): 42 | categories = ['alt.atheism', 'soc.religion.christian'] 43 | newsgroups_train = fetch_20newsgroups(subset='train', 44 | categories=categories) 45 | newsgroups_test = fetch_20newsgroups(subset='test', 46 | categories=categories) 47 | class_names = ['atheism', 'christian'] 48 | 
vectorizer = TfidfVectorizer(lowercase=False) 49 | train_vectors = vectorizer.fit_transform(newsgroups_train.data) 50 | test_vectors = vectorizer.transform(newsgroups_test.data) 51 | nb = MultinomialNB(alpha=.01) 52 | nb.fit(train_vectors, newsgroups_train.target) 53 | pred = nb.predict(test_vectors) 54 | f1_score(newsgroups_test.target, pred, average='weighted') 55 | c = make_pipeline(vectorizer, nb) 56 | 57 | explainer = LimeTextExplainer(class_names=class_names, random_state=10) 58 | exp_1 = explainer.explain_instance(newsgroups_test.data[83], 59 | c.predict_proba, num_features=6) 60 | 61 | explainer = LimeTextExplainer(class_names=class_names, random_state=10) 62 | exp_2 = explainer.explain_instance(newsgroups_test.data[83], 63 | c.predict_proba, num_features=6) 64 | 65 | self.assertTrue(exp_1.as_map() == exp_2.as_map()) 66 | 67 | def test_lime_text_tabular_not_equal_random_state(self): 68 | categories = ['alt.atheism', 'soc.religion.christian'] 69 | newsgroups_train = fetch_20newsgroups(subset='train', 70 | categories=categories) 71 | newsgroups_test = fetch_20newsgroups(subset='test', 72 | categories=categories) 73 | class_names = ['atheism', 'christian'] 74 | vectorizer = TfidfVectorizer(lowercase=False) 75 | train_vectors = vectorizer.fit_transform(newsgroups_train.data) 76 | test_vectors = vectorizer.transform(newsgroups_test.data) 77 | nb = MultinomialNB(alpha=.01) 78 | nb.fit(train_vectors, newsgroups_train.target) 79 | pred = nb.predict(test_vectors) 80 | f1_score(newsgroups_test.target, pred, average='weighted') 81 | c = make_pipeline(vectorizer, nb) 82 | 83 | explainer = LimeTextExplainer( 84 | class_names=class_names, random_state=10) 85 | exp_1 = explainer.explain_instance(newsgroups_test.data[83], 86 | c.predict_proba, num_features=6) 87 | 88 | explainer = LimeTextExplainer( 89 | class_names=class_names, random_state=20) 90 | exp_2 = explainer.explain_instance(newsgroups_test.data[83], 91 | c.predict_proba, num_features=6) 92 | 93 | self.assertFalse(exp_1.as_map() == exp_2.as_map()) 94 | 95 | def test_indexed_characters_bow(self): 96 | s = 'Please, take your time' 97 | inverse_vocab = ['P', 'l', 'e', 'a', 's', ',', ' ', 't', 'k', 'y', 'o', 'u', 'r', 'i', 'm'] 98 | positions = [[0], [1], [2, 5, 11, 21], [3, 9], 99 | [4], [6], [7, 12, 17], [8, 18], [10], 100 | [13], [14], [15], [16], [19], [20]] 101 | ic = IndexedCharacters(s) 102 | 103 | self.assertTrue(np.array_equal(ic.as_np, np.array(list(s)))) 104 | self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s)))) 105 | self.assertTrue(ic.inverse_vocab == inverse_vocab) 106 | self.assertTrue(ic.positions == positions) 107 | 108 | def test_indexed_characters_not_bow(self): 109 | s = 'Please, take your time' 110 | 111 | ic = IndexedCharacters(s, bow=False) 112 | 113 | self.assertTrue(np.array_equal(ic.as_np, np.array(list(s)))) 114 | self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s)))) 115 | self.assertTrue(ic.inverse_vocab == list(s)) 116 | self.assertTrue(np.array_equal(ic.positions, np.arange(len(s)))) 117 | 118 | def test_indexed_string_regex(self): 119 | s = 'Please, take your time. Please' 120 | tokenized_string = np.array( 121 | ['Please', ', ', 'take', ' ', 'your', ' ', 'time', '. 
', 'Please']) 122 | inverse_vocab = ['Please', 'take', 'your', 'time'] 123 | start_positions = [0, 6, 8, 12, 13, 17, 18, 22, 24] 124 | positions = [[0, 8], [2], [4], [6]] 125 | indexed_string = IndexedString(s) 126 | 127 | self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string)) 128 | self.assertTrue(np.array_equal(indexed_string.string_start, start_positions)) 129 | self.assertTrue(indexed_string.inverse_vocab == inverse_vocab) 130 | self.assertTrue(np.array_equal(indexed_string.positions, positions)) 131 | 132 | def test_indexed_string_callable(self): 133 | s = 'aabbccddaa' 134 | 135 | def tokenizer(string): 136 | return [string[i] + string[i + 1] for i in range(0, len(string) - 1, 2)] 137 | 138 | tokenized_string = np.array(['aa', 'bb', 'cc', 'dd', 'aa']) 139 | inverse_vocab = ['aa', 'bb', 'cc', 'dd'] 140 | start_positions = [0, 2, 4, 6, 8] 141 | positions = [[0, 4], [1], [2], [3]] 142 | indexed_string = IndexedString(s, tokenizer) 143 | 144 | self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string)) 145 | self.assertTrue(np.array_equal(indexed_string.string_start, start_positions)) 146 | self.assertTrue(indexed_string.inverse_vocab == inverse_vocab) 147 | self.assertTrue(np.array_equal(indexed_string.positions, positions)) 148 | 149 | def test_indexed_string_inverse_removing_tokenizer(self): 150 | s = 'This is a good movie. This, it is a great movie.' 151 | 152 | def tokenizer(string): 153 | return re.split(r'(?:\W+)|$', string) 154 | 155 | indexed_string = IndexedString(s, tokenizer) 156 | 157 | self.assertEqual(s, indexed_string.inverse_removing([])) 158 | 159 | def test_indexed_string_inverse_removing_regex(self): 160 | s = 'This is a good movie. This is a great movie' 161 | indexed_string = IndexedString(s) 162 | 163 | self.assertEqual(s, indexed_string.inverse_removing([])) 164 | 165 | 166 | if __name__ == '__main__': 167 | unittest.main() 168 | -------------------------------------------------------------------------------- /slime/tests/test_discretize.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import TestCase 3 | 4 | import numpy as np 5 | 6 | from sklearn.datasets import load_iris 7 | 8 | from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer 9 | 10 | 11 | class TestDiscretize(TestCase): 12 | 13 | def setUp(self): 14 | iris = load_iris() 15 | 16 | self.feature_names = iris.feature_names 17 | self.x = iris.data 18 | self.y = iris.target 19 | 20 | def check_random_state_for_discretizer_class(self, DiscretizerClass): 21 | # ---------------------------------------------------------------------- 22 | # -----------Check if the same random_state produces the same----------- 23 | # -------------results for different discretizer instances.------------- 24 | # ---------------------------------------------------------------------- 25 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 26 | random_state=10) 27 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) 28 | 29 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 30 | random_state=10) 31 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) 32 | 33 | self.assertEqual((x_1 == x_2).sum(), x_1.shape[0] * x_1.shape[1]) 34 | 35 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 36 | random_state=np.random.RandomState(10)) 37 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) 38 | 39 | discretizer = 
DiscretizerClass(self.x, [], self.feature_names, self.y, 40 | random_state=np.random.RandomState(10)) 41 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) 42 | 43 | self.assertEqual((x_1 == x_2).sum(), x_1.shape[0] * x_1.shape[1]) 44 | 45 | # ---------------------------------------------------------------------- 46 | # ---------Check if two different random_state values produces---------- 47 | # -------different results for different discretizers instances.-------- 48 | # ---------------------------------------------------------------------- 49 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 50 | random_state=10) 51 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) 52 | 53 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 54 | random_state=20) 55 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) 56 | 57 | self.assertFalse((x_1 == x_2).sum() == x_1.shape[0] * x_1.shape[1]) 58 | 59 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 60 | random_state=np.random.RandomState(10)) 61 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x)) 62 | 63 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y, 64 | random_state=np.random.RandomState(20)) 65 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x)) 66 | 67 | self.assertFalse((x_1 == x_2).sum() == x_1.shape[0] * x_1.shape[1]) 68 | 69 | def test_random_state(self): 70 | self.check_random_state_for_discretizer_class(QuartileDiscretizer) 71 | 72 | self.check_random_state_for_discretizer_class(DecileDiscretizer) 73 | 74 | self.check_random_state_for_discretizer_class(EntropyDiscretizer) 75 | 76 | def test_feature_names_1(self): 77 | self.maxDiff = None 78 | discretizer = QuartileDiscretizer(self.x, [], self.feature_names, 79 | self.y, random_state=10) 80 | self.assertDictEqual( 81 | {0: ['sepal length (cm) <= 5.10', 82 | '5.10 < sepal length (cm) <= 5.80', 83 | '5.80 < sepal length (cm) <= 6.40', 84 | 'sepal length (cm) > 6.40'], 85 | 1: ['sepal width (cm) <= 2.80', 86 | '2.80 < sepal width (cm) <= 3.00', 87 | '3.00 < sepal width (cm) <= 3.30', 88 | 'sepal width (cm) > 3.30'], 89 | 2: ['petal length (cm) <= 1.60', 90 | '1.60 < petal length (cm) <= 4.35', 91 | '4.35 < petal length (cm) <= 5.10', 92 | 'petal length (cm) > 5.10'], 93 | 3: ['petal width (cm) <= 0.30', 94 | '0.30 < petal width (cm) <= 1.30', 95 | '1.30 < petal width (cm) <= 1.80', 96 | 'petal width (cm) > 1.80']}, 97 | discretizer.names) 98 | 99 | def test_feature_names_2(self): 100 | self.maxDiff = None 101 | discretizer = DecileDiscretizer(self.x, [], self.feature_names, self.y, 102 | random_state=10) 103 | self.assertDictEqual( 104 | {0: ['sepal length (cm) <= 4.80', 105 | '4.80 < sepal length (cm) <= 5.00', 106 | '5.00 < sepal length (cm) <= 5.27', 107 | '5.27 < sepal length (cm) <= 5.60', 108 | '5.60 < sepal length (cm) <= 5.80', 109 | '5.80 < sepal length (cm) <= 6.10', 110 | '6.10 < sepal length (cm) <= 6.30', 111 | '6.30 < sepal length (cm) <= 6.52', 112 | '6.52 < sepal length (cm) <= 6.90', 113 | 'sepal length (cm) > 6.90'], 114 | 1: ['sepal width (cm) <= 2.50', 115 | '2.50 < sepal width (cm) <= 2.70', 116 | '2.70 < sepal width (cm) <= 2.80', 117 | '2.80 < sepal width (cm) <= 3.00', 118 | '3.00 < sepal width (cm) <= 3.10', 119 | '3.10 < sepal width (cm) <= 3.20', 120 | '3.20 < sepal width (cm) <= 3.40', 121 | '3.40 < sepal width (cm) <= 3.61', 122 | 'sepal width (cm) > 3.61'], 123 | 2: ['petal length (cm) <= 1.40', 124 | '1.40 < 
petal length (cm) <= 1.50', 125 | '1.50 < petal length (cm) <= 1.70', 126 | '1.70 < petal length (cm) <= 3.90', 127 | '3.90 < petal length (cm) <= 4.35', 128 | '4.35 < petal length (cm) <= 4.64', 129 | '4.64 < petal length (cm) <= 5.00', 130 | '5.00 < petal length (cm) <= 5.32', 131 | '5.32 < petal length (cm) <= 5.80', 132 | 'petal length (cm) > 5.80'], 133 | 3: ['petal width (cm) <= 0.20', 134 | '0.20 < petal width (cm) <= 0.40', 135 | '0.40 < petal width (cm) <= 1.16', 136 | '1.16 < petal width (cm) <= 1.30', 137 | '1.30 < petal width (cm) <= 1.50', 138 | '1.50 < petal width (cm) <= 1.80', 139 | '1.80 < petal width (cm) <= 1.90', 140 | '1.90 < petal width (cm) <= 2.20', 141 | 'petal width (cm) > 2.20']}, 142 | discretizer.names) 143 | 144 | def test_feature_names_3(self): 145 | self.maxDiff = None 146 | discretizer = EntropyDiscretizer(self.x, [], self.feature_names, 147 | self.y, random_state=10) 148 | self.assertDictEqual( 149 | {0: ['sepal length (cm) <= 4.85', 150 | '4.85 < sepal length (cm) <= 5.45', 151 | '5.45 < sepal length (cm) <= 5.55', 152 | '5.55 < sepal length (cm) <= 5.85', 153 | '5.85 < sepal length (cm) <= 6.15', 154 | '6.15 < sepal length (cm) <= 7.05', 155 | 'sepal length (cm) > 7.05'], 156 | 1: ['sepal width (cm) <= 2.45', 157 | '2.45 < sepal width (cm) <= 2.95', 158 | '2.95 < sepal width (cm) <= 3.05', 159 | '3.05 < sepal width (cm) <= 3.35', 160 | '3.35 < sepal width (cm) <= 3.45', 161 | '3.45 < sepal width (cm) <= 3.55', 162 | 'sepal width (cm) > 3.55'], 163 | 2: ['petal length (cm) <= 2.45', 164 | '2.45 < petal length (cm) <= 4.45', 165 | '4.45 < petal length (cm) <= 4.75', 166 | '4.75 < petal length (cm) <= 5.15', 167 | 'petal length (cm) > 5.15'], 168 | 3: ['petal width (cm) <= 0.80', 169 | '0.80 < petal width (cm) <= 1.35', 170 | '1.35 < petal width (cm) <= 1.75', 171 | '1.75 < petal width (cm) <= 1.85', 172 | 'petal width (cm) > 1.85']}, 173 | discretizer.names) 174 | 175 | 176 | if __name__ == '__main__': 177 | unittest.main() 178 | -------------------------------------------------------------------------------- /slime/discretize.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discretizers classes, to be used in lime_tabular 3 | """ 4 | import numpy as np 5 | import sklearn 6 | import sklearn.tree 7 | import scipy 8 | from sklearn.utils import check_random_state 9 | from abc import ABCMeta, abstractmethod 10 | 11 | 12 | class BaseDiscretizer(): 13 | """ 14 | Abstract class - Build a class that inherits from this class to implement 15 | a custom discretizer. 16 | Method bins() is to be redefined in the child class, as it is the actual 17 | custom part of the discretizer. 18 | """ 19 | 20 | __metaclass__ = ABCMeta # abstract class 21 | 22 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None, 23 | data_stats=None): 24 | """Initializer 25 | Args: 26 | data: numpy 2d array 27 | categorical_features: list of indices (ints) corresponding to the 28 | categorical columns. These features will not be discretized. 29 | Everything else will be considered continuous, and will be 30 | discretized. 31 | categorical_names: map from int to list of names, where 32 | categorical_names[x][y] represents the name of the yth value of 33 | column x. 34 | feature_names: list of names (strings) corresponding to the columns 35 | in the training data. 
36 | data_stats: must have 'means', 'stds', 'mins' and 'maxs', use this 37 | if you don't want these values to be computed from data 38 | """ 39 | self.to_discretize = ([x for x in range(data.shape[1]) 40 | if x not in categorical_features]) 41 | self.data_stats = data_stats 42 | self.names = {} 43 | self.lambdas = {} 44 | self.means = {} 45 | self.stds = {} 46 | self.mins = {} 47 | self.maxs = {} 48 | self.random_state = check_random_state(random_state) 49 | 50 | # To override when implementing a custom binning 51 | bins = self.bins(data, labels) 52 | bins = [np.unique(x) for x in bins] 53 | 54 | # Read the stats from data_stats if exists 55 | if data_stats: 56 | self.means = self.data_stats.get("means") 57 | self.stds = self.data_stats.get("stds") 58 | self.mins = self.data_stats.get("mins") 59 | self.maxs = self.data_stats.get("maxs") 60 | 61 | for feature, qts in zip(self.to_discretize, bins): 62 | n_bins = qts.shape[0] # Actually number of borders (= #bins-1) 63 | boundaries = np.min(data[:, feature]), np.max(data[:, feature]) 64 | name = feature_names[feature] 65 | 66 | self.names[feature] = ['%s <= %.2f' % (name, qts[0])] 67 | for i in range(n_bins - 1): 68 | self.names[feature].append('%.2f < %s <= %.2f' % 69 | (qts[i], name, qts[i + 1])) 70 | self.names[feature].append('%s > %.2f' % (name, qts[n_bins - 1])) 71 | 72 | self.lambdas[feature] = lambda x, qts=qts: np.searchsorted(qts, x) 73 | discretized = self.lambdas[feature](data[:, feature]) 74 | 75 | # If data stats are provided no need to compute the below set of details 76 | if data_stats: 77 | continue 78 | 79 | self.means[feature] = [] 80 | self.stds[feature] = [] 81 | for x in range(n_bins + 1): 82 | selection = data[discretized == x, feature] 83 | mean = 0 if len(selection) == 0 else np.mean(selection) 84 | self.means[feature].append(mean) 85 | std = 0 if len(selection) == 0 else np.std(selection) 86 | std += 0.00000000001 87 | self.stds[feature].append(std) 88 | self.mins[feature] = [boundaries[0]] + qts.tolist() 89 | self.maxs[feature] = qts.tolist() + [boundaries[1]] 90 | 91 | @abstractmethod 92 | def bins(self, data, labels): 93 | """ 94 | To be overridden 95 | Returns for each feature to discretize the boundaries 96 | that form each bin of the discretizer 97 | """ 98 | raise NotImplementedError("Must override bins() method") 99 | 100 | def discretize(self, data): 101 | """Discretizes the data. 102 | Args: 103 | data: numpy 2d or 1d array 104 | Returns: 105 | numpy array of same dimension, discretized. 
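        A usage sketch (illustrative only, reusing the iris setup from the
        tests above; QuartileDiscretizer is defined later in this module):

            from sklearn.datasets import load_iris

            iris = load_iris()
            disc = QuartileDiscretizer(iris.data, [], iris.feature_names,
                                       labels=iris.target, random_state=10)
            binned = disc.discretize(iris.data)   # bin indices per continuous feature
            approx = disc.undiscretize(binned)    # values re-sampled within each bin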
106 | """ 107 | ret = data.copy() 108 | for feature in self.lambdas: 109 | if len(data.shape) == 1: 110 | ret[feature] = int(self.lambdas[feature](ret[feature])) 111 | else: 112 | ret[:, feature] = self.lambdas[feature]( 113 | ret[:, feature]).astype(int) 114 | return ret 115 | 116 | def get_undiscretize_values(self, feature, values): 117 | mins = np.array(self.mins[feature])[values] 118 | maxs = np.array(self.maxs[feature])[values] 119 | 120 | means = np.array(self.means[feature])[values] 121 | stds = np.array(self.stds[feature])[values] 122 | minz = (mins - means) / stds 123 | maxz = (maxs - means) / stds 124 | min_max_unequal = (minz != maxz) 125 | 126 | ret = minz 127 | ret[np.where(min_max_unequal)] = scipy.stats.truncnorm.rvs( 128 | minz[min_max_unequal], 129 | maxz[min_max_unequal], 130 | loc=means[min_max_unequal], 131 | scale=stds[min_max_unequal], 132 | random_state=self.random_state 133 | ) 134 | return ret 135 | 136 | def undiscretize(self, data): 137 | ret = data.copy() 138 | for feature in self.means: 139 | if len(data.shape) == 1: 140 | ret[feature] = self.get_undiscretize_values( 141 | feature, ret[feature].astype(int).reshape(-1, 1) 142 | ) 143 | else: 144 | ret[:, feature] = self.get_undiscretize_values( 145 | feature, ret[:, feature].astype(int) 146 | ) 147 | return ret 148 | 149 | 150 | class StatsDiscretizer(BaseDiscretizer): 151 | """ 152 | Class to be used to supply the data stats info when discretize_continuous is true 153 | """ 154 | 155 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None, 156 | data_stats=None): 157 | 158 | BaseDiscretizer.__init__(self, data, categorical_features, 159 | feature_names, labels=labels, 160 | random_state=random_state, 161 | data_stats=data_stats) 162 | 163 | def bins(self, data, labels): 164 | bins_from_stats = self.data_stats.get("bins") 165 | bins = [] 166 | if bins_from_stats is not None: 167 | for feature in self.to_discretize: 168 | bins_from_stats_feature = bins_from_stats.get(feature) 169 | if bins_from_stats_feature is not None: 170 | qts = np.array(bins_from_stats_feature) 171 | bins.append(qts) 172 | return bins 173 | 174 | 175 | class QuartileDiscretizer(BaseDiscretizer): 176 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None): 177 | 178 | BaseDiscretizer.__init__(self, data, categorical_features, 179 | feature_names, labels=labels, 180 | random_state=random_state) 181 | 182 | def bins(self, data, labels): 183 | bins = [] 184 | for feature in self.to_discretize: 185 | qts = np.array(np.percentile(data[:, feature], [25, 50, 75])) 186 | bins.append(qts) 187 | return bins 188 | 189 | 190 | class DecileDiscretizer(BaseDiscretizer): 191 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None): 192 | BaseDiscretizer.__init__(self, data, categorical_features, 193 | feature_names, labels=labels, 194 | random_state=random_state) 195 | 196 | def bins(self, data, labels): 197 | bins = [] 198 | for feature in self.to_discretize: 199 | qts = np.array(np.percentile(data[:, feature], 200 | [10, 20, 30, 40, 50, 60, 70, 80, 90])) 201 | bins.append(qts) 202 | return bins 203 | 204 | 205 | class EntropyDiscretizer(BaseDiscretizer): 206 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None): 207 | if(labels is None): 208 | raise ValueError('Labels must be not None when using \ 209 | EntropyDiscretizer') 210 | BaseDiscretizer.__init__(self, data, categorical_features, 211 | 
feature_names, labels=labels, 212 | random_state=random_state) 213 | 214 | def bins(self, data, labels): 215 | bins = [] 216 | for feature in self.to_discretize: 217 | # Entropy splitting / at most 8 bins so max_depth=3 218 | dt = sklearn.tree.DecisionTreeClassifier(criterion='entropy', 219 | max_depth=3, 220 | random_state=self.random_state) 221 | x = np.reshape(data[:, feature], (-1, 1)) 222 | dt.fit(x, labels) 223 | qts = dt.tree_.threshold[np.where(dt.tree_.children_left > -1)] 224 | 225 | if qts.shape[0] == 0: 226 | qts = np.array([np.median(data[:, feature])]) 227 | else: 228 | qts = np.sort(qts) 229 | 230 | bins.append(qts) 231 | 232 | return bins 233 | -------------------------------------------------------------------------------- /slime/explanation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Explanation class, with visualization functions. 3 | """ 4 | from io import open 5 | import os 6 | import os.path 7 | import json 8 | import string 9 | import numpy as np 10 | 11 | from .exceptions import LimeError 12 | 13 | from sklearn.utils import check_random_state 14 | 15 | 16 | def id_generator(size=15, random_state=None): 17 | """Helper function to generate random div ids. This is useful for embedding 18 | HTML into ipython notebooks.""" 19 | chars = list(string.ascii_uppercase + string.digits) 20 | return ''.join(random_state.choice(chars, size, replace=True)) 21 | 22 | 23 | class DomainMapper(object): 24 | """Class for mapping features to the specific domain. 25 | 26 | The idea is that there would be a subclass for each domain (text, tables, 27 | images, etc), so that we can have a general Explanation class, and separate 28 | out the specifics of visualizing features in here. 29 | """ 30 | 31 | def __init__(self): 32 | pass 33 | 34 | def map_exp_ids(self, exp, **kwargs): 35 | """Maps the feature ids to concrete names. 36 | 37 | Default behaviour is the identity function. Subclasses can implement 38 | this as they see fit. 39 | 40 | Args: 41 | exp: list of tuples [(id, weight), (id,weight)] 42 | kwargs: optional keyword arguments 43 | 44 | Returns: 45 | exp: list of tuples [(name, weight), (name, weight)...] 46 | """ 47 | return exp 48 | 49 | def visualize_instance_html(self, 50 | exp, 51 | label, 52 | div_name, 53 | exp_object_name, 54 | **kwargs): 55 | """Produces html for visualizing the instance. 56 | 57 | Default behaviour does nothing. Subclasses can implement this as they 58 | see fit. 59 | 60 | Args: 61 | exp: list of tuples [(id, weight), (id,weight)] 62 | label: label id (integer) 63 | div_name: name of div object to be used for rendering(in js) 64 | exp_object_name: name of js explanation object 65 | kwargs: optional keyword arguments 66 | 67 | Returns: 68 | js code for visualizing the instance 69 | """ 70 | return '' 71 | 72 | 73 | class Explanation(object): 74 | """Object returned by explainers.""" 75 | 76 | def __init__(self, 77 | domain_mapper, 78 | mode='classification', 79 | class_names=None, 80 | random_state=None): 81 | """ 82 | 83 | Initializer. 84 | 85 | Args: 86 | domain_mapper: must inherit from DomainMapper class 87 | type: "classification" or "regression" 88 | class_names: list of class names (only used for classification) 89 | random_state: an integer or numpy.RandomState that will be used to 90 | generate random numbers. If None, the random state will be 91 | initialized using the internal numpy seed. 
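        Explanation objects are normally produced by the explainers rather
        than constructed directly. A consumption sketch, where exp stands for
        an object returned by an explainer:

            exp.as_list(label=1)    # [(feature description, weight), ...]
            exp.as_map()            # {label: [(feature_id, weight), ...]}
            fig = exp.as_pyplot_figure(label=1)   # requires matplotlib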
92 | """ 93 | self.random_state = random_state 94 | self.mode = mode 95 | self.domain_mapper = domain_mapper 96 | self.local_exp = {} 97 | self.intercept = {} 98 | self.score = {} 99 | self.local_pred = {} 100 | if mode == 'classification': 101 | self.class_names = class_names 102 | self.top_labels = None 103 | self.predict_proba = None 104 | elif mode == 'regression': 105 | self.class_names = ['negative', 'positive'] 106 | self.predicted_value = None 107 | self.min_value = 0.0 108 | self.max_value = 1.0 109 | self.dummy_label = 1 110 | else: 111 | raise LimeError('Invalid explanation mode "{}". ' 112 | 'Should be either "classification" ' 113 | 'or "regression".'.format(mode)) 114 | 115 | def available_labels(self): 116 | """ 117 | Returns the list of classification labels for which we have any explanations. 118 | """ 119 | try: 120 | assert self.mode == "classification" 121 | except AssertionError: 122 | raise NotImplementedError('Not supported for regression explanations.') 123 | else: 124 | ans = self.top_labels if self.top_labels else self.local_exp.keys() 125 | return list(ans) 126 | 127 | def as_list(self, label=1, **kwargs): 128 | """Returns the explanation as a list. 129 | 130 | Args: 131 | label: desired label. If you ask for a label for which an 132 | explanation wasn't computed, will throw an exception. 133 | Will be ignored for regression explanations. 134 | kwargs: keyword arguments, passed to domain_mapper 135 | 136 | Returns: 137 | list of tuples (representation, weight), where representation is 138 | given by domain_mapper. Weight is a float. 139 | """ 140 | label_to_use = label if self.mode == "classification" else self.dummy_label 141 | ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs) 142 | ans = [(x[0], float(x[1])) for x in ans] 143 | return ans 144 | 145 | def as_map(self): 146 | """Returns the map of explanations. 147 | 148 | Returns: 149 | Map from label to list of tuples (feature_id, weight). 150 | """ 151 | return self.local_exp 152 | 153 | def as_pyplot_figure(self, label=1, **kwargs): 154 | """Returns the explanation as a pyplot figure. 155 | 156 | Will throw an error if you don't have matplotlib installed 157 | Args: 158 | label: desired label. If you ask for a label for which an 159 | explanation wasn't computed, will throw an exception. 160 | Will be ignored for regression explanations. 161 | kwargs: keyword arguments, passed to domain_mapper 162 | 163 | Returns: 164 | pyplot figure (barchart). 165 | """ 166 | import matplotlib.pyplot as plt 167 | exp = self.as_list(label=label, **kwargs) 168 | fig = plt.figure() 169 | vals = [x[1] for x in exp] 170 | names = [x[0] for x in exp] 171 | vals.reverse() 172 | names.reverse() 173 | colors = ['green' if x > 0 else 'red' for x in vals] 174 | pos = np.arange(len(exp)) + .5 175 | plt.barh(pos, vals, align='center', color=colors) 176 | plt.yticks(pos, names) 177 | if self.mode == "classification": 178 | title = 'Local explanation for class %s' % self.class_names[label] 179 | else: 180 | title = 'Local explanation' 181 | plt.title(title) 182 | return fig 183 | 184 | def show_in_notebook(self, 185 | labels=None, 186 | predict_proba=True, 187 | show_predicted_value=True, 188 | **kwargs): 189 | """Shows html explanation in ipython notebook. 190 | 191 | See as_html() for parameters. 
192 | This will throw an error if you don't have IPython installed""" 193 | 194 | from IPython.core.display import display, HTML 195 | display(HTML(self.as_html(labels=labels, 196 | predict_proba=predict_proba, 197 | show_predicted_value=show_predicted_value, 198 | **kwargs))) 199 | 200 | def save_to_file(self, 201 | file_path, 202 | labels=None, 203 | predict_proba=True, 204 | show_predicted_value=True, 205 | **kwargs): 206 | """Saves html explanation to file. . 207 | 208 | Params: 209 | file_path: file to save explanations to 210 | 211 | See as_html() for additional parameters. 212 | 213 | """ 214 | file_ = open(file_path, 'w', encoding='utf8') 215 | file_.write(self.as_html(labels=labels, 216 | predict_proba=predict_proba, 217 | show_predicted_value=show_predicted_value, 218 | **kwargs)) 219 | file_.close() 220 | 221 | def as_html(self, 222 | labels=None, 223 | predict_proba=True, 224 | show_predicted_value=True, 225 | **kwargs): 226 | """Returns the explanation as an html page. 227 | 228 | Args: 229 | labels: desired labels to show explanations for (as barcharts). 230 | If you ask for a label for which an explanation wasn't 231 | computed, will throw an exception. If None, will show 232 | explanations for all available labels. (only used for classification) 233 | predict_proba: if true, add barchart with prediction probabilities 234 | for the top classes. (only used for classification) 235 | show_predicted_value: if true, add barchart with expected value 236 | (only used for regression) 237 | kwargs: keyword arguments, passed to domain_mapper 238 | 239 | Returns: 240 | code for an html page, including javascript includes. 241 | """ 242 | 243 | def jsonize(x): 244 | return json.dumps(x, ensure_ascii=False) 245 | 246 | if labels is None and self.mode == "classification": 247 | labels = self.available_labels() 248 | 249 | this_dir, _ = os.path.split(__file__) 250 | bundle = open(os.path.join(this_dir, 'bundle.js'), 251 | encoding="utf8").read() 252 | 253 | out = u''' 254 | 255 | ''' % bundle 256 | random_id = id_generator(size=15, random_state=check_random_state(self.random_state)) 257 | out += u''' 258 |
259 | ''' % random_id 260 | 261 | predict_proba_js = '' 262 | if self.mode == "classification" and predict_proba: 263 | predict_proba_js = u''' 264 | var pp_div = top_div.append('div') 265 | .classed('lime predict_proba', true); 266 | var pp_svg = pp_div.append('svg').style('width', '100%%'); 267 | var pp = new lime.PredictProba(pp_svg, %s, %s); 268 | ''' % (jsonize([str(x) for x in self.class_names]), 269 | jsonize(list(self.predict_proba.astype(float)))) 270 | 271 | predict_value_js = '' 272 | if self.mode == "regression" and show_predicted_value: 273 | # reference self.predicted_value 274 | # (svg, predicted_value, min_value, max_value) 275 | predict_value_js = u''' 276 | var pp_div = top_div.append('div') 277 | .classed('lime predicted_value', true); 278 | var pp_svg = pp_div.append('svg').style('width', '100%%'); 279 | var pp = new lime.PredictedValue(pp_svg, %s, %s, %s); 280 | ''' % (jsonize(float(self.predicted_value)), 281 | jsonize(float(self.min_value)), 282 | jsonize(float(self.max_value))) 283 | 284 | exp_js = '''var exp_div; 285 | var exp = new lime.Explanation(%s); 286 | ''' % (jsonize([str(x) for x in self.class_names])) 287 | 288 | if self.mode == "classification": 289 | for label in labels: 290 | exp = jsonize(self.as_list(label)) 291 | exp_js += u''' 292 | exp_div = top_div.append('div').classed('lime explanation', true); 293 | exp.show(%s, %d, exp_div); 294 | ''' % (exp, label) 295 | else: 296 | exp = jsonize(self.as_list()) 297 | exp_js += u''' 298 | exp_div = top_div.append('div').classed('lime explanation', true); 299 | exp.show(%s, %s, exp_div); 300 | ''' % (exp, self.dummy_label) 301 | 302 | raw_js = '''var raw_div = top_div.append('div');''' 303 | 304 | if self.mode == "classification": 305 | html_data = self.local_exp[labels[0]] 306 | else: 307 | html_data = self.local_exp[self.dummy_label] 308 | 309 | raw_js += self.domain_mapper.visualize_instance_html( 310 | html_data, 311 | labels[0] if self.mode == "classification" else self.dummy_label, 312 | 'raw_div', 313 | 'exp', 314 | **kwargs) 315 | out += u''' 316 | 323 | ''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js) 324 | out += u'' 325 | 326 | return out 327 | -------------------------------------------------------------------------------- /slime/lime_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains abstract functionality for learning locally linear sparse model. 3 | """ 4 | import numpy as np 5 | import scipy as sp 6 | import sklearn 7 | from sklearn.linear_model import Ridge 8 | from slime_lm._least_angle import lars_path 9 | from sklearn.utils import check_random_state 10 | 11 | 12 | class LimeBase(object): 13 | """Class for learning a locally linear sparse model from perturbed data""" 14 | def __init__(self, 15 | kernel_fn, 16 | verbose=False, 17 | random_state=None): 18 | """Init function 19 | 20 | Args: 21 | kernel_fn: function that transforms an array of distances into an 22 | array of proximity values (floats). 23 | verbose: if true, print local prediction values from linear model. 24 | random_state: an integer or numpy.RandomState that will be used to 25 | generate random numbers. If None, the random state will be 26 | initialized using the internal numpy seed. 
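        A construction sketch (the kernel below mirrors the exponential
        kernel that the text and image explainers in this package pass in):

            import numpy as np
            from functools import partial

            def kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

            base = LimeBase(partial(kernel, kernel_width=25), verbose=False,
                            random_state=0)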
27 | """ 28 | self.kernel_fn = kernel_fn 29 | self.verbose = verbose 30 | self.random_state = check_random_state(random_state) 31 | 32 | @staticmethod 33 | def generate_lars_path(weighted_data, weighted_labels, testing=False, alpha=0.05): 34 | """Generates the lars path for weighted data. 35 | 36 | Args: 37 | weighted_data: data that has been weighted by kernel 38 | weighted_label: labels, weighted by kernel 39 | 40 | Returns: 41 | (alphas, coefs), both are arrays corresponding to the 42 | regularization parameter and coefficients, respectively 43 | """ 44 | x_vector = weighted_data 45 | if not testing: 46 | alphas, _, coefs = lars_path(x_vector, 47 | weighted_labels, 48 | method='lasso', 49 | verbose=False, 50 | alpha=alpha) 51 | return alphas, coefs 52 | else: 53 | alphas, _, coefs, test_result = lars_path(x_vector, 54 | weighted_labels, 55 | method='lasso', 56 | verbose=False, 57 | testing=testing) 58 | return alphas, coefs, test_result 59 | 60 | def forward_selection(self, data, labels, weights, num_features): 61 | """Iteratively adds features to the model""" 62 | clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state) 63 | used_features = [] 64 | for _ in range(min(num_features, data.shape[1])): 65 | max_ = -100000000 66 | best = 0 67 | for feature in range(data.shape[1]): 68 | if feature in used_features: 69 | continue 70 | clf.fit(data[:, used_features + [feature]], labels, 71 | sample_weight=weights) 72 | score = clf.score(data[:, used_features + [feature]], 73 | labels, 74 | sample_weight=weights) 75 | if score > max_: 76 | best = feature 77 | max_ = score 78 | used_features.append(best) 79 | return np.array(used_features) 80 | 81 | def feature_selection(self, data, labels, weights, num_features, method, testing=False, alpha=0.05): 82 | """Selects features for the model. 
see explain_instance_with_data to 83 | understand the parameters.""" 84 | if method == 'none': 85 | return np.array(range(data.shape[1])) 86 | elif method == 'forward_selection': 87 | return self.forward_selection(data, labels, weights, num_features) 88 | elif method == 'highest_weights': 89 | clf = Ridge(alpha=0.01, fit_intercept=True, 90 | random_state=self.random_state) 91 | clf.fit(data, labels, sample_weight=weights) 92 | 93 | coef = clf.coef_ 94 | if sp.sparse.issparse(data): 95 | coef = sp.sparse.csr_matrix(clf.coef_) 96 | weighted_data = coef.multiply(data[0]) 97 | # Note: most efficient to slice the data before reversing 98 | sdata = len(weighted_data.data) 99 | argsort_data = np.abs(weighted_data.data).argsort() 100 | # Edge case where data is more sparse than requested number of feature importances 101 | # In that case, we just pad with zero-valued features 102 | if sdata < num_features: 103 | nnz_indexes = argsort_data[::-1] 104 | indices = weighted_data.indices[nnz_indexes] 105 | num_to_pad = num_features - sdata 106 | indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype))) 107 | indices_set = set(indices) 108 | pad_counter = 0 109 | for i in range(data.shape[1]): 110 | if i not in indices_set: 111 | indices[pad_counter + sdata] = i 112 | pad_counter += 1 113 | if pad_counter >= num_to_pad: 114 | break 115 | else: 116 | nnz_indexes = argsort_data[sdata - num_features:sdata][::-1] 117 | indices = weighted_data.indices[nnz_indexes] 118 | return indices 119 | else: 120 | weighted_data = coef * data[0] 121 | feature_weights = sorted( 122 | zip(range(data.shape[1]), weighted_data), 123 | key=lambda x: np.abs(x[1]), 124 | reverse=True) 125 | return np.array([x[0] for x in feature_weights[:num_features]]) 126 | elif method == 'lasso_path': 127 | if not testing: 128 | weighted_data = ((data - np.average(data, axis=0, weights=weights)) 129 | * np.sqrt(weights[:, np.newaxis])) 130 | weighted_labels = ((labels - np.average(labels, weights=weights)) 131 | * np.sqrt(weights)) 132 | 133 | nonzero = range(weighted_data.shape[1]) 134 | _, coefs = self.generate_lars_path(weighted_data, 135 | weighted_labels) 136 | for i in range(len(coefs.T) - 1, 0, -1): 137 | nonzero = coefs.T[i].nonzero()[0] 138 | if len(nonzero) <= num_features: 139 | break 140 | used_features = nonzero 141 | return used_features 142 | else: 143 | weighted_data = ((data - np.average(data, axis=0, weights=weights)) 144 | * np.sqrt(weights[:, np.newaxis])) 145 | weighted_labels = ((labels - np.average(labels, weights=weights)) 146 | * np.sqrt(weights)) 147 | 148 | # Xscaler = sklearn.preprocessing.StandardScaler() 149 | # Xscaler.fit(weighted_data) 150 | # weighted_data = Xscaler.transform(weighted_data) 151 | 152 | # Yscaler = sklearn.preprocessing.StandardScaler() 153 | # Yscaler.fit(weighted_labels.reshape(-1, 1)) 154 | # weighted_labels = Yscaler.transform(weighted_labels.reshape(-1, 1)).ravel() 155 | 156 | nonzero = range(weighted_data.shape[1]) 157 | alphas, coefs, test_result = self.generate_lars_path(weighted_data, 158 | weighted_labels, 159 | testing=True, 160 | alpha=alpha) 161 | for i in range(len(coefs.T) - 1, 0, -1): 162 | nonzero = coefs.T[i].nonzero()[0] 163 | if len(nonzero) <= num_features: 164 | break 165 | used_features = nonzero 166 | return used_features, test_result 167 | elif method == 'auto': 168 | if num_features <= 6: 169 | n_method = 'forward_selection' 170 | else: 171 | n_method = 'highest_weights' 172 | return self.feature_selection(data, labels, weights, 173 | num_features, 
n_method) 174 | 175 | def explain_instance_with_data(self, 176 | neighborhood_data, 177 | neighborhood_labels, 178 | distances, 179 | label, 180 | num_features, 181 | feature_selection='auto', 182 | model_regressor=None): 183 | """Takes perturbed data, labels and distances, returns explanation. 184 | 185 | Args: 186 | neighborhood_data: perturbed data, 2d array. first element is 187 | assumed to be the original data point. 188 | neighborhood_labels: corresponding perturbed labels. should have as 189 | many columns as the number of possible labels. 190 | distances: distances to original data point. 191 | label: label for which we want an explanation 192 | num_features: maximum number of features in explanation 193 | feature_selection: how to select num_features. options are: 194 | 'forward_selection': iteratively add features to the model. 195 | This is costly when num_features is high 196 | 'highest_weights': selects the features that have the highest 197 | product of absolute weight * original data point when 198 | learning with all the features 199 | 'lasso_path': chooses features based on the lasso 200 | regularization path 201 | 'none': uses all features, ignores num_features 202 | 'auto': uses forward_selection if num_features <= 6, and 203 | 'highest_weights' otherwise. 204 | model_regressor: sklearn regressor to use in explanation. 205 | Defaults to Ridge regression if None. Must have 206 | model_regressor.coef_ and 'sample_weight' as a parameter 207 | to model_regressor.fit() 208 | 209 | Returns: 210 | (intercept, exp, score, local_pred): 211 | intercept is a float. 212 | exp is a sorted list of tuples, where each tuple (x,y) corresponds 213 | to the feature id (x) and the local weight (y). The list is sorted 214 | by decreasing absolute value of y. 215 | score is the R^2 value of the returned explanation 216 | local_pred is the prediction of the explanation model on the original instance 217 | """ 218 | 219 | weights = self.kernel_fn(distances) 220 | labels_column = neighborhood_labels[:, label] 221 | used_features = self.feature_selection(neighborhood_data, 222 | labels_column, 223 | weights, 224 | num_features, 225 | feature_selection) 226 | if model_regressor is None: 227 | model_regressor = Ridge(alpha=1, fit_intercept=True, 228 | random_state=self.random_state) 229 | easy_model = model_regressor 230 | easy_model.fit(neighborhood_data[:, used_features], 231 | labels_column, sample_weight=weights) 232 | prediction_score = easy_model.score( 233 | neighborhood_data[:, used_features], 234 | labels_column, sample_weight=weights) 235 | 236 | local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1)) 237 | 238 | if self.verbose: 239 | print('Intercept', easy_model.intercept_) 240 | print('Prediction_local', local_pred,) 241 | print('Right:', neighborhood_labels[0, label]) 242 | return (easy_model.intercept_, 243 | sorted(zip(used_features, easy_model.coef_), 244 | key=lambda x: np.abs(x[1]), reverse=True), 245 | prediction_score, local_pred) 246 | 247 | def testing_explain_instance_with_data(self, 248 | neighborhood_data, 249 | neighborhood_labels, 250 | distances, 251 | label, 252 | num_features, 253 | feature_selection='lasso_path', 254 | model_regressor=None, 255 | alpha=0.05): 256 | """Takes perturbed data, labels and distances, returns explanation. 257 | This is a helper function for slime. 258 | 259 | Args: 260 | neighborhood_data: perturbed data, 2d array. first element is 261 | assumed to be the original data point. 
262 | neighborhood_labels: corresponding perturbed labels. should have as 263 | many columns as the number of possible labels. 264 | distances: distances to original data point. 265 | label: label for which we want an explanation 266 | num_features: maximum number of features in explanation 267 | feature_selection: how to select num_features. options are: 268 | 'forward_selection': iteratively add features to the model. 269 | This is costly when num_features is high 270 | 'highest_weights': selects the features that have the highest 271 | product of absolute weight * original data point when 272 | learning with all the features 273 | 'lasso_path': chooses features based on the lasso 274 | regularization path 275 | 'none': uses all features, ignores num_features 276 | 'auto': uses forward_selection if num_features <= 6, and 277 | 'highest_weights' otherwise. 278 | model_regressor: sklearn regressor to use in explanation. 279 | Defaults to Ridge regression if None. Must have 280 | model_regressor.coef_ and 'sample_weight' as a parameter 281 | to model_regressor.fit() 282 | alpha: significance level of hypothesis testing. 283 | 284 | Returns: 285 | (intercept, exp, score, local_pred): 286 | intercept is a float. 287 | exp is a sorted list of tuples, where each tuple (x,y) corresponds 288 | to the feature id (x) and the local weight (y). The list is sorted 289 | by decreasing absolute value of y. 290 | score is the R^2 value of the returned explanation 291 | local_pred is the prediction of the explanation model on the original instance 292 | """ 293 | weights = self.kernel_fn(distances) 294 | labels_column = neighborhood_labels[:, label] 295 | used_features, test_result = self.feature_selection(neighborhood_data, 296 | labels_column, 297 | weights, 298 | num_features, 299 | feature_selection, 300 | testing=True, 301 | alpha=alpha) 302 | if model_regressor is None: 303 | model_regressor = Ridge(alpha=1, fit_intercept=True, 304 | random_state=self.random_state) 305 | easy_model = model_regressor 306 | easy_model.fit(neighborhood_data[:, used_features], 307 | labels_column, sample_weight=weights) 308 | prediction_score = easy_model.score( 309 | neighborhood_data[:, used_features], 310 | labels_column, sample_weight=weights) 311 | 312 | local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1)) 313 | 314 | if self.verbose: 315 | print('Intercept', easy_model.intercept_) 316 | print('Prediction_local', local_pred,) 317 | print('Right:', neighborhood_labels[0, label]) 318 | return (easy_model.intercept_, 319 | sorted(zip(used_features, easy_model.coef_), 320 | key=lambda x: np.abs(x[1]), reverse=True), 321 | prediction_score, local_pred, used_features, test_result) 322 | 323 | -------------------------------------------------------------------------------- /slime/lime_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for explaining text classifiers. 3 | """ 4 | from functools import partial 5 | import itertools 6 | import json 7 | import re 8 | 9 | import numpy as np 10 | import scipy as sp 11 | import sklearn 12 | from sklearn.utils import check_random_state 13 | 14 | from . import explanation 15 | from . import lime_base 16 | 17 | 18 | class TextDomainMapper(explanation.DomainMapper): 19 | """Maps feature ids to words or word-positions""" 20 | 21 | def __init__(self, indexed_string): 22 | """Initializer. 
23 | 24 | Args: 25 | indexed_string: lime_text.IndexedString, original string 26 | """ 27 | self.indexed_string = indexed_string 28 | 29 | def map_exp_ids(self, exp, positions=False): 30 | """Maps ids to words or word-position strings. 31 | 32 | Args: 33 | exp: list of tuples [(id, weight), (id,weight)] 34 | positions: if True, also return word positions 35 | 36 | Returns: 37 | list of tuples (word, weight), or (word_positions, weight) if 38 | examples: ('bad', 1) or ('bad_3-6-12', 1) 39 | """ 40 | if positions: 41 | exp = [('%s_%s' % ( 42 | self.indexed_string.word(x[0]), 43 | '-'.join( 44 | map(str, 45 | self.indexed_string.string_position(x[0])))), x[1]) 46 | for x in exp] 47 | else: 48 | exp = [(self.indexed_string.word(x[0]), x[1]) for x in exp] 49 | return exp 50 | 51 | def visualize_instance_html(self, exp, label, div_name, exp_object_name, 52 | text=True, opacity=True): 53 | """Adds text with highlighted words to visualization. 54 | 55 | Args: 56 | exp: list of tuples [(id, weight), (id,weight)] 57 | label: label id (integer) 58 | div_name: name of div object to be used for rendering(in js) 59 | exp_object_name: name of js explanation object 60 | text: if False, return empty 61 | opacity: if True, fade colors according to weight 62 | """ 63 | if not text: 64 | return u'' 65 | text = (self.indexed_string.raw_string() 66 | .encode('utf-8', 'xmlcharrefreplace').decode('utf-8')) 67 | text = re.sub(r'[<>&]', '|', text) 68 | exp = [(self.indexed_string.word(x[0]), 69 | self.indexed_string.string_position(x[0]), 70 | x[1]) for x in exp] 71 | all_occurrences = list(itertools.chain.from_iterable( 72 | [itertools.product([x[0]], x[1], [x[2]]) for x in exp])) 73 | all_occurrences = [(x[0], int(x[1]), x[2]) for x in all_occurrences] 74 | ret = ''' 75 | %s.show_raw_text(%s, %d, %s, %s, %s); 76 | ''' % (exp_object_name, json.dumps(all_occurrences), label, 77 | json.dumps(text), div_name, json.dumps(opacity)) 78 | return ret 79 | 80 | 81 | class IndexedString(object): 82 | """String with various indexes.""" 83 | 84 | def __init__(self, raw_string, split_expression=r'\W+', bow=True, 85 | mask_string=None): 86 | """Initializer. 87 | 88 | Args: 89 | raw_string: string with raw text in it 90 | split_expression: Regex string or callable. If regex string, will be used with re.split. 91 | If callable, the function should return a list of tokens. 92 | bow: if True, a word is the same everywhere in the text - i.e. we 93 | will index multiple occurrences of the same word. If False, 94 | order matters, so that the same word will have different ids 95 | according to position. 96 | mask_string: If not None, replace words with this if bow=False 97 | if None, default value is UNKWORDZ 98 | """ 99 | self.raw = raw_string 100 | self.mask_string = 'UNKWORDZ' if mask_string is None else mask_string 101 | 102 | if callable(split_expression): 103 | tokens = split_expression(self.raw) 104 | self.as_list = self._segment_with_tokens(self.raw, tokens) 105 | tokens = set(tokens) 106 | 107 | def non_word(string): 108 | return string not in tokens 109 | 110 | else: 111 | # with the split_expression as a non-capturing group (?:), we don't need to filter out 112 | # the separator character from the split results. 
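            # For example, with the default r'\W+' expression a string like
            # 'Please, take your time.' is split into the word tokens plus the
            # separator pieces (', ', ' ', '.'); the separators are kept in
            # as_list so the string can be reconstructed, but are excluded from
            # the vocabulary by the non_word check below.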
113 | splitter = re.compile(r'(%s)|$' % split_expression) 114 | self.as_list = [s for s in splitter.split(self.raw) if s] 115 | non_word = splitter.match 116 | 117 | self.as_np = np.array(self.as_list) 118 | self.string_start = np.hstack( 119 | ([0], np.cumsum([len(x) for x in self.as_np[:-1]]))) 120 | vocab = {} 121 | self.inverse_vocab = [] 122 | self.positions = [] 123 | self.bow = bow 124 | non_vocab = set() 125 | for i, word in enumerate(self.as_np): 126 | if word in non_vocab: 127 | continue 128 | if non_word(word): 129 | non_vocab.add(word) 130 | continue 131 | if bow: 132 | if word not in vocab: 133 | vocab[word] = len(vocab) 134 | self.inverse_vocab.append(word) 135 | self.positions.append([]) 136 | idx_word = vocab[word] 137 | self.positions[idx_word].append(i) 138 | else: 139 | self.inverse_vocab.append(word) 140 | self.positions.append(i) 141 | if not bow: 142 | self.positions = np.array(self.positions) 143 | 144 | def raw_string(self): 145 | """Returns the original raw string""" 146 | return self.raw 147 | 148 | def num_words(self): 149 | """Returns the number of tokens in the vocabulary for this document.""" 150 | return len(self.inverse_vocab) 151 | 152 | def word(self, id_): 153 | """Returns the word that corresponds to id_ (int)""" 154 | return self.inverse_vocab[id_] 155 | 156 | def string_position(self, id_): 157 | """Returns a np array with indices to id_ (int) occurrences""" 158 | if self.bow: 159 | return self.string_start[self.positions[id_]] 160 | else: 161 | return self.string_start[[self.positions[id_]]] 162 | 163 | def inverse_removing(self, words_to_remove): 164 | """Returns a string after removing the appropriate words. 165 | 166 | If self.bow is false, replaces word with UNKWORDZ instead of removing 167 | it. 168 | 169 | Args: 170 | words_to_remove: list of ids (ints) to remove 171 | 172 | Returns: 173 | original raw string with appropriate words removed. 174 | """ 175 | mask = np.ones(self.as_np.shape[0], dtype='bool') 176 | mask[self.__get_idxs(words_to_remove)] = False 177 | if not self.bow: 178 | return ''.join( 179 | [self.as_list[i] if mask[i] else self.mask_string 180 | for i in range(mask.shape[0])]) 181 | return ''.join([self.as_list[v] for v in mask.nonzero()[0]]) 182 | 183 | @staticmethod 184 | def _segment_with_tokens(text, tokens): 185 | """Segment a string around the tokens created by a passed-in tokenizer""" 186 | list_form = [] 187 | text_ptr = 0 188 | for token in tokens: 189 | inter_token_string = [] 190 | while not text[text_ptr:].startswith(token): 191 | inter_token_string.append(text[text_ptr]) 192 | text_ptr += 1 193 | if text_ptr >= len(text): 194 | raise ValueError("Tokenization produced tokens that do not belong in string!") 195 | text_ptr += len(token) 196 | if inter_token_string: 197 | list_form.append(''.join(inter_token_string)) 198 | list_form.append(token) 199 | if text_ptr < len(text): 200 | list_form.append(text[text_ptr:]) 201 | return list_form 202 | 203 | def __get_idxs(self, words): 204 | """Returns indexes to appropriate words.""" 205 | if self.bow: 206 | return list(itertools.chain.from_iterable( 207 | [self.positions[z] for z in words])) 208 | else: 209 | return self.positions[words] 210 | 211 | 212 | class IndexedCharacters(object): 213 | """String with various indexes.""" 214 | 215 | def __init__(self, raw_string, bow=True, mask_string=None): 216 | """Initializer. 217 | 218 | Args: 219 | raw_string: string with raw text in it 220 | bow: if True, a char is the same everywhere in the text - i.e. 
we 221 | will index multiple occurrences of the same character. If False, 222 | order matters, so that the same word will have different ids 223 | according to position. 224 | mask_string: If not None, replace characters with this if bow=False 225 | if None, default value is chr(0) 226 | """ 227 | self.raw = raw_string 228 | self.as_list = list(self.raw) 229 | self.as_np = np.array(self.as_list) 230 | self.mask_string = chr(0) if mask_string is None else mask_string 231 | self.string_start = np.arange(len(self.raw)) 232 | vocab = {} 233 | self.inverse_vocab = [] 234 | self.positions = [] 235 | self.bow = bow 236 | non_vocab = set() 237 | for i, char in enumerate(self.as_np): 238 | if char in non_vocab: 239 | continue 240 | if bow: 241 | if char not in vocab: 242 | vocab[char] = len(vocab) 243 | self.inverse_vocab.append(char) 244 | self.positions.append([]) 245 | idx_char = vocab[char] 246 | self.positions[idx_char].append(i) 247 | else: 248 | self.inverse_vocab.append(char) 249 | self.positions.append(i) 250 | if not bow: 251 | self.positions = np.array(self.positions) 252 | 253 | def raw_string(self): 254 | """Returns the original raw string""" 255 | return self.raw 256 | 257 | def num_words(self): 258 | """Returns the number of tokens in the vocabulary for this document.""" 259 | return len(self.inverse_vocab) 260 | 261 | def word(self, id_): 262 | """Returns the word that corresponds to id_ (int)""" 263 | return self.inverse_vocab[id_] 264 | 265 | def string_position(self, id_): 266 | """Returns a np array with indices to id_ (int) occurrences""" 267 | if self.bow: 268 | return self.string_start[self.positions[id_]] 269 | else: 270 | return self.string_start[[self.positions[id_]]] 271 | 272 | def inverse_removing(self, words_to_remove): 273 | """Returns a string after removing the appropriate words. 274 | 275 | If self.bow is false, replaces word with UNKWORDZ instead of removing 276 | it. 277 | 278 | Args: 279 | words_to_remove: list of ids (ints) to remove 280 | 281 | Returns: 282 | original raw string with appropriate words removed. 283 | """ 284 | mask = np.ones(self.as_np.shape[0], dtype='bool') 285 | mask[self.__get_idxs(words_to_remove)] = False 286 | if not self.bow: 287 | return ''.join( 288 | [self.as_list[i] if mask[i] else self.mask_string 289 | for i in range(mask.shape[0])]) 290 | return ''.join([self.as_list[v] for v in mask.nonzero()[0]]) 291 | 292 | def __get_idxs(self, words): 293 | """Returns indexes to appropriate words.""" 294 | if self.bow: 295 | return list(itertools.chain.from_iterable( 296 | [self.positions[z] for z in words])) 297 | else: 298 | return self.positions[words] 299 | 300 | 301 | class LimeTextExplainer(object): 302 | """Explains text classifiers. 303 | Currently, we are using an exponential kernel on cosine distance, and 304 | restricting explanations to words that are present in documents.""" 305 | 306 | def __init__(self, 307 | kernel_width=25, 308 | kernel=None, 309 | verbose=False, 310 | class_names=None, 311 | feature_selection='auto', 312 | split_expression=r'\W+', 313 | bow=True, 314 | mask_string=None, 315 | random_state=None, 316 | char_level=False): 317 | """Init function. 318 | 319 | Args: 320 | kernel_width: kernel width for the exponential kernel. 321 | kernel: similarity kernel that takes euclidean distances and kernel 322 | width as input and outputs weights in (0,1). If None, defaults to 323 | an exponential kernel. 
324 | verbose: if true, print local prediction values from linear model 325 | class_names: list of class names, ordered according to whatever the 326 | classifier is using. If not present, class names will be '0', 327 | '1', ... 328 | feature_selection: feature selection method. can be 329 | 'forward_selection', 'lasso_path', 'none' or 'auto'. 330 | See function 'explain_instance_with_data' in lime_base.py for 331 | details on what each of the options does. 332 | split_expression: Regex string or callable. If regex string, will be used with re.split. 333 | If callable, the function should return a list of tokens. 334 | bow: if True (bag of words), will perturb input data by removing 335 | all occurrences of individual words or characters. 336 | Explanations will be in terms of these words. Otherwise, will 337 | explain in terms of word-positions, so that a word may be 338 | important the first time it appears and unimportant the second. 339 | Only set to false if the classifier uses word order in some way 340 | (bigrams, etc), or if you set char_level=True. 341 | mask_string: String used to mask tokens or characters if bow=False 342 | if None, will be 'UNKWORDZ' if char_level=False, chr(0) 343 | otherwise. 344 | random_state: an integer or numpy.RandomState that will be used to 345 | generate random numbers. If None, the random state will be 346 | initialized using the internal numpy seed. 347 | char_level: an boolean identifying that we treat each character 348 | as an independent occurence in the string 349 | """ 350 | 351 | if kernel is None: 352 | def kernel(d, kernel_width): 353 | return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) 354 | 355 | kernel_fn = partial(kernel, kernel_width=kernel_width) 356 | 357 | self.random_state = check_random_state(random_state) 358 | self.base = lime_base.LimeBase(kernel_fn, verbose, 359 | random_state=self.random_state) 360 | self.class_names = class_names 361 | self.vocabulary = None 362 | self.feature_selection = feature_selection 363 | self.bow = bow 364 | self.mask_string = mask_string 365 | self.split_expression = split_expression 366 | self.char_level = char_level 367 | 368 | def explain_instance(self, 369 | text_instance, 370 | classifier_fn, 371 | labels=(1,), 372 | top_labels=None, 373 | num_features=10, 374 | num_samples=5000, 375 | distance_metric='cosine', 376 | model_regressor=None): 377 | """Generates explanations for a prediction. 378 | 379 | First, we generate neighborhood data by randomly hiding features from 380 | the instance (see __data_labels_distance_mapping). We then learn 381 | locally weighted linear models on this neighborhood data to explain 382 | each of the classes in an interpretable way (see lime_base.py). 383 | 384 | Args: 385 | text_instance: raw text string to be explained. 386 | classifier_fn: classifier prediction probability function, which 387 | takes a list of d strings and outputs a (d, k) numpy array with 388 | prediction probabilities, where k is the number of classes. 389 | For ScikitClassifiers , this is classifier.predict_proba. 390 | labels: iterable with labels to be explained. 391 | top_labels: if not None, ignore labels and produce explanations for 392 | the K labels with highest prediction probabilities, where K is 393 | this parameter. 
394 | num_features: maximum number of features present in explanation 395 | num_samples: size of the neighborhood to learn the linear model 396 | distance_metric: the distance metric to use for sample weighting, 397 | defaults to cosine similarity 398 | model_regressor: sklearn regressor to use in explanation. Defaults 399 | to Ridge regression in LimeBase. Must have model_regressor.coef_ 400 | and 'sample_weight' as a parameter to model_regressor.fit() 401 | Returns: 402 | An Explanation object (see explanation.py) with the corresponding 403 | explanations. 404 | """ 405 | 406 | indexed_string = (IndexedCharacters( 407 | text_instance, bow=self.bow, mask_string=self.mask_string) 408 | if self.char_level else 409 | IndexedString(text_instance, bow=self.bow, 410 | split_expression=self.split_expression, 411 | mask_string=self.mask_string)) 412 | domain_mapper = TextDomainMapper(indexed_string) 413 | data, yss, distances = self.__data_labels_distances( 414 | indexed_string, classifier_fn, num_samples, 415 | distance_metric=distance_metric) 416 | if self.class_names is None: 417 | self.class_names = [str(x) for x in range(yss[0].shape[0])] 418 | ret_exp = explanation.Explanation(domain_mapper=domain_mapper, 419 | class_names=self.class_names, 420 | random_state=self.random_state) 421 | ret_exp.predict_proba = yss[0] 422 | if top_labels: 423 | labels = np.argsort(yss[0])[-top_labels:] 424 | ret_exp.top_labels = list(labels) 425 | ret_exp.top_labels.reverse() 426 | for label in labels: 427 | (ret_exp.intercept[label], 428 | ret_exp.local_exp[label], 429 | ret_exp.score[label], 430 | ret_exp.local_pred[label]) = self.base.explain_instance_with_data( 431 | data, yss, distances, label, num_features, 432 | model_regressor=model_regressor, 433 | feature_selection=self.feature_selection) 434 | return ret_exp 435 | 436 | def __data_labels_distances(self, 437 | indexed_string, 438 | classifier_fn, 439 | num_samples, 440 | distance_metric='cosine'): 441 | """Generates a neighborhood around a prediction. 442 | 443 | Generates neighborhood data by randomly removing words from 444 | the instance, and predicting with the classifier. Uses cosine distance 445 | to compute distances between original and perturbed instances. 446 | Args: 447 | indexed_string: document (IndexedString) to be explained, 448 | classifier_fn: classifier prediction probability function, which 449 | takes a string and outputs prediction probabilities. For 450 | ScikitClassifier, this is classifier.predict_proba. 451 | num_samples: size of the neighborhood to learn the linear model 452 | distance_metric: the distance metric to use for sample weighting, 453 | defaults to cosine similarity. 454 | 455 | 456 | Returns: 457 | A tuple (data, labels, distances), where: 458 | data: dense num_samples * K binary matrix, where K is the 459 | number of tokens in indexed_string. The first row is the 460 | original instance, and thus a row of ones. 461 | labels: num_samples * L matrix, where L is the number of target 462 | labels 463 | distances: cosine distance between the original instance and 464 | each perturbed instance (computed in the binary 'data' 465 | matrix), times 100. 
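        For instance, with a 4-token document a perturbed row [1, 0, 1, 1]
        means the word with id 1 was removed, and the matching entry of
        inverse_data is the raw string with that word's occurrences dropped.
        A sketch of the public call that drives this, where clf.predict_proba
        stands in for any probability function over a list of strings:

            explainer = LimeTextExplainer(class_names=['negative', 'positive'])
            exp = explainer.explain_instance('This is a good movie.',
                                             clf.predict_proba,
                                             num_features=6)
            exp.as_list(label=1)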
466 | """ 467 | 468 | def distance_fn(x): 469 | return sklearn.metrics.pairwise.pairwise_distances( 470 | x, x[0], metric=distance_metric).ravel() * 100 471 | 472 | doc_size = indexed_string.num_words() 473 | sample = self.random_state.randint(1, doc_size + 1, num_samples - 1) 474 | data = np.ones((num_samples, doc_size)) 475 | data[0] = np.ones(doc_size) 476 | features_range = range(doc_size) 477 | inverse_data = [indexed_string.raw_string()] 478 | for i, size in enumerate(sample, start=1): 479 | inactive = self.random_state.choice(features_range, size, 480 | replace=False) 481 | data[i, inactive] = 0 482 | inverse_data.append(indexed_string.inverse_removing(inactive)) 483 | labels = classifier_fn(inverse_data) 484 | distances = distance_fn(sp.sparse.csr_matrix(data)) 485 | return data, labels, distances 486 | -------------------------------------------------------------------------------- /slime/lime_image.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for explaining classifiers that use Image data. 3 | """ 4 | import copy 5 | from functools import partial 6 | 7 | import numpy as np 8 | import sklearn 9 | from sklearn.utils import check_random_state 10 | from skimage.color import gray2rgb 11 | from tqdm.auto import tqdm 12 | 13 | 14 | from . import lime_base 15 | from .wrappers.scikit_image import SegmentationAlgorithm 16 | 17 | 18 | class ImageExplanation(object): 19 | def __init__(self, image, segments): 20 | """Init function. 21 | 22 | Args: 23 | image: 3d numpy array 24 | segments: 2d numpy array, with the output from skimage.segmentation 25 | """ 26 | self.image = image 27 | self.segments = segments 28 | self.intercept = {} 29 | self.local_exp = {} 30 | self.local_pred = {} 31 | self.score = {} 32 | 33 | def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False, 34 | num_features=5, min_weight=0.): 35 | """Init function. 36 | 37 | Args: 38 | label: label to explain 39 | positive_only: if True, only take superpixels that positively contribute to 40 | the prediction of the label. 41 | negative_only: if True, only take superpixels that negatively contribute to 42 | the prediction of the label. If false, and so is positive_only, then both 43 | negativey and positively contributions will be taken. 
44 | Both can't be True at the same time 45 | hide_rest: if True, make the non-explanation part of the return 46 | image gray 47 | num_features: number of superpixels to include in explanation 48 | min_weight: minimum weight of the superpixels to include in explanation 49 | 50 | Returns: 51 | (image, mask), where image is a 3d numpy array and mask is a 2d 52 | numpy array that can be used with 53 | skimage.segmentation.mark_boundaries 54 | """ 55 | if label not in self.local_exp: 56 | raise KeyError('Label not in explanation') 57 | if positive_only & negative_only: 58 | raise ValueError("Positive_only and negative_only cannot be true at the same time.") 59 | segments = self.segments 60 | image = self.image 61 | exp = self.local_exp[label] 62 | mask = np.zeros(segments.shape, segments.dtype) 63 | if hide_rest: 64 | temp = np.zeros(self.image.shape) 65 | else: 66 | temp = self.image.copy() 67 | if positive_only: 68 | fs = [x[0] for x in exp 69 | if x[1] > 0 and x[1] > min_weight][:num_features] 70 | if negative_only: 71 | fs = [x[0] for x in exp 72 | if x[1] < 0 and abs(x[1]) > min_weight][:num_features] 73 | if positive_only or negative_only: 74 | for f in fs: 75 | temp[segments == f] = image[segments == f].copy() 76 | mask[segments == f] = 1 77 | return temp, mask 78 | else: 79 | for f, w in exp[:num_features]: 80 | if np.abs(w) < min_weight: 81 | continue 82 | c = 0 if w < 0 else 1 83 | mask[segments == f] = -1 if w < 0 else 1 84 | temp[segments == f] = image[segments == f].copy() 85 | temp[segments == f, c] = np.max(image) 86 | return temp, mask 87 | 88 | 89 | class LimeImageExplainer(object): 90 | """Explains predictions on Image (i.e. matrix) data. 91 | For numerical features, perturb them by sampling from a Normal(0,1) and 92 | doing the inverse operation of mean-centering and scaling, according to the 93 | means and stds in the training data. For categorical features, perturb by 94 | sampling according to the training distribution, and making a binary 95 | feature that is 1 when the value is the same as the instance being 96 | explained.""" 97 | 98 | def __init__(self, kernel_width=.25, kernel=None, verbose=False, 99 | feature_selection='auto', random_state=None): 100 | """Init function. 101 | 102 | Args: 103 | kernel_width: kernel width for the exponential kernel. 104 | If None, defaults to sqrt(number of columns) * 0.75. 105 | kernel: similarity kernel that takes euclidean distances and kernel 106 | width as input and outputs weights in (0,1). If None, defaults to 107 | an exponential kernel. 108 | verbose: if true, print local prediction values from linear model 109 | feature_selection: feature selection method. can be 110 | 'forward_selection', 'lasso_path', 'none' or 'auto'. 111 | See function 'explain_instance_with_data' in lime_base.py for 112 | details on what each of the options does. 113 | random_state: an integer or numpy.RandomState that will be used to 114 | generate random numbers. If None, the random state will be 115 | initialized using the internal numpy seed. 
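        A construction sketch (a custom kernel must take distances and a
        kernel width and return weights in (0, 1), like the default defined
        just below):

            import numpy as np

            def my_kernel(d, kernel_width):
                return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

            explainer = LimeImageExplainer(kernel_width=0.25, kernel=my_kernel,
                                           random_state=42)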
116 | """ 117 | kernel_width = float(kernel_width) 118 | 119 | if kernel is None: 120 | def kernel(d, kernel_width): 121 | return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2)) 122 | 123 | kernel_fn = partial(kernel, kernel_width=kernel_width) 124 | 125 | self.random_state = check_random_state(random_state) 126 | self.feature_selection = feature_selection 127 | self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state) 128 | 129 | def explain_instance(self, image, classifier_fn, labels=(1,), 130 | hide_color=None, 131 | top_labels=5, num_features=100000, num_samples=1000, 132 | batch_size=10, 133 | segmentation_fn=None, 134 | distance_metric='cosine', 135 | model_regressor=None, 136 | random_seed=None, 137 | progress_bar=True): 138 | """Generates explanations for a prediction. 139 | 140 | First, we generate neighborhood data by randomly perturbing features 141 | from the instance (see __data_inverse). We then learn locally weighted 142 | linear models on this neighborhood data to explain each of the classes 143 | in an interpretable way (see lime_base.py). 144 | 145 | Args: 146 | image: 3 dimension RGB image. If this is only two dimensional, 147 | we will assume it's a grayscale image and call gray2rgb. 148 | classifier_fn: classifier prediction probability function, which 149 | takes a numpy array and outputs prediction probabilities. For 150 | ScikitClassifiers , this is classifier.predict_proba. 151 | labels: iterable with labels to be explained. 152 | hide_color: TODO 153 | top_labels: if not None, ignore labels and produce explanations for 154 | the K labels with highest prediction probabilities, where K is 155 | this parameter. 156 | num_features: maximum number of features present in explanation 157 | num_samples: size of the neighborhood to learn the linear model 158 | batch_size: TODO 159 | distance_metric: the distance metric to use for weights. 160 | model_regressor: sklearn regressor to use in explanation. Defaults 161 | to Ridge regression in LimeBase. Must have model_regressor.coef_ 162 | and 'sample_weight' as a parameter to model_regressor.fit() 163 | segmentation_fn: SegmentationAlgorithm, wrapped skimage 164 | segmentation function 165 | random_seed: integer used as random seed for the segmentation 166 | algorithm. If None, a random integer, between 0 and 1000, 167 | will be generated using the internal random number generator. 168 | progress_bar: if True, show tqdm progress bar. 169 | 170 | Returns: 171 | An ImageExplanation object (see lime_image.py) with the corresponding 172 | explanations. 
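Example (a minimal sketch; ``img`` is assumed to be an H x W x 3 numpy array and ``clf`` a fitted classifier whose predict_proba accepts a batch of images):

    explainer = LimeImageExplainer(random_state=42)
    explanation = explainer.explain_instance(img, clf.predict_proba,
                                             top_labels=3, num_samples=1000)
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0],
                                                positive_only=True, num_features=5)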
173 | """ 174 | if len(image.shape) == 2: 175 | image = gray2rgb(image) 176 | if random_seed is None: 177 | random_seed = self.random_state.randint(0, high=1000) 178 | 179 | if segmentation_fn is None: 180 | segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, 181 | max_dist=200, ratio=0.2, 182 | random_seed=random_seed) 183 | try: 184 | segments = segmentation_fn(image) 185 | except ValueError as e: 186 | raise e 187 | 188 | fudged_image = image.copy() 189 | if hide_color is None: 190 | for x in np.unique(segments): 191 | fudged_image[segments == x] = ( 192 | np.mean(image[segments == x][:, 0]), 193 | np.mean(image[segments == x][:, 1]), 194 | np.mean(image[segments == x][:, 2])) 195 | else: 196 | fudged_image[:] = hide_color 197 | 198 | top = labels 199 | 200 | data, labels = self.data_labels(image, fudged_image, segments, 201 | classifier_fn, num_samples, 202 | batch_size=batch_size, 203 | progress_bar=progress_bar) 204 | 205 | distances = sklearn.metrics.pairwise_distances( 206 | data, 207 | data[0].reshape(1, -1), 208 | metric=distance_metric 209 | ).ravel() 210 | 211 | ret_exp = ImageExplanation(image, segments) 212 | if top_labels: 213 | top = np.argsort(labels[0])[-top_labels:] 214 | ret_exp.top_labels = list(top) 215 | ret_exp.top_labels.reverse() 216 | for label in top: 217 | (ret_exp.intercept[label], 218 | ret_exp.local_exp[label], 219 | ret_exp.score[label], 220 | ret_exp.local_pred[label]) = self.base.explain_instance_with_data( 221 | data, labels, distances, label, num_features, 222 | model_regressor=model_regressor, 223 | feature_selection=self.feature_selection) 224 | return ret_exp 225 | 226 | def testing_explain_instance(self, image, classifier_fn, labels=(1,), 227 | hide_color=None, 228 | top_labels=5, num_features=100000, num_samples=1000, 229 | batch_size=10, 230 | segmentation_fn=None, 231 | distance_metric='cosine', 232 | model_regressor=None, 233 | alpha=0.05, 234 | random_seed=None, 235 | progress_bar=True): 236 | """Generates explanations for a prediction. 237 | 238 | First, we generate neighborhood data by randomly perturbing features 239 | from the instance (see __data_inverse). We then learn locally weighted 240 | linear models on this neighborhood data to explain each of the classes 241 | in an interpretable way (see lime_base.py). 242 | 243 | Args: 244 | image: 3 dimension RGB image. If this is only two dimensional, 245 | we will assume it's a grayscale image and call gray2rgb. 246 | classifier_fn: classifier prediction probability function, which 247 | takes a numpy array and outputs prediction probabilities. For 248 | ScikitClassifiers , this is classifier.predict_proba. 249 | labels: iterable with labels to be explained. 250 | hide_color: TODO 251 | top_labels: if not None, ignore labels and produce explanations for 252 | the K labels with highest prediction probabilities, where K is 253 | this parameter. 254 | num_features: maximum number of features present in explanation 255 | num_samples: size of the neighborhood to learn the linear model 256 | batch_size: TODO 257 | distance_metric: the distance metric to use for weights. 258 | model_regressor: sklearn regressor to use in explanation. Defaults 259 | to Ridge regression in LimeBase. Must have model_regressor.coef_ 260 | and 'sample_weight' as a parameter to model_regressor.fit() 261 | segmentation_fn: SegmentationAlgorithm, wrapped skimage 262 | segmentation function 263 | random_seed: integer used as random seed for the segmentation 264 | algorithm. 
If None, a random integer, between 0 and 1000, 265 | will be generated using the internal random number generator. 266 | progress_bar: if True, show tqdm progress bar. 267 | 268 | Returns: 269 | An ImageExplanation object (see lime_image.py) with the corresponding 270 | explanations. 271 | """ 272 | if len(image.shape) == 2: 273 | image = gray2rgb(image) 274 | if random_seed is None: 275 | random_seed = self.random_state.randint(0, high=1000) 276 | 277 | if segmentation_fn is None: 278 | segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4, 279 | max_dist=200, ratio=0.2, 280 | random_seed=random_seed) 281 | try: 282 | segments = segmentation_fn(image) 283 | except ValueError as e: 284 | raise e 285 | 286 | fudged_image = image.copy() 287 | if hide_color is None: 288 | for x in np.unique(segments): 289 | fudged_image[segments == x] = ( 290 | np.mean(image[segments == x][:, 0]), 291 | np.mean(image[segments == x][:, 1]), 292 | np.mean(image[segments == x][:, 2])) 293 | else: 294 | fudged_image[:] = hide_color 295 | 296 | top = labels 297 | 298 | data, labels = self.data_labels(image, fudged_image, segments, 299 | classifier_fn, num_samples, 300 | batch_size=batch_size, 301 | progress_bar=progress_bar) 302 | 303 | distances = sklearn.metrics.pairwise_distances( 304 | data, 305 | data[0].reshape(1, -1), 306 | metric=distance_metric 307 | ).ravel() 308 | 309 | ret_exp = ImageExplanation(image, segments) 310 | if top_labels: 311 | top = np.argsort(labels[0])[-top_labels:] 312 | ret_exp.top_labels = list(top) 313 | ret_exp.top_labels.reverse() 314 | for label in top: 315 | (ret_exp.intercept[label], 316 | ret_exp.local_exp[label], 317 | ret_exp.score[label], 318 | ret_exp.local_pred[label], 319 | used_features, 320 | test_result) = self.base.testing_explain_instance_with_data( 321 | data, labels, distances, label, num_features, 322 | model_regressor=model_regressor, 323 | feature_selection=self.feature_selection, 324 | alpha=alpha) 325 | return ret_exp, test_result 326 | 327 | def data_labels(self, 328 | image, 329 | fudged_image, 330 | segments, 331 | classifier_fn, 332 | num_samples, 333 | batch_size=10, 334 | progress_bar=True): 335 | """Generates images and predictions in the neighborhood of this image. 336 | 337 | Args: 338 | image: 3d numpy array, the image 339 | fudged_image: 3d numpy array, image to replace original image when 340 | superpixel is turned off 341 | segments: segmentation of the image 342 | classifier_fn: function that takes a list of images and returns a 343 | matrix of prediction probabilities 344 | num_samples: size of the neighborhood to learn the linear model 345 | batch_size: classifier_fn will be called on batches of this size. 346 | progress_bar: if True, show tqdm progress bar. 
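Note (a sketch of the sampling scheme implemented below; ``explainer``, ``img``, ``fudged``, ``segments`` and ``clf`` are assumed to be set up as in explain_instance): each row of the returned ``data`` is a binary mask over superpixels, where 1 keeps the superpixel from ``image`` and 0 replaces it with the matching region of ``fudged_image``:

    data, labels = explainer.data_labels(img, fudged, segments,
                                         clf.predict_proba, num_samples=100)
    # data.shape == (100, number of superpixels); data[0] is all ones,
    # so labels[0] is the prediction on the unperturbed image.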
347 | 348 | Returns: 349 | A tuple (data, labels), where: 350 | data: dense binary matrix of shape (num_samples, num_superpixels); a 0 entry means the corresponding superpixel is replaced by fudged_image 351 | labels: prediction probabilities matrix 352 | """ 353 | n_features = np.unique(segments).shape[0] 354 | data = self.random_state.randint(0, 2, num_samples * n_features)\ 355 | .reshape((num_samples, n_features)) 356 | labels = [] 357 | data[0, :] = 1 358 | imgs = [] 359 | rows = tqdm(data) if progress_bar else data 360 | for row in rows: 361 | temp = copy.deepcopy(image) 362 | zeros = np.where(row == 0)[0] 363 | mask = np.zeros(segments.shape).astype(bool) 364 | for z in zeros: 365 | mask[segments == z] = True 366 | temp[mask] = fudged_image[mask] 367 | imgs.append(temp) 368 | if len(imgs) == batch_size: 369 | preds = classifier_fn(np.array(imgs)) 370 | labels.extend(preds) 371 | imgs = [] 372 | if len(imgs) > 0: 373 | preds = classifier_fn(np.array(imgs)) 374 | labels.extend(preds) 375 | return data, np.array(labels) 376 | 377 | def slime(self, 378 | image, classifier_fn, labels=(1,), 379 | hide_color=None, 380 | top_labels=5, num_features=100000, num_samples=1000, 381 | batch_size=10, 382 | segmentation_fn=None, 383 | distance_metric='cosine', 384 | model_regressor=None, 385 | n_max=10000, 386 | alpha=0.05, 387 | tol=1e-3, 388 | random_seed=None, 389 | progress_bar=True 390 | ): 391 | """Generates explanations for a prediction with S-LIME. 392 | 393 | First, we generate neighborhood data by randomly perturbing superpixels 394 | of the instance (see data_labels). We then learn locally weighted 395 | linear models on this neighborhood data to explain each of the classes 396 | in an interpretable way (see lime_base.py), growing the neighborhood until hypothesis tests indicate that the selected features are stable. 397 | 398 | Args: 399 | image: 3 dimension RGB image. If this is only two dimensional, 400 | we will assume it's a grayscale image and call gray2rgb. 401 | classifier_fn: classifier prediction probability function, which 402 | takes a numpy array of perturbed images and outputs prediction 403 | probabilities. For ScikitClassifiers, this is 404 | classifier.predict_proba. The prediction function needs to work 405 | on batches of images (the perturbed images generated by 406 | randomly turning superpixels of the instance on and off, 407 | see data_labels). 408 | labels: iterable with labels to be explained. 409 | top_labels: if not None, ignore labels and produce explanations for 410 | the K labels with highest prediction probabilities, where K is 411 | this parameter. 412 | num_features: maximum number of features present in explanation 413 | num_samples: initial size of the neighborhood to learn the linear model 414 | distance_metric: the distance metric to use for weights. 415 | model_regressor: sklearn regressor to use in explanation. Defaults 416 | to Ridge regression in LimeBase. Must have model_regressor.coef_ 417 | and 'sample_weight' as a parameter to model_regressor.fit() 418 | segmentation_fn: SegmentationAlgorithm, wrapped skimage 419 | segmentation function 420 | n_max: maximum number of synthetic samples to generate. 421 | alpha: significance level of hypothesis testing. 422 | tol: tolerance level of hypothesis testing. 423 | 424 | Returns: 425 | An ImageExplanation object (see lime_image.py) with the corresponding 426 | explanations.
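Example (a minimal sketch; ``img`` is assumed to be an H x W x 3 numpy array and ``clf`` a fitted classifier exposing predict_proba):

    explainer = LimeImageExplainer(random_state=42)
    explanation = explainer.slime(img, clf.predict_proba, top_labels=1,
                                  num_samples=500, n_max=4000, alpha=0.05, tol=1e-3)
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0])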
427 | """ 428 | 429 | while True: 430 | ret_exp, test_result = self.testing_explain_instance(image=image, 431 | classifier_fn=classifier_fn, 432 | labels=labels, 433 | hide_color=hide_color, 434 | top_labels=top_labels, 435 | num_features=num_features, 436 | num_samples=num_samples, 437 | batch_size=batch_size, 438 | segmentation_fn=segmentation_fn, 439 | distance_metric=distance_metric, 440 | model_regressor=model_regressor, 441 | alpha=alpha, 442 | random_seed=random_seed, 443 | progress_bar=progress_bar) 444 | flag = False 445 | for k in range(1, num_features): # changes num_features + 1 to num_features because it fixes bug 446 | if test_result[k][0] < -tol: 447 | flag = True 448 | break 449 | if flag and num_samples != n_max: 450 | num_samples = min(int(test_result[k][1]), 2 * num_samples) 451 | if num_samples > n_max: 452 | num_samples = n_max 453 | else: 454 | break 455 | 456 | return ret_exp 457 | -------------------------------------------------------------------------------- /slime/tests/test_lime_tabular.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import collections 5 | import sklearn # noqa 6 | import sklearn.datasets 7 | import sklearn.ensemble 8 | import sklearn.linear_model # noqa 9 | from numpy.testing import assert_array_equal 10 | from sklearn.datasets import load_iris, make_classification, make_multilabel_classification 11 | from sklearn.ensemble import RandomForestClassifier 12 | from sklearn.linear_model import LinearRegression 13 | from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer 14 | 15 | 16 | try: 17 | from sklearn.model_selection import train_test_split 18 | except ImportError: 19 | # Deprecated in scikit-learn version 0.18, removed in 0.20 20 | from sklearn.cross_validation import train_test_split 21 | 22 | from lime.lime_tabular import LimeTabularExplainer 23 | 24 | 25 | class TestLimeTabular(unittest.TestCase): 26 | 27 | def setUp(self): 28 | iris = load_iris() 29 | 30 | self.feature_names = iris.feature_names 31 | self.target_names = iris.target_names 32 | 33 | (self.train, 34 | self.test, 35 | self.labels_train, 36 | self.labels_test) = train_test_split(iris.data, iris.target, train_size=0.80) 37 | 38 | def test_lime_explainer_good_regressor(self): 39 | np.random.seed(1) 40 | rf = RandomForestClassifier(n_estimators=500) 41 | rf.fit(self.train, self.labels_train) 42 | i = np.random.randint(0, self.test.shape[0]) 43 | 44 | explainer = LimeTabularExplainer(self.train, 45 | mode="classification", 46 | feature_names=self.feature_names, 47 | class_names=self.target_names, 48 | discretize_continuous=True) 49 | 50 | exp = explainer.explain_instance(self.test[i], 51 | rf.predict_proba, 52 | num_features=2, 53 | model_regressor=LinearRegression()) 54 | 55 | self.assertIsNotNone(exp) 56 | keys = [x[0] for x in exp.as_list()] 57 | self.assertEqual(1, 58 | sum([1 if 'petal width' in x else 0 for x in keys]), 59 | "Petal Width is a major feature") 60 | self.assertEqual(1, 61 | sum([1 if 'petal length' in x else 0 for x in keys]), 62 | "Petal Length is a major feature") 63 | 64 | def test_lime_explainer_good_regressor_synthetic_data(self): 65 | X, y = make_classification(n_samples=1000, 66 | n_features=20, 67 | n_informative=2, 68 | n_redundant=2, 69 | random_state=10) 70 | 71 | rf = RandomForestClassifier(n_estimators=500) 72 | rf.fit(X, y) 73 | instance = np.random.randint(0, X.shape[0]) 74 | feature_names = ["feature" + str(i) for i in 
range(20)] 75 | explainer = LimeTabularExplainer(X, 76 | feature_names=feature_names, 77 | discretize_continuous=True) 78 | 79 | exp = explainer.explain_instance(X[instance], rf.predict_proba) 80 | 81 | self.assertIsNotNone(exp) 82 | self.assertEqual(10, len(exp.as_list())) 83 | 84 | def test_lime_explainer_sparse_synthetic_data(self): 85 | n_features = 20 86 | X, y = make_multilabel_classification(n_samples=100, 87 | sparse=True, 88 | n_features=n_features, 89 | n_classes=1, 90 | n_labels=2) 91 | rf = RandomForestClassifier(n_estimators=500) 92 | rf.fit(X, y) 93 | instance = np.random.randint(0, X.shape[0]) 94 | feature_names = ["feature" + str(i) for i in range(n_features)] 95 | explainer = LimeTabularExplainer(X, 96 | feature_names=feature_names, 97 | discretize_continuous=True) 98 | 99 | exp = explainer.explain_instance(X[instance], rf.predict_proba) 100 | 101 | self.assertIsNotNone(exp) 102 | self.assertEqual(10, len(exp.as_list())) 103 | 104 | def test_lime_explainer_no_regressor(self): 105 | np.random.seed(1) 106 | 107 | rf = RandomForestClassifier(n_estimators=500) 108 | rf.fit(self.train, self.labels_train) 109 | i = np.random.randint(0, self.test.shape[0]) 110 | 111 | explainer = LimeTabularExplainer(self.train, 112 | feature_names=self.feature_names, 113 | class_names=self.target_names, 114 | discretize_continuous=True) 115 | 116 | exp = explainer.explain_instance(self.test[i], 117 | rf.predict_proba, 118 | num_features=2) 119 | self.assertIsNotNone(exp) 120 | keys = [x[0] for x in exp.as_list()] 121 | self.assertEqual(1, 122 | sum([1 if 'petal width' in x else 0 for x in keys]), 123 | "Petal Width is a major feature") 124 | self.assertEqual(1, 125 | sum([1 if 'petal length' in x else 0 for x in keys]), 126 | "Petal Length is a major feature") 127 | 128 | def test_lime_explainer_entropy_discretizer(self): 129 | np.random.seed(1) 130 | 131 | rf = RandomForestClassifier(n_estimators=500) 132 | rf.fit(self.train, self.labels_train) 133 | i = np.random.randint(0, self.test.shape[0]) 134 | 135 | explainer = LimeTabularExplainer(self.train, 136 | feature_names=self.feature_names, 137 | class_names=self.target_names, 138 | training_labels=self.labels_train, 139 | discretize_continuous=True, 140 | discretizer='entropy') 141 | 142 | exp = explainer.explain_instance(self.test[i], 143 | rf.predict_proba, 144 | num_features=2) 145 | self.assertIsNotNone(exp) 146 | keys = [x[0] for x in exp.as_list()] 147 | print(keys) 148 | self.assertEqual(1, 149 | sum([1 if 'petal width' in x else 0 for x in keys]), 150 | "Petal Width is a major feature") 151 | self.assertEqual(1, 152 | sum([1 if 'petal length' in x else 0 for x in keys]), 153 | "Petal Length is a major feature") 154 | 155 | def test_lime_tabular_explainer_equal_random_state(self): 156 | X, y = make_classification(n_samples=1000, 157 | n_features=20, 158 | n_informative=2, 159 | n_redundant=2, 160 | random_state=10) 161 | 162 | rf = RandomForestClassifier(n_estimators=500, random_state=10) 163 | rf.fit(X, y) 164 | instance = np.random.RandomState(10).randint(0, X.shape[0]) 165 | feature_names = ["feature" + str(i) for i in range(20)] 166 | 167 | # ---------------------------------------------------------------------- 168 | # -------------------------Quartile Discretizer------------------------- 169 | # ---------------------------------------------------------------------- 170 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 171 | random_state=10) 172 | explainer_1 = LimeTabularExplainer(X, 173 | feature_names=feature_names, 
174 | discretize_continuous=True, 175 | discretizer=discretizer, 176 | random_state=10) 177 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 178 | num_samples=500) 179 | 180 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 181 | random_state=10) 182 | explainer_2 = LimeTabularExplainer(X, 183 | feature_names=feature_names, 184 | discretize_continuous=True, 185 | discretizer=discretizer, 186 | random_state=10) 187 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 188 | num_samples=500) 189 | 190 | self.assertDictEqual(exp_1.as_map(), exp_2.as_map()) 191 | 192 | # ---------------------------------------------------------------------- 193 | # --------------------------Decile Discretizer-------------------------- 194 | # ---------------------------------------------------------------------- 195 | discretizer = DecileDiscretizer(X, [], feature_names, y, 196 | random_state=10) 197 | explainer_1 = LimeTabularExplainer(X, 198 | feature_names=feature_names, 199 | discretize_continuous=True, 200 | discretizer=discretizer, 201 | random_state=10) 202 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 203 | num_samples=500) 204 | 205 | discretizer = DecileDiscretizer(X, [], feature_names, y, 206 | random_state=10) 207 | explainer_2 = LimeTabularExplainer(X, 208 | feature_names=feature_names, 209 | discretize_continuous=True, 210 | discretizer=discretizer, 211 | random_state=10) 212 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 213 | num_samples=500) 214 | 215 | self.assertDictEqual(exp_1.as_map(), exp_2.as_map()) 216 | 217 | # ---------------------------------------------------------------------- 218 | # -------------------------Entropy Discretizer-------------------------- 219 | # ---------------------------------------------------------------------- 220 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 221 | random_state=10) 222 | explainer_1 = LimeTabularExplainer(X, 223 | feature_names=feature_names, 224 | discretize_continuous=True, 225 | discretizer=discretizer, 226 | random_state=10) 227 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 228 | num_samples=500) 229 | 230 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 231 | random_state=10) 232 | explainer_2 = LimeTabularExplainer(X, 233 | feature_names=feature_names, 234 | discretize_continuous=True, 235 | discretizer=discretizer, 236 | random_state=10) 237 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 238 | num_samples=500) 239 | 240 | self.assertDictEqual(exp_1.as_map(), exp_2.as_map()) 241 | 242 | def test_lime_tabular_explainer_not_equal_random_state(self): 243 | X, y = make_classification(n_samples=1000, 244 | n_features=20, 245 | n_informative=2, 246 | n_redundant=2, 247 | random_state=10) 248 | 249 | rf = RandomForestClassifier(n_estimators=500, random_state=10) 250 | rf.fit(X, y) 251 | instance = np.random.RandomState(10).randint(0, X.shape[0]) 252 | feature_names = ["feature" + str(i) for i in range(20)] 253 | 254 | # ---------------------------------------------------------------------- 255 | # -------------------------Quartile Discretizer------------------------- 256 | # ---------------------------------------------------------------------- 257 | 258 | # ---------------------------------[1]---------------------------------- 259 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 260 | random_state=20) 261 | explainer_1 = LimeTabularExplainer(X, 262 | 
feature_names=feature_names, 263 | discretize_continuous=True, 264 | discretizer=discretizer, 265 | random_state=10) 266 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 267 | num_samples=500) 268 | 269 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 270 | random_state=10) 271 | explainer_2 = LimeTabularExplainer(X, 272 | feature_names=feature_names, 273 | discretize_continuous=True, 274 | discretizer=discretizer, 275 | random_state=10) 276 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 277 | num_samples=500) 278 | 279 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 280 | 281 | # ---------------------------------[2]---------------------------------- 282 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 283 | random_state=20) 284 | explainer_1 = LimeTabularExplainer(X, 285 | feature_names=feature_names, 286 | discretize_continuous=True, 287 | discretizer=discretizer, 288 | random_state=20) 289 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 290 | num_samples=500) 291 | 292 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 293 | random_state=10) 294 | explainer_2 = LimeTabularExplainer(X, 295 | feature_names=feature_names, 296 | discretize_continuous=True, 297 | discretizer=discretizer, 298 | random_state=10) 299 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 300 | num_samples=500) 301 | 302 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 303 | 304 | # ---------------------------------[3]---------------------------------- 305 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 306 | random_state=20) 307 | explainer_1 = LimeTabularExplainer(X, 308 | feature_names=feature_names, 309 | discretize_continuous=True, 310 | discretizer=discretizer, 311 | random_state=20) 312 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 313 | num_samples=500) 314 | 315 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 316 | random_state=20) 317 | explainer_2 = LimeTabularExplainer(X, 318 | feature_names=feature_names, 319 | discretize_continuous=True, 320 | discretizer=discretizer, 321 | random_state=10) 322 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 323 | num_samples=500) 324 | 325 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 326 | 327 | # ---------------------------------[4]---------------------------------- 328 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 329 | random_state=20) 330 | explainer_1 = LimeTabularExplainer(X, 331 | feature_names=feature_names, 332 | discretize_continuous=True, 333 | discretizer=discretizer, 334 | random_state=20) 335 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 336 | num_samples=500) 337 | 338 | discretizer = QuartileDiscretizer(X, [], feature_names, y, 339 | random_state=20) 340 | explainer_2 = LimeTabularExplainer(X, 341 | feature_names=feature_names, 342 | discretize_continuous=True, 343 | discretizer=discretizer, 344 | random_state=20) 345 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 346 | num_samples=500) 347 | 348 | self.assertFalse(exp_1.as_map() != exp_2.as_map()) 349 | 350 | # ---------------------------------------------------------------------- 351 | # --------------------------Decile Discretizer-------------------------- 352 | # ---------------------------------------------------------------------- 353 | 354 | # ---------------------------------[1]---------------------------------- 355 | discretizer = 
DecileDiscretizer(X, [], feature_names, y, 356 | random_state=20) 357 | explainer_1 = LimeTabularExplainer(X, 358 | feature_names=feature_names, 359 | discretize_continuous=True, 360 | discretizer=discretizer, 361 | random_state=10) 362 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 363 | num_samples=500) 364 | 365 | discretizer = DecileDiscretizer(X, [], feature_names, y, 366 | random_state=10) 367 | explainer_2 = LimeTabularExplainer(X, 368 | feature_names=feature_names, 369 | discretize_continuous=True, 370 | discretizer=discretizer, 371 | random_state=10) 372 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 373 | num_samples=500) 374 | 375 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 376 | 377 | # ---------------------------------[2]---------------------------------- 378 | discretizer = DecileDiscretizer(X, [], feature_names, y, 379 | random_state=20) 380 | explainer_1 = LimeTabularExplainer(X, 381 | feature_names=feature_names, 382 | discretize_continuous=True, 383 | discretizer=discretizer, 384 | random_state=20) 385 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 386 | num_samples=500) 387 | 388 | discretizer = DecileDiscretizer(X, [], feature_names, y, 389 | random_state=10) 390 | explainer_2 = LimeTabularExplainer(X, 391 | feature_names=feature_names, 392 | discretize_continuous=True, 393 | discretizer=discretizer, 394 | random_state=10) 395 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 396 | num_samples=500) 397 | 398 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 399 | 400 | # ---------------------------------[3]---------------------------------- 401 | discretizer = DecileDiscretizer(X, [], feature_names, y, 402 | random_state=20) 403 | explainer_1 = LimeTabularExplainer(X, 404 | feature_names=feature_names, 405 | discretize_continuous=True, 406 | discretizer=discretizer, 407 | random_state=20) 408 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 409 | num_samples=500) 410 | 411 | discretizer = DecileDiscretizer(X, [], feature_names, y, 412 | random_state=20) 413 | explainer_2 = LimeTabularExplainer(X, 414 | feature_names=feature_names, 415 | discretize_continuous=True, 416 | discretizer=discretizer, 417 | random_state=10) 418 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 419 | num_samples=500) 420 | 421 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 422 | 423 | # ---------------------------------[4]---------------------------------- 424 | discretizer = DecileDiscretizer(X, [], feature_names, y, 425 | random_state=20) 426 | explainer_1 = LimeTabularExplainer(X, 427 | feature_names=feature_names, 428 | discretize_continuous=True, 429 | discretizer=discretizer, 430 | random_state=20) 431 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 432 | num_samples=500) 433 | 434 | discretizer = DecileDiscretizer(X, [], feature_names, y, 435 | random_state=20) 436 | explainer_2 = LimeTabularExplainer(X, 437 | feature_names=feature_names, 438 | discretize_continuous=True, 439 | discretizer=discretizer, 440 | random_state=20) 441 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 442 | num_samples=500) 443 | 444 | self.assertFalse(exp_1.as_map() != exp_2.as_map()) 445 | 446 | # ---------------------------------------------------------------------- 447 | # --------------------------Entropy Discretizer------------------------- 448 | # ---------------------------------------------------------------------- 449 
| 450 | # ---------------------------------[1]---------------------------------- 451 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 452 | random_state=20) 453 | explainer_1 = LimeTabularExplainer(X, 454 | feature_names=feature_names, 455 | discretize_continuous=True, 456 | discretizer=discretizer, 457 | random_state=10) 458 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 459 | num_samples=500) 460 | 461 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 462 | random_state=10) 463 | explainer_2 = LimeTabularExplainer(X, 464 | feature_names=feature_names, 465 | discretize_continuous=True, 466 | discretizer=discretizer, 467 | random_state=10) 468 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 469 | num_samples=500) 470 | 471 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 472 | 473 | # ---------------------------------[2]---------------------------------- 474 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 475 | random_state=20) 476 | explainer_1 = LimeTabularExplainer(X, 477 | feature_names=feature_names, 478 | discretize_continuous=True, 479 | discretizer=discretizer, 480 | random_state=20) 481 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 482 | num_samples=500) 483 | 484 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 485 | random_state=10) 486 | explainer_2 = LimeTabularExplainer(X, 487 | feature_names=feature_names, 488 | discretize_continuous=True, 489 | discretizer=discretizer, 490 | random_state=10) 491 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 492 | num_samples=500) 493 | 494 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 495 | 496 | # ---------------------------------[3]---------------------------------- 497 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 498 | random_state=20) 499 | explainer_1 = LimeTabularExplainer(X, 500 | feature_names=feature_names, 501 | discretize_continuous=True, 502 | discretizer=discretizer, 503 | random_state=20) 504 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 505 | num_samples=500) 506 | 507 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 508 | random_state=20) 509 | explainer_2 = LimeTabularExplainer(X, 510 | feature_names=feature_names, 511 | discretize_continuous=True, 512 | discretizer=discretizer, 513 | random_state=10) 514 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 515 | num_samples=500) 516 | 517 | self.assertTrue(exp_1.as_map() != exp_2.as_map()) 518 | 519 | # ---------------------------------[4]---------------------------------- 520 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 521 | random_state=20) 522 | explainer_1 = LimeTabularExplainer(X, 523 | feature_names=feature_names, 524 | discretize_continuous=True, 525 | discretizer=discretizer, 526 | random_state=20) 527 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba, 528 | num_samples=500) 529 | 530 | discretizer = EntropyDiscretizer(X, [], feature_names, y, 531 | random_state=20) 532 | explainer_2 = LimeTabularExplainer(X, 533 | feature_names=feature_names, 534 | discretize_continuous=True, 535 | discretizer=discretizer, 536 | random_state=20) 537 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba, 538 | num_samples=500) 539 | 540 | self.assertFalse(exp_1.as_map() != exp_2.as_map()) 541 | 542 | def testFeatureNamesAndCategoricalFeats(self): 543 | training_data = np.array([[0., 1.], [1., 0.]]) 544 | 545 | explainer = 
LimeTabularExplainer(training_data=training_data) 546 | self.assertEqual(explainer.feature_names, ['0', '1']) 547 | self.assertEqual(explainer.categorical_features, [0, 1]) 548 | 549 | explainer = LimeTabularExplainer( 550 | training_data=training_data, 551 | feature_names=np.array(['one', 'two']) 552 | ) 553 | self.assertEqual(explainer.feature_names, ['one', 'two']) 554 | 555 | explainer = LimeTabularExplainer( 556 | training_data=training_data, 557 | categorical_features=np.array([0]), 558 | discretize_continuous=False 559 | ) 560 | self.assertEqual(explainer.categorical_features, [0]) 561 | 562 | def testFeatureValues(self): 563 | training_data = np.array([ 564 | [0, 0, 2], 565 | [1, 1, 0], 566 | [0, 2, 2], 567 | [1, 3, 0] 568 | ]) 569 | 570 | explainer = LimeTabularExplainer( 571 | training_data=training_data, 572 | categorical_features=[0, 1, 2] 573 | ) 574 | 575 | self.assertEqual(set(explainer.feature_values[0]), {0, 1}) 576 | self.assertEqual(set(explainer.feature_values[1]), {0, 1, 2, 3}) 577 | self.assertEqual(set(explainer.feature_values[2]), {0, 2}) 578 | 579 | assert_array_equal(explainer.feature_frequencies[0], np.array([.5, .5])) 580 | assert_array_equal(explainer.feature_frequencies[1], np.array([.25, .25, .25, .25])) 581 | assert_array_equal(explainer.feature_frequencies[2], np.array([.5, .5])) 582 | 583 | def test_lime_explainer_with_data_stats(self): 584 | np.random.seed(1) 585 | 586 | rf = RandomForestClassifier(n_estimators=500) 587 | rf.fit(self.train, self.labels_train) 588 | i = np.random.randint(0, self.test.shape[0]) 589 | 590 | # Generate stats using a quartile descritizer 591 | descritizer = QuartileDiscretizer(self.train, [], self.feature_names, self.target_names, 592 | random_state=20) 593 | 594 | d_means = descritizer.means 595 | d_stds = descritizer.stds 596 | d_mins = descritizer.mins 597 | d_maxs = descritizer.maxs 598 | d_bins = descritizer.bins(self.train, self.target_names) 599 | 600 | # Compute feature values and frequencies of all columns 601 | cat_features = np.arange(self.train.shape[1]) 602 | discretized_training_data = descritizer.discretize(self.train) 603 | 604 | feature_values = {} 605 | feature_frequencies = {} 606 | for feature in cat_features: 607 | column = discretized_training_data[:, feature] 608 | feature_count = collections.Counter(column) 609 | values, frequencies = map(list, zip(*(feature_count.items()))) 610 | feature_values[feature] = values 611 | feature_frequencies[feature] = frequencies 612 | 613 | # Convert bins to list from array 614 | d_bins_revised = {} 615 | index = 0 616 | for bin in d_bins: 617 | d_bins_revised[index] = bin.tolist() 618 | index = index+1 619 | 620 | # Descritized stats 621 | data_stats = {} 622 | data_stats["means"] = d_means 623 | data_stats["stds"] = d_stds 624 | data_stats["maxs"] = d_maxs 625 | data_stats["mins"] = d_mins 626 | data_stats["bins"] = d_bins_revised 627 | data_stats["feature_values"] = feature_values 628 | data_stats["feature_frequencies"] = feature_frequencies 629 | 630 | data = np.zeros((2, len(self.feature_names))) 631 | explainer = LimeTabularExplainer( 632 | data, feature_names=self.feature_names, random_state=10, 633 | training_data_stats=data_stats, training_labels=self.target_names) 634 | 635 | exp = explainer.explain_instance(self.test[i], 636 | rf.predict_proba, 637 | num_features=2, 638 | model_regressor=LinearRegression()) 639 | 640 | self.assertIsNotNone(exp) 641 | keys = [x[0] for x in exp.as_list()] 642 | self.assertEqual(1, 643 | sum([1 if 'petal width' in x else 0 for x 
in keys]), 644 | "Petal Width is a major feature") 645 | self.assertEqual(1, 646 | sum([1 if 'petal length' in x else 0 for x in keys]), 647 | "Petal Length is a major feature") 648 | 649 | 650 | if __name__ == '__main__': 651 | unittest.main() 652 | -------------------------------------------------------------------------------- /slime_lm/_least_angle.py: -------------------------------------------------------------------------------- 1 | """ 2 | Least Angle Regression algorithm. See the documentation on the 3 | Generalized Linear Model for a complete discussion. 4 | """ 5 | # Author: Fabian Pedregosa 6 | # Alexandre Gramfort 7 | # Gael Varoquaux 8 | # 9 | # License: BSD 3 clause 10 | 11 | from math import log 12 | import sys 13 | import warnings 14 | 15 | import numpy as np 16 | from scipy import linalg, interpolate 17 | from scipy.linalg.lapack import get_lapack_funcs 18 | from scipy import stats 19 | from joblib import Parallel 20 | 21 | # mypy error: Module 'sklearn.utils' has no attribute 'arrayfuncs' 22 | from sklearn.utils import arrayfuncs, as_float_array # type: ignore 23 | from sklearn.exceptions import ConvergenceWarning 24 | 25 | SOLVE_TRIANGULAR_ARGS = {'check_finite': False} 26 | 27 | 28 | def lars_path( 29 | X, 30 | y, 31 | Xy=None, 32 | *, 33 | Gram=None, 34 | max_iter=500, 35 | alpha_min=0, 36 | method="lar", 37 | copy_X=True, 38 | eps=np.finfo(float).eps, 39 | copy_Gram=True, 40 | verbose=0, 41 | return_path=True, 42 | return_n_iter=False, 43 | positive=False, 44 | testing=False, 45 | alpha=0.05, 46 | testing_stop=False, 47 | testing_verbose=False, 48 | ): 49 | """Compute Least Angle Regression or Lasso path using LARS algorithm [1] 50 | The optimization objective for the case method='lasso' is:: 51 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 52 | in the case of method='lars', the objective function is only known in 53 | the form of an implicit equation (see discussion in [1]) 54 | Read more in the :ref:`User Guide `. 55 | Parameters 56 | ---------- 57 | X : None or array-like of shape (n_samples, n_features) 58 | Input data. Note that if X is None then the Gram matrix must be 59 | specified, i.e., cannot be None or False. 60 | y : None or array-like of shape (n_samples,) 61 | Input targets. 62 | Xy : array-like of shape (n_samples,) or (n_samples, n_targets), \ 63 | default=None 64 | Xy = np.dot(X.T, y) that can be precomputed. It is useful 65 | only when the Gram matrix is precomputed. 66 | Gram : None, 'auto', array-like of shape (n_features, n_features), \ 67 | default=None 68 | Precomputed Gram matrix (X' * X), if ``'auto'``, the Gram 69 | matrix is precomputed from the given X, if there are more samples 70 | than features. 71 | max_iter : int, default=500 72 | Maximum number of iterations to perform, set to infinity for no limit. 73 | alpha_min : float, default=0 74 | Minimum correlation along the path. It corresponds to the 75 | regularization parameter alpha parameter in the Lasso. 76 | method : {'lar', 'lasso'}, default='lar' 77 | Specifies the returned model. Select ``'lar'`` for Least Angle 78 | Regression, ``'lasso'`` for the Lasso. 79 | copy_X : bool, default=True 80 | If ``False``, ``X`` is overwritten. 81 | eps : float, default=np.finfo(float).eps 82 | The machine-precision regularization in the computation of the 83 | Cholesky diagonal factors. Increase this for very ill-conditioned 84 | systems. 
Unlike the ``tol`` parameter in some iterative 85 | optimization-based algorithms, this parameter does not control 86 | the tolerance of the optimization. 87 | copy_Gram : bool, default=True 88 | If ``False``, ``Gram`` is overwritten. 89 | verbose : int, default=0 90 | Controls output verbosity. 91 | return_path : bool, default=True 92 | If ``return_path==True`` returns the entire path, else returns only the 93 | last point of the path. 94 | return_n_iter : bool, default=False 95 | Whether to return the number of iterations. 96 | positive : bool, default=False 97 | Restrict coefficients to be >= 0. 98 | This option is only allowed with method 'lasso'. Note that the model 99 | coefficients will not converge to the ordinary-least-squares solution 100 | for small values of alpha. Only coefficients up to the smallest alpha 101 | value (``alphas_[alphas_ > 0.].min()`` when fit_path=True) reached by 102 | the stepwise Lars-Lasso algorithm are typically in congruence with the 103 | solution of the coordinate descent lasso_path function. 104 | testing : bool, default=False 105 | Whether to conduct hypothesis testing each time a new variable enters the active set. 106 | alpha : float, default=0.05 107 | Significance level of hypothesis testing. Valid only if testing is True. 108 | testing_stop : bool, default=False 109 | If set to True, stops calculating future paths when the test yields 110 | insignificant results. 111 | Only takes effect when testing is set to True. 112 | testing_verbose : bool, default=False 113 | Controls output verbosity for the hypothesis testing procedure. 114 | Returns 115 | ------- 116 | alphas : array-like of shape (n_alphas + 1,) 117 | Maximum of covariances (in absolute value) at each iteration. 118 | ``n_alphas`` is either ``max_iter``, ``n_features`` or the 119 | number of nodes in the path with ``alpha >= alpha_min``, whichever 120 | is smaller. 121 | active : array-like of shape (n_alphas,) 122 | Indices of active variables at the end of the path. 123 | coefs : array-like of shape (n_features, n_alphas + 1) 124 | Coefficients along the path 125 | n_iter : int 126 | Number of iterations run. Returned only if return_n_iter is set 127 | to True. 128 | test_result : dictionary 129 | Contains testing results in the form of [test_stats, new_n] produced 130 | at each step. Returned only if testing is set to True. 131 | See Also 132 | -------- 133 | lars_path_gram 134 | lasso_path 135 | lasso_path_gram 136 | LassoLars 137 | Lars 138 | LassoLarsCV 139 | LarsCV 140 | sklearn.decomposition.sparse_encode 141 | References 142 | ---------- 143 | .. [1] "Least Angle Regression", Efron et al. 144 | http://statweb.stanford.edu/~tibs/ftp/lars.pdf 145 | .. [2] `Wikipedia entry on the Least-angle regression 146 | <https://en.wikipedia.org/wiki/Least-angle_regression>`_ 147 | .. [3] `Wikipedia entry on the Lasso 148 | <https://en.wikipedia.org/wiki/Lasso_(statistics)>`_ 149 | """ 150 | if X is None and Gram is not None: 151 | raise ValueError( 152 | 'X cannot be None if Gram is not None. ' 153 | 'Use lars_path_gram to avoid passing X and y.'
154 | ) 155 | return _lars_path_solver( 156 | X=X, y=y, Xy=Xy, Gram=Gram, n_samples=None, max_iter=max_iter, 157 | alpha_min=alpha_min, method=method, copy_X=copy_X, 158 | eps=eps, copy_Gram=copy_Gram, verbose=verbose, return_path=return_path, 159 | return_n_iter=return_n_iter, positive=positive, testing=testing, 160 | alpha=alpha, testing_stop=testing_stop, testing_verbose=testing_verbose) 161 | 162 | 163 | def lars_path_gram( 164 | Xy, 165 | Gram, 166 | *, 167 | n_samples, 168 | max_iter=500, 169 | alpha_min=0, 170 | method="lar", 171 | copy_X=True, 172 | eps=np.finfo(float).eps, 173 | copy_Gram=True, 174 | verbose=0, 175 | return_path=True, 176 | return_n_iter=False, 177 | positive=False 178 | ): 179 | """lars_path in the sufficient stats mode [1] 180 | The optimization objective for the case method='lasso' is:: 181 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 182 | in the case of method='lars', the objective function is only known in 183 | the form of an implicit equation (see discussion in [1]) 184 | Read more in the :ref:`User Guide `. 185 | Parameters 186 | ---------- 187 | Xy : array-like of shape (n_samples,) or (n_samples, n_targets) 188 | Xy = np.dot(X.T, y). 189 | Gram : array-like of shape (n_features, n_features) 190 | Gram = np.dot(X.T * X). 191 | n_samples : int or float 192 | Equivalent size of sample. 193 | max_iter : int, default=500 194 | Maximum number of iterations to perform, set to infinity for no limit. 195 | alpha_min : float, default=0 196 | Minimum correlation along the path. It corresponds to the 197 | regularization parameter alpha parameter in the Lasso. 198 | method : {'lar', 'lasso'}, default='lar' 199 | Specifies the returned model. Select ``'lar'`` for Least Angle 200 | Regression, ``'lasso'`` for the Lasso. 201 | copy_X : bool, default=True 202 | If ``False``, ``X`` is overwritten. 203 | eps : float, default=np.finfo(float).eps 204 | The machine-precision regularization in the computation of the 205 | Cholesky diagonal factors. Increase this for very ill-conditioned 206 | systems. Unlike the ``tol`` parameter in some iterative 207 | optimization-based algorithms, this parameter does not control 208 | the tolerance of the optimization. 209 | copy_Gram : bool, default=True 210 | If ``False``, ``Gram`` is overwritten. 211 | verbose : int, default=0 212 | Controls output verbosity. 213 | return_path : bool, default=True 214 | If ``return_path==True`` returns the entire path, else returns only the 215 | last point of the path. 216 | return_n_iter : bool, default=False 217 | Whether to return the number of iterations. 218 | positive : bool, default=False 219 | Restrict coefficients to be >= 0. 220 | This option is only allowed with method 'lasso'. Note that the model 221 | coefficients will not converge to the ordinary-least-squares solution 222 | for small values of alpha. Only coefficients up to the smallest alpha 223 | value (``alphas_[alphas_ > 0.].min()`` when fit_path=True) reached by 224 | the stepwise Lars-Lasso algorithm are typically in congruence with the 225 | solution of the coordinate descent lasso_path function. 226 | Returns 227 | ------- 228 | alphas : array-like of shape (n_alphas + 1,) 229 | Maximum of covariances (in absolute value) at each iteration. 230 | ``n_alphas`` is either ``max_iter``, ``n_features`` or the 231 | number of nodes in the path with ``alpha >= alpha_min``, whichever 232 | is smaller. 233 | active : array-like of shape (n_alphas,) 234 | Indices of active variables at the end of the path. 
235 | coefs : array-like of shape (n_features, n_alphas + 1) 236 | Coefficients along the path 237 | n_iter : int 238 | Number of iterations run. Returned only if return_n_iter is set 239 | to True. 240 | See Also 241 | -------- 242 | lars_path 243 | lasso_path 244 | lasso_path_gram 245 | LassoLars 246 | Lars 247 | LassoLarsCV 248 | LarsCV 249 | sklearn.decomposition.sparse_encode 250 | References 251 | ---------- 252 | .. [1] "Least Angle Regression", Efron et al. 253 | http://statweb.stanford.edu/~tibs/ftp/lars.pdf 254 | .. [2] `Wikipedia entry on the Least-angle regression 255 | `_ 256 | .. [3] `Wikipedia entry on the Lasso 257 | `_ 258 | """ 259 | return _lars_path_solver( 260 | X=None, y=None, Xy=Xy, Gram=Gram, n_samples=n_samples, 261 | max_iter=max_iter, alpha_min=alpha_min, method=method, 262 | copy_X=copy_X, eps=eps, copy_Gram=copy_Gram, 263 | verbose=verbose, return_path=return_path, 264 | return_n_iter=return_n_iter, positive=positive) 265 | 266 | 267 | def _lars_path_solver( 268 | X, 269 | y, 270 | Xy=None, 271 | Gram=None, 272 | n_samples=None, 273 | max_iter=500, 274 | alpha_min=0, 275 | method="lar", 276 | copy_X=True, 277 | eps=np.finfo(float).eps, 278 | copy_Gram=True, 279 | verbose=0, 280 | return_path=True, 281 | return_n_iter=False, 282 | positive=False, 283 | testing=False, 284 | alpha=0.05, 285 | testing_stop=False, 286 | testing_verbose=False, 287 | ): 288 | """Compute Least Angle Regression or Lasso path using LARS algorithm [1] 289 | The optimization objective for the case method='lasso' is:: 290 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1 291 | in the case of method='lars', the objective function is only known in 292 | the form of an implicit equation (see discussion in [1]) 293 | Read more in the :ref:`User Guide `. 294 | Parameters 295 | ---------- 296 | X : None or ndarray of shape (n_samples, n_features) 297 | Input data. Note that if X is None then Gram must be specified, 298 | i.e., cannot be None or False. 299 | y : None or ndarray of shape (n_samples,) 300 | Input targets. 301 | Xy : array-like of shape (n_samples,) or (n_samples, n_targets), \ 302 | default=None 303 | `Xy = np.dot(X.T, y)` that can be precomputed. It is useful 304 | only when the Gram matrix is precomputed. 305 | Gram : None, 'auto' or array-like of shape (n_features, n_features), \ 306 | default=None 307 | Precomputed Gram matrix `(X' * X)`, if ``'auto'``, the Gram 308 | matrix is precomputed from the given X, if there are more samples 309 | than features. 310 | n_samples : int or float, default=None 311 | Equivalent size of sample. If `None`, it will be `n_samples`. 312 | max_iter : int, default=500 313 | Maximum number of iterations to perform, set to infinity for no limit. 314 | alpha_min : float, default=0 315 | Minimum correlation along the path. It corresponds to the 316 | regularization parameter alpha parameter in the Lasso. 317 | method : {'lar', 'lasso'}, default='lar' 318 | Specifies the returned model. Select ``'lar'`` for Least Angle 319 | Regression, ``'lasso'`` for the Lasso. 320 | copy_X : bool, default=True 321 | If ``False``, ``X`` is overwritten. 322 | eps : float, default=np.finfo(float).eps 323 | The machine-precision regularization in the computation of the 324 | Cholesky diagonal factors. Increase this for very ill-conditioned 325 | systems. Unlike the ``tol`` parameter in some iterative 326 | optimization-based algorithms, this parameter does not control 327 | the tolerance of the optimization. 
328 | copy_Gram : bool, default=True 329 | If ``False``, ``Gram`` is overwritten. 330 | verbose : int, default=0 331 | Controls output verbosity. 332 | return_path : bool, default=True 333 | If ``return_path==True`` returns the entire path, else returns only the 334 | last point of the path. 335 | return_n_iter : bool, default=False 336 | Whether to return the number of iterations. 337 | positive : bool, default=False 338 | Restrict coefficients to be >= 0. 339 | This option is only allowed with method 'lasso'. Note that the model 340 | coefficients will not converge to the ordinary-least-squares solution 341 | for small values of alpha. Only coefficients up to the smallest alpha 342 | value (``alphas_[alphas_ > 0.].min()`` when fit_path=True) reached by 343 | the stepwise Lars-Lasso algorithm are typically in congruence with the 344 | solution of the coordinate descent lasso_path function. 345 | testing : bool, default=False 346 | Whether to conduct hypothesis testing each time a new variable enters 347 | alpha : float, default=0.05 348 | Significance level of hypothesis testing. Valid only if testing is True. 349 | testing_stop : bool, default=False 350 | If set to True, stops calculating future paths when the test yields 351 | insignificant results. 352 | Only takes effect when testing is set to True. 353 | testing_verbose : bool, default=True 354 | Controls output verbosity for hypothese testing procedure. 355 | Returns 356 | ------- 357 | alphas : array-like of shape (n_alphas + 1,) 358 | Maximum of covariances (in absolute value) at each iteration. 359 | ``n_alphas`` is either ``max_iter``, ``n_features`` or the 360 | number of nodes in the path with ``alpha >= alpha_min``, whichever 361 | is smaller. 362 | active : array-like of shape (n_alphas,) 363 | Indices of active variables at the end of the path. 364 | coefs : array-like of shape (n_features, n_alphas + 1) 365 | Coefficients along the path 366 | n_iter : int 367 | Number of iterations run. Returned only if return_n_iter is set 368 | to True. 369 | test_result: dictionary 370 | Contains testing results in the form of [test_stats, new_n] produced 371 | at each step. Returned only if testing is set to True. 372 | See Also 373 | -------- 374 | lasso_path 375 | LassoLars 376 | Lars 377 | LassoLarsCV 378 | LarsCV 379 | sklearn.decomposition.sparse_encode 380 | References 381 | ---------- 382 | .. [1] "Least Angle Regression", Efron et al. 383 | http://statweb.stanford.edu/~tibs/ftp/lars.pdf 384 | .. [2] `Wikipedia entry on the Least-angle regression 385 | `_ 386 | .. [3] `Wikipedia entry on the Lasso 387 | `_ 388 | """ 389 | if method == "lar" and positive: 390 | raise ValueError( 391 | "Positive constraint not supported for 'lar' " "coding method." 
392 | ) 393 | 394 | n_samples = n_samples if n_samples is not None else y.size 395 | 396 | if Xy is None: 397 | Cov = np.dot(X.T, y) 398 | else: 399 | Cov = Xy.copy() 400 | 401 | if Gram is None or Gram is False: 402 | Gram = None 403 | if X is None: 404 | raise ValueError('X and Gram cannot both be unspecified.') 405 | elif isinstance(Gram, str) and Gram == 'auto' or Gram is True: 406 | if Gram is True or X.shape[0] > X.shape[1]: 407 | Gram = np.dot(X.T, X) 408 | else: 409 | Gram = None 410 | elif copy_Gram: 411 | Gram = Gram.copy() 412 | 413 | if Gram is None: 414 | n_features = X.shape[1] 415 | else: 416 | n_features = Cov.shape[0] 417 | if Gram.shape != (n_features, n_features): 418 | raise ValueError('The shapes of the inputs Gram and Xy' 419 | ' do not match.') 420 | 421 | if copy_X and X is not None and Gram is None: 422 | # force copy. setting the array to be fortran-ordered 423 | # speeds up the calculation of the (partial) Gram matrix 424 | # and allows to easily swap columns 425 | X = X.copy('F') 426 | 427 | max_features = min(max_iter, n_features) 428 | 429 | dtypes = set(a.dtype for a in (X, y, Xy, Gram) if a is not None) 430 | if len(dtypes) == 1: 431 | # use the precision level of input data if it is consistent 432 | return_dtype = next(iter(dtypes)) 433 | else: 434 | # fallback to double precision otherwise 435 | return_dtype = np.float64 436 | 437 | if return_path: 438 | coefs = np.zeros((max_features + 1, n_features), dtype=return_dtype) 439 | alphas = np.zeros(max_features + 1, dtype=return_dtype) 440 | else: 441 | coef, prev_coef = (np.zeros(n_features, dtype=return_dtype), 442 | np.zeros(n_features, dtype=return_dtype)) 443 | alpha, prev_alpha = (np.array([0.], dtype=return_dtype), 444 | np.array([0.], dtype=return_dtype)) 445 | # above better ideas? 446 | 447 | n_iter, n_active = 0, 0 448 | active, indices = list(), np.arange(n_features) 449 | # holds the sign of covariance 450 | sign_active = np.empty(max_features, dtype=np.int8) 451 | drop = False 452 | 453 | # will hold the cholesky factorization. Only lower part is 454 | # referenced. 455 | if Gram is None: 456 | L = np.empty((max_features, max_features), dtype=X.dtype) 457 | swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,)) 458 | else: 459 | L = np.empty((max_features, max_features), dtype=Gram.dtype) 460 | swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (Cov,)) 461 | solve_cholesky, = get_lapack_funcs(('potrs',), (L,)) 462 | 463 | if verbose: 464 | if verbose > 1: 465 | print("Step\t\tAdded\t\tDropped\t\tActive set size\t\tC") 466 | else: 467 | sys.stdout.write('.') 468 | sys.stdout.flush() 469 | 470 | tiny32 = np.finfo(np.float32).tiny # to avoid division by 0 warning 471 | equality_tolerance = np.finfo(np.float32).eps 472 | 473 | residual = y - 0 474 | coef = np.zeros(n_features) 475 | test_result = {} 476 | 477 | if Gram is not None: 478 | Gram_copy = Gram.copy() 479 | Cov_copy = Cov.copy() 480 | 481 | z_score = stats.norm.ppf(1 - alpha) 482 | while True: 483 | if not testing: 484 | if Cov.size: 485 | if positive: 486 | C_idx = np.argmax(Cov) 487 | else: 488 | C_idx = np.argmax(np.abs(Cov)) 489 | 490 | C_ = Cov[C_idx] 491 | 492 | if positive: 493 | C = C_ 494 | else: 495 | C = np.fabs(C_) 496 | else: 497 | C = 0. 
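        # The `else` branch below adds the hypothesis-testing step used by
        # S-LIME: before the most correlated candidate feature enters the
        # active set, its correlation with the current residual is compared to
        # that of the second most correlated candidate using a
        # normal-approximation test statistic (built from
        # np.cov(x1 * residual, x2 * residual)). A negative statistic means the
        # selection is not significant at level `alpha`; `new_n` then estimates
        # the sample size that would be needed, and [test_stats, new_n] is
        # recorded in test_result[n_active + 1].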
498 | else: 499 | # not implemented when if positive is set to True 500 | if Cov.size: 501 | if positive: 502 | C_idx = np.argmax(Cov) 503 | else: 504 | C_idx = np.argmax(np.abs(Cov)) 505 | if Cov.size > 1: 506 | C_idx_second = np.abs(Cov).argsort()[-2] 507 | 508 | x1 = X.T[n_active + C_idx] 509 | x2 = X.T[n_active + C_idx_second] 510 | 511 | residual = y - np.dot(X[:, :n_active], coef[active]) 512 | u = np.array([np.dot(x1, residual), np.dot(x2, residual)]) / len(y) 513 | cov = np.cov(x1 * residual, x2 * residual) 514 | 515 | new_n = len(y) 516 | if u[0] >= 0 and u[1] >= 0: 517 | test_stats = u[0] - u[1] - z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y)) 518 | if test_stats < 0: 519 | z_alpha = (u[0] - u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y)) 520 | new_n = new_n * (z_score / z_alpha) ** 2 521 | elif u[0] >= 0 and u[1] < 0: 522 | test_stats = u[0] + u[1] - z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] + cov[0][1] + cov[1][0]) / len(y)) 523 | if test_stats < 0: 524 | z_alpha = (u[0] + u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y)) 525 | new_n = new_n * (z_score / z_alpha) ** 2 526 | elif u[0] < 0 and u[1] >= 0: 527 | test_stats = -(u[0] + u[1] + z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] + cov[0][1] + cov[1][0]) / len(y))) 528 | if test_stats < 0: 529 | z_alpha = (-u[0] - u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y)) 530 | new_n = new_n * (z_score / z_alpha) ** 2 531 | else: 532 | test_stats = -(u[0] - u[1] + z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y))) 533 | if test_stats < 0: 534 | z_alpha = (-u[0] + u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y)) 535 | new_n = new_n * (z_score / z_alpha) ** 2 536 | 537 | test_result[n_active + 1] = [test_stats, new_n] 538 | 539 | if testing_verbose: 540 | print("Selecting " + str(n_active + 1) + "th varieble: ") 541 | print("Correlations: " + str(np.round(u, 4))) 542 | print("Test statistics: " + str(round(test_stats, 4))) 543 | 544 | if testing_stop: 545 | if test_stats < 0: 546 | if testing_verbose: 547 | print("Not enough samples!") 548 | return alphas, active, coefs.T, test_result 549 | else: 550 | test_result[n_active + 1] = [0, 0] 551 | 552 | C_ = Cov[C_idx] 553 | 554 | if positive: 555 | C = C_ 556 | else: 557 | C = np.fabs(C_) 558 | else: 559 | C = 0. 
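        # From here the standard LARS/lasso path update resumes: the current
        # alpha (C / n_samples) is recorded, the alpha_min / max_iter stopping
        # conditions are checked, and, when a new regressor is added, the
        # Cholesky factorization of the active-set Gram matrix is extended
        # before the equiangular (least squares) direction is computed.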
        if return_path:
            alpha = alphas[n_iter, np.newaxis]
            coef = coefs[n_iter]
            prev_alpha = alphas[n_iter - 1, np.newaxis]
            prev_coef = coefs[n_iter - 1]

        alpha[0] = C / n_samples
        if alpha[0] <= alpha_min + equality_tolerance:  # early stopping
            if abs(alpha[0] - alpha_min) > equality_tolerance:
                # interpolation factor 0 <= ss < 1
                if n_iter > 0:
                    # In the first iteration, all alphas are zero, the formula
                    # below would make ss a NaN
                    ss = ((prev_alpha[0] - alpha_min) /
                          (prev_alpha[0] - alpha[0]))
                    coef[:] = prev_coef + ss * (coef - prev_coef)
                alpha[0] = alpha_min
            if return_path:
                coefs[n_iter] = coef
            break

        if n_iter >= max_iter or n_active >= n_features:
            break
        if not drop:

            ##########################################################
            # Append x_j to the Cholesky factorization of (Xa * Xa') #
            #                                                        #
            #            ( L   0 )                                   #
            #     L  ->  (       )  , where L * w = Xa' x_j          #
            #            ( w   z )    and z = ||x_j||                #
            #                                                        #
            ##########################################################

            if positive:
                sign_active[n_active] = np.ones_like(C_)
            else:
                sign_active[n_active] = np.sign(C_)
            m, n = n_active, C_idx + n_active

            Cov[C_idx], Cov[0] = swap(Cov[C_idx], Cov[0])
            indices[n], indices[m] = indices[m], indices[n]
            Cov_not_shortened = Cov
            Cov = Cov[1:]  # remove Cov[0]

            if Gram is None:
                X.T[n], X.T[m] = swap(X.T[n], X.T[m])
                c = nrm2(X.T[n_active]) ** 2
                L[n_active, :n_active] = \
                    np.dot(X.T[n_active], X.T[:n_active].T)
            else:
                # swap only works in place if the matrix is fortran
                # contiguous ...
                Gram[m], Gram[n] = swap(Gram[m], Gram[n])
                Gram[:, m], Gram[:, n] = swap(Gram[:, m], Gram[:, n])
                c = Gram[n_active, n_active]
                L[n_active, :n_active] = Gram[n_active, :n_active]

            # Update the cholesky decomposition for the Gram matrix
            if n_active:
                linalg.solve_triangular(L[:n_active, :n_active],
                                        L[n_active, :n_active],
                                        trans=0, lower=1,
                                        overwrite_b=True,
                                        **SOLVE_TRIANGULAR_ARGS)

            v = np.dot(L[n_active, :n_active], L[n_active, :n_active])
            diag = max(np.sqrt(np.abs(c - v)), eps)
            L[n_active, n_active] = diag

            if diag < 1e-7:
                # The system is becoming too ill-conditioned.
                # We have degenerate vectors in our active set.
                # We'll 'drop for good' the last regressor added.

                # Note: this case is very rare. It is no longer triggered by
                # the test suite. The `equality_tolerance` margin added in 0.16
                # to get early stopping to work consistently on all versions of
                # Python including 32 bit Python under Windows seems to make it
                # very difficult to trigger the 'drop for good' strategy.
                warnings.warn('Regressors in active set degenerate. '
                              'Dropping a regressor, after %i iterations, '
                              'i.e. alpha=%.3e, '
                              'with an active set of %i regressors, and '
                              'the smallest cholesky pivot element being %.3e.'
                              ' Reduce max_iter or increase eps parameters.'
                              % (n_iter, alpha, n_active, diag),
                              ConvergenceWarning)

                # XXX: need to figure a 'drop for good' way
                Cov = Cov_not_shortened
                Cov[0] = 0
                Cov[C_idx], Cov[0] = swap(Cov[C_idx], Cov[0])
                continue

            active.append(indices[n_active])
            n_active += 1

            if verbose > 1:
                print("%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, active[-1], '',
                                                      n_active, C))

        if method == 'lasso' and n_iter > 0 and prev_alpha[0] < alpha[0]:
            # alpha is increasing. This is because the updates of Cov are
            # bringing in too much numerical error that is greater than
            # the remaining correlation with the regressors. Time to bail out
            warnings.warn('Early stopping the lars path, as the residues '
                          'are small and the current value of alpha is no '
                          'longer well controlled. %i iterations, alpha=%.3e, '
                          'previous alpha=%.3e, with an active set of %i '
                          'regressors.'
                          % (n_iter, alpha, prev_alpha, n_active),
                          ConvergenceWarning)
            break

        # least squares solution
        least_squares, _ = solve_cholesky(L[:n_active, :n_active],
                                          sign_active[:n_active],
                                          lower=True)

        if least_squares.size == 1 and least_squares == 0:
            # This happens because sign_active[:n_active] = 0
            least_squares[...] = 1
            AA = 1.
        else:
            # is this really needed ?
            AA = 1. / np.sqrt(np.sum(least_squares * sign_active[:n_active]))

            if not np.isfinite(AA):
                # L is too ill-conditioned
                i = 0
                L_ = L[:n_active, :n_active].copy()
                while not np.isfinite(AA):
                    L_.flat[::n_active + 1] += (2 ** i) * eps
                    least_squares, _ = solve_cholesky(
                        L_, sign_active[:n_active], lower=True)
                    tmp = max(np.sum(least_squares * sign_active[:n_active]),
                              eps)
                    AA = 1. / np.sqrt(tmp)
                    i += 1
            least_squares *= AA
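
        # Equiangular step: eq_dir is the direction equally correlated with all
        # active variables; corr_eq_dir is each inactive variable's correlation
        # with that direction.  gamma_ below is the step length at which a new
        # variable would enter the active set (or, for lasso, at which a
        # coefficient crosses zero and gets dropped).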
        if Gram is None:
            # equiangular direction of variables in the active set
            eq_dir = np.dot(X.T[:n_active].T, least_squares)
            # correlation between each inactive variable and
            # the equiangular vector
            corr_eq_dir = np.dot(X.T[n_active:], eq_dir)
        else:
            # if huge number of features, this takes 50% of time, I
            # think could be avoided if we just update it using an
            # orthogonal (QR) decomposition of X
            corr_eq_dir = np.dot(Gram[:n_active, n_active:].T,
                                 least_squares)

        g1 = arrayfuncs.min_pos((C - Cov) / (AA - corr_eq_dir + tiny32))
        if positive:
            gamma_ = min(g1, C / AA)
        else:
            g2 = arrayfuncs.min_pos((C + Cov) / (AA + corr_eq_dir + tiny32))
            gamma_ = min(g1, g2, C / AA)

        # TODO: better names for these variables: z
        drop = False
        z = -coef[active] / (least_squares + tiny32)
        z_pos = arrayfuncs.min_pos(z)
        if z_pos < gamma_:
            # some coefficients have changed sign
            idx = np.where(z == z_pos)[0][::-1]

            # update the sign, important for LAR
            sign_active[idx] = -sign_active[idx]

            if method == 'lasso':
                gamma_ = z_pos
            drop = True

        n_iter += 1

        if return_path:
            if n_iter >= coefs.shape[0]:
                del coef, alpha, prev_alpha, prev_coef
                # resize the coefs and alphas array
                add_features = 2 * max(1, (max_features - n_active))
                coefs = np.resize(coefs, (n_iter + add_features, n_features))
                coefs[-add_features:] = 0
                alphas = np.resize(alphas, n_iter + add_features)
                alphas[-add_features:] = 0
            coef = coefs[n_iter]
            prev_coef = coefs[n_iter - 1]
        else:
            # mimic the effect of incrementing n_iter on the array references
            prev_coef = coef
            prev_alpha[0] = alpha[0]
            coef = np.zeros_like(coef)

        coef[active] = prev_coef[active] + gamma_ * least_squares

        # update correlations
        Cov -= gamma_ * corr_eq_dir

        # See if any coefficient has changed sign
        if drop and method == 'lasso':

            # handle the case when idx is not length of 1
            for ii in idx:
                arrayfuncs.cholesky_delete(L[:n_active, :n_active], ii)

            n_active -= 1
            # handle the case when idx is not length of 1
            drop_idx = [active.pop(ii) for ii in idx]

            if Gram is None:
                # propagate dropped variable
                for ii in idx:
                    for i in range(ii, n_active):
                        X.T[i], X.T[i + 1] = swap(X.T[i], X.T[i + 1])
                        # yeah this is stupid
                        indices[i], indices[i + 1] = indices[i + 1], indices[i]

                # TODO: this could be updated
                residual = y - np.dot(X[:, :n_active], coef[active])
                temp = np.dot(X.T[n_active], residual)

                Cov = np.r_[temp, Cov]
            else:
                for ii in idx:
                    for i in range(ii, n_active):
                        indices[i], indices[i + 1] = indices[i + 1], indices[i]
                        Gram[i], Gram[i + 1] = swap(Gram[i], Gram[i + 1])
                        Gram[:, i], Gram[:, i + 1] = swap(Gram[:, i],
                                                          Gram[:, i + 1])

                # Cov_n = Cov_j + x_j * X + increment(betas) TODO:
                # will this still work with multiple drops ?

                # recompute covariance. Probably could be done better
                # wrong as Xy is not swapped with the rest of variables

                # TODO: this could be updated
                temp = Cov_copy[drop_idx] - np.dot(Gram_copy[drop_idx], coef)
                Cov = np.r_[temp, Cov]

            sign_active = np.delete(sign_active, idx)
            sign_active = np.append(sign_active, 0.)  # just to maintain size
            if verbose > 1:
                print("%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, '', drop_idx,
                                                      n_active, abs(temp)))

    if return_path:
        # resize coefs in case of early stop
        alphas = alphas[:n_iter + 1]
        coefs = coefs[:n_iter + 1]

        if return_n_iter:
            return alphas, active, coefs.T, n_iter
        else:
            if testing:
                return alphas, active, coefs.T, test_result
            else:
                return alphas, active, coefs.T
    else:
        if return_n_iter:
            return alpha, active, coef, n_iter
        else:
            return alpha, active, coef
--------------------------------------------------------------------------------
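For reference, the snippet below is a minimal, self-contained sketch of the pairwise significance test performed by the testing branch above, showing only the case where both candidate correlations are non-negative. It uses plain numpy/scipy rather than this module's API; x1, x2 and residual are made-up stand-ins for the two leading candidate columns and the current LARS residual, and significance plays the role of the alpha argument used above.

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
n = 1000                        # number of perturbation samples
x1 = rng.normal(size=n)         # stand-in for the leading candidate column
x2 = rng.normal(size=n)         # stand-in for the runner-up column
residual = rng.normal(size=n)   # stand-in for the current LARS residual

significance = 0.05
z_score = stats.norm.ppf(1 - significance)

# correlations of the two candidates with the residual, as in the code above
u = np.array([np.dot(x1, residual), np.dot(x2, residual)]) / n
cov = np.cov(x1 * residual, x2 * residual)

# case u[0] >= 0 and u[1] >= 0: is the gap u[0] - u[1] significant?
var_term = 2 * (cov[0, 0] + cov[1, 1] - cov[0, 1] - cov[1, 0]) / n
test_stats = u[0] - u[1] - z_score * np.sqrt(var_term)

if test_stats < 0:
    # not significant: back out the sample size that would have been needed
    z_alpha = (u[0] - u[1]) / np.sqrt(var_term)
    required_n = n * (z_score / z_alpha) ** 2
    print("not significant; roughly %.0f samples would be needed" % required_n)
else:
    print("the leading variable is selected at significance level %.2f" % significance)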