├── slime
│   ├── __init__.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── test_generic_utils.py
│   │   ├── test_scikit_image.py
│   │   ├── test_lime_text.py
│   │   ├── test_discretize.py
│   │   └── test_lime_tabular.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── generic_utils.py
│   ├── wrappers
│   │   ├── __init__.py
│   │   └── scikit_image.py
│   ├── .DS_Store
│   ├── exceptions.py
│   ├── js
│   │   ├── main.js
│   │   ├── predict_proba.js
│   │   ├── bar_chart.js
│   │   ├── predicted_value.js
│   │   └── explanation.js
│   ├── test_table.html
│   ├── style.css
│   ├── webpack.config.js
│   ├── package.json
│   ├── submodular_pick.py
│   ├── discretize.py
│   ├── explanation.py
│   ├── lime_base.py
│   ├── lime_text.py
│   └── lime_image.py
├── slime_lm
│   ├── __init__.py
│   └── _least_angle.py
├── MANIFEST.in
├── .DS_Store
├── doc
│   └── images
│       ├── demo1.png
│       └── demo2.png
├── setup.py
├── LICENSE
├── .gitignore
└── README.md
/slime/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime_lm/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/slime/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include slime/*.js
2 | include LICENSE
3 |
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/.DS_Store
--------------------------------------------------------------------------------
/slime/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/slime/.DS_Store
--------------------------------------------------------------------------------
/slime/exceptions.py:
--------------------------------------------------------------------------------
1 | class LimeError(Exception):
2 | """Raise for errors"""
3 |
--------------------------------------------------------------------------------
/doc/images/demo1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/doc/images/demo1.png
--------------------------------------------------------------------------------
/doc/images/demo2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZhengzeZhou/slime/HEAD/doc/images/demo2.png
--------------------------------------------------------------------------------
/slime/js/main.js:
--------------------------------------------------------------------------------
1 | if (!global._babelPolyfill) {
2 | require('babel-polyfill')
3 | }
4 |
5 |
6 | import Explanation from './explanation.js';
7 | import Barchart from './bar_chart.js';
8 | import PredictProba from './predict_proba.js';
9 | import PredictedValue from './predicted_value.js';
10 | require('../style.css');
11 |
12 | export {Explanation, Barchart, PredictProba, PredictedValue};
13 | //require('style-loader');
14 |
15 |
16 |
--------------------------------------------------------------------------------
/slime/test_table.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/slime/style.css:
--------------------------------------------------------------------------------
1 | .lime {
2 | all: initial;
3 | }
4 | .lime.top_div {
5 | display: flex;
6 | flex-wrap: wrap;
7 | }
8 | .lime.predict_proba {
9 | width: 245px;
10 | }
11 | .lime.predicted_value {
12 | width: 245px;
13 | }
14 | .lime.explanation {
15 | width: 350px;
16 | }
17 |
18 | .lime.text_div {
19 | max-height:300px;
20 | flex: 1 0 300px;
21 | overflow:scroll;
22 | }
23 | .lime.table_div {
24 | max-height:300px;
25 | flex: 1 0 300px;
26 | overflow:scroll;
27 | }
28 | .lime.table_div table {
29 | border-collapse: collapse;
30 | color: white;
31 | border-style: hidden;
32 | margin: 0 auto;
33 | }
34 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(name='slime',
4 | version='0.1',
5 | description='Stabilized-LIME for Model Explanation',
6 | url='https://github.com/ZhengzeZhou/slime',
7 | author='Zhengze Zhou',
8 | author_email='zz433@cornell.edu',
9 | license='BSD',
10 | packages=find_packages(exclude=['js', 'node_modules', 'tests']),
11 | python_requires='>=3.5',
12 | install_requires=[
13 | 'lime',
14 | 'matplotlib',
15 | 'numpy',
16 | 'scipy',
17 | 'tqdm >= 4.29.1',
18 | 'scikit-learn>=0.18',
19 | 'scikit-image>=0.12',
20 | 'pyDOE2==1.3.0'
21 | ],
22 | extras_require={
23 | 'dev': ['pytest', 'flake8'],
24 | },
25 | include_package_data=True,
26 | zip_safe=False)
27 |
--------------------------------------------------------------------------------
/slime/webpack.config.js:
--------------------------------------------------------------------------------
1 | var path = require('path');
2 | var webpack = require('webpack');
3 |
4 | module.exports = {
5 | entry: './js/main.js',
6 | output: {
7 | path: __dirname,
8 | filename: 'bundle.js',
9 | library: 'lime'
10 | },
11 | module: {
12 | loaders: [
13 | {
14 | loader: 'babel-loader',
15 | test: path.join(__dirname, 'js'),
16 | query: {
17 | presets: 'es2015-ie',
18 | },
19 |
20 | },
21 | {
22 | test: /\.css$/,
23 | loaders: ['style-loader', 'css-loader'],
24 |
25 | }
26 |
27 | ]
28 | },
29 | plugins: [
30 | // Avoid publishing files when compilation fails
31 | new webpack.NoErrorsPlugin()
32 | ],
33 | stats: {
34 | // Nice colored output
35 | colors: true
36 | },
37 | // Create Sourcemaps for the bundle
38 | devtool: 'source-map',
39 | };
40 |
41 |
--------------------------------------------------------------------------------
/slime/utils/generic_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import inspect
3 | import types
4 |
5 |
6 | def has_arg(fn, arg_name):
7 | """Checks if a callable accepts a given keyword argument.
8 |
9 | Args:
10 | fn: callable to inspect
11 | arg_name: string, keyword argument name to check
12 |
13 | Returns:
14 | bool, whether `fn` accepts an `arg_name` keyword argument.
15 | """
16 | if sys.version_info < (3,):
17 | if isinstance(fn, types.FunctionType) or isinstance(fn, types.MethodType):
18 | arg_spec = inspect.getargspec(fn)
19 | else:
20 | try:
21 | arg_spec = inspect.getargspec(fn.__call__)
22 | except AttributeError:
23 | return False
24 | return (arg_name in arg_spec.args)
25 | elif sys.version_info < (3, 6):
26 | arg_spec = inspect.getfullargspec(fn)
27 | return (arg_name in arg_spec.args or
28 | arg_name in arg_spec.kwonlyargs)
29 | else:
30 | try:
31 | signature = inspect.signature(fn)
32 | except ValueError:
33 | # handling Cython
34 | signature = inspect.signature(fn.__call__)
35 | parameter = signature.parameters.get(arg_name)
36 | if parameter is None:
37 | return False
38 | return (parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
39 | inspect.Parameter.KEYWORD_ONLY))
40 |
--------------------------------------------------------------------------------
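A quick usage sketch for `has_arg` (a hypothetical call site; the import path follows this repository's layout, while the bundled wrappers and tests import the same helper from the upstream `lime` package):

```python
# Minimal sketch: introspect whether a callable accepts a given keyword argument.
from slime.utils.generic_utils import has_arg  # assumed import path (slime/utils/generic_utils.py)


def segment(image, kernel_size=4, max_dist=200):
    """Toy stand-in for a segmentation routine."""
    return image


print(has_arg(segment, 'kernel_size'))  # True: named parameter in the signature
print(has_arg(segment, 'ratio'))        # False: `segment` does not accept it
```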
/slime/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lime",
3 | "version": "1.0.0",
4 | "description": "",
5 | "main": "main.js",
6 | "scripts": {
7 | "build": "webpack",
8 | "watch": "webpack --watch",
9 | "start": "webpack-dev-server --hot --inline",
10 | "lint": "eslint js"
11 | },
12 | "repository": {
13 | "type": "git",
14 | "url": "git+https://github.com/marcotcr/lime.git"
15 | },
16 | "author": "Marco Tulio Ribeiro ",
17 | "license": "TODO",
18 | "bugs": {
19 | "url": "https://github.com/marcotcr/lime/issues"
20 | },
21 | "homepage": "https://github.com/marcotcr/lime#readme",
22 | "devDependencies": {
23 | "babel-cli": "^6.8.0",
24 | "babel-core": "^6.17.0",
25 | "babel-eslint": "^6.1.0",
26 | "babel-loader": "^6.2.4",
27 | "babel-polyfill": "^6.16.0",
28 | "babel-preset-es2015": "^6.0.15",
29 | "babel-preset-es2015-ie": "^6.6.2",
30 | "css-loader": "^0.23.1",
31 | "eslint": "^6.6.0",
32 | "node-libs-browser": "^0.5.3",
33 | "style-loader": "^0.13.1",
34 | "webpack": "^1.13.0",
35 | "webpack-dev-server": "^1.14.1"
36 | },
37 | "dependencies": {
38 | "d3": "^3.5.17",
39 | "lodash": "^4.11.2"
40 | },
41 | "eslintConfig": {
42 | "parser": "babel-eslint",
43 | "parserOptions": {
44 | "ecmaVersion": 6,
45 | "sourceType": "module",
46 | "ecmaFeatures": {
47 | "jsx": true
48 | }
49 | },
50 | "extends": "eslint:recommended"
51 | }
52 | }
53 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2021, Zhengze Zhou
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled python modules.
2 | *.pyc
3 |
4 | # Setuptools distribution folder.
5 | /dist/
6 |
7 | /lime/node_modules
8 |
9 | # Python egg metadata, regenerated from source files by setuptools.
10 | /*.egg-info
11 |
12 | # Unit test / coverage reports
13 | .cache
14 |
15 | # Created by https://www.gitignore.io/api/pycharm
16 |
17 | ### PyCharm ###
18 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
19 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
20 |
21 | # User-specific stuff:
22 | .idea/workspace.xml
23 | .idea/tasks.xml
24 | .idea/dictionaries
25 | .idea/vcs.xml
26 | .idea/jsLibraryMappings.xml
27 |
28 | # Sensitive or high-churn files:
29 | .idea/dataSources.ids
30 | .idea/dataSources.xml
31 | .idea/dataSources.local.xml
32 | .idea/sqlDataSources.xml
33 | .idea/dynamic.xml
34 | .idea/uiDesigner.xml
35 |
36 | # Gradle:
37 | .idea/gradle.xml
38 | .idea/libraries
39 |
40 | # Mongo Explorer plugin:
41 | .idea/mongoSettings.xml
42 |
43 | ## File-based project format:
44 | *.iws
45 |
46 | ## Plugin-specific files:
47 |
48 | # IntelliJ
49 | /out/
50 |
51 | # mpeltonen/sbt-idea plugin
52 | .idea_modules/
53 |
54 | # JIRA plugin
55 | atlassian-ide-plugin.xml
56 |
57 | # Crashlytics plugin (for Android Studio and IntelliJ)
58 | com_crashlytics_export_strings.xml
59 | crashlytics.properties
60 | crashlytics-build.properties
61 | fabric.properties
62 |
63 | ### PyCharm Patch ###
64 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
65 |
66 | # *.iml
67 | # modules.xml
68 | # .idea/misc.xml
69 | # *.ipr
70 |
71 | # Pycharm
72 | .idea
73 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # slime
2 |
3 | This repository holds code for replicating the experiments in the paper [S-LIME: Stabilized-LIME for Model Explanation](), to appear in [KDD2021](https://www.kdd.org/kdd2021/).
4 |
5 | It is built on the implementation of [LIME](https://github.com/marcotcr/lime) with added functionalities.
6 |
7 | ## Introduction
8 |
9 | It has been shown that post hoc explanations based on perturbations (such as LIME) exhibit large instability, posing serious challenges to the effectiveness of the method itself and harming user trust. S-LIME stands for Stabilized-LIME: it uses a hypothesis testing framework based on the central limit theorem to determine the number of perturbation points needed to guarantee stability of the resulting explanation.
10 |
11 | ## Installation
12 |
13 | Clone the repository and install using pip:
14 |
15 | ```sh
16 | git clone https://github.com/ZhengzeZhou/slime.git
17 | cd slime
18 | pip install .
19 | ```
20 |
21 | ## Usage
22 |
23 | Currently, S-LIME only supports tabular data, and only when the feature selection method is set to "lasso_path". We are working on extending it to other data types and feature selection methods.
24 |
25 | The following screenshot shows a typical usage of LIME on the breast cancer data. We can easily observe that two runs of the explanation algorithm result in different features being selected.
26 |
27 | 
28 |
29 | S-LIME is invoked by calling **explainer.slime** instead of **explainer.explain_instance**. *n_max* indicates the maximum number of synthetic samples to generate, and *alpha* denotes the significance level of the hypothesis tests. S-LIME explanations are guaranteed to be stable with high probability (see the usage sketch after this README section).
30 |
31 | 
32 |
33 | ## Notebooks
34 |
35 | - [Breast Cancer Data](https://github.com/ZhengzeZhou/slime/blob/main/doc/notebooks/Breast%20Cancer%20Data.ipynb)
36 | - [MARS](https://github.com/ZhengzeZhou/slime/blob/main/doc/notebooks/MARS.ipynb)
37 | - [Dog Images](https://github.com/joangog/slime/blob/main/doc/notebooks/Dogs.ipynb)
38 |
39 |
40 |
--------------------------------------------------------------------------------
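A minimal sketch of the S-LIME call described in the README above. Treat it as an assumption-laden illustration: only `explainer.slime`, `n_max`, and `alpha` come from the README; the `slime.lime_tabular` module and the explainer's constructor arguments are assumed to mirror the upstream LIME tabular API (see the notebooks linked above for working end-to-end examples).

```python
# Hedged sketch: stabilized explanations on the breast cancer data.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from slime import lime_tabular  # assumed module; mirrors upstream LIME's tabular explainer

data = load_breast_cancer()
clf = RandomForestClassifier(n_estimators=100, random_state=0)
clf.fit(data.data, data.target)

explainer = lime_tabular.LimeTabularExplainer(
    data.data,
    feature_names=data.feature_names,
    class_names=data.target_names,
    discretize_continuous=True,
    feature_selection='lasso_path')  # S-LIME currently requires lasso_path

# Plain LIME: two runs of this call may select different features.
exp = explainer.explain_instance(data.data[0], clf.predict_proba, num_features=5)

# S-LIME: keep generating synthetic samples (up to n_max) until the selected
# features are stable at significance level alpha.
s_exp = explainer.slime(data.data[0], clf.predict_proba,
                        num_features=5, n_max=10000, alpha=0.05)
print(s_exp.as_list())
```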
/slime/tests/test_generic_utils.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import sys
3 | from lime.utils.generic_utils import has_arg
4 |
5 |
6 | class TestGenericUtils(unittest.TestCase):
7 |
8 | def test_has_arg(self):
9 | # fn is callable / is not callable
10 |
11 | class FooNotCallable:
12 |
13 | def __init__(self, word):
14 | self.message = word
15 |
16 | class FooCallable:
17 |
18 | def __init__(self, word):
19 | self.message = word
20 |
21 | def __call__(self, message):
22 | return message
23 |
24 | def positional_argument_call(self, arg1):
25 | return self.message
26 |
27 | def multiple_positional_arguments_call(self, *args):
28 | res = []
29 | for a in args:
30 | res.append(a)
31 | return res
32 |
33 | def keyword_argument_call(self, filter_=True):
34 | res = self.message
35 | if filter_:
36 | res = 'KO'
37 | return res
38 |
39 | def multiple_keyword_arguments_call(self, arg1='1', arg2='2'):
40 | return self.message + arg1 + arg2
41 |
42 | def undefined_keyword_arguments_call(self, **kwargs):
43 | res = self.message
44 | for a in kwargs:
45 | res = res + a
46 | return res
47 |
48 | foo_callable = FooCallable('OK')
49 | self.assertTrue(has_arg(foo_callable, 'message'))
50 |
51 | if sys.version_info < (3,):
52 | foo_not_callable = FooNotCallable('KO')
53 | self.assertFalse(has_arg(foo_not_callable, 'message'))
54 | elif sys.version_info < (3, 6):
55 | with self.assertRaises(TypeError):
56 | foo_not_callable = FooNotCallable('KO')
57 | has_arg(foo_not_callable, 'message')
58 |
59 | # Python 2, argument in / not in valid arguments / keyword arguments
60 | if sys.version_info < (3,):
61 | self.assertFalse(has_arg(foo_callable, 'invalid_arg'))
62 | self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1'))
63 | self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX'))
64 | self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX'))
65 | self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_'))
66 | self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2'))
67 | self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3'))
68 | self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX'))
69 | # Python 3, argument in / not in valid arguments / keyword arguments
70 | elif sys.version_info < (3, 6):
71 | self.assertFalse(has_arg(foo_callable, 'invalid_arg'))
72 | self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1'))
73 | self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX'))
74 | self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX'))
75 | self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_'))
76 | self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2'))
77 | self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3'))
78 | self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX'))
79 | else:
80 | self.assertFalse(has_arg(foo_callable, 'invalid_arg'))
81 | self.assertTrue(has_arg(foo_callable.positional_argument_call, 'arg1'))
82 | self.assertFalse(has_arg(foo_callable.multiple_positional_arguments_call, 'argX'))
83 | self.assertFalse(has_arg(foo_callable.keyword_argument_call, 'argX'))
84 | self.assertTrue(has_arg(foo_callable.keyword_argument_call, 'filter_'))
85 | self.assertTrue(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg2'))
86 | self.assertFalse(has_arg(foo_callable.multiple_keyword_arguments_call, 'arg3'))
87 | self.assertFalse(has_arg(foo_callable.undefined_keyword_arguments_call, 'argX'))
88 | # argname is None
89 | self.assertFalse(has_arg(foo_callable, None))
90 |
91 |
92 | if __name__ == '__main__':
93 | unittest.main()
94 |
--------------------------------------------------------------------------------
/slime/js/predict_proba.js:
--------------------------------------------------------------------------------
1 | import d3 from 'd3';
2 | import {range, sortBy} from 'lodash';
3 |
4 | class PredictProba {
5 | // svg: d3 object with the svg in question
6 | // class_names: array of class names
7 | // predict_probas: array of prediction probabilities
8 | constructor(svg, class_names, predict_probas, title='Prediction probabilities') {
9 | let width = parseInt(svg.style('width'));
10 | this.names = class_names;
11 | this.names.push('Other');
12 | if (class_names.length < 10) {
13 | this.colors = d3.scale.category10().domain(this.names);
14 | this.colors_i = d3.scale.category10().domain(range(this.names.length));
15 | }
16 | else {
17 | this.colors = d3.scale.category20().domain(this.names);
18 | this.colors_i = d3.scale.category20().domain(range(this.names.length));
19 | }
20 | let [names, data] = this.map_classes(this.names, predict_probas);
21 | let bar_x = width - 125;
22 | let class_names_width = bar_x;
23 | let bar_width = width - bar_x - 32;
24 | let x_scale = d3.scale.linear().range([0, bar_width]);
25 | let bar_height = 17;
26 | let space_between_bars = 5;
27 | let bar_yshift= title === '' ? 0 : 35;
28 | let n_bars = Math.min(5, data.length);
29 | this.svg_height = n_bars * (bar_height + space_between_bars) + bar_yshift;
30 | svg.style('height', this.svg_height + 'px');
31 | let this_object = this;
32 | if (title !== '') {
33 | svg.append('text')
34 | .text(title)
35 | .attr('x', 20)
36 | .attr('y', 20);
37 | }
38 | let bar_y = i => (bar_height + space_between_bars) * i + bar_yshift;
39 | let bar = svg.append("g");
40 |
41 | for (let i of range(data.length)) {
42 | var color = this.colors(names[i]);
43 | if (names[i] == 'Other' && this.names.length > 20) {
44 | color = '#5F9EA0';
45 | }
46 | let rect = bar.append("rect");
47 | rect.attr("x", bar_x)
48 | .attr("y", bar_y(i))
49 | .attr("height", bar_height)
50 | .attr("width", x_scale(data[i]))
51 | .style("fill", color);
52 | bar.append("rect").attr("x", bar_x)
53 | .attr("y", bar_y(i))
54 | .attr("height", bar_height)
55 | .attr("width", bar_width - 1)
56 | .attr("fill-opacity", 0)
57 | .attr("stroke", "black");
58 | let text = bar.append("text");
59 | text.classed("prob_text", true);
60 | text.attr("y", bar_y(i) + bar_height - 3).attr("fill", "black").style("font", "14px tahoma, sans-serif");
61 | text = bar.append("text");
62 | text.attr("x", bar_x + x_scale(data[i]) + 5)
63 | .attr("y", bar_y(i) + bar_height - 3)
64 | .attr("fill", "black")
65 | .style("font", "14px tahoma, sans-serif")
66 | .text(data[i].toFixed(2));
67 | text = bar.append("text");
68 | text.attr("x", bar_x - 10)
69 | .attr("y", bar_y(i) + bar_height - 3)
70 | .attr("fill", "black")
71 | .attr("text-anchor", "end")
72 | .style("font", "14px tahoma, sans-serif")
73 | .text(names[i]);
74 | while (text.node().getBBox()['width'] + 1 > (class_names_width - 10)) {
75 | // TODO: ta mostrando só dois, e talvez quando hover mostrar o texto
76 | // todo
77 | let cur_text = text.text().slice(0, text.text().length - 5);
78 | text.text(cur_text + '...');
79 | if (cur_text === '') {
80 | break
81 | }
82 | }
83 | }
84 | }
85 | map_classes(class_names, predict_proba) {
86 | if (class_names.length <= 6) {
87 | return [class_names, predict_proba];
88 | }
89 | let class_dict = range(predict_proba.length).map(i => ({'name': class_names[i], 'prob': predict_proba[i], 'i' : i}));
90 | let sorted = sortBy(class_dict, d => -d.prob);
91 | let other = new Set();
92 | range(4, sorted.length).map(d => other.add(sorted[d].name));
93 | let other_prob = 0;
94 | let ret_probs = [];
95 | let ret_names = [];
96 | for (let d of range(sorted.length)) {
97 | if (other.has(sorted[d].name)) {
98 | other_prob += sorted[d].prob;
99 | }
100 | else {
101 | ret_probs.push(sorted[d].prob);
102 | ret_names.push(sorted[d].name);
103 | }
104 | };
105 | ret_names.push("Other");
106 | ret_probs.push(other_prob);
107 | return [ret_names, ret_probs];
108 | }
109 |
110 | }
111 | export default PredictProba;
112 |
113 |
114 |
--------------------------------------------------------------------------------
/slime/js/bar_chart.js:
--------------------------------------------------------------------------------
1 | import d3 from 'd3';
2 | class Barchart {
3 | // svg: d3 object with the svg in question
4 | // exp_array: list of (feature_name, weight)
5 | constructor(svg, exp_array, two_sided=true, titles=undefined, colors=['red', 'green'], show_numbers=false, bar_height=5) {
6 | let svg_width = Math.min(600, parseInt(svg.style('width')));
7 | let bar_width = two_sided ? svg_width / 2 : svg_width;
8 | if (titles === undefined) {
9 | titles = two_sided ? ['Cons', 'Pros'] : 'Pros';
10 | }
11 | if (show_numbers) {
12 | bar_width = bar_width - 30;
13 | }
14 | let x_offset = two_sided ? svg_width / 2 : 10;
15 | // 13.1 is +- the width of W, the widest letter.
16 | if (two_sided && titles.length == 2) {
17 | svg.append('text')
18 | .attr('x', svg_width / 4)
19 | .attr('y', 15)
20 | .attr('font-size', '20')
21 | .attr('text-anchor', 'middle')
22 | .style('fill', colors[0])
23 | .text(titles[0]);
24 |
25 | svg.append('text')
26 | .attr('x', svg_width / 4 * 3)
27 | .attr('y', 15)
28 | .attr('font-size', '20')
29 | .attr('text-anchor', 'middle')
30 | .style('fill', colors[1])
31 | .text(titles[1]);
32 | }
33 | else {
34 | let pos = two_sided ? svg_width / 2 : x_offset;
35 | let anchor = two_sided ? 'middle' : 'begin';
36 | svg.append('text')
37 | .attr('x', pos)
38 | .attr('y', 15)
39 | .attr('font-size', '20')
40 | .attr('text-anchor', anchor)
41 | .text(titles);
42 | }
43 | let yshift = 20;
44 | let space_between_bars = 0;
45 | let text_height = 16;
46 | let space_between_bar_and_text = 3;
47 | let total_bar_height = text_height + space_between_bar_and_text + bar_height + space_between_bars;
48 | let total_height = (total_bar_height) * exp_array.length;
49 | this.svg_height = total_height + yshift;
50 | let yscale = d3.scale.linear()
51 | .domain([0, exp_array.length])
52 | .range([yshift, yshift + total_height])
53 | let names = exp_array.map(v => v[0]);
54 | let weights = exp_array.map(v => v[1]);
55 | let max_weight = Math.max(...(weights.map(v=>Math.abs(v))));
56 | let xscale = d3.scale.linear()
57 | .domain([0,Math.max(1, max_weight)])
58 | .range([0, bar_width]);
59 |
60 | for (var i = 0; i < exp_array.length; ++i) {
61 | let name = names[i];
62 | let weight = weights[i];
63 | var size = xscale(Math.abs(weight));
64 | let to_the_right = (weight > 0 || !two_sided)
65 | let text = svg.append('text')
66 | .attr('x', to_the_right ? x_offset + 2 : x_offset - 2)
67 | .attr('y', yscale(i) + text_height)
68 | .attr('text-anchor', to_the_right ? 'begin' : 'end')
69 | .attr('font-size', '14')
70 | .text(name);
71 | while (text.node().getBBox()['width'] + 1 > bar_width) {
72 | let cur_text = text.text().slice(0, text.text().length - 5);
73 | text.text(cur_text + '...');
74 | if (text === '...') {
75 | break;
76 | }
77 | }
78 | let bar = svg.append('rect')
79 | .attr('height', bar_height)
80 | .attr('x', to_the_right ? x_offset : x_offset - size)
81 | .attr('y', text_height + yscale(i) + space_between_bar_and_text)// + bar_height)
82 | .attr('width', size)
83 | .style('fill', weight > 0 ? colors[1] : colors[0]);
84 | if (show_numbers) {
85 | let bartext = svg.append('text')
86 | .attr('x', to_the_right ? x_offset + size + 1 : x_offset - size - 1)
87 | .attr('text-anchor', (weight > 0 || !two_sided) ? 'begin' : 'end')
88 | .attr('y', bar_height + yscale(i) + text_height + space_between_bar_and_text)
89 | .attr('font-size', '10')
90 | .text(Math.abs(weight).toFixed(2));
91 | }
92 | }
93 | let line = svg.append("line")
94 | .attr("x1", x_offset)
95 | .attr("x2", x_offset)
96 | .attr("y1", bar_height + yshift)
97 | .attr("y2", Math.max(bar_height, yscale(exp_array.length)))
98 | .style("stroke-width",2)
99 | .style("stroke", "black");
100 | }
101 |
102 | }
103 | export default Barchart;
104 |
--------------------------------------------------------------------------------
/slime/wrappers/scikit_image.py:
--------------------------------------------------------------------------------
1 | import types
2 | from lime.utils.generic_utils import has_arg
3 | from skimage.segmentation import felzenszwalb, slic, quickshift
4 |
5 |
6 | class BaseWrapper(object):
7 | """Base class for LIME Scikit-Image wrapper
8 |
9 |
10 | Args:
11 | target_fn: callable function or class instance
12 | target_params: dict, parameters to pass to the target_fn
13 |
14 |
15 | 'target_params' takes parameters required to instantiate the
16 | desired Scikit-Image class/model
17 | """
18 |
19 | def __init__(self, target_fn=None, **target_params):
20 | self.target_fn = target_fn
21 | self.target_params = target_params
22 |
23 | def _check_params(self, parameters):
24 | """Checks for mistakes in 'parameters'
25 |
26 | Args :
27 | parameters: dict, parameters to be checked
28 |
29 | Raises :
30 | ValueError: if any parameter is not a valid argument for the target function
31 | or the target function is not defined
32 | TypeError: if argument parameters is not iterable
33 | """
34 | a_valid_fn = []
35 | if self.target_fn is None:
36 | if callable(self):
37 | a_valid_fn.append(self.__call__)
38 | else:
39 | raise TypeError('invalid argument: tested object is not callable,\
40 | please provide a valid target_fn')
41 | elif isinstance(self.target_fn, types.FunctionType) \
42 | or isinstance(self.target_fn, types.MethodType):
43 | a_valid_fn.append(self.target_fn)
44 | else:
45 | a_valid_fn.append(self.target_fn.__call__)
46 |
47 | if not isinstance(parameters, str):
48 | for p in parameters:
49 | for fn in a_valid_fn:
50 | if has_arg(fn, p):
51 | pass
52 | else:
53 | raise ValueError('{} is not a valid parameter'.format(p))
54 | else:
55 | raise TypeError('invalid argument: list or dictionary expected')
56 |
57 | def set_params(self, **params):
58 | """Sets the parameters of this estimator.
59 | Args:
60 | **params: Dictionary of parameter names mapped to their values.
61 |
62 | Raises :
63 | ValueError: if any parameter is not a valid argument
64 | for the target function
65 | """
66 | self._check_params(params)
67 | self.target_params = params
68 |
69 | def filter_params(self, fn, override=None):
70 | """Filters `target_params` and return those in `fn`'s arguments.
71 | Args:
72 | fn : arbitrary function
73 | override: dict, values to override target_params
74 | Returns:
75 | result : dict, dictionary containing variables
76 | in both target_params and fn's arguments.
77 | """
78 | override = override or {}
79 | result = {}
80 | for name, value in self.target_params.items():
81 | if has_arg(fn, name):
82 | result.update({name: value})
83 | result.update(override)
84 | return result
85 |
86 |
87 | class SegmentationAlgorithm(BaseWrapper):
88 | """ Define the image segmentation function based on Scikit-Image
89 | implementation and a set of provided parameters
90 |
91 | Args:
92 | algo_type: string, segmentation algorithm among the following:
93 | 'quickshift', 'slic', 'felzenszwalb'
94 | target_params: dict, algorithm parameters (valid model parameters
95 | as defined in the Scikit-Image documentation)
96 | """
97 |
98 | def __init__(self, algo_type, **target_params):
99 | self.algo_type = algo_type
100 | if (self.algo_type == 'quickshift'):
101 | BaseWrapper.__init__(self, quickshift, **target_params)
102 | kwargs = self.filter_params(quickshift)
103 | self.set_params(**kwargs)
104 | elif (self.algo_type == 'felzenszwalb'):
105 | BaseWrapper.__init__(self, felzenszwalb, **target_params)
106 | kwargs = self.filter_params(felzenszwalb)
107 | self.set_params(**kwargs)
108 | elif (self.algo_type == 'slic'):
109 | BaseWrapper.__init__(self, slic, **target_params)
110 | kwargs = self.filter_params(slic)
111 | self.set_params(**kwargs)
112 |
113 | def __call__(self, *args):
114 | return self.target_fn(args[0], **self.target_params)
115 |
--------------------------------------------------------------------------------
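A short usage sketch for `SegmentationAlgorithm`, mirroring what `test_scikit_image.py` below exercises (note: this repository ships the module as `slime/wrappers/scikit_image.py`, while the bundled tests import the upstream `lime.wrappers` equivalent):

```python
# Sketch: wrap skimage's quickshift so that only parameters it accepts are passed through.
from skimage.data import chelsea
from skimage.util import img_as_float
from slime.wrappers.scikit_image import SegmentationAlgorithm  # assumed import path

img = img_as_float(chelsea()[::2, ::2])

# filter_params/set_params keep only arguments that quickshift actually accepts.
segmenter = SegmentationAlgorithm('quickshift', kernel_size=3, max_dist=6, ratio=0.5)
segments = segmenter(img)   # equivalent to calling quickshift(img, kernel_size=3, ...)
print(segments.shape)       # one integer label per pixel
```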
/slime/tests/test_scikit_image.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from lime.wrappers.scikit_image import BaseWrapper
3 | from lime.wrappers.scikit_image import SegmentationAlgorithm
4 | from skimage.segmentation import quickshift
5 | from skimage.data import chelsea
6 | from skimage.util import img_as_float
7 | import numpy as np
8 |
9 |
10 | class TestBaseWrapper(unittest.TestCase):
11 |
12 | def test_base_wrapper(self):
13 |
14 | obj_with_params = BaseWrapper(a=10, b='message')
15 | obj_without_params = BaseWrapper()
16 |
17 | def foo_fn():
18 | return 'bar'
19 |
20 | obj_with_fn = BaseWrapper(foo_fn)
21 | self.assertEqual(obj_with_params.target_params, {'a': 10, 'b': 'message'})
22 | self.assertEqual(obj_without_params.target_params, {})
23 | self.assertEqual(obj_with_fn.target_fn(), 'bar')
24 |
25 | def test__check_params(self):
26 |
27 | def bar_fn(a):
28 | return str(a)
29 |
30 | class Pipo():
31 |
32 | def __init__(self):
33 | self.name = 'pipo'
34 |
35 | def __call__(self, message):
36 | return message
37 |
38 | pipo = Pipo()
39 | obj_with_valid_fn = BaseWrapper(bar_fn, a=10, b='message')
40 | obj_with_valid_callable_fn = BaseWrapper(pipo, c=10, d='message')
41 | obj_with_invalid_fn = BaseWrapper([1, 2, 3], fn_name='invalid')
42 |
43 | # target_fn is not a callable or function/method
44 | with self.assertRaises(AttributeError):
45 | obj_with_invalid_fn._check_params('fn_name')
46 |
47 | # parameters is not in target_fn args
48 | with self.assertRaises(ValueError):
49 | obj_with_valid_fn._check_params(['c'])
50 | obj_with_valid_callable_fn._check_params(['e'])
51 |
52 | # params is in target_fn args
53 | try:
54 | obj_with_valid_fn._check_params(['a'])
55 | obj_with_valid_callable_fn._check_params(['message'])
56 | except Exception:
57 | self.fail("_check_params() raised an unexpected exception")
58 |
59 | # params is not a dict or list
60 | with self.assertRaises(TypeError):
61 | obj_with_valid_fn._check_params(None)
62 | with self.assertRaises(TypeError):
63 | obj_with_valid_fn._check_params('param_name')
64 |
65 | def test_set_params(self):
66 |
67 | class Pipo():
68 |
69 | def __init__(self):
70 | self.name = 'pipo'
71 |
72 | def __call__(self, message):
73 | return message
74 | pipo = Pipo()
75 | obj = BaseWrapper(pipo)
76 |
77 | # argument is set accordingly
78 | obj.set_params(message='OK')
79 | self.assertEqual(obj.target_params, {'message': 'OK'})
80 | self.assertEqual(obj.target_fn(**obj.target_params), 'OK')
81 |
82 | # invalid argument is passed
83 | try:
84 | obj = BaseWrapper(Pipo())
85 | obj.set_params(invalid='KO')
86 | except Exception:
87 | self.assertEqual(obj.target_params, {})
88 |
89 | def test_filter_params(self):
90 |
91 | # right arguments are kept and wrong ones dismissed
92 | def baz_fn(a, b, c=True):
93 | if c:
94 | return a + b
95 | else:
96 | return a
97 | obj_ = BaseWrapper(baz_fn, a=10, b=100, d=1000)
98 | self.assertEqual(obj_.filter_params(baz_fn), {'a': 10, 'b': 100})
99 |
100 | # target_params is overriden using 'override' argument
101 | self.assertEqual(obj_.filter_params(baz_fn, override={'c': False}),
102 | {'a': 10, 'b': 100, 'c': False})
103 |
104 |
105 | class TestSegmentationAlgorithm(unittest.TestCase):
106 |
107 | def test_instanciate_segmentation_algorithm(self):
108 | img = img_as_float(chelsea()[::2, ::2])
109 |
110 | # wrapped functions provide the same result
111 | fn = SegmentationAlgorithm('quickshift', kernel_size=3, max_dist=6,
112 | ratio=0.5, random_seed=133)
113 | fn_result = fn(img)
114 | original_result = quickshift(img, kernel_size=3, max_dist=6, ratio=0.5,
115 | random_seed=133)
116 |
117 | # same segments
118 | self.assertTrue(np.array_equal(fn_result, original_result))
119 |
120 | def test_instanciate_slic(self):
121 | pass
122 |
123 | def test_instanciate_felzenszwalb(self):
124 | pass
125 |
126 |
127 | if __name__ == '__main__':
128 | unittest.main()
129 |
--------------------------------------------------------------------------------
/slime/js/predicted_value.js:
--------------------------------------------------------------------------------
1 | import d3 from 'd3';
2 | import {range, sortBy} from 'lodash';
3 |
4 | class PredictedValue {
5 | // svg: d3 object with the svg in question
6 | // class_names: array of class names
7 | // predict_probas: array of prediction probabilities
8 | constructor(svg, predicted_value, min_value, max_value, title='Predicted value', log_coords = false) {
9 |
10 | if (min_value == max_value){
11 | var width_proportion = 1.0;
12 | } else {
13 | var width_proportion = (predicted_value - min_value) / (max_value - min_value);
14 | }
15 |
16 |
17 | let width = parseInt(svg.style('width'))
18 |
19 | this.color = d3.scale.category10()
20 | this.color('predicted_value')
21 | // + 2 is due to it being a float
22 | let num_digits = Math.floor(Math.max(Math.log10(Math.abs(min_value)), Math.log10(Math.abs(max_value)))) + 2
23 | num_digits = Math.max(num_digits, 3)
24 |
25 | let corner_width = 12 * num_digits;
26 | let corner_padding = 5.5 * num_digits;
27 | let bar_x = corner_width + corner_padding;
28 | let bar_width = width - corner_width * 2 - corner_padding * 2;
29 | let x_scale = d3.scale.linear().range([0, bar_width]);
30 | let bar_height = 17;
31 | let bar_yshift= title === '' ? 0 : 35;
32 | let n_bars = 1;
33 | let this_object = this;
34 | if (title !== '') {
35 | svg.append('text')
36 | .text(title)
37 | .attr('x', 20)
38 | .attr('y', 20);
39 | }
40 | let bar_y = bar_yshift;
41 | let bar = svg.append("g");
42 |
43 | //filled in bar representing predicted value in range
44 | let rect = bar.append("rect");
45 | rect.attr("x", bar_x)
46 | .attr("y", bar_y)
47 | .attr("height", bar_height)
48 | .attr("width", x_scale(width_proportion))
49 | .style("fill", this.color);
50 |
51 | //empty box representing range
52 | bar.append("rect").attr("x", bar_x)
53 | .attr("y", bar_y)
54 | .attr("height", bar_height)
55 | .attr("width",x_scale(1))
56 | .attr("fill-opacity", 0)
57 | .attr("stroke", "black");
58 | let text = bar.append("text");
59 | text.classed("prob_text", true);
60 | text.attr("y", bar_y + bar_height - 3).attr("fill", "black").style("font", "14px tahoma, sans-serif");
61 |
62 |
63 | //text for min value
64 | text = bar.append("text");
65 | text.attr("x", bar_x - corner_padding)
66 | .attr("y", bar_y + bar_height - 3)
67 | .attr("fill", "black")
68 | .attr("text-anchor", "end")
69 | .style("font", "14px tahoma, sans-serif")
70 | .text(min_value.toFixed(2));
71 |
72 | //text for range min annotation
73 | let v_adjust_min_value_annotation = text.node().getBBox().height;
74 | text = bar.append("text");
75 | text.attr("x", bar_x - corner_padding)
76 | .attr("y", bar_y + bar_height - 3 + v_adjust_min_value_annotation)
77 | .attr("fill", "black")
78 | .attr("text-anchor", "end")
79 | .style("font", "14px tahoma, sans-serif")
80 | .text("(min)");
81 |
82 |
83 | //text for predicted value
84 | // console.log('bar height: ' + bar_height)
85 | text = bar.append("text");
86 | text.text(predicted_value.toFixed(2));
87 | // let h_adjust_predicted_value_text = text.node().getBBox().width / 2;
88 | let v_adjust_predicted_value_text = text.node().getBBox().height;
89 | text.attr("x", bar_x + x_scale(width_proportion))
90 | .attr("y", bar_y + bar_height + v_adjust_predicted_value_text)
91 | .attr("fill", "black")
92 | .attr("text-anchor", "middle")
93 | .style("font", "14px tahoma, sans-serif")
94 |
95 |
96 |
97 |
98 |
99 | //text for max value
100 | text = bar.append("text");
101 | text.text(max_value.toFixed(2));
102 | // let h_adjust = text.node().getBBox().width;
103 | text.attr("x", bar_x + bar_width + corner_padding)
104 | .attr("y", bar_y + bar_height - 3)
105 | .attr("fill", "black")
106 | .attr("text-anchor", "begin")
107 | .style("font", "14px tahoma, sans-serif");
108 |
109 |
110 | //text for range max annotation
111 | let v_adjust_max_value_annotation = text.node().getBBox().height;
112 | text = bar.append("text");
113 | text.attr("x", bar_x + bar_width + corner_padding)
114 | .attr("y", bar_y + bar_height - 3 + v_adjust_min_value_annotation)
115 | .attr("fill", "black")
116 | .attr("text-anchor", "begin")
117 | .style("font", "14px tahoma, sans-serif")
118 | .text("(max)");
119 |
120 |
121 | //readjust svg size
122 | // let svg_width = width + 1 * h_adjust;
123 | // svg.style('width', svg_width + 'px');
124 |
125 | this.svg_height = n_bars * (bar_height) + bar_yshift + (2 * text.node().getBBox().height) + 10;
126 | svg.style('height', this.svg_height + 'px');
127 | if (log_coords) {
128 | console.log("svg width: " + svg_width);
129 | console.log("svg height: " + this.svg_height);
130 | console.log("bar_y: " + bar_y);
131 | console.log("bar_x: " + bar_x);
132 | console.log("Min value: " + min_value);
133 | console.log("Max value: " + max_value);
134 | console.log("Pred value: " + predicted_value);
135 | }
136 | }
137 | }
138 |
139 |
140 | export default PredictedValue;
141 |
--------------------------------------------------------------------------------
/slime/js/explanation.js:
--------------------------------------------------------------------------------
1 | import d3 from 'd3';
2 | import Barchart from './bar_chart.js';
3 | import {range, sortBy} from 'lodash';
4 | class Explanation {
5 | constructor(class_names) {
6 | this.names = class_names;
7 | if (class_names.length < 10) {
8 | this.colors = d3.scale.category10().domain(this.names);
9 | this.colors_i = d3.scale.category10().domain(range(this.names.length));
10 | }
11 | else {
12 | this.colors = d3.scale.category20().domain(this.names);
13 | this.colors_i = d3.scale.category20().domain(range(this.names.length));
14 | }
15 | }
16 | // exp: [(feature-name, weight), ...]
17 | // label: int
18 | // div: d3 selection
19 | show(exp, label, div) {
20 | let svg = div.append('svg').style('width', '100%');
21 | let colors=['#5F9EA0', this.colors_i(label)];
22 | let names = [`NOT ${this.names[label]}`, this.names[label]];
23 | if (this.names.length == 2) {
24 | colors=[this.colors_i(0), this.colors_i(1)];
25 | names = this.names;
26 | }
27 | let plot = new Barchart(svg, exp, true, names, colors, true, 10);
28 | svg.style('height', plot.svg_height + 'px');
29 | }
30 | // exp has all ocurrences of words, with start index and weight:
31 | // exp = [('word', 132, -0.13), ('word3', 111, 1.3)
32 | show_raw_text(exp, label, raw, div, opacity=true) {
33 | //let colors=['#5F9EA0', this.colors(this.exp['class'])];
34 | let colors=['#5F9EA0', this.colors_i(label)];
35 | if (this.names.length == 2) {
36 | colors=[this.colors_i(0), this.colors_i(1)];
37 | }
38 | let word_lists = [[], []];
39 | let max_weight = -1;
40 | for (let [word, start, weight] of exp) {
41 | if (weight > 0) {
42 | word_lists[1].push([start, start + word.length, weight]);
43 | }
44 | else {
45 | word_lists[0].push([start, start + word.length, -weight]);
46 | }
47 | max_weight = Math.max(max_weight, Math.abs(weight));
48 | }
49 | if (!opacity) {
50 | max_weight = 0;
51 | }
52 | this.display_raw_text(div, raw, word_lists, colors, max_weight, true);
53 | }
54 | // exp is list of (feature_name, value, weight)
55 | show_raw_tabular(exp, label, div) {
56 | div.classed('lime', true).classed('table_div', true);
57 | let colors=['#5F9EA0', this.colors_i(label)];
58 | if (this.names.length == 2) {
59 | colors=[this.colors_i(0), this.colors_i(1)];
60 | }
61 | const table = div.append('table');
62 | const thead = table.append('tr');
63 | thead.append('td').text('Feature');
64 | thead.append('td').text('Value');
65 | thead.style('color', 'black')
66 | .style('font-size', '20px');
67 | for (let [fname, value, weight] of exp) {
68 | const tr = table.append('tr');
69 | tr.style('border-style', 'hidden');
70 | tr.append('td').text(fname);
71 | tr.append('td').text(value);
72 | if (weight > 0) {
73 | tr.style('background-color', colors[1]);
74 | }
75 | else if (weight < 0) {
76 | tr.style('background-color', colors[0]);
77 | }
78 | else {
79 | tr.style('color', 'black');
80 | }
81 | }
82 | }
83 | hexToRgb(hex) {
84 | let result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex);
85 | return result ? {
86 | r: parseInt(result[1], 16),
87 | g: parseInt(result[2], 16),
88 | b: parseInt(result[3], 16)
89 | } : null;
90 | }
91 | applyAlpha(hex, alpha) {
92 | let components = this.hexToRgb(hex);
93 | return 'rgba(' + components.r + "," + components.g + "," + components.b + "," + alpha.toFixed(3) + ")"
94 | }
95 | // word_lists is an array of arrays, of length (colors). if positions is true,
96 | // word_lists is an array of [start,end] positions instead
97 | display_raw_text(div, raw_text, word_lists=[], colors=[], max_weight=1, positions=false) {
98 | div.classed('lime', true).classed('text_div', true);
99 | div.append('h3').text('Text with highlighted words');
100 | let highlight_tag = 'span';
101 | let text_span = div.append('span').style('white-space', 'pre-wrap').text(raw_text);
102 | let position_lists = word_lists;
103 | if (!positions) {
104 | position_lists = this.wordlists_to_positions(word_lists, raw_text);
105 | }
106 | let objects = []
107 | for (let i of range(position_lists.length)) {
108 | position_lists[i].map(x => objects.push({'label' : i, 'start': x[0], 'end': x[1], 'alpha': max_weight === 0 ? 1: x[2] / max_weight}));
109 | }
110 | objects = sortBy(objects, x=>x['start']);
111 | let node = text_span.node().childNodes[0];
112 | let subtract = 0;
113 | for (let obj of objects) {
114 | let word = raw_text.slice(obj.start, obj.end);
115 | let start = obj.start - subtract;
116 | let end = obj.end - subtract;
117 | let match = document.createElement(highlight_tag);
118 | match.appendChild(document.createTextNode(word));
119 | match.style.backgroundColor = this.applyAlpha(colors[obj.label], obj.alpha);
120 | let after = node.splitText(start);
121 | after.nodeValue = after.nodeValue.substring(word.length);
122 | node.parentNode.insertBefore(match, after);
123 | subtract += end;
124 | node = after;
125 | }
126 | }
127 | wordlists_to_positions(word_lists, raw_text) {
128 | let ret = []
129 | for(let words of word_lists) {
130 | if (words.length === 0) {
131 | ret.push([]);
132 | continue;
133 | }
134 | let re = new RegExp("\\b(" + words.join('|') + ")\\b",'gm')
135 | let temp;
136 | let list = [];
137 | while ((temp = re.exec(raw_text)) !== null) {
138 | list.push([temp.index, temp.index + temp[0].length]);
139 | }
140 | ret.push(list);
141 | }
142 | return ret;
143 | }
144 |
145 | }
146 | export default Explanation;
147 |
--------------------------------------------------------------------------------
/slime/submodular_pick.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import warnings
3 |
4 |
5 | class SubmodularPick(object):
6 | """Class for submodular pick
7 |
8 | Saves a representative sample of explanation objects using SP-LIME,
9 | as well as saving all generated explanations
10 |
11 | First, a collection of candidate explanations are generated
12 | (see explain_instance). From these candidates, num_exps_desired are
13 | chosen using submodular pick. (see marcotcr et al paper)."""
14 |
15 | def __init__(self,
16 | explainer,
17 | data,
18 | predict_fn,
19 | method='sample',
20 | sample_size=1000,
21 | num_exps_desired=5,
22 | num_features=10,
23 | **kwargs):
24 |
25 | """
26 | Args:
27 | data: a numpy array where each row is a single input into predict_fn
28 | predict_fn: prediction function. For classifiers, this should be a
29 | function that takes a numpy array and outputs prediction
30 | probabilities. For regressors, this takes a numpy array and
31 | returns the predictions. For ScikitClassifiers, this is
32 | `classifier.predict_proba()`. For ScikitRegressors, this
33 | is `regressor.predict()`. The prediction function needs to work
34 | on multiple feature vectors (the vectors randomly perturbed
35 | from the data_row).
36 | method: The method to use to generate candidate explanations
37 | method == 'sample' will sample the data uniformly at
38 | random. The sample size is given by sample_size. Otherwise
39 | if method == 'full' then explanations will be generated for the
40 | entire data.
41 | sample_size: The number of instances to explain if method == 'sample'
42 | num_exps_desired: The number of explanation objects returned
43 | num_features: maximum number of features present in explanation
44 |
45 |
46 | Sets value:
47 | sp_explanations: A list of explanation objects that has a high coverage
48 | explanations: All the candidate explanations saved for potential future use.
49 | """
50 |
51 | top_labels = kwargs.get('top_labels', 1)
52 | if 'top_labels' in kwargs:
53 | del kwargs['top_labels']
54 | # Parse args
55 | if method == 'sample':
56 | if sample_size > len(data):
57 | warnings.warn("""Requested sample size larger than
58 | size of input data. Using all data""")
59 | sample_size = len(data)
60 | all_indices = np.arange(len(data))
61 | np.random.shuffle(all_indices)
62 | sample_indices = all_indices[:sample_size]
63 | elif method == 'full':
64 | sample_indices = np.arange(len(data))
65 | else:
66 | raise ValueError('Method must be \'sample\' or \'full\'')
67 |
68 | # Generate Explanations
69 | self.explanations = []
70 | for i in sample_indices:
71 | self.explanations.append(
72 | explainer.explain_instance(
73 | data[i], predict_fn, num_features=num_features,
74 | top_labels=top_labels,
75 | **kwargs))
76 | # Error handling
77 | try:
78 | num_exps_desired = int(num_exps_desired)
79 | except TypeError:
80 | return("Requested number of explanations should be an integer")
81 | if num_exps_desired > len(self.explanations):
82 | warnings.warn("""Requested number of explanations larger than
83 | total number of explanations, returning all
84 | explanations instead.""")
85 | num_exps_desired = min(num_exps_desired, len(self.explanations))
86 |
87 | # Find all the explanation model features used. Defines the dimension d'
88 | features_dict = {}
89 | feature_iter = 0
90 | for exp in self.explanations:
91 | labels = exp.available_labels() if exp.mode == 'classification' else [1]
92 | for label in labels:
93 | for feature, _ in exp.as_list(label=label):
94 | if feature not in features_dict.keys():
95 | features_dict[feature] = (feature_iter)
96 | feature_iter += 1
97 | d_prime = len(features_dict.keys())
98 |
99 | # Create the n x d' dimensional 'explanation matrix', W
100 | W = np.zeros((len(self.explanations), d_prime))
101 | for i, exp in enumerate(self.explanations):
102 | labels = exp.available_labels() if exp.mode == 'classification' else [1]
103 | for label in labels:
104 | for feature, value in exp.as_list(label):
105 | W[i, features_dict[feature]] += value
106 |
107 | # Create the global importance vector, I_j described in the paper
108 | importance = np.sum(abs(W), axis=0)**.5
109 |
110 | # Now run the SP-LIME greedy algorithm
111 | remaining_indices = set(range(len(self.explanations)))
112 | V = []
113 | for _ in range(num_exps_desired):
114 | best = 0
115 | best_ind = None
116 | current = 0
117 | for i in remaining_indices:
118 | current = np.dot(
119 | (np.sum(abs(W)[V + [i]], axis=0) > 0), importance
120 | ) # coverage function
121 | if current >= best:
122 | best = current
123 | best_ind = i
124 | V.append(best_ind)
125 | remaining_indices -= {best_ind}
126 |
127 | self.sp_explanations = [self.explanations[i] for i in V]
128 | self.V = V
129 |
--------------------------------------------------------------------------------
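A sketch of SP-LIME as implemented above. The `SubmodularPick` arguments follow the constructor defined in this file; the tabular explainer itself is assumed (same caveat as in the README example):

```python
# Sketch: keep a small, high-coverage subset of explanations with SubmodularPick.
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from slime import lime_tabular            # assumed module, as in the README example
from slime.submodular_pick import SubmodularPick

data = load_breast_cancer()
clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(data.data, data.target)

explainer = lime_tabular.LimeTabularExplainer(
    data.data, feature_names=data.feature_names, class_names=data.target_names)

# Explain a random sample of 20 rows, then greedily keep the 3 with the best coverage.
sp = SubmodularPick(explainer, data.data, clf.predict_proba,
                    method='sample', sample_size=20,
                    num_exps_desired=3, num_features=5)

for exp in sp.sp_explanations:   # the representative subset (length num_exps_desired)
    print(exp.as_list())
print(len(sp.explanations))      # all candidate explanations (sample_size of them)
```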
/slime/tests/test_lime_text.py:
--------------------------------------------------------------------------------
1 | import re
2 | import unittest
3 |
4 | import sklearn # noqa
5 | from sklearn.datasets import fetch_20newsgroups
6 | from sklearn.feature_extraction.text import TfidfVectorizer
7 | from sklearn.metrics import f1_score
8 | from sklearn.naive_bayes import MultinomialNB
9 | from sklearn.pipeline import make_pipeline
10 |
11 | import numpy as np
12 |
13 | from lime.lime_text import LimeTextExplainer
14 | from lime.lime_text import IndexedCharacters, IndexedString
15 |
16 |
17 | class TestLimeText(unittest.TestCase):
18 |
19 | def test_lime_text_explainer_good_regressor(self):
20 | categories = ['alt.atheism', 'soc.religion.christian']
21 | newsgroups_train = fetch_20newsgroups(subset='train',
22 | categories=categories)
23 | newsgroups_test = fetch_20newsgroups(subset='test',
24 | categories=categories)
25 | class_names = ['atheism', 'christian']
26 | vectorizer = TfidfVectorizer(lowercase=False)
27 | train_vectors = vectorizer.fit_transform(newsgroups_train.data)
28 | test_vectors = vectorizer.transform(newsgroups_test.data)
29 | nb = MultinomialNB(alpha=.01)
30 | nb.fit(train_vectors, newsgroups_train.target)
31 | pred = nb.predict(test_vectors)
32 | f1_score(newsgroups_test.target, pred, average='weighted')
33 | c = make_pipeline(vectorizer, nb)
34 | explainer = LimeTextExplainer(class_names=class_names)
35 | idx = 83
36 | exp = explainer.explain_instance(newsgroups_test.data[idx],
37 | c.predict_proba, num_features=6)
38 | self.assertIsNotNone(exp)
39 | self.assertEqual(6, len(exp.as_list()))
40 |
41 | def test_lime_text_tabular_equal_random_state(self):
42 | categories = ['alt.atheism', 'soc.religion.christian']
43 | newsgroups_train = fetch_20newsgroups(subset='train',
44 | categories=categories)
45 | newsgroups_test = fetch_20newsgroups(subset='test',
46 | categories=categories)
47 | class_names = ['atheism', 'christian']
48 | vectorizer = TfidfVectorizer(lowercase=False)
49 | train_vectors = vectorizer.fit_transform(newsgroups_train.data)
50 | test_vectors = vectorizer.transform(newsgroups_test.data)
51 | nb = MultinomialNB(alpha=.01)
52 | nb.fit(train_vectors, newsgroups_train.target)
53 | pred = nb.predict(test_vectors)
54 | f1_score(newsgroups_test.target, pred, average='weighted')
55 | c = make_pipeline(vectorizer, nb)
56 |
57 | explainer = LimeTextExplainer(class_names=class_names, random_state=10)
58 | exp_1 = explainer.explain_instance(newsgroups_test.data[83],
59 | c.predict_proba, num_features=6)
60 |
61 | explainer = LimeTextExplainer(class_names=class_names, random_state=10)
62 | exp_2 = explainer.explain_instance(newsgroups_test.data[83],
63 | c.predict_proba, num_features=6)
64 |
65 | self.assertTrue(exp_1.as_map() == exp_2.as_map())
66 |
67 | def test_lime_text_tabular_not_equal_random_state(self):
68 | categories = ['alt.atheism', 'soc.religion.christian']
69 | newsgroups_train = fetch_20newsgroups(subset='train',
70 | categories=categories)
71 | newsgroups_test = fetch_20newsgroups(subset='test',
72 | categories=categories)
73 | class_names = ['atheism', 'christian']
74 | vectorizer = TfidfVectorizer(lowercase=False)
75 | train_vectors = vectorizer.fit_transform(newsgroups_train.data)
76 | test_vectors = vectorizer.transform(newsgroups_test.data)
77 | nb = MultinomialNB(alpha=.01)
78 | nb.fit(train_vectors, newsgroups_train.target)
79 | pred = nb.predict(test_vectors)
80 | f1_score(newsgroups_test.target, pred, average='weighted')
81 | c = make_pipeline(vectorizer, nb)
82 |
83 | explainer = LimeTextExplainer(
84 | class_names=class_names, random_state=10)
85 | exp_1 = explainer.explain_instance(newsgroups_test.data[83],
86 | c.predict_proba, num_features=6)
87 |
88 | explainer = LimeTextExplainer(
89 | class_names=class_names, random_state=20)
90 | exp_2 = explainer.explain_instance(newsgroups_test.data[83],
91 | c.predict_proba, num_features=6)
92 |
93 | self.assertFalse(exp_1.as_map() == exp_2.as_map())
94 |
95 | def test_indexed_characters_bow(self):
96 | s = 'Please, take your time'
97 | inverse_vocab = ['P', 'l', 'e', 'a', 's', ',', ' ', 't', 'k', 'y', 'o', 'u', 'r', 'i', 'm']
98 | positions = [[0], [1], [2, 5, 11, 21], [3, 9],
99 | [4], [6], [7, 12, 17], [8, 18], [10],
100 | [13], [14], [15], [16], [19], [20]]
101 | ic = IndexedCharacters(s)
102 |
103 | self.assertTrue(np.array_equal(ic.as_np, np.array(list(s))))
104 | self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s))))
105 | self.assertTrue(ic.inverse_vocab == inverse_vocab)
106 | self.assertTrue(ic.positions == positions)
107 |
108 | def test_indexed_characters_not_bow(self):
109 | s = 'Please, take your time'
110 |
111 | ic = IndexedCharacters(s, bow=False)
112 |
113 | self.assertTrue(np.array_equal(ic.as_np, np.array(list(s))))
114 | self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s))))
115 | self.assertTrue(ic.inverse_vocab == list(s))
116 | self.assertTrue(np.array_equal(ic.positions, np.arange(len(s))))
117 |
118 | def test_indexed_string_regex(self):
119 | s = 'Please, take your time. Please'
120 | tokenized_string = np.array(
121 | ['Please', ', ', 'take', ' ', 'your', ' ', 'time', '. ', 'Please'])
122 | inverse_vocab = ['Please', 'take', 'your', 'time']
123 | start_positions = [0, 6, 8, 12, 13, 17, 18, 22, 24]
124 | positions = [[0, 8], [2], [4], [6]]
125 | indexed_string = IndexedString(s)
126 |
127 | self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string))
128 | self.assertTrue(np.array_equal(indexed_string.string_start, start_positions))
129 | self.assertTrue(indexed_string.inverse_vocab == inverse_vocab)
130 | self.assertTrue(np.array_equal(indexed_string.positions, positions))
131 |
132 | def test_indexed_string_callable(self):
133 | s = 'aabbccddaa'
134 |
135 | def tokenizer(string):
136 | return [string[i] + string[i + 1] for i in range(0, len(string) - 1, 2)]
137 |
138 | tokenized_string = np.array(['aa', 'bb', 'cc', 'dd', 'aa'])
139 | inverse_vocab = ['aa', 'bb', 'cc', 'dd']
140 | start_positions = [0, 2, 4, 6, 8]
141 | positions = [[0, 4], [1], [2], [3]]
142 | indexed_string = IndexedString(s, tokenizer)
143 |
144 | self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string))
145 | self.assertTrue(np.array_equal(indexed_string.string_start, start_positions))
146 | self.assertTrue(indexed_string.inverse_vocab == inverse_vocab)
147 | self.assertTrue(np.array_equal(indexed_string.positions, positions))
148 |
149 | def test_indexed_string_inverse_removing_tokenizer(self):
150 | s = 'This is a good movie. This, it is a great movie.'
151 |
152 | def tokenizer(string):
153 | return re.split(r'(?:\W+)|$', string)
154 |
155 | indexed_string = IndexedString(s, tokenizer)
156 |
157 | self.assertEqual(s, indexed_string.inverse_removing([]))
158 |
159 | def test_indexed_string_inverse_removing_regex(self):
160 | s = 'This is a good movie. This is a great movie'
161 | indexed_string = IndexedString(s)
162 |
163 | self.assertEqual(s, indexed_string.inverse_removing([]))
164 |
165 |
166 | if __name__ == '__main__':
167 | unittest.main()
168 |
--------------------------------------------------------------------------------
/slime/tests/test_discretize.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest import TestCase
3 |
4 | import numpy as np
5 |
6 | from sklearn.datasets import load_iris
7 |
8 | from slime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer
9 |
10 |
11 | class TestDiscretize(TestCase):
12 |
13 | def setUp(self):
14 | iris = load_iris()
15 |
16 | self.feature_names = iris.feature_names
17 | self.x = iris.data
18 | self.y = iris.target
19 |
20 | def check_random_state_for_discretizer_class(self, DiscretizerClass):
21 | # ----------------------------------------------------------------------
22 | # -----------Check if the same random_state produces the same-----------
23 | # -------------results for different discretizer instances.-------------
24 | # ----------------------------------------------------------------------
25 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
26 | random_state=10)
27 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x))
28 |
29 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
30 | random_state=10)
31 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x))
32 |
33 | self.assertEqual((x_1 == x_2).sum(), x_1.shape[0] * x_1.shape[1])
34 |
35 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
36 | random_state=np.random.RandomState(10))
37 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x))
38 |
39 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
40 | random_state=np.random.RandomState(10))
41 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x))
42 |
43 | self.assertEqual((x_1 == x_2).sum(), x_1.shape[0] * x_1.shape[1])
44 |
45 | # ----------------------------------------------------------------------
46 |         # ---------Check if two different random_state values produce-----------
47 |         # --------different results for different discretizer instances.--------
48 | # ----------------------------------------------------------------------
49 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
50 | random_state=10)
51 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x))
52 |
53 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
54 | random_state=20)
55 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x))
56 |
57 | self.assertFalse((x_1 == x_2).sum() == x_1.shape[0] * x_1.shape[1])
58 |
59 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
60 | random_state=np.random.RandomState(10))
61 | x_1 = discretizer.undiscretize(discretizer.discretize(self.x))
62 |
63 | discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
64 | random_state=np.random.RandomState(20))
65 | x_2 = discretizer.undiscretize(discretizer.discretize(self.x))
66 |
67 | self.assertFalse((x_1 == x_2).sum() == x_1.shape[0] * x_1.shape[1])
68 |
69 | def test_random_state(self):
70 | self.check_random_state_for_discretizer_class(QuartileDiscretizer)
71 |
72 | self.check_random_state_for_discretizer_class(DecileDiscretizer)
73 |
74 | self.check_random_state_for_discretizer_class(EntropyDiscretizer)
75 |
76 | def test_feature_names_1(self):
77 | self.maxDiff = None
78 | discretizer = QuartileDiscretizer(self.x, [], self.feature_names,
79 | self.y, random_state=10)
80 | self.assertDictEqual(
81 | {0: ['sepal length (cm) <= 5.10',
82 | '5.10 < sepal length (cm) <= 5.80',
83 | '5.80 < sepal length (cm) <= 6.40',
84 | 'sepal length (cm) > 6.40'],
85 | 1: ['sepal width (cm) <= 2.80',
86 | '2.80 < sepal width (cm) <= 3.00',
87 | '3.00 < sepal width (cm) <= 3.30',
88 | 'sepal width (cm) > 3.30'],
89 | 2: ['petal length (cm) <= 1.60',
90 | '1.60 < petal length (cm) <= 4.35',
91 | '4.35 < petal length (cm) <= 5.10',
92 | 'petal length (cm) > 5.10'],
93 | 3: ['petal width (cm) <= 0.30',
94 | '0.30 < petal width (cm) <= 1.30',
95 | '1.30 < petal width (cm) <= 1.80',
96 | 'petal width (cm) > 1.80']},
97 | discretizer.names)
98 |
99 | def test_feature_names_2(self):
100 | self.maxDiff = None
101 | discretizer = DecileDiscretizer(self.x, [], self.feature_names, self.y,
102 | random_state=10)
103 | self.assertDictEqual(
104 | {0: ['sepal length (cm) <= 4.80',
105 | '4.80 < sepal length (cm) <= 5.00',
106 | '5.00 < sepal length (cm) <= 5.27',
107 | '5.27 < sepal length (cm) <= 5.60',
108 | '5.60 < sepal length (cm) <= 5.80',
109 | '5.80 < sepal length (cm) <= 6.10',
110 | '6.10 < sepal length (cm) <= 6.30',
111 | '6.30 < sepal length (cm) <= 6.52',
112 | '6.52 < sepal length (cm) <= 6.90',
113 | 'sepal length (cm) > 6.90'],
114 | 1: ['sepal width (cm) <= 2.50',
115 | '2.50 < sepal width (cm) <= 2.70',
116 | '2.70 < sepal width (cm) <= 2.80',
117 | '2.80 < sepal width (cm) <= 3.00',
118 | '3.00 < sepal width (cm) <= 3.10',
119 | '3.10 < sepal width (cm) <= 3.20',
120 | '3.20 < sepal width (cm) <= 3.40',
121 | '3.40 < sepal width (cm) <= 3.61',
122 | 'sepal width (cm) > 3.61'],
123 | 2: ['petal length (cm) <= 1.40',
124 | '1.40 < petal length (cm) <= 1.50',
125 | '1.50 < petal length (cm) <= 1.70',
126 | '1.70 < petal length (cm) <= 3.90',
127 | '3.90 < petal length (cm) <= 4.35',
128 | '4.35 < petal length (cm) <= 4.64',
129 | '4.64 < petal length (cm) <= 5.00',
130 | '5.00 < petal length (cm) <= 5.32',
131 | '5.32 < petal length (cm) <= 5.80',
132 | 'petal length (cm) > 5.80'],
133 | 3: ['petal width (cm) <= 0.20',
134 | '0.20 < petal width (cm) <= 0.40',
135 | '0.40 < petal width (cm) <= 1.16',
136 | '1.16 < petal width (cm) <= 1.30',
137 | '1.30 < petal width (cm) <= 1.50',
138 | '1.50 < petal width (cm) <= 1.80',
139 | '1.80 < petal width (cm) <= 1.90',
140 | '1.90 < petal width (cm) <= 2.20',
141 | 'petal width (cm) > 2.20']},
142 | discretizer.names)
143 |
144 | def test_feature_names_3(self):
145 | self.maxDiff = None
146 | discretizer = EntropyDiscretizer(self.x, [], self.feature_names,
147 | self.y, random_state=10)
148 | self.assertDictEqual(
149 | {0: ['sepal length (cm) <= 4.85',
150 | '4.85 < sepal length (cm) <= 5.45',
151 | '5.45 < sepal length (cm) <= 5.55',
152 | '5.55 < sepal length (cm) <= 5.85',
153 | '5.85 < sepal length (cm) <= 6.15',
154 | '6.15 < sepal length (cm) <= 7.05',
155 | 'sepal length (cm) > 7.05'],
156 | 1: ['sepal width (cm) <= 2.45',
157 | '2.45 < sepal width (cm) <= 2.95',
158 | '2.95 < sepal width (cm) <= 3.05',
159 | '3.05 < sepal width (cm) <= 3.35',
160 | '3.35 < sepal width (cm) <= 3.45',
161 | '3.45 < sepal width (cm) <= 3.55',
162 | 'sepal width (cm) > 3.55'],
163 | 2: ['petal length (cm) <= 2.45',
164 | '2.45 < petal length (cm) <= 4.45',
165 | '4.45 < petal length (cm) <= 4.75',
166 | '4.75 < petal length (cm) <= 5.15',
167 | 'petal length (cm) > 5.15'],
168 | 3: ['petal width (cm) <= 0.80',
169 | '0.80 < petal width (cm) <= 1.35',
170 | '1.35 < petal width (cm) <= 1.75',
171 | '1.75 < petal width (cm) <= 1.85',
172 | 'petal width (cm) > 1.85']},
173 | discretizer.names)
174 |
175 |
176 | if __name__ == '__main__':
177 | unittest.main()
178 |
--------------------------------------------------------------------------------
/slime/discretize.py:
--------------------------------------------------------------------------------
1 | """
2 | Discretizers classes, to be used in lime_tabular
3 | """
4 | import numpy as np
5 | import sklearn
6 | import sklearn.tree
7 | import scipy
8 | from sklearn.utils import check_random_state
9 | from abc import ABCMeta, abstractmethod
10 |
11 |
12 | class BaseDiscretizer():
13 | """
14 | Abstract class - Build a class that inherits from this class to implement
15 | a custom discretizer.
16 | Method bins() is to be redefined in the child class, as it is the actual
17 | custom part of the discretizer.
18 | """
19 |
20 | __metaclass__ = ABCMeta # abstract class
21 |
22 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None,
23 | data_stats=None):
24 | """Initializer
25 | Args:
26 | data: numpy 2d array
27 | categorical_features: list of indices (ints) corresponding to the
28 | categorical columns. These features will not be discretized.
29 | Everything else will be considered continuous, and will be
30 | discretized.
31 |             feature_names: list of names (strings) corresponding to the columns
32 |                 in the training data.
33 |             labels: target labels, only used by discretizers that need them
34 |                 (e.g. EntropyDiscretizer)
35 |             random_state: an integer or numpy.RandomState used to generate random numbers
36 |             data_stats: must have 'means', 'stds', 'mins' and 'maxs', use this
37 |                 if you don't want these values to be computed from data
38 | """
39 | self.to_discretize = ([x for x in range(data.shape[1])
40 | if x not in categorical_features])
41 | self.data_stats = data_stats
42 | self.names = {}
43 | self.lambdas = {}
44 | self.means = {}
45 | self.stds = {}
46 | self.mins = {}
47 | self.maxs = {}
48 | self.random_state = check_random_state(random_state)
49 |
50 | # To override when implementing a custom binning
51 | bins = self.bins(data, labels)
52 | bins = [np.unique(x) for x in bins]
53 |
54 | # Read the stats from data_stats if exists
55 | if data_stats:
56 | self.means = self.data_stats.get("means")
57 | self.stds = self.data_stats.get("stds")
58 | self.mins = self.data_stats.get("mins")
59 | self.maxs = self.data_stats.get("maxs")
60 |
61 | for feature, qts in zip(self.to_discretize, bins):
62 | n_bins = qts.shape[0] # Actually number of borders (= #bins-1)
63 | boundaries = np.min(data[:, feature]), np.max(data[:, feature])
64 | name = feature_names[feature]
65 |
66 | self.names[feature] = ['%s <= %.2f' % (name, qts[0])]
67 | for i in range(n_bins - 1):
68 | self.names[feature].append('%.2f < %s <= %.2f' %
69 | (qts[i], name, qts[i + 1]))
70 | self.names[feature].append('%s > %.2f' % (name, qts[n_bins - 1]))
71 |
72 | self.lambdas[feature] = lambda x, qts=qts: np.searchsorted(qts, x)
73 | discretized = self.lambdas[feature](data[:, feature])
74 |
75 | # If data stats are provided no need to compute the below set of details
76 | if data_stats:
77 | continue
78 |
79 | self.means[feature] = []
80 | self.stds[feature] = []
81 | for x in range(n_bins + 1):
82 | selection = data[discretized == x, feature]
83 | mean = 0 if len(selection) == 0 else np.mean(selection)
84 | self.means[feature].append(mean)
85 | std = 0 if len(selection) == 0 else np.std(selection)
86 |                 std += 0.00000000001  # avoid a zero std when sampling in undiscretize
87 | self.stds[feature].append(std)
88 | self.mins[feature] = [boundaries[0]] + qts.tolist()
89 | self.maxs[feature] = qts.tolist() + [boundaries[1]]
90 |
91 | @abstractmethod
92 | def bins(self, data, labels):
93 | """
94 | To be overridden
95 | Returns for each feature to discretize the boundaries
96 | that form each bin of the discretizer
97 | """
98 | raise NotImplementedError("Must override bins() method")
99 |
100 | def discretize(self, data):
101 | """Discretizes the data.
102 | Args:
103 | data: numpy 2d or 1d array
104 | Returns:
105 | numpy array of same dimension, discretized.
106 | """
107 | ret = data.copy()
108 | for feature in self.lambdas:
109 | if len(data.shape) == 1:
110 | ret[feature] = int(self.lambdas[feature](ret[feature]))
111 | else:
112 | ret[:, feature] = self.lambdas[feature](
113 | ret[:, feature]).astype(int)
114 | return ret
115 |
116 | def get_undiscretize_values(self, feature, values):
117 | mins = np.array(self.mins[feature])[values]
118 | maxs = np.array(self.maxs[feature])[values]
119 |
120 | means = np.array(self.means[feature])[values]
121 | stds = np.array(self.stds[feature])[values]
122 | minz = (mins - means) / stds
123 | maxz = (maxs - means) / stds
124 | min_max_unequal = (minz != maxz)
125 |
126 | ret = minz
127 | ret[np.where(min_max_unequal)] = scipy.stats.truncnorm.rvs(
128 | minz[min_max_unequal],
129 | maxz[min_max_unequal],
130 | loc=means[min_max_unequal],
131 | scale=stds[min_max_unequal],
132 | random_state=self.random_state
133 | )
134 | return ret
135 |
136 | def undiscretize(self, data):
137 | ret = data.copy()
138 | for feature in self.means:
139 | if len(data.shape) == 1:
140 | ret[feature] = self.get_undiscretize_values(
141 | feature, ret[feature].astype(int).reshape(-1, 1)
142 | )
143 | else:
144 | ret[:, feature] = self.get_undiscretize_values(
145 | feature, ret[:, feature].astype(int)
146 | )
147 | return ret
148 |
149 |
150 | class StatsDiscretizer(BaseDiscretizer):
151 | """
152 | Class to be used to supply the data stats info when discretize_continuous is true
153 | """
154 |
155 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None,
156 | data_stats=None):
157 |
158 | BaseDiscretizer.__init__(self, data, categorical_features,
159 | feature_names, labels=labels,
160 | random_state=random_state,
161 | data_stats=data_stats)
162 |
163 | def bins(self, data, labels):
164 | bins_from_stats = self.data_stats.get("bins")
165 | bins = []
166 | if bins_from_stats is not None:
167 | for feature in self.to_discretize:
168 | bins_from_stats_feature = bins_from_stats.get(feature)
169 | if bins_from_stats_feature is not None:
170 | qts = np.array(bins_from_stats_feature)
171 | bins.append(qts)
172 | return bins
173 |
174 |
175 | class QuartileDiscretizer(BaseDiscretizer):
176 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None):
177 |
178 | BaseDiscretizer.__init__(self, data, categorical_features,
179 | feature_names, labels=labels,
180 | random_state=random_state)
181 |
182 | def bins(self, data, labels):
183 | bins = []
184 | for feature in self.to_discretize:
185 | qts = np.array(np.percentile(data[:, feature], [25, 50, 75]))
186 | bins.append(qts)
187 | return bins
188 |
189 |
190 | class DecileDiscretizer(BaseDiscretizer):
191 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None):
192 | BaseDiscretizer.__init__(self, data, categorical_features,
193 | feature_names, labels=labels,
194 | random_state=random_state)
195 |
196 | def bins(self, data, labels):
197 | bins = []
198 | for feature in self.to_discretize:
199 | qts = np.array(np.percentile(data[:, feature],
200 | [10, 20, 30, 40, 50, 60, 70, 80, 90]))
201 | bins.append(qts)
202 | return bins
203 |
204 |
205 | class EntropyDiscretizer(BaseDiscretizer):
206 | def __init__(self, data, categorical_features, feature_names, labels=None, random_state=None):
207 |         if labels is None:
208 |             raise ValueError('Labels must not be None when using '
209 |                              'EntropyDiscretizer')
210 | BaseDiscretizer.__init__(self, data, categorical_features,
211 | feature_names, labels=labels,
212 | random_state=random_state)
213 |
214 | def bins(self, data, labels):
215 | bins = []
216 | for feature in self.to_discretize:
217 | # Entropy splitting / at most 8 bins so max_depth=3
218 | dt = sklearn.tree.DecisionTreeClassifier(criterion='entropy',
219 | max_depth=3,
220 | random_state=self.random_state)
221 | x = np.reshape(data[:, feature], (-1, 1))
222 | dt.fit(x, labels)
223 | qts = dt.tree_.threshold[np.where(dt.tree_.children_left > -1)]
224 |
225 | if qts.shape[0] == 0:
226 | qts = np.array([np.median(data[:, feature])])
227 | else:
228 | qts = np.sort(qts)
229 |
230 | bins.append(qts)
231 |
232 | return bins
233 |
--------------------------------------------------------------------------------
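A minimal usage sketch for the discretizers above, assuming the package is importable as slime (matching the directory layout): QuartileDiscretizer bins each continuous column at its quartiles, and undiscretize draws values back from a truncated normal fitted to each bin.

from sklearn.datasets import load_iris

from slime.discretize import QuartileDiscretizer

iris = load_iris()

# No categorical features, so every column is treated as continuous and binned at its quartiles.
discretizer = QuartileDiscretizer(iris.data, [], iris.feature_names,
                                  labels=iris.target, random_state=10)

binned = discretizer.discretize(iris.data)    # bin ids per feature, same shape as iris.data
restored = discretizer.undiscretize(binned)   # draws from a truncated normal fitted to each bin

print(discretizer.names[0])                   # readable bin labels for 'sepal length (cm)'
print(binned[0], restored[0])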
/slime/explanation.py:
--------------------------------------------------------------------------------
1 | """
2 | Explanation class, with visualization functions.
3 | """
4 | from io import open
5 | import os
6 | import os.path
7 | import json
8 | import string
9 | import numpy as np
10 |
11 | from .exceptions import LimeError
12 |
13 | from sklearn.utils import check_random_state
14 |
15 |
16 | def id_generator(size=15, random_state=None):
17 | """Helper function to generate random div ids. This is useful for embedding
18 | HTML into ipython notebooks."""
19 | chars = list(string.ascii_uppercase + string.digits)
20 | return ''.join(random_state.choice(chars, size, replace=True))
21 |
22 |
23 | class DomainMapper(object):
24 | """Class for mapping features to the specific domain.
25 |
26 | The idea is that there would be a subclass for each domain (text, tables,
27 | images, etc), so that we can have a general Explanation class, and separate
28 | out the specifics of visualizing features in here.
29 | """
30 |
31 | def __init__(self):
32 | pass
33 |
34 | def map_exp_ids(self, exp, **kwargs):
35 | """Maps the feature ids to concrete names.
36 |
37 | Default behaviour is the identity function. Subclasses can implement
38 | this as they see fit.
39 |
40 | Args:
41 | exp: list of tuples [(id, weight), (id,weight)]
42 | kwargs: optional keyword arguments
43 |
44 | Returns:
45 | exp: list of tuples [(name, weight), (name, weight)...]
46 | """
47 | return exp
48 |
49 | def visualize_instance_html(self,
50 | exp,
51 | label,
52 | div_name,
53 | exp_object_name,
54 | **kwargs):
55 | """Produces html for visualizing the instance.
56 |
57 | Default behaviour does nothing. Subclasses can implement this as they
58 | see fit.
59 |
60 | Args:
61 | exp: list of tuples [(id, weight), (id,weight)]
62 | label: label id (integer)
63 |             div_name: name of div object to be used for rendering (in js)
64 | exp_object_name: name of js explanation object
65 | kwargs: optional keyword arguments
66 |
67 | Returns:
68 | js code for visualizing the instance
69 | """
70 | return ''
71 |
72 |
73 | class Explanation(object):
74 | """Object returned by explainers."""
75 |
76 | def __init__(self,
77 | domain_mapper,
78 | mode='classification',
79 | class_names=None,
80 | random_state=None):
81 | """
82 |
83 | Initializer.
84 |
85 | Args:
86 | domain_mapper: must inherit from DomainMapper class
87 |             mode: "classification" or "regression"
88 | class_names: list of class names (only used for classification)
89 | random_state: an integer or numpy.RandomState that will be used to
90 | generate random numbers. If None, the random state will be
91 | initialized using the internal numpy seed.
92 | """
93 | self.random_state = random_state
94 | self.mode = mode
95 | self.domain_mapper = domain_mapper
96 | self.local_exp = {}
97 | self.intercept = {}
98 | self.score = {}
99 | self.local_pred = {}
100 | if mode == 'classification':
101 | self.class_names = class_names
102 | self.top_labels = None
103 | self.predict_proba = None
104 | elif mode == 'regression':
105 | self.class_names = ['negative', 'positive']
106 | self.predicted_value = None
107 | self.min_value = 0.0
108 | self.max_value = 1.0
109 | self.dummy_label = 1
110 | else:
111 | raise LimeError('Invalid explanation mode "{}". '
112 | 'Should be either "classification" '
113 | 'or "regression".'.format(mode))
114 |
115 | def available_labels(self):
116 | """
117 | Returns the list of classification labels for which we have any explanations.
118 | """
119 | try:
120 | assert self.mode == "classification"
121 | except AssertionError:
122 | raise NotImplementedError('Not supported for regression explanations.')
123 | else:
124 | ans = self.top_labels if self.top_labels else self.local_exp.keys()
125 | return list(ans)
126 |
127 | def as_list(self, label=1, **kwargs):
128 | """Returns the explanation as a list.
129 |
130 | Args:
131 | label: desired label. If you ask for a label for which an
132 | explanation wasn't computed, will throw an exception.
133 | Will be ignored for regression explanations.
134 | kwargs: keyword arguments, passed to domain_mapper
135 |
136 | Returns:
137 | list of tuples (representation, weight), where representation is
138 | given by domain_mapper. Weight is a float.
139 | """
140 | label_to_use = label if self.mode == "classification" else self.dummy_label
141 | ans = self.domain_mapper.map_exp_ids(self.local_exp[label_to_use], **kwargs)
142 | ans = [(x[0], float(x[1])) for x in ans]
143 | return ans
144 |
145 | def as_map(self):
146 | """Returns the map of explanations.
147 |
148 | Returns:
149 | Map from label to list of tuples (feature_id, weight).
150 | """
151 | return self.local_exp
152 |
153 | def as_pyplot_figure(self, label=1, **kwargs):
154 | """Returns the explanation as a pyplot figure.
155 |
156 | Will throw an error if you don't have matplotlib installed
157 | Args:
158 | label: desired label. If you ask for a label for which an
159 | explanation wasn't computed, will throw an exception.
160 | Will be ignored for regression explanations.
161 | kwargs: keyword arguments, passed to domain_mapper
162 |
163 | Returns:
164 | pyplot figure (barchart).
165 | """
166 | import matplotlib.pyplot as plt
167 | exp = self.as_list(label=label, **kwargs)
168 | fig = plt.figure()
169 | vals = [x[1] for x in exp]
170 | names = [x[0] for x in exp]
171 | vals.reverse()
172 | names.reverse()
173 | colors = ['green' if x > 0 else 'red' for x in vals]
174 | pos = np.arange(len(exp)) + .5
175 | plt.barh(pos, vals, align='center', color=colors)
176 | plt.yticks(pos, names)
177 | if self.mode == "classification":
178 | title = 'Local explanation for class %s' % self.class_names[label]
179 | else:
180 | title = 'Local explanation'
181 | plt.title(title)
182 | return fig
183 |
184 | def show_in_notebook(self,
185 | labels=None,
186 | predict_proba=True,
187 | show_predicted_value=True,
188 | **kwargs):
189 | """Shows html explanation in ipython notebook.
190 |
191 | See as_html() for parameters.
192 | This will throw an error if you don't have IPython installed"""
193 |
194 | from IPython.core.display import display, HTML
195 | display(HTML(self.as_html(labels=labels,
196 | predict_proba=predict_proba,
197 | show_predicted_value=show_predicted_value,
198 | **kwargs)))
199 |
200 | def save_to_file(self,
201 | file_path,
202 | labels=None,
203 | predict_proba=True,
204 | show_predicted_value=True,
205 | **kwargs):
206 |         """Saves html explanation to file.
207 |
208 | Params:
209 | file_path: file to save explanations to
210 |
211 | See as_html() for additional parameters.
212 |
213 | """
214 | file_ = open(file_path, 'w', encoding='utf8')
215 | file_.write(self.as_html(labels=labels,
216 | predict_proba=predict_proba,
217 | show_predicted_value=show_predicted_value,
218 | **kwargs))
219 | file_.close()
220 |
221 | def as_html(self,
222 | labels=None,
223 | predict_proba=True,
224 | show_predicted_value=True,
225 | **kwargs):
226 | """Returns the explanation as an html page.
227 |
228 | Args:
229 | labels: desired labels to show explanations for (as barcharts).
230 | If you ask for a label for which an explanation wasn't
231 | computed, will throw an exception. If None, will show
232 | explanations for all available labels. (only used for classification)
233 | predict_proba: if true, add barchart with prediction probabilities
234 | for the top classes. (only used for classification)
235 | show_predicted_value: if true, add barchart with expected value
236 | (only used for regression)
237 | kwargs: keyword arguments, passed to domain_mapper
238 |
239 | Returns:
240 | code for an html page, including javascript includes.
241 | """
242 |
243 | def jsonize(x):
244 | return json.dumps(x, ensure_ascii=False)
245 |
246 | if labels is None and self.mode == "classification":
247 | labels = self.available_labels()
248 |
249 | this_dir, _ = os.path.split(__file__)
250 | bundle = open(os.path.join(this_dir, 'bundle.js'),
251 | encoding="utf8").read()
252 |
253 |         out = u'''<html>
254 |         <meta http-equiv="content-type" content="text/html; charset=UTF8">
255 |         <head><script>%s </script></head><body>''' % bundle
256 | random_id = id_generator(size=15, random_state=check_random_state(self.random_state))
257 |         out += u'''
258 |         <div class="lime top_div" id="top_div%s"></div>
259 |         ''' % random_id
260 |
261 | predict_proba_js = ''
262 | if self.mode == "classification" and predict_proba:
263 | predict_proba_js = u'''
264 | var pp_div = top_div.append('div')
265 | .classed('lime predict_proba', true);
266 | var pp_svg = pp_div.append('svg').style('width', '100%%');
267 | var pp = new lime.PredictProba(pp_svg, %s, %s);
268 | ''' % (jsonize([str(x) for x in self.class_names]),
269 | jsonize(list(self.predict_proba.astype(float))))
270 |
271 | predict_value_js = ''
272 | if self.mode == "regression" and show_predicted_value:
273 | # reference self.predicted_value
274 | # (svg, predicted_value, min_value, max_value)
275 | predict_value_js = u'''
276 | var pp_div = top_div.append('div')
277 | .classed('lime predicted_value', true);
278 | var pp_svg = pp_div.append('svg').style('width', '100%%');
279 | var pp = new lime.PredictedValue(pp_svg, %s, %s, %s);
280 | ''' % (jsonize(float(self.predicted_value)),
281 | jsonize(float(self.min_value)),
282 | jsonize(float(self.max_value)))
283 |
284 | exp_js = '''var exp_div;
285 | var exp = new lime.Explanation(%s);
286 | ''' % (jsonize([str(x) for x in self.class_names]))
287 |
288 | if self.mode == "classification":
289 | for label in labels:
290 | exp = jsonize(self.as_list(label))
291 | exp_js += u'''
292 | exp_div = top_div.append('div').classed('lime explanation', true);
293 | exp.show(%s, %d, exp_div);
294 | ''' % (exp, label)
295 | else:
296 | exp = jsonize(self.as_list())
297 | exp_js += u'''
298 | exp_div = top_div.append('div').classed('lime explanation', true);
299 | exp.show(%s, %s, exp_div);
300 | ''' % (exp, self.dummy_label)
301 |
302 | raw_js = '''var raw_div = top_div.append('div');'''
303 |
304 | if self.mode == "classification":
305 | html_data = self.local_exp[labels[0]]
306 | else:
307 | html_data = self.local_exp[self.dummy_label]
308 |
309 | raw_js += self.domain_mapper.visualize_instance_html(
310 | html_data,
311 | labels[0] if self.mode == "classification" else self.dummy_label,
312 | 'raw_div',
313 | 'exp',
314 | **kwargs)
315 |         out += u'''
316 |         <script>
317 |         var top_div = d3.select('#top_div%s').classed('lime top_div', true);
318 |         %s
319 |         %s
320 |         %s
321 |         %s
322 |         </script>
323 |         ''' % (random_id, predict_proba_js, predict_value_js, exp_js, raw_js)
324 |         out += u'</body></html>'
325 |
326 | return out
327 |
--------------------------------------------------------------------------------
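An Explanation is normally produced by one of the explainers (for text, LimeTextExplainer below) rather than constructed directly. The following is a hedged sketch of its accessors: the probability function is a toy stand-in for a real classifier's predict_proba, as_pyplot_figure needs matplotlib, and save_to_file additionally needs the bundled bundle.js produced by the webpack build.

import numpy as np

from slime.lime_text import LimeTextExplainer

# Toy stand-in for pipeline.predict_proba: the class-1 probability grows with
# the number of times 'good' appears in a document.
def predict_proba(texts):
    scores = np.clip([t.lower().split().count('good') / 5.0 for t in texts], 0.0, 1.0)
    return np.column_stack([1.0 - scores, scores])

explainer = LimeTextExplainer(class_names=['neg', 'pos'], random_state=10)
exp = explainer.explain_instance('This is a good good movie', predict_proba,
                                 num_features=3, num_samples=500)

print(exp.available_labels())         # labels for which explanations were computed
print(exp.as_list(label=1))           # [(word, weight), ...] via the domain mapper
print(exp.as_map())                   # {label: [(feature_id, weight), ...]}

fig = exp.as_pyplot_figure(label=1)   # matplotlib bar chart of the local weights
exp.save_to_file('explanation.html')  # standalone HTML page built by as_html()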
/slime/lime_base.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains abstract functionality for learning locally linear sparse model.
3 | """
4 | import numpy as np
5 | import scipy as sp
6 | import sklearn
7 | from sklearn.linear_model import Ridge
8 | from slime_lm._least_angle import lars_path
9 | from sklearn.utils import check_random_state
10 |
11 |
12 | class LimeBase(object):
13 | """Class for learning a locally linear sparse model from perturbed data"""
14 | def __init__(self,
15 | kernel_fn,
16 | verbose=False,
17 | random_state=None):
18 | """Init function
19 |
20 | Args:
21 | kernel_fn: function that transforms an array of distances into an
22 | array of proximity values (floats).
23 | verbose: if true, print local prediction values from linear model.
24 | random_state: an integer or numpy.RandomState that will be used to
25 | generate random numbers. If None, the random state will be
26 | initialized using the internal numpy seed.
27 | """
28 | self.kernel_fn = kernel_fn
29 | self.verbose = verbose
30 | self.random_state = check_random_state(random_state)
31 |
32 | @staticmethod
33 | def generate_lars_path(weighted_data, weighted_labels, testing=False, alpha=0.05):
34 | """Generates the lars path for weighted data.
35 |
36 | Args:
37 | weighted_data: data that has been weighted by kernel
38 |             weighted_labels: labels, weighted by kernel
39 |
40 |         Returns:
41 |             (alphas, coefs), arrays of the regularization parameters and
42 |             coefficients; if testing=True, a test_result is also returned
43 | """
44 | x_vector = weighted_data
45 | if not testing:
46 | alphas, _, coefs = lars_path(x_vector,
47 | weighted_labels,
48 | method='lasso',
49 | verbose=False,
50 | alpha=alpha)
51 | return alphas, coefs
52 | else:
53 | alphas, _, coefs, test_result = lars_path(x_vector,
54 | weighted_labels,
55 | method='lasso',
56 | verbose=False,
57 | testing=testing)
58 | return alphas, coefs, test_result
59 |
60 | def forward_selection(self, data, labels, weights, num_features):
61 | """Iteratively adds features to the model"""
62 | clf = Ridge(alpha=0, fit_intercept=True, random_state=self.random_state)
63 | used_features = []
64 | for _ in range(min(num_features, data.shape[1])):
65 | max_ = -100000000
66 | best = 0
67 | for feature in range(data.shape[1]):
68 | if feature in used_features:
69 | continue
70 | clf.fit(data[:, used_features + [feature]], labels,
71 | sample_weight=weights)
72 | score = clf.score(data[:, used_features + [feature]],
73 | labels,
74 | sample_weight=weights)
75 | if score > max_:
76 | best = feature
77 | max_ = score
78 | used_features.append(best)
79 | return np.array(used_features)
80 |
81 | def feature_selection(self, data, labels, weights, num_features, method, testing=False, alpha=0.05):
82 | """Selects features for the model. see explain_instance_with_data to
83 | understand the parameters."""
84 | if method == 'none':
85 | return np.array(range(data.shape[1]))
86 | elif method == 'forward_selection':
87 | return self.forward_selection(data, labels, weights, num_features)
88 | elif method == 'highest_weights':
89 | clf = Ridge(alpha=0.01, fit_intercept=True,
90 | random_state=self.random_state)
91 | clf.fit(data, labels, sample_weight=weights)
92 |
93 | coef = clf.coef_
94 | if sp.sparse.issparse(data):
95 | coef = sp.sparse.csr_matrix(clf.coef_)
96 | weighted_data = coef.multiply(data[0])
97 | # Note: most efficient to slice the data before reversing
98 | sdata = len(weighted_data.data)
99 | argsort_data = np.abs(weighted_data.data).argsort()
100 | # Edge case where data is more sparse than requested number of feature importances
101 | # In that case, we just pad with zero-valued features
102 | if sdata < num_features:
103 | nnz_indexes = argsort_data[::-1]
104 | indices = weighted_data.indices[nnz_indexes]
105 | num_to_pad = num_features - sdata
106 | indices = np.concatenate((indices, np.zeros(num_to_pad, dtype=indices.dtype)))
107 | indices_set = set(indices)
108 | pad_counter = 0
109 | for i in range(data.shape[1]):
110 | if i not in indices_set:
111 | indices[pad_counter + sdata] = i
112 | pad_counter += 1
113 | if pad_counter >= num_to_pad:
114 | break
115 | else:
116 | nnz_indexes = argsort_data[sdata - num_features:sdata][::-1]
117 | indices = weighted_data.indices[nnz_indexes]
118 | return indices
119 | else:
120 | weighted_data = coef * data[0]
121 | feature_weights = sorted(
122 | zip(range(data.shape[1]), weighted_data),
123 | key=lambda x: np.abs(x[1]),
124 | reverse=True)
125 | return np.array([x[0] for x in feature_weights[:num_features]])
126 | elif method == 'lasso_path':
127 | if not testing:
128 | weighted_data = ((data - np.average(data, axis=0, weights=weights))
129 | * np.sqrt(weights[:, np.newaxis]))
130 | weighted_labels = ((labels - np.average(labels, weights=weights))
131 | * np.sqrt(weights))
132 |
133 | nonzero = range(weighted_data.shape[1])
134 | _, coefs = self.generate_lars_path(weighted_data,
135 | weighted_labels)
136 | for i in range(len(coefs.T) - 1, 0, -1):
137 | nonzero = coefs.T[i].nonzero()[0]
138 | if len(nonzero) <= num_features:
139 | break
140 | used_features = nonzero
141 | return used_features
142 | else:
143 | weighted_data = ((data - np.average(data, axis=0, weights=weights))
144 | * np.sqrt(weights[:, np.newaxis]))
145 | weighted_labels = ((labels - np.average(labels, weights=weights))
146 | * np.sqrt(weights))
147 |
148 | # Xscaler = sklearn.preprocessing.StandardScaler()
149 | # Xscaler.fit(weighted_data)
150 | # weighted_data = Xscaler.transform(weighted_data)
151 |
152 | # Yscaler = sklearn.preprocessing.StandardScaler()
153 | # Yscaler.fit(weighted_labels.reshape(-1, 1))
154 | # weighted_labels = Yscaler.transform(weighted_labels.reshape(-1, 1)).ravel()
155 |
156 | nonzero = range(weighted_data.shape[1])
157 | alphas, coefs, test_result = self.generate_lars_path(weighted_data,
158 | weighted_labels,
159 | testing=True,
160 | alpha=alpha)
161 | for i in range(len(coefs.T) - 1, 0, -1):
162 | nonzero = coefs.T[i].nonzero()[0]
163 | if len(nonzero) <= num_features:
164 | break
165 | used_features = nonzero
166 | return used_features, test_result
167 | elif method == 'auto':
168 | if num_features <= 6:
169 | n_method = 'forward_selection'
170 | else:
171 | n_method = 'highest_weights'
172 | return self.feature_selection(data, labels, weights,
173 | num_features, n_method)
174 |
175 | def explain_instance_with_data(self,
176 | neighborhood_data,
177 | neighborhood_labels,
178 | distances,
179 | label,
180 | num_features,
181 | feature_selection='auto',
182 | model_regressor=None):
183 | """Takes perturbed data, labels and distances, returns explanation.
184 |
185 | Args:
186 | neighborhood_data: perturbed data, 2d array. first element is
187 | assumed to be the original data point.
188 | neighborhood_labels: corresponding perturbed labels. should have as
189 | many columns as the number of possible labels.
190 | distances: distances to original data point.
191 | label: label for which we want an explanation
192 | num_features: maximum number of features in explanation
193 | feature_selection: how to select num_features. options are:
194 | 'forward_selection': iteratively add features to the model.
195 | This is costly when num_features is high
196 | 'highest_weights': selects the features that have the highest
197 | product of absolute weight * original data point when
198 | learning with all the features
199 | 'lasso_path': chooses features based on the lasso
200 | regularization path
201 | 'none': uses all features, ignores num_features
202 | 'auto': uses forward_selection if num_features <= 6, and
203 | 'highest_weights' otherwise.
204 | model_regressor: sklearn regressor to use in explanation.
205 | Defaults to Ridge regression if None. Must have
206 | model_regressor.coef_ and 'sample_weight' as a parameter
207 | to model_regressor.fit()
208 |
209 | Returns:
210 | (intercept, exp, score, local_pred):
211 | intercept is a float.
212 | exp is a sorted list of tuples, where each tuple (x,y) corresponds
213 | to the feature id (x) and the local weight (y). The list is sorted
214 | by decreasing absolute value of y.
215 | score is the R^2 value of the returned explanation
216 | local_pred is the prediction of the explanation model on the original instance
217 | """
218 |
219 | weights = self.kernel_fn(distances)
220 | labels_column = neighborhood_labels[:, label]
221 | used_features = self.feature_selection(neighborhood_data,
222 | labels_column,
223 | weights,
224 | num_features,
225 | feature_selection)
226 | if model_regressor is None:
227 | model_regressor = Ridge(alpha=1, fit_intercept=True,
228 | random_state=self.random_state)
229 | easy_model = model_regressor
230 | easy_model.fit(neighborhood_data[:, used_features],
231 | labels_column, sample_weight=weights)
232 | prediction_score = easy_model.score(
233 | neighborhood_data[:, used_features],
234 | labels_column, sample_weight=weights)
235 |
236 | local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
237 |
238 | if self.verbose:
239 | print('Intercept', easy_model.intercept_)
240 | print('Prediction_local', local_pred,)
241 | print('Right:', neighborhood_labels[0, label])
242 | return (easy_model.intercept_,
243 | sorted(zip(used_features, easy_model.coef_),
244 | key=lambda x: np.abs(x[1]), reverse=True),
245 | prediction_score, local_pred)
246 |
247 | def testing_explain_instance_with_data(self,
248 | neighborhood_data,
249 | neighborhood_labels,
250 | distances,
251 | label,
252 | num_features,
253 | feature_selection='lasso_path',
254 | model_regressor=None,
255 | alpha=0.05):
256 | """Takes perturbed data, labels and distances, returns explanation.
257 | This is a helper function for slime.
258 |
259 | Args:
260 | neighborhood_data: perturbed data, 2d array. first element is
261 | assumed to be the original data point.
262 | neighborhood_labels: corresponding perturbed labels. should have as
263 | many columns as the number of possible labels.
264 | distances: distances to original data point.
265 | label: label for which we want an explanation
266 | num_features: maximum number of features in explanation
267 | feature_selection: how to select num_features. options are:
268 | 'forward_selection': iteratively add features to the model.
269 | This is costly when num_features is high
270 | 'highest_weights': selects the features that have the highest
271 | product of absolute weight * original data point when
272 | learning with all the features
273 | 'lasso_path': chooses features based on the lasso
274 | regularization path
275 | 'none': uses all features, ignores num_features
276 | 'auto': uses forward_selection if num_features <= 6, and
277 | 'highest_weights' otherwise.
278 | model_regressor: sklearn regressor to use in explanation.
279 | Defaults to Ridge regression if None. Must have
280 | model_regressor.coef_ and 'sample_weight' as a parameter
281 | to model_regressor.fit()
282 | alpha: significance level of hypothesis testing.
283 |
284 | Returns:
285 | (intercept, exp, score, local_pred):
286 | intercept is a float.
287 | exp is a sorted list of tuples, where each tuple (x,y) corresponds
288 | to the feature id (x) and the local weight (y). The list is sorted
289 | by decreasing absolute value of y.
290 | score is the R^2 value of the returned explanation
291 | local_pred is the prediction of the explanation model on the original instance
292 | """
293 | weights = self.kernel_fn(distances)
294 | labels_column = neighborhood_labels[:, label]
295 | used_features, test_result = self.feature_selection(neighborhood_data,
296 | labels_column,
297 | weights,
298 | num_features,
299 | feature_selection,
300 | testing=True,
301 | alpha=alpha)
302 | if model_regressor is None:
303 | model_regressor = Ridge(alpha=1, fit_intercept=True,
304 | random_state=self.random_state)
305 | easy_model = model_regressor
306 | easy_model.fit(neighborhood_data[:, used_features],
307 | labels_column, sample_weight=weights)
308 | prediction_score = easy_model.score(
309 | neighborhood_data[:, used_features],
310 | labels_column, sample_weight=weights)
311 |
312 | local_pred = easy_model.predict(neighborhood_data[0, used_features].reshape(1, -1))
313 |
314 | if self.verbose:
315 | print('Intercept', easy_model.intercept_)
316 | print('Prediction_local', local_pred,)
317 | print('Right:', neighborhood_labels[0, label])
318 | return (easy_model.intercept_,
319 | sorted(zip(used_features, easy_model.coef_),
320 | key=lambda x: np.abs(x[1]), reverse=True),
321 | prediction_score, local_pred, used_features, test_result)
322 |
323 |
--------------------------------------------------------------------------------
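A small sketch of LimeBase on synthetic neighborhood data, illustrating the kernel weighting and feature selection that explain_instance_with_data performs. The data, kernel width, and coefficients below are illustrative assumptions, and the import assumes the package plus its slime_lm dependency are installed. The slime-specific testing_explain_instance_with_data follows the same call pattern but additionally returns the selected features and the hypothesis-test result from the modified LARS path.

import numpy as np

from slime.lime_base import LimeBase

# Exponential kernel on distance, the same form LimeTextExplainer uses by default.
def kernel(d, kernel_width=3.0):
    return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))

rng = np.random.RandomState(0)

# Synthetic binary neighborhood: 200 perturbations of a 10-feature instance.
# The first row is the original instance (all features present); the label
# probability depends only on features 0 and 3.
data = rng.randint(0, 2, size=(200, 10)).astype(float)
data[0] = 1.0
probs = 1.0 / (1.0 + np.exp(-(2.0 * data[:, 0] - 1.5 * data[:, 3])))
neighborhood_labels = np.column_stack([1.0 - probs, probs])
distances = np.sqrt(((data - data[0]) ** 2).sum(axis=1))

base = LimeBase(kernel_fn=kernel, random_state=0)
intercept, exp, score, local_pred = base.explain_instance_with_data(
    data, neighborhood_labels, distances, label=1,
    num_features=2, feature_selection='forward_selection')
print(exp)  # features 0 and 3 should carry the largest absolute weights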
/slime/lime_text.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for explaining text classifiers.
3 | """
4 | from functools import partial
5 | import itertools
6 | import json
7 | import re
8 |
9 | import numpy as np
10 | import scipy as sp
11 | import sklearn
12 | from sklearn.utils import check_random_state
13 |
14 | from . import explanation
15 | from . import lime_base
16 |
17 |
18 | class TextDomainMapper(explanation.DomainMapper):
19 | """Maps feature ids to words or word-positions"""
20 |
21 | def __init__(self, indexed_string):
22 | """Initializer.
23 |
24 | Args:
25 | indexed_string: lime_text.IndexedString, original string
26 | """
27 | self.indexed_string = indexed_string
28 |
29 | def map_exp_ids(self, exp, positions=False):
30 | """Maps ids to words or word-position strings.
31 |
32 | Args:
33 | exp: list of tuples [(id, weight), (id,weight)]
34 | positions: if True, also return word positions
35 |
36 | Returns:
37 |             list of tuples (word, weight), or (word_positions, weight) if
38 |             positions is True. Examples: ('bad', 1) or ('bad_3-6-12', 1)
39 | """
40 | if positions:
41 | exp = [('%s_%s' % (
42 | self.indexed_string.word(x[0]),
43 | '-'.join(
44 | map(str,
45 | self.indexed_string.string_position(x[0])))), x[1])
46 | for x in exp]
47 | else:
48 | exp = [(self.indexed_string.word(x[0]), x[1]) for x in exp]
49 | return exp
50 |
51 | def visualize_instance_html(self, exp, label, div_name, exp_object_name,
52 | text=True, opacity=True):
53 | """Adds text with highlighted words to visualization.
54 |
55 | Args:
56 | exp: list of tuples [(id, weight), (id,weight)]
57 | label: label id (integer)
58 |             div_name: name of div object to be used for rendering (in js)
59 | exp_object_name: name of js explanation object
60 | text: if False, return empty
61 | opacity: if True, fade colors according to weight
62 | """
63 | if not text:
64 | return u''
65 | text = (self.indexed_string.raw_string()
66 | .encode('utf-8', 'xmlcharrefreplace').decode('utf-8'))
67 | text = re.sub(r'[<>&]', '|', text)
68 | exp = [(self.indexed_string.word(x[0]),
69 | self.indexed_string.string_position(x[0]),
70 | x[1]) for x in exp]
71 | all_occurrences = list(itertools.chain.from_iterable(
72 | [itertools.product([x[0]], x[1], [x[2]]) for x in exp]))
73 | all_occurrences = [(x[0], int(x[1]), x[2]) for x in all_occurrences]
74 | ret = '''
75 | %s.show_raw_text(%s, %d, %s, %s, %s);
76 | ''' % (exp_object_name, json.dumps(all_occurrences), label,
77 | json.dumps(text), div_name, json.dumps(opacity))
78 | return ret
79 |
80 |
81 | class IndexedString(object):
82 | """String with various indexes."""
83 |
84 | def __init__(self, raw_string, split_expression=r'\W+', bow=True,
85 | mask_string=None):
86 | """Initializer.
87 |
88 | Args:
89 | raw_string: string with raw text in it
90 | split_expression: Regex string or callable. If regex string, will be used with re.split.
91 | If callable, the function should return a list of tokens.
92 | bow: if True, a word is the same everywhere in the text - i.e. we
93 | will index multiple occurrences of the same word. If False,
94 | order matters, so that the same word will have different ids
95 | according to position.
96 |             mask_string: If not None, replace words with this if bow=False.
97 |                 If None, the default value is 'UNKWORDZ'.
98 | """
99 | self.raw = raw_string
100 | self.mask_string = 'UNKWORDZ' if mask_string is None else mask_string
101 |
102 | if callable(split_expression):
103 | tokens = split_expression(self.raw)
104 | self.as_list = self._segment_with_tokens(self.raw, tokens)
105 | tokens = set(tokens)
106 |
107 | def non_word(string):
108 | return string not in tokens
109 |
110 | else:
111 |             # the capturing group keeps the separators in the split results, so the
112 |             # original string can be reconstructed; splitter.match identifies separators.
113 | splitter = re.compile(r'(%s)|$' % split_expression)
114 | self.as_list = [s for s in splitter.split(self.raw) if s]
115 | non_word = splitter.match
116 |
117 | self.as_np = np.array(self.as_list)
118 | self.string_start = np.hstack(
119 | ([0], np.cumsum([len(x) for x in self.as_np[:-1]])))
120 | vocab = {}
121 | self.inverse_vocab = []
122 | self.positions = []
123 | self.bow = bow
124 | non_vocab = set()
125 | for i, word in enumerate(self.as_np):
126 | if word in non_vocab:
127 | continue
128 | if non_word(word):
129 | non_vocab.add(word)
130 | continue
131 | if bow:
132 | if word not in vocab:
133 | vocab[word] = len(vocab)
134 | self.inverse_vocab.append(word)
135 | self.positions.append([])
136 | idx_word = vocab[word]
137 | self.positions[idx_word].append(i)
138 | else:
139 | self.inverse_vocab.append(word)
140 | self.positions.append(i)
141 | if not bow:
142 | self.positions = np.array(self.positions)
143 |
144 | def raw_string(self):
145 | """Returns the original raw string"""
146 | return self.raw
147 |
148 | def num_words(self):
149 | """Returns the number of tokens in the vocabulary for this document."""
150 | return len(self.inverse_vocab)
151 |
152 | def word(self, id_):
153 | """Returns the word that corresponds to id_ (int)"""
154 | return self.inverse_vocab[id_]
155 |
156 | def string_position(self, id_):
157 | """Returns a np array with indices to id_ (int) occurrences"""
158 | if self.bow:
159 | return self.string_start[self.positions[id_]]
160 | else:
161 | return self.string_start[[self.positions[id_]]]
162 |
163 | def inverse_removing(self, words_to_remove):
164 | """Returns a string after removing the appropriate words.
165 |
166 | If self.bow is false, replaces word with UNKWORDZ instead of removing
167 | it.
168 |
169 | Args:
170 | words_to_remove: list of ids (ints) to remove
171 |
172 | Returns:
173 | original raw string with appropriate words removed.
174 | """
175 | mask = np.ones(self.as_np.shape[0], dtype='bool')
176 | mask[self.__get_idxs(words_to_remove)] = False
177 | if not self.bow:
178 | return ''.join(
179 | [self.as_list[i] if mask[i] else self.mask_string
180 | for i in range(mask.shape[0])])
181 | return ''.join([self.as_list[v] for v in mask.nonzero()[0]])
182 |
183 | @staticmethod
184 | def _segment_with_tokens(text, tokens):
185 | """Segment a string around the tokens created by a passed-in tokenizer"""
186 | list_form = []
187 | text_ptr = 0
188 | for token in tokens:
189 | inter_token_string = []
190 | while not text[text_ptr:].startswith(token):
191 | inter_token_string.append(text[text_ptr])
192 | text_ptr += 1
193 | if text_ptr >= len(text):
194 | raise ValueError("Tokenization produced tokens that do not belong in string!")
195 | text_ptr += len(token)
196 | if inter_token_string:
197 | list_form.append(''.join(inter_token_string))
198 | list_form.append(token)
199 | if text_ptr < len(text):
200 | list_form.append(text[text_ptr:])
201 | return list_form
202 |
203 | def __get_idxs(self, words):
204 | """Returns indexes to appropriate words."""
205 | if self.bow:
206 | return list(itertools.chain.from_iterable(
207 | [self.positions[z] for z in words]))
208 | else:
209 | return self.positions[words]
210 |
211 |
212 | class IndexedCharacters(object):
213 | """String with various indexes."""
214 |
215 | def __init__(self, raw_string, bow=True, mask_string=None):
216 | """Initializer.
217 |
218 | Args:
219 | raw_string: string with raw text in it
220 | bow: if True, a char is the same everywhere in the text - i.e. we
221 | will index multiple occurrences of the same character. If False,
222 |                 order matters, so that the same character will have different ids
223 | according to position.
224 |             mask_string: If not None, replace characters with this if bow=False.
225 |                 If None, the default value is chr(0).
226 | """
227 | self.raw = raw_string
228 | self.as_list = list(self.raw)
229 | self.as_np = np.array(self.as_list)
230 | self.mask_string = chr(0) if mask_string is None else mask_string
231 | self.string_start = np.arange(len(self.raw))
232 | vocab = {}
233 | self.inverse_vocab = []
234 | self.positions = []
235 | self.bow = bow
236 | non_vocab = set()
237 | for i, char in enumerate(self.as_np):
238 | if char in non_vocab:
239 | continue
240 | if bow:
241 | if char not in vocab:
242 | vocab[char] = len(vocab)
243 | self.inverse_vocab.append(char)
244 | self.positions.append([])
245 | idx_char = vocab[char]
246 | self.positions[idx_char].append(i)
247 | else:
248 | self.inverse_vocab.append(char)
249 | self.positions.append(i)
250 | if not bow:
251 | self.positions = np.array(self.positions)
252 |
253 | def raw_string(self):
254 | """Returns the original raw string"""
255 | return self.raw
256 |
257 | def num_words(self):
258 | """Returns the number of tokens in the vocabulary for this document."""
259 | return len(self.inverse_vocab)
260 |
261 | def word(self, id_):
262 | """Returns the word that corresponds to id_ (int)"""
263 | return self.inverse_vocab[id_]
264 |
265 | def string_position(self, id_):
266 | """Returns a np array with indices to id_ (int) occurrences"""
267 | if self.bow:
268 | return self.string_start[self.positions[id_]]
269 | else:
270 | return self.string_start[[self.positions[id_]]]
271 |
272 | def inverse_removing(self, words_to_remove):
273 | """Returns a string after removing the appropriate words.
274 |
275 |         If self.bow is false, replaces each character with the mask string
276 |         (chr(0) by default) instead of removing it.
277 |
278 | Args:
279 | words_to_remove: list of ids (ints) to remove
280 |
281 | Returns:
282 | original raw string with appropriate words removed.
283 | """
284 | mask = np.ones(self.as_np.shape[0], dtype='bool')
285 | mask[self.__get_idxs(words_to_remove)] = False
286 | if not self.bow:
287 | return ''.join(
288 | [self.as_list[i] if mask[i] else self.mask_string
289 | for i in range(mask.shape[0])])
290 | return ''.join([self.as_list[v] for v in mask.nonzero()[0]])
291 |
292 | def __get_idxs(self, words):
293 | """Returns indexes to appropriate words."""
294 | if self.bow:
295 | return list(itertools.chain.from_iterable(
296 | [self.positions[z] for z in words]))
297 | else:
298 | return self.positions[words]
299 |
300 |
301 | class LimeTextExplainer(object):
302 | """Explains text classifiers.
303 | Currently, we are using an exponential kernel on cosine distance, and
304 | restricting explanations to words that are present in documents."""
305 |
306 | def __init__(self,
307 | kernel_width=25,
308 | kernel=None,
309 | verbose=False,
310 | class_names=None,
311 | feature_selection='auto',
312 | split_expression=r'\W+',
313 | bow=True,
314 | mask_string=None,
315 | random_state=None,
316 | char_level=False):
317 | """Init function.
318 |
319 | Args:
320 | kernel_width: kernel width for the exponential kernel.
321 | kernel: similarity kernel that takes euclidean distances and kernel
322 | width as input and outputs weights in (0,1). If None, defaults to
323 | an exponential kernel.
324 | verbose: if true, print local prediction values from linear model
325 | class_names: list of class names, ordered according to whatever the
326 | classifier is using. If not present, class names will be '0',
327 | '1', ...
328 | feature_selection: feature selection method. can be
329 | 'forward_selection', 'lasso_path', 'none' or 'auto'.
330 | See function 'explain_instance_with_data' in lime_base.py for
331 | details on what each of the options does.
332 | split_expression: Regex string or callable. If regex string, will be used with re.split.
333 | If callable, the function should return a list of tokens.
334 | bow: if True (bag of words), will perturb input data by removing
335 | all occurrences of individual words or characters.
336 | Explanations will be in terms of these words. Otherwise, will
337 | explain in terms of word-positions, so that a word may be
338 | important the first time it appears and unimportant the second.
339 | Only set to false if the classifier uses word order in some way
340 | (bigrams, etc), or if you set char_level=True.
341 | mask_string: String used to mask tokens or characters if bow=False
342 | if None, will be 'UNKWORDZ' if char_level=False, chr(0)
343 | otherwise.
344 | random_state: an integer or numpy.RandomState that will be used to
345 | generate random numbers. If None, the random state will be
346 | initialized using the internal numpy seed.
347 |             char_level: a boolean indicating that we treat each character
348 |                 as an independent occurrence in the string
349 | """
350 |
351 | if kernel is None:
352 | def kernel(d, kernel_width):
353 | return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
354 |
355 | kernel_fn = partial(kernel, kernel_width=kernel_width)
356 |
357 | self.random_state = check_random_state(random_state)
358 | self.base = lime_base.LimeBase(kernel_fn, verbose,
359 | random_state=self.random_state)
360 | self.class_names = class_names
361 | self.vocabulary = None
362 | self.feature_selection = feature_selection
363 | self.bow = bow
364 | self.mask_string = mask_string
365 | self.split_expression = split_expression
366 | self.char_level = char_level
367 |
368 | def explain_instance(self,
369 | text_instance,
370 | classifier_fn,
371 | labels=(1,),
372 | top_labels=None,
373 | num_features=10,
374 | num_samples=5000,
375 | distance_metric='cosine',
376 | model_regressor=None):
377 | """Generates explanations for a prediction.
378 |
379 | First, we generate neighborhood data by randomly hiding features from
380 | the instance (see __data_labels_distance_mapping). We then learn
381 | locally weighted linear models on this neighborhood data to explain
382 | each of the classes in an interpretable way (see lime_base.py).
383 |
384 | Args:
385 | text_instance: raw text string to be explained.
386 | classifier_fn: classifier prediction probability function, which
387 | takes a list of d strings and outputs a (d, k) numpy array with
388 | prediction probabilities, where k is the number of classes.
389 |                 For ScikitClassifiers, this is classifier.predict_proba.
390 | labels: iterable with labels to be explained.
391 | top_labels: if not None, ignore labels and produce explanations for
392 | the K labels with highest prediction probabilities, where K is
393 | this parameter.
394 | num_features: maximum number of features present in explanation
395 | num_samples: size of the neighborhood to learn the linear model
396 | distance_metric: the distance metric to use for sample weighting,
397 | defaults to cosine similarity
398 | model_regressor: sklearn regressor to use in explanation. Defaults
399 | to Ridge regression in LimeBase. Must have model_regressor.coef_
400 | and 'sample_weight' as a parameter to model_regressor.fit()
401 | Returns:
402 | An Explanation object (see explanation.py) with the corresponding
403 | explanations.
404 | """
405 |
406 | indexed_string = (IndexedCharacters(
407 | text_instance, bow=self.bow, mask_string=self.mask_string)
408 | if self.char_level else
409 | IndexedString(text_instance, bow=self.bow,
410 | split_expression=self.split_expression,
411 | mask_string=self.mask_string))
412 | domain_mapper = TextDomainMapper(indexed_string)
413 | data, yss, distances = self.__data_labels_distances(
414 | indexed_string, classifier_fn, num_samples,
415 | distance_metric=distance_metric)
416 | if self.class_names is None:
417 | self.class_names = [str(x) for x in range(yss[0].shape[0])]
418 | ret_exp = explanation.Explanation(domain_mapper=domain_mapper,
419 | class_names=self.class_names,
420 | random_state=self.random_state)
421 | ret_exp.predict_proba = yss[0]
422 | if top_labels:
423 | labels = np.argsort(yss[0])[-top_labels:]
424 | ret_exp.top_labels = list(labels)
425 | ret_exp.top_labels.reverse()
426 | for label in labels:
427 | (ret_exp.intercept[label],
428 | ret_exp.local_exp[label],
429 | ret_exp.score[label],
430 | ret_exp.local_pred[label]) = self.base.explain_instance_with_data(
431 | data, yss, distances, label, num_features,
432 | model_regressor=model_regressor,
433 | feature_selection=self.feature_selection)
434 | return ret_exp
435 |
436 | def __data_labels_distances(self,
437 | indexed_string,
438 | classifier_fn,
439 | num_samples,
440 | distance_metric='cosine'):
441 | """Generates a neighborhood around a prediction.
442 |
443 | Generates neighborhood data by randomly removing words from
444 | the instance, and predicting with the classifier. Uses cosine distance
445 | to compute distances between original and perturbed instances.
446 | Args:
447 | indexed_string: document (IndexedString) to be explained,
448 | classifier_fn: classifier prediction probability function, which
449 | takes a string and outputs prediction probabilities. For
450 | ScikitClassifier, this is classifier.predict_proba.
451 | num_samples: size of the neighborhood to learn the linear model
452 | distance_metric: the distance metric to use for sample weighting,
453 | defaults to cosine similarity.
454 |
455 |
456 | Returns:
457 | A tuple (data, labels, distances), where:
458 | data: dense num_samples * K binary matrix, where K is the
459 | number of tokens in indexed_string. The first row is the
460 | original instance, and thus a row of ones.
461 | labels: num_samples * L matrix, where L is the number of target
462 | labels
463 | distances: cosine distance between the original instance and
464 | each perturbed instance (computed in the binary 'data'
465 | matrix), times 100.
466 | """
467 |
468 | def distance_fn(x):
469 | return sklearn.metrics.pairwise.pairwise_distances(
470 | x, x[0], metric=distance_metric).ravel() * 100
471 |
472 | doc_size = indexed_string.num_words()
473 | sample = self.random_state.randint(1, doc_size + 1, num_samples - 1)
474 | data = np.ones((num_samples, doc_size))
475 | data[0] = np.ones(doc_size)
476 | features_range = range(doc_size)
477 | inverse_data = [indexed_string.raw_string()]
478 | for i, size in enumerate(sample, start=1):
479 | inactive = self.random_state.choice(features_range, size,
480 | replace=False)
481 | data[i, inactive] = 0
482 | inverse_data.append(indexed_string.inverse_removing(inactive))
483 | labels = classifier_fn(inverse_data)
484 | distances = distance_fn(sp.sparse.csr_matrix(data))
485 | return data, labels, distances
486 |
--------------------------------------------------------------------------------
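
A note on the neighborhood construction above: __data_labels_distances reduces to a few lines of numpy, namely rows of a binary matrix with random entries zeroed out, plus cosine distances to the first row scaled by 100. A minimal, self-contained sketch with toy sizes (illustration only, not part of the package):

import numpy as np
from sklearn.metrics.pairwise import pairwise_distances

rng = np.random.RandomState(0)
num_samples, doc_size = 5, 4                      # toy sizes: 4 tokens, 5 samples
data = np.ones((num_samples, doc_size))           # first row = original instance
sizes = rng.randint(1, doc_size + 1, num_samples - 1)
for i, size in enumerate(sizes, start=1):
    inactive = rng.choice(doc_size, size, replace=False)
    data[i, inactive] = 0                         # 0 marks a removed token
# cosine distance of every perturbed row to the original instance, times 100
distances = pairwise_distances(data, data[0].reshape(1, -1),
                               metric='cosine').ravel() * 100
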
/slime/lime_image.py:
--------------------------------------------------------------------------------
1 | """
2 | Functions for explaining classifiers that use Image data.
3 | """
4 | import copy
5 | from functools import partial
6 |
7 | import numpy as np
8 | import sklearn
9 | from sklearn.utils import check_random_state
10 | from skimage.color import gray2rgb
11 | from tqdm.auto import tqdm
12 |
13 |
14 | from . import lime_base
15 | from .wrappers.scikit_image import SegmentationAlgorithm
16 |
17 |
18 | class ImageExplanation(object):
19 | def __init__(self, image, segments):
20 | """Init function.
21 |
22 | Args:
23 | image: 3d numpy array
24 | segments: 2d numpy array, with the output from skimage.segmentation
25 | """
26 | self.image = image
27 | self.segments = segments
28 | self.intercept = {}
29 | self.local_exp = {}
30 | self.local_pred = {}
31 | self.score = {}
32 |
33 | def get_image_and_mask(self, label, positive_only=True, negative_only=False, hide_rest=False,
34 | num_features=5, min_weight=0.):
35 |         """Returns the explanation image and mask for a given label.
36 |
37 | Args:
38 | label: label to explain
39 | positive_only: if True, only take superpixels that positively contribute to
40 | the prediction of the label.
41 | negative_only: if True, only take superpixels that negatively contribute to
42 |                 the prediction of the label. If False, and positive_only is also
43 |                 False, then both negative and positive contributions will be taken.
44 |                 Both cannot be True at the same time.
45 | hide_rest: if True, make the non-explanation part of the return
46 | image gray
47 | num_features: number of superpixels to include in explanation
48 | min_weight: minimum weight of the superpixels to include in explanation
49 |
50 | Returns:
51 | (image, mask), where image is a 3d numpy array and mask is a 2d
52 | numpy array that can be used with
53 | skimage.segmentation.mark_boundaries
54 | """
55 | if label not in self.local_exp:
56 | raise KeyError('Label not in explanation')
57 | if positive_only & negative_only:
58 | raise ValueError("Positive_only and negative_only cannot be true at the same time.")
59 | segments = self.segments
60 | image = self.image
61 | exp = self.local_exp[label]
62 | mask = np.zeros(segments.shape, segments.dtype)
63 | if hide_rest:
64 | temp = np.zeros(self.image.shape)
65 | else:
66 | temp = self.image.copy()
67 | if positive_only:
68 | fs = [x[0] for x in exp
69 | if x[1] > 0 and x[1] > min_weight][:num_features]
70 | if negative_only:
71 | fs = [x[0] for x in exp
72 | if x[1] < 0 and abs(x[1]) > min_weight][:num_features]
73 | if positive_only or negative_only:
74 | for f in fs:
75 | temp[segments == f] = image[segments == f].copy()
76 | mask[segments == f] = 1
77 | return temp, mask
78 | else:
79 | for f, w in exp[:num_features]:
80 | if np.abs(w) < min_weight:
81 | continue
82 | c = 0 if w < 0 else 1
83 | mask[segments == f] = -1 if w < 0 else 1
84 | temp[segments == f] = image[segments == f].copy()
85 | temp[segments == f, c] = np.max(image)
86 | return temp, mask
87 |
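    # Illustrative use of the returned pair (not part of the class): assuming
    # `explanation` is an ImageExplanation produced by the explainer defined
    # below, the mask can be overlaid with scikit-image, e.g.
    #
    #     from skimage.segmentation import mark_boundaries
    #     temp, mask = explanation.get_image_and_mask(
    #         explanation.top_labels[0], positive_only=True, num_features=5)
    #     overlay = mark_boundaries(temp / 255.0, mask)  # assumes a 0-255 image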
88 |
89 | class LimeImageExplainer(object):
90 |     """Explains predictions on Image (i.e. matrix) data.
91 |     The image is first segmented into superpixels. Neighborhood data is then
92 |     generated by randomly turning superpixels on and off; superpixels that are
93 |     turned off are replaced by a fudged value (the mean color of the
94 |     superpixel, or hide_color if given). A locally weighted linear model is
95 |     fit on this neighborhood to explain each class in an interpretable way
96 |     (see lime_base.py)."""
97 |
98 | def __init__(self, kernel_width=.25, kernel=None, verbose=False,
99 | feature_selection='auto', random_state=None):
100 | """Init function.
101 |
102 | Args:
103 | kernel_width: kernel width for the exponential kernel.
104 |                 Defaults to 0.25.
105 | kernel: similarity kernel that takes euclidean distances and kernel
106 | width as input and outputs weights in (0,1). If None, defaults to
107 | an exponential kernel.
108 | verbose: if true, print local prediction values from linear model
109 | feature_selection: feature selection method. can be
110 | 'forward_selection', 'lasso_path', 'none' or 'auto'.
111 | See function 'explain_instance_with_data' in lime_base.py for
112 | details on what each of the options does.
113 | random_state: an integer or numpy.RandomState that will be used to
114 | generate random numbers. If None, the random state will be
115 | initialized using the internal numpy seed.
116 | """
117 | kernel_width = float(kernel_width)
118 |
119 | if kernel is None:
120 | def kernel(d, kernel_width):
121 | return np.sqrt(np.exp(-(d ** 2) / kernel_width ** 2))
122 |
123 | kernel_fn = partial(kernel, kernel_width=kernel_width)
124 |
125 | self.random_state = check_random_state(random_state)
126 | self.feature_selection = feature_selection
127 | self.base = lime_base.LimeBase(kernel_fn, verbose, random_state=self.random_state)
128 |
129 | def explain_instance(self, image, classifier_fn, labels=(1,),
130 | hide_color=None,
131 | top_labels=5, num_features=100000, num_samples=1000,
132 | batch_size=10,
133 | segmentation_fn=None,
134 | distance_metric='cosine',
135 | model_regressor=None,
136 | random_seed=None,
137 | progress_bar=True):
138 | """Generates explanations for a prediction.
139 |
140 | First, we generate neighborhood data by randomly perturbing features
141 | from the instance (see __data_inverse). We then learn locally weighted
142 | linear models on this neighborhood data to explain each of the classes
143 | in an interpretable way (see lime_base.py).
144 |
145 | Args:
146 | image: 3 dimension RGB image. If this is only two dimensional,
147 | we will assume it's a grayscale image and call gray2rgb.
148 | classifier_fn: classifier prediction probability function, which
149 | takes a numpy array and outputs prediction probabilities. For
150 |                 ScikitClassifiers, this is classifier.predict_proba.
151 | labels: iterable with labels to be explained.
152 | hide_color: TODO
153 | top_labels: if not None, ignore labels and produce explanations for
154 | the K labels with highest prediction probabilities, where K is
155 | this parameter.
156 | num_features: maximum number of features present in explanation
157 | num_samples: size of the neighborhood to learn the linear model
158 | batch_size: TODO
159 | distance_metric: the distance metric to use for weights.
160 | model_regressor: sklearn regressor to use in explanation. Defaults
161 | to Ridge regression in LimeBase. Must have model_regressor.coef_
162 | and 'sample_weight' as a parameter to model_regressor.fit()
163 | segmentation_fn: SegmentationAlgorithm, wrapped skimage
164 | segmentation function
165 | random_seed: integer used as random seed for the segmentation
166 | algorithm. If None, a random integer, between 0 and 1000,
167 | will be generated using the internal random number generator.
168 | progress_bar: if True, show tqdm progress bar.
169 |
170 | Returns:
171 | An ImageExplanation object (see lime_image.py) with the corresponding
172 | explanations.
173 | """
174 | if len(image.shape) == 2:
175 | image = gray2rgb(image)
176 | if random_seed is None:
177 | random_seed = self.random_state.randint(0, high=1000)
178 |
179 | if segmentation_fn is None:
180 | segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
181 | max_dist=200, ratio=0.2,
182 | random_seed=random_seed)
183 | try:
184 | segments = segmentation_fn(image)
185 | except ValueError as e:
186 | raise e
187 |
188 | fudged_image = image.copy()
189 | if hide_color is None:
190 | for x in np.unique(segments):
191 | fudged_image[segments == x] = (
192 | np.mean(image[segments == x][:, 0]),
193 | np.mean(image[segments == x][:, 1]),
194 | np.mean(image[segments == x][:, 2]))
195 | else:
196 | fudged_image[:] = hide_color
197 |
198 | top = labels
199 |
200 | data, labels = self.data_labels(image, fudged_image, segments,
201 | classifier_fn, num_samples,
202 | batch_size=batch_size,
203 | progress_bar=progress_bar)
204 |
205 | distances = sklearn.metrics.pairwise_distances(
206 | data,
207 | data[0].reshape(1, -1),
208 | metric=distance_metric
209 | ).ravel()
210 |
211 | ret_exp = ImageExplanation(image, segments)
212 | if top_labels:
213 | top = np.argsort(labels[0])[-top_labels:]
214 | ret_exp.top_labels = list(top)
215 | ret_exp.top_labels.reverse()
216 | for label in top:
217 | (ret_exp.intercept[label],
218 | ret_exp.local_exp[label],
219 | ret_exp.score[label],
220 | ret_exp.local_pred[label]) = self.base.explain_instance_with_data(
221 | data, labels, distances, label, num_features,
222 | model_regressor=model_regressor,
223 | feature_selection=self.feature_selection)
224 | return ret_exp
225 |
226 | def testing_explain_instance(self, image, classifier_fn, labels=(1,),
227 | hide_color=None,
228 | top_labels=5, num_features=100000, num_samples=1000,
229 | batch_size=10,
230 | segmentation_fn=None,
231 | distance_metric='cosine',
232 | model_regressor=None,
233 | alpha=0.05,
234 | random_seed=None,
235 | progress_bar=True):
236 | """Generates explanations for a prediction.
237 |
238 | First, we generate neighborhood data by randomly perturbing features
239 | from the instance (see __data_inverse). We then learn locally weighted
240 | linear models on this neighborhood data to explain each of the classes
241 | in an interpretable way (see lime_base.py).
242 |
243 | Args:
244 | image: 3 dimension RGB image. If this is only two dimensional,
245 | we will assume it's a grayscale image and call gray2rgb.
246 | classifier_fn: classifier prediction probability function, which
247 | takes a numpy array and outputs prediction probabilities. For
248 |                 ScikitClassifiers, this is classifier.predict_proba.
249 | labels: iterable with labels to be explained.
250 | hide_color: TODO
251 | top_labels: if not None, ignore labels and produce explanations for
252 | the K labels with highest prediction probabilities, where K is
253 | this parameter.
254 | num_features: maximum number of features present in explanation
255 | num_samples: size of the neighborhood to learn the linear model
256 | batch_size: TODO
257 | distance_metric: the distance metric to use for weights.
258 | model_regressor: sklearn regressor to use in explanation. Defaults
259 | to Ridge regression in LimeBase. Must have model_regressor.coef_
260 | and 'sample_weight' as a parameter to model_regressor.fit()
261 | segmentation_fn: SegmentationAlgorithm, wrapped skimage
262 | segmentation function
263 | random_seed: integer used as random seed for the segmentation
264 | algorithm. If None, a random integer, between 0 and 1000,
265 | will be generated using the internal random number generator.
266 | progress_bar: if True, show tqdm progress bar.
267 |
268 | Returns:
269 | An ImageExplanation object (see lime_image.py) with the corresponding
270 | explanations.
271 | """
272 | if len(image.shape) == 2:
273 | image = gray2rgb(image)
274 | if random_seed is None:
275 | random_seed = self.random_state.randint(0, high=1000)
276 |
277 | if segmentation_fn is None:
278 | segmentation_fn = SegmentationAlgorithm('quickshift', kernel_size=4,
279 | max_dist=200, ratio=0.2,
280 | random_seed=random_seed)
281 | try:
282 | segments = segmentation_fn(image)
283 | except ValueError as e:
284 | raise e
285 |
286 | fudged_image = image.copy()
287 | if hide_color is None:
288 | for x in np.unique(segments):
289 | fudged_image[segments == x] = (
290 | np.mean(image[segments == x][:, 0]),
291 | np.mean(image[segments == x][:, 1]),
292 | np.mean(image[segments == x][:, 2]))
293 | else:
294 | fudged_image[:] = hide_color
295 |
296 | top = labels
297 |
298 | data, labels = self.data_labels(image, fudged_image, segments,
299 | classifier_fn, num_samples,
300 | batch_size=batch_size,
301 | progress_bar=progress_bar)
302 |
303 | distances = sklearn.metrics.pairwise_distances(
304 | data,
305 | data[0].reshape(1, -1),
306 | metric=distance_metric
307 | ).ravel()
308 |
309 | ret_exp = ImageExplanation(image, segments)
310 | if top_labels:
311 | top = np.argsort(labels[0])[-top_labels:]
312 | ret_exp.top_labels = list(top)
313 | ret_exp.top_labels.reverse()
314 | for label in top:
315 | (ret_exp.intercept[label],
316 | ret_exp.local_exp[label],
317 | ret_exp.score[label],
318 | ret_exp.local_pred[label],
319 | used_features,
320 | test_result) = self.base.testing_explain_instance_with_data(
321 | data, labels, distances, label, num_features,
322 | model_regressor=model_regressor,
323 | feature_selection=self.feature_selection,
324 | alpha=alpha)
325 | return ret_exp, test_result
326 |
327 | def data_labels(self,
328 | image,
329 | fudged_image,
330 | segments,
331 | classifier_fn,
332 | num_samples,
333 | batch_size=10,
334 | progress_bar=True):
335 | """Generates images and predictions in the neighborhood of this image.
336 |
337 | Args:
338 | image: 3d numpy array, the image
339 | fudged_image: 3d numpy array, image to replace original image when
340 | superpixel is turned off
341 | segments: segmentation of the image
342 | classifier_fn: function that takes a list of images and returns a
343 | matrix of prediction probabilities
344 | num_samples: size of the neighborhood to learn the linear model
345 | batch_size: classifier_fn will be called on batches of this size.
346 | progress_bar: if True, show tqdm progress bar.
347 |
348 | Returns:
349 | A tuple (data, labels), where:
350 |                 data: dense num_samples * num_superpixels binary matrix
351 | labels: prediction probabilities matrix
352 | """
353 | n_features = np.unique(segments).shape[0]
354 | data = self.random_state.randint(0, 2, num_samples * n_features)\
355 | .reshape((num_samples, n_features))
356 | labels = []
357 | data[0, :] = 1
358 | imgs = []
359 | rows = tqdm(data) if progress_bar else data
360 | for row in rows:
361 | temp = copy.deepcopy(image)
362 | zeros = np.where(row == 0)[0]
363 | mask = np.zeros(segments.shape).astype(bool)
364 | for z in zeros:
365 | mask[segments == z] = True
366 | temp[mask] = fudged_image[mask]
367 | imgs.append(temp)
368 | if len(imgs) == batch_size:
369 | preds = classifier_fn(np.array(imgs))
370 | labels.extend(preds)
371 | imgs = []
372 | if len(imgs) > 0:
373 | preds = classifier_fn(np.array(imgs))
374 | labels.extend(preds)
375 | return data, np.array(labels)
376 |
377 | def slime(self,
378 | image, classifier_fn, labels=(1,),
379 | hide_color=None,
380 | top_labels=5, num_features=100000, num_samples=1000,
381 | batch_size=10,
382 | segmentation_fn=None,
383 | distance_metric='cosine',
384 | model_regressor=None,
385 | n_max=10000,
386 | alpha=0.05,
387 | tol=1e-3,
388 | random_seed=None,
389 | progress_bar=True
390 | ):
391 | """Generates explanations for a prediction with S-LIME.
392 |
393 | First, we generate neighborhood data by randomly perturbing features
394 | from the instance (see __data_inverse). We then learn locally weighted
395 | linear models on this neighborhood data to explain each of the classes
396 | in an interpretable way (see lime_base.py).
397 |
398 |         Args:
399 |             image: 3 dimension RGB image. If this is only two dimensional,
400 |                 we will assume it's a grayscale image and call gray2rgb.
401 |             classifier_fn: classifier prediction probability function, which
402 |                 takes a numpy array and outputs prediction probabilities. For
403 |                 ScikitClassifiers, this is classifier.predict_proba. The function
404 |                 needs to work on batches of perturbed images, i.e. on numpy
405 |                 arrays of shape (batch_size, height, width, channels).
406 |             labels: iterable with labels to be explained.
407 |             hide_color: if None, superpixels that are turned off are replaced
408 |                 by their mean color; otherwise they are set to hide_color.
409 |             top_labels: if not None, ignore labels and produce explanations for
410 |                 the K labels with highest prediction probabilities, where K is
411 |                 this parameter.
412 |             num_features: maximum number of features present in explanation
413 |             num_samples: initial size of the neighborhood to learn the linear model
414 |             distance_metric: the distance metric to use for weights.
415 |             model_regressor: sklearn regressor to use in explanation. Defaults
416 |                 to Ridge regression in LimeBase. Must have model_regressor.coef_
417 |                 and 'sample_weight' as a parameter to model_regressor.fit()
418 |             segmentation_fn: SegmentationAlgorithm, wrapped skimage
419 |                 segmentation function
420 |             n_max: maximum number of synthetic samples to generate.
421 |             alpha: significance level of hypothesis testing.
422 |             tol: tolerance level of hypothesis testing.
423 |
424 | Returns:
425 | An Explanation object (see explanation.py) with the corresponding
426 | explanations.
427 | """
428 |
429 | while True:
430 | ret_exp, test_result = self.testing_explain_instance(image=image,
431 | classifier_fn=classifier_fn,
432 | labels=labels,
433 | hide_color=hide_color,
434 | top_labels=top_labels,
435 | num_features=num_features,
436 | num_samples=num_samples,
437 | batch_size=batch_size,
438 | segmentation_fn=segmentation_fn,
439 | distance_metric=distance_metric,
440 | model_regressor=model_regressor,
441 | alpha=alpha,
442 | random_seed=random_seed,
443 | progress_bar=progress_bar)
444 | flag = False
445 |             for k in range(1, num_features):  # changed from num_features + 1 to num_features to fix a bug
446 | if test_result[k][0] < -tol:
447 | flag = True
448 | break
449 | if flag and num_samples != n_max:
450 | num_samples = min(int(test_result[k][1]), 2 * num_samples)
451 | if num_samples > n_max:
452 | num_samples = n_max
453 | else:
454 | break
455 |
456 | return ret_exp
457 |
--------------------------------------------------------------------------------
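
End to end, the S-LIME image explainer defined above can be driven as in the following hedged sketch. Here predict_fn is a toy stand-in for a real model's predict_proba, and the sample sizes are illustrative only:

import numpy as np
from slime.lime_image import LimeImageExplainer

def predict_fn(images):
    # toy stand-in classifier: P(class 1) grows with the mean brightness
    p1 = images.mean(axis=(1, 2, 3))
    return np.stack([1.0 - p1, p1], axis=1)

image = np.random.RandomState(0).rand(64, 64, 3)   # random RGB image in [0, 1]
explainer = LimeImageExplainer(random_state=0)
# slime() re-runs testing_explain_instance, growing num_samples (up to n_max)
# until the per-feature hypothesis tests are significant at level alpha.
explanation = explainer.slime(image, predict_fn, top_labels=1,
                              num_samples=256, n_max=1024, alpha=0.05,
                              progress_bar=False)
print(explanation.top_labels)

The returned ImageExplanation can then be rendered with get_image_and_mask and skimage.segmentation.mark_boundaries, as sketched in the comments inside lime_image.py above.
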
/slime/tests/test_lime_tabular.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 | import collections
5 | import sklearn # noqa
6 | import sklearn.datasets
7 | import sklearn.ensemble
8 | import sklearn.linear_model # noqa
9 | from numpy.testing import assert_array_equal
10 | from sklearn.datasets import load_iris, make_classification, make_multilabel_classification
11 | from sklearn.ensemble import RandomForestClassifier
12 | from sklearn.linear_model import LinearRegression
13 | from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer
14 |
15 |
16 | try:
17 | from sklearn.model_selection import train_test_split
18 | except ImportError:
19 | # Deprecated in scikit-learn version 0.18, removed in 0.20
20 | from sklearn.cross_validation import train_test_split
21 |
22 | from lime.lime_tabular import LimeTabularExplainer
23 |
24 |
25 | class TestLimeTabular(unittest.TestCase):
26 |
27 | def setUp(self):
28 | iris = load_iris()
29 |
30 | self.feature_names = iris.feature_names
31 | self.target_names = iris.target_names
32 |
33 | (self.train,
34 | self.test,
35 | self.labels_train,
36 | self.labels_test) = train_test_split(iris.data, iris.target, train_size=0.80)
37 |
38 | def test_lime_explainer_good_regressor(self):
39 | np.random.seed(1)
40 | rf = RandomForestClassifier(n_estimators=500)
41 | rf.fit(self.train, self.labels_train)
42 | i = np.random.randint(0, self.test.shape[0])
43 |
44 | explainer = LimeTabularExplainer(self.train,
45 | mode="classification",
46 | feature_names=self.feature_names,
47 | class_names=self.target_names,
48 | discretize_continuous=True)
49 |
50 | exp = explainer.explain_instance(self.test[i],
51 | rf.predict_proba,
52 | num_features=2,
53 | model_regressor=LinearRegression())
54 |
55 | self.assertIsNotNone(exp)
56 | keys = [x[0] for x in exp.as_list()]
57 | self.assertEqual(1,
58 | sum([1 if 'petal width' in x else 0 for x in keys]),
59 | "Petal Width is a major feature")
60 | self.assertEqual(1,
61 | sum([1 if 'petal length' in x else 0 for x in keys]),
62 | "Petal Length is a major feature")
63 |
64 | def test_lime_explainer_good_regressor_synthetic_data(self):
65 | X, y = make_classification(n_samples=1000,
66 | n_features=20,
67 | n_informative=2,
68 | n_redundant=2,
69 | random_state=10)
70 |
71 | rf = RandomForestClassifier(n_estimators=500)
72 | rf.fit(X, y)
73 | instance = np.random.randint(0, X.shape[0])
74 | feature_names = ["feature" + str(i) for i in range(20)]
75 | explainer = LimeTabularExplainer(X,
76 | feature_names=feature_names,
77 | discretize_continuous=True)
78 |
79 | exp = explainer.explain_instance(X[instance], rf.predict_proba)
80 |
81 | self.assertIsNotNone(exp)
82 | self.assertEqual(10, len(exp.as_list()))
83 |
84 | def test_lime_explainer_sparse_synthetic_data(self):
85 | n_features = 20
86 | X, y = make_multilabel_classification(n_samples=100,
87 | sparse=True,
88 | n_features=n_features,
89 | n_classes=1,
90 | n_labels=2)
91 | rf = RandomForestClassifier(n_estimators=500)
92 | rf.fit(X, y)
93 | instance = np.random.randint(0, X.shape[0])
94 | feature_names = ["feature" + str(i) for i in range(n_features)]
95 | explainer = LimeTabularExplainer(X,
96 | feature_names=feature_names,
97 | discretize_continuous=True)
98 |
99 | exp = explainer.explain_instance(X[instance], rf.predict_proba)
100 |
101 | self.assertIsNotNone(exp)
102 | self.assertEqual(10, len(exp.as_list()))
103 |
104 | def test_lime_explainer_no_regressor(self):
105 | np.random.seed(1)
106 |
107 | rf = RandomForestClassifier(n_estimators=500)
108 | rf.fit(self.train, self.labels_train)
109 | i = np.random.randint(0, self.test.shape[0])
110 |
111 | explainer = LimeTabularExplainer(self.train,
112 | feature_names=self.feature_names,
113 | class_names=self.target_names,
114 | discretize_continuous=True)
115 |
116 | exp = explainer.explain_instance(self.test[i],
117 | rf.predict_proba,
118 | num_features=2)
119 | self.assertIsNotNone(exp)
120 | keys = [x[0] for x in exp.as_list()]
121 | self.assertEqual(1,
122 | sum([1 if 'petal width' in x else 0 for x in keys]),
123 | "Petal Width is a major feature")
124 | self.assertEqual(1,
125 | sum([1 if 'petal length' in x else 0 for x in keys]),
126 | "Petal Length is a major feature")
127 |
128 | def test_lime_explainer_entropy_discretizer(self):
129 | np.random.seed(1)
130 |
131 | rf = RandomForestClassifier(n_estimators=500)
132 | rf.fit(self.train, self.labels_train)
133 | i = np.random.randint(0, self.test.shape[0])
134 |
135 | explainer = LimeTabularExplainer(self.train,
136 | feature_names=self.feature_names,
137 | class_names=self.target_names,
138 | training_labels=self.labels_train,
139 | discretize_continuous=True,
140 | discretizer='entropy')
141 |
142 | exp = explainer.explain_instance(self.test[i],
143 | rf.predict_proba,
144 | num_features=2)
145 | self.assertIsNotNone(exp)
146 | keys = [x[0] for x in exp.as_list()]
147 | print(keys)
148 | self.assertEqual(1,
149 | sum([1 if 'petal width' in x else 0 for x in keys]),
150 | "Petal Width is a major feature")
151 | self.assertEqual(1,
152 | sum([1 if 'petal length' in x else 0 for x in keys]),
153 | "Petal Length is a major feature")
154 |
155 | def test_lime_tabular_explainer_equal_random_state(self):
156 | X, y = make_classification(n_samples=1000,
157 | n_features=20,
158 | n_informative=2,
159 | n_redundant=2,
160 | random_state=10)
161 |
162 | rf = RandomForestClassifier(n_estimators=500, random_state=10)
163 | rf.fit(X, y)
164 | instance = np.random.RandomState(10).randint(0, X.shape[0])
165 | feature_names = ["feature" + str(i) for i in range(20)]
166 |
167 | # ----------------------------------------------------------------------
168 | # -------------------------Quartile Discretizer-------------------------
169 | # ----------------------------------------------------------------------
170 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
171 | random_state=10)
172 | explainer_1 = LimeTabularExplainer(X,
173 | feature_names=feature_names,
174 | discretize_continuous=True,
175 | discretizer=discretizer,
176 | random_state=10)
177 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
178 | num_samples=500)
179 |
180 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
181 | random_state=10)
182 | explainer_2 = LimeTabularExplainer(X,
183 | feature_names=feature_names,
184 | discretize_continuous=True,
185 | discretizer=discretizer,
186 | random_state=10)
187 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
188 | num_samples=500)
189 |
190 | self.assertDictEqual(exp_1.as_map(), exp_2.as_map())
191 |
192 | # ----------------------------------------------------------------------
193 | # --------------------------Decile Discretizer--------------------------
194 | # ----------------------------------------------------------------------
195 | discretizer = DecileDiscretizer(X, [], feature_names, y,
196 | random_state=10)
197 | explainer_1 = LimeTabularExplainer(X,
198 | feature_names=feature_names,
199 | discretize_continuous=True,
200 | discretizer=discretizer,
201 | random_state=10)
202 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
203 | num_samples=500)
204 |
205 | discretizer = DecileDiscretizer(X, [], feature_names, y,
206 | random_state=10)
207 | explainer_2 = LimeTabularExplainer(X,
208 | feature_names=feature_names,
209 | discretize_continuous=True,
210 | discretizer=discretizer,
211 | random_state=10)
212 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
213 | num_samples=500)
214 |
215 | self.assertDictEqual(exp_1.as_map(), exp_2.as_map())
216 |
217 | # ----------------------------------------------------------------------
218 | # -------------------------Entropy Discretizer--------------------------
219 | # ----------------------------------------------------------------------
220 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
221 | random_state=10)
222 | explainer_1 = LimeTabularExplainer(X,
223 | feature_names=feature_names,
224 | discretize_continuous=True,
225 | discretizer=discretizer,
226 | random_state=10)
227 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
228 | num_samples=500)
229 |
230 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
231 | random_state=10)
232 | explainer_2 = LimeTabularExplainer(X,
233 | feature_names=feature_names,
234 | discretize_continuous=True,
235 | discretizer=discretizer,
236 | random_state=10)
237 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
238 | num_samples=500)
239 |
240 | self.assertDictEqual(exp_1.as_map(), exp_2.as_map())
241 |
242 | def test_lime_tabular_explainer_not_equal_random_state(self):
243 | X, y = make_classification(n_samples=1000,
244 | n_features=20,
245 | n_informative=2,
246 | n_redundant=2,
247 | random_state=10)
248 |
249 | rf = RandomForestClassifier(n_estimators=500, random_state=10)
250 | rf.fit(X, y)
251 | instance = np.random.RandomState(10).randint(0, X.shape[0])
252 | feature_names = ["feature" + str(i) for i in range(20)]
253 |
254 | # ----------------------------------------------------------------------
255 | # -------------------------Quartile Discretizer-------------------------
256 | # ----------------------------------------------------------------------
257 |
258 | # ---------------------------------[1]----------------------------------
259 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
260 | random_state=20)
261 | explainer_1 = LimeTabularExplainer(X,
262 | feature_names=feature_names,
263 | discretize_continuous=True,
264 | discretizer=discretizer,
265 | random_state=10)
266 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
267 | num_samples=500)
268 |
269 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
270 | random_state=10)
271 | explainer_2 = LimeTabularExplainer(X,
272 | feature_names=feature_names,
273 | discretize_continuous=True,
274 | discretizer=discretizer,
275 | random_state=10)
276 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
277 | num_samples=500)
278 |
279 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
280 |
281 | # ---------------------------------[2]----------------------------------
282 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
283 | random_state=20)
284 | explainer_1 = LimeTabularExplainer(X,
285 | feature_names=feature_names,
286 | discretize_continuous=True,
287 | discretizer=discretizer,
288 | random_state=20)
289 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
290 | num_samples=500)
291 |
292 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
293 | random_state=10)
294 | explainer_2 = LimeTabularExplainer(X,
295 | feature_names=feature_names,
296 | discretize_continuous=True,
297 | discretizer=discretizer,
298 | random_state=10)
299 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
300 | num_samples=500)
301 |
302 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
303 |
304 | # ---------------------------------[3]----------------------------------
305 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
306 | random_state=20)
307 | explainer_1 = LimeTabularExplainer(X,
308 | feature_names=feature_names,
309 | discretize_continuous=True,
310 | discretizer=discretizer,
311 | random_state=20)
312 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
313 | num_samples=500)
314 |
315 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
316 | random_state=20)
317 | explainer_2 = LimeTabularExplainer(X,
318 | feature_names=feature_names,
319 | discretize_continuous=True,
320 | discretizer=discretizer,
321 | random_state=10)
322 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
323 | num_samples=500)
324 |
325 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
326 |
327 | # ---------------------------------[4]----------------------------------
328 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
329 | random_state=20)
330 | explainer_1 = LimeTabularExplainer(X,
331 | feature_names=feature_names,
332 | discretize_continuous=True,
333 | discretizer=discretizer,
334 | random_state=20)
335 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
336 | num_samples=500)
337 |
338 | discretizer = QuartileDiscretizer(X, [], feature_names, y,
339 | random_state=20)
340 | explainer_2 = LimeTabularExplainer(X,
341 | feature_names=feature_names,
342 | discretize_continuous=True,
343 | discretizer=discretizer,
344 | random_state=20)
345 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
346 | num_samples=500)
347 |
348 | self.assertFalse(exp_1.as_map() != exp_2.as_map())
349 |
350 | # ----------------------------------------------------------------------
351 | # --------------------------Decile Discretizer--------------------------
352 | # ----------------------------------------------------------------------
353 |
354 | # ---------------------------------[1]----------------------------------
355 | discretizer = DecileDiscretizer(X, [], feature_names, y,
356 | random_state=20)
357 | explainer_1 = LimeTabularExplainer(X,
358 | feature_names=feature_names,
359 | discretize_continuous=True,
360 | discretizer=discretizer,
361 | random_state=10)
362 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
363 | num_samples=500)
364 |
365 | discretizer = DecileDiscretizer(X, [], feature_names, y,
366 | random_state=10)
367 | explainer_2 = LimeTabularExplainer(X,
368 | feature_names=feature_names,
369 | discretize_continuous=True,
370 | discretizer=discretizer,
371 | random_state=10)
372 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
373 | num_samples=500)
374 |
375 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
376 |
377 | # ---------------------------------[2]----------------------------------
378 | discretizer = DecileDiscretizer(X, [], feature_names, y,
379 | random_state=20)
380 | explainer_1 = LimeTabularExplainer(X,
381 | feature_names=feature_names,
382 | discretize_continuous=True,
383 | discretizer=discretizer,
384 | random_state=20)
385 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
386 | num_samples=500)
387 |
388 | discretizer = DecileDiscretizer(X, [], feature_names, y,
389 | random_state=10)
390 | explainer_2 = LimeTabularExplainer(X,
391 | feature_names=feature_names,
392 | discretize_continuous=True,
393 | discretizer=discretizer,
394 | random_state=10)
395 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
396 | num_samples=500)
397 |
398 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
399 |
400 | # ---------------------------------[3]----------------------------------
401 | discretizer = DecileDiscretizer(X, [], feature_names, y,
402 | random_state=20)
403 | explainer_1 = LimeTabularExplainer(X,
404 | feature_names=feature_names,
405 | discretize_continuous=True,
406 | discretizer=discretizer,
407 | random_state=20)
408 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
409 | num_samples=500)
410 |
411 | discretizer = DecileDiscretizer(X, [], feature_names, y,
412 | random_state=20)
413 | explainer_2 = LimeTabularExplainer(X,
414 | feature_names=feature_names,
415 | discretize_continuous=True,
416 | discretizer=discretizer,
417 | random_state=10)
418 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
419 | num_samples=500)
420 |
421 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
422 |
423 | # ---------------------------------[4]----------------------------------
424 | discretizer = DecileDiscretizer(X, [], feature_names, y,
425 | random_state=20)
426 | explainer_1 = LimeTabularExplainer(X,
427 | feature_names=feature_names,
428 | discretize_continuous=True,
429 | discretizer=discretizer,
430 | random_state=20)
431 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
432 | num_samples=500)
433 |
434 | discretizer = DecileDiscretizer(X, [], feature_names, y,
435 | random_state=20)
436 | explainer_2 = LimeTabularExplainer(X,
437 | feature_names=feature_names,
438 | discretize_continuous=True,
439 | discretizer=discretizer,
440 | random_state=20)
441 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
442 | num_samples=500)
443 |
444 | self.assertFalse(exp_1.as_map() != exp_2.as_map())
445 |
446 | # ----------------------------------------------------------------------
447 | # --------------------------Entropy Discretizer-------------------------
448 | # ----------------------------------------------------------------------
449 |
450 | # ---------------------------------[1]----------------------------------
451 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
452 | random_state=20)
453 | explainer_1 = LimeTabularExplainer(X,
454 | feature_names=feature_names,
455 | discretize_continuous=True,
456 | discretizer=discretizer,
457 | random_state=10)
458 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
459 | num_samples=500)
460 |
461 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
462 | random_state=10)
463 | explainer_2 = LimeTabularExplainer(X,
464 | feature_names=feature_names,
465 | discretize_continuous=True,
466 | discretizer=discretizer,
467 | random_state=10)
468 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
469 | num_samples=500)
470 |
471 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
472 |
473 | # ---------------------------------[2]----------------------------------
474 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
475 | random_state=20)
476 | explainer_1 = LimeTabularExplainer(X,
477 | feature_names=feature_names,
478 | discretize_continuous=True,
479 | discretizer=discretizer,
480 | random_state=20)
481 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
482 | num_samples=500)
483 |
484 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
485 | random_state=10)
486 | explainer_2 = LimeTabularExplainer(X,
487 | feature_names=feature_names,
488 | discretize_continuous=True,
489 | discretizer=discretizer,
490 | random_state=10)
491 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
492 | num_samples=500)
493 |
494 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
495 |
496 | # ---------------------------------[3]----------------------------------
497 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
498 | random_state=20)
499 | explainer_1 = LimeTabularExplainer(X,
500 | feature_names=feature_names,
501 | discretize_continuous=True,
502 | discretizer=discretizer,
503 | random_state=20)
504 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
505 | num_samples=500)
506 |
507 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
508 | random_state=20)
509 | explainer_2 = LimeTabularExplainer(X,
510 | feature_names=feature_names,
511 | discretize_continuous=True,
512 | discretizer=discretizer,
513 | random_state=10)
514 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
515 | num_samples=500)
516 |
517 | self.assertTrue(exp_1.as_map() != exp_2.as_map())
518 |
519 | # ---------------------------------[4]----------------------------------
520 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
521 | random_state=20)
522 | explainer_1 = LimeTabularExplainer(X,
523 | feature_names=feature_names,
524 | discretize_continuous=True,
525 | discretizer=discretizer,
526 | random_state=20)
527 | exp_1 = explainer_1.explain_instance(X[instance], rf.predict_proba,
528 | num_samples=500)
529 |
530 | discretizer = EntropyDiscretizer(X, [], feature_names, y,
531 | random_state=20)
532 | explainer_2 = LimeTabularExplainer(X,
533 | feature_names=feature_names,
534 | discretize_continuous=True,
535 | discretizer=discretizer,
536 | random_state=20)
537 | exp_2 = explainer_2.explain_instance(X[instance], rf.predict_proba,
538 | num_samples=500)
539 |
540 | self.assertFalse(exp_1.as_map() != exp_2.as_map())
541 |
542 | def testFeatureNamesAndCategoricalFeats(self):
543 | training_data = np.array([[0., 1.], [1., 0.]])
544 |
545 | explainer = LimeTabularExplainer(training_data=training_data)
546 | self.assertEqual(explainer.feature_names, ['0', '1'])
547 | self.assertEqual(explainer.categorical_features, [0, 1])
548 |
549 | explainer = LimeTabularExplainer(
550 | training_data=training_data,
551 | feature_names=np.array(['one', 'two'])
552 | )
553 | self.assertEqual(explainer.feature_names, ['one', 'two'])
554 |
555 | explainer = LimeTabularExplainer(
556 | training_data=training_data,
557 | categorical_features=np.array([0]),
558 | discretize_continuous=False
559 | )
560 | self.assertEqual(explainer.categorical_features, [0])
561 |
562 | def testFeatureValues(self):
563 | training_data = np.array([
564 | [0, 0, 2],
565 | [1, 1, 0],
566 | [0, 2, 2],
567 | [1, 3, 0]
568 | ])
569 |
570 | explainer = LimeTabularExplainer(
571 | training_data=training_data,
572 | categorical_features=[0, 1, 2]
573 | )
574 |
575 | self.assertEqual(set(explainer.feature_values[0]), {0, 1})
576 | self.assertEqual(set(explainer.feature_values[1]), {0, 1, 2, 3})
577 | self.assertEqual(set(explainer.feature_values[2]), {0, 2})
578 |
579 | assert_array_equal(explainer.feature_frequencies[0], np.array([.5, .5]))
580 | assert_array_equal(explainer.feature_frequencies[1], np.array([.25, .25, .25, .25]))
581 | assert_array_equal(explainer.feature_frequencies[2], np.array([.5, .5]))
582 |
583 | def test_lime_explainer_with_data_stats(self):
584 | np.random.seed(1)
585 |
586 | rf = RandomForestClassifier(n_estimators=500)
587 | rf.fit(self.train, self.labels_train)
588 | i = np.random.randint(0, self.test.shape[0])
589 |
590 |         # Generate stats using a quartile discretizer
591 | descritizer = QuartileDiscretizer(self.train, [], self.feature_names, self.target_names,
592 | random_state=20)
593 |
594 | d_means = descritizer.means
595 | d_stds = descritizer.stds
596 | d_mins = descritizer.mins
597 | d_maxs = descritizer.maxs
598 | d_bins = descritizer.bins(self.train, self.target_names)
599 |
600 | # Compute feature values and frequencies of all columns
601 | cat_features = np.arange(self.train.shape[1])
602 | discretized_training_data = descritizer.discretize(self.train)
603 |
604 | feature_values = {}
605 | feature_frequencies = {}
606 | for feature in cat_features:
607 | column = discretized_training_data[:, feature]
608 | feature_count = collections.Counter(column)
609 | values, frequencies = map(list, zip(*(feature_count.items())))
610 | feature_values[feature] = values
611 | feature_frequencies[feature] = frequencies
612 |
613 | # Convert bins to list from array
614 | d_bins_revised = {}
615 | index = 0
616 | for bin in d_bins:
617 | d_bins_revised[index] = bin.tolist()
618 | index = index+1
619 |
620 |         # Discretized stats
621 | data_stats = {}
622 | data_stats["means"] = d_means
623 | data_stats["stds"] = d_stds
624 | data_stats["maxs"] = d_maxs
625 | data_stats["mins"] = d_mins
626 | data_stats["bins"] = d_bins_revised
627 | data_stats["feature_values"] = feature_values
628 | data_stats["feature_frequencies"] = feature_frequencies
629 |
630 | data = np.zeros((2, len(self.feature_names)))
631 | explainer = LimeTabularExplainer(
632 | data, feature_names=self.feature_names, random_state=10,
633 | training_data_stats=data_stats, training_labels=self.target_names)
634 |
635 | exp = explainer.explain_instance(self.test[i],
636 | rf.predict_proba,
637 | num_features=2,
638 | model_regressor=LinearRegression())
639 |
640 | self.assertIsNotNone(exp)
641 | keys = [x[0] for x in exp.as_list()]
642 | self.assertEqual(1,
643 | sum([1 if 'petal width' in x else 0 for x in keys]),
644 | "Petal Width is a major feature")
645 | self.assertEqual(1,
646 | sum([1 if 'petal length' in x else 0 for x in keys]),
647 | "Petal Length is a major feature")
648 |
649 |
650 | if __name__ == '__main__':
651 | unittest.main()
652 |
--------------------------------------------------------------------------------
/slime_lm/_least_angle.py:
--------------------------------------------------------------------------------
1 | """
2 | Least Angle Regression algorithm. See the documentation on the
3 | Generalized Linear Model for a complete discussion.
4 | """
5 | # Author: Fabian Pedregosa
6 | # Alexandre Gramfort
7 | # Gael Varoquaux
8 | #
9 | # License: BSD 3 clause
10 |
11 | from math import log
12 | import sys
13 | import warnings
14 |
15 | import numpy as np
16 | from scipy import linalg, interpolate
17 | from scipy.linalg.lapack import get_lapack_funcs
18 | from scipy import stats
19 | from joblib import Parallel
20 |
21 | # mypy error: Module 'sklearn.utils' has no attribute 'arrayfuncs'
22 | from sklearn.utils import arrayfuncs, as_float_array # type: ignore
23 | from sklearn.exceptions import ConvergenceWarning
24 |
25 | SOLVE_TRIANGULAR_ARGS = {'check_finite': False}
26 |
27 |
28 | def lars_path(
29 | X,
30 | y,
31 | Xy=None,
32 | *,
33 | Gram=None,
34 | max_iter=500,
35 | alpha_min=0,
36 | method="lar",
37 | copy_X=True,
38 | eps=np.finfo(float).eps,
39 | copy_Gram=True,
40 | verbose=0,
41 | return_path=True,
42 | return_n_iter=False,
43 | positive=False,
44 | testing=False,
45 | alpha=0.05,
46 | testing_stop=False,
47 | testing_verbose=False,
48 | ):
49 | """Compute Least Angle Regression or Lasso path using LARS algorithm [1]
50 | The optimization objective for the case method='lasso' is::
51 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
52 | in the case of method='lars', the objective function is only known in
53 | the form of an implicit equation (see discussion in [1])
54 | Read more in the :ref:`User Guide `.
55 | Parameters
56 | ----------
57 | X : None or array-like of shape (n_samples, n_features)
58 | Input data. Note that if X is None then the Gram matrix must be
59 | specified, i.e., cannot be None or False.
60 | y : None or array-like of shape (n_samples,)
61 | Input targets.
62 | Xy : array-like of shape (n_samples,) or (n_samples, n_targets), \
63 | default=None
64 | Xy = np.dot(X.T, y) that can be precomputed. It is useful
65 | only when the Gram matrix is precomputed.
66 | Gram : None, 'auto', array-like of shape (n_features, n_features), \
67 | default=None
68 | Precomputed Gram matrix (X' * X), if ``'auto'``, the Gram
69 | matrix is precomputed from the given X, if there are more samples
70 | than features.
71 | max_iter : int, default=500
72 | Maximum number of iterations to perform, set to infinity for no limit.
73 | alpha_min : float, default=0
74 | Minimum correlation along the path. It corresponds to the
75 | regularization parameter alpha parameter in the Lasso.
76 | method : {'lar', 'lasso'}, default='lar'
77 | Specifies the returned model. Select ``'lar'`` for Least Angle
78 | Regression, ``'lasso'`` for the Lasso.
79 | copy_X : bool, default=True
80 | If ``False``, ``X`` is overwritten.
81 | eps : float, default=np.finfo(float).eps
82 | The machine-precision regularization in the computation of the
83 | Cholesky diagonal factors. Increase this for very ill-conditioned
84 | systems. Unlike the ``tol`` parameter in some iterative
85 | optimization-based algorithms, this parameter does not control
86 | the tolerance of the optimization.
87 | copy_Gram : bool, default=True
88 | If ``False``, ``Gram`` is overwritten.
89 | verbose : int, default=0
90 | Controls output verbosity.
91 | return_path : bool, default=True
92 | If ``return_path==True`` returns the entire path, else returns only the
93 | last point of the path.
94 | return_n_iter : bool, default=False
95 | Whether to return the number of iterations.
96 | positive : bool, default=False
97 | Restrict coefficients to be >= 0.
98 | This option is only allowed with method 'lasso'. Note that the model
99 | coefficients will not converge to the ordinary-least-squares solution
100 | for small values of alpha. Only coefficients up to the smallest alpha
101 | value (``alphas_[alphas_ > 0.].min()`` when fit_path=True) reached by
102 | the stepwise Lars-Lasso algorithm are typically in congruence with the
103 | solution of the coordinate descent lasso_path function.
104 | testing : bool, default=False
105 | Whether to conduct hypothesis testing each time a new variable enters
106 | alpha : float, default=0.05
107 | Significance level of hypothesis testing. Valid only if testing is True.
108 | testing_stop : bool, default=False
109 | If set to True, stops calculating future paths when the test yields
110 | insignificant results.
111 | Only takes effect when testing is set to True.
112 |     testing_verbose : bool, default=False
113 |         Controls output verbosity for the hypothesis testing procedure.
114 | Returns
115 | -------
116 | alphas : array-like of shape (n_alphas + 1,)
117 | Maximum of covariances (in absolute value) at each iteration.
118 | ``n_alphas`` is either ``max_iter``, ``n_features`` or the
119 | number of nodes in the path with ``alpha >= alpha_min``, whichever
120 | is smaller.
121 | active : array-like of shape (n_alphas,)
122 | Indices of active variables at the end of the path.
123 | coefs : array-like of shape (n_features, n_alphas + 1)
124 | Coefficients along the path
125 | n_iter : int
126 | Number of iterations run. Returned only if return_n_iter is set
127 | to True.
128 |     test_result : dict
129 | Contains testing results in the form of [test_stats, new_n] produced
130 | at each step. Returned only if testing is set to True.
131 | See Also
132 | --------
133 | lars_path_gram
134 | lasso_path
135 | lasso_path_gram
136 | LassoLars
137 | Lars
138 | LassoLarsCV
139 | LarsCV
140 | sklearn.decomposition.sparse_encode
141 | References
142 | ----------
143 | .. [1] "Least Angle Regression", Efron et al.
144 | http://statweb.stanford.edu/~tibs/ftp/lars.pdf
145 | .. [2] `Wikipedia entry on the Least-angle regression
146 |        <https://en.wikipedia.org/wiki/Least-angle_regression>`_
147 | .. [3] `Wikipedia entry on the Lasso
148 |        <https://en.wikipedia.org/wiki/Lasso_(statistics)>`_
149 | """
150 | if X is None and Gram is not None:
151 | raise ValueError(
152 |             'X cannot be None if Gram is not None. '
153 | 'Use lars_path_gram to avoid passing X and y.'
154 | )
155 | return _lars_path_solver(
156 | X=X, y=y, Xy=Xy, Gram=Gram, n_samples=None, max_iter=max_iter,
157 | alpha_min=alpha_min, method=method, copy_X=copy_X,
158 | eps=eps, copy_Gram=copy_Gram, verbose=verbose, return_path=return_path,
159 | return_n_iter=return_n_iter, positive=positive, testing=testing,
160 | alpha=alpha, testing_stop=testing_stop, testing_verbose=testing_verbose)
161 |
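# Illustrative call pattern (not executed here): with testing=True the solver
# performs a hypothesis test each time a new variable enters the path and also
# returns `test_result`, a dict holding [test_stats, new_n] for each step,
# where a negative statistic suggests the selection at that step is not yet
# stable and new_n estimates the sample size that would be needed. Assuming
# test_result is appended last, as listed in the Returns section above:
#
#     out = lars_path(X, y, method='lasso', testing=True, alpha=0.05)
#     test_result = out[-1]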
162 |
163 | def lars_path_gram(
164 | Xy,
165 | Gram,
166 | *,
167 | n_samples,
168 | max_iter=500,
169 | alpha_min=0,
170 | method="lar",
171 | copy_X=True,
172 | eps=np.finfo(float).eps,
173 | copy_Gram=True,
174 | verbose=0,
175 | return_path=True,
176 | return_n_iter=False,
177 | positive=False
178 | ):
179 | """lars_path in the sufficient stats mode [1]
180 | The optimization objective for the case method='lasso' is::
181 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
182 | in the case of method='lars', the objective function is only known in
183 | the form of an implicit equation (see discussion in [1])
184 | Read more in the :ref:`User Guide `.
185 | Parameters
186 | ----------
187 | Xy : array-like of shape (n_samples,) or (n_samples, n_targets)
188 | Xy = np.dot(X.T, y).
189 | Gram : array-like of shape (n_features, n_features)
190 |         Gram = np.dot(X.T, X).
191 | n_samples : int or float
192 | Equivalent size of sample.
193 | max_iter : int, default=500
194 | Maximum number of iterations to perform, set to infinity for no limit.
195 | alpha_min : float, default=0
196 | Minimum correlation along the path. It corresponds to the
197 | regularization parameter alpha parameter in the Lasso.
198 | method : {'lar', 'lasso'}, default='lar'
199 | Specifies the returned model. Select ``'lar'`` for Least Angle
200 | Regression, ``'lasso'`` for the Lasso.
201 | copy_X : bool, default=True
202 | If ``False``, ``X`` is overwritten.
203 | eps : float, default=np.finfo(float).eps
204 | The machine-precision regularization in the computation of the
205 | Cholesky diagonal factors. Increase this for very ill-conditioned
206 | systems. Unlike the ``tol`` parameter in some iterative
207 | optimization-based algorithms, this parameter does not control
208 | the tolerance of the optimization.
209 | copy_Gram : bool, default=True
210 | If ``False``, ``Gram`` is overwritten.
211 | verbose : int, default=0
212 | Controls output verbosity.
213 | return_path : bool, default=True
214 | If ``return_path==True`` returns the entire path, else returns only the
215 | last point of the path.
216 | return_n_iter : bool, default=False
217 | Whether to return the number of iterations.
218 | positive : bool, default=False
219 | Restrict coefficients to be >= 0.
220 | This option is only allowed with method 'lasso'. Note that the model
221 | coefficients will not converge to the ordinary-least-squares solution
222 | for small values of alpha. Only coefficients up to the smallest alpha
223 | value (``alphas_[alphas_ > 0.].min()`` when fit_path=True) reached by
224 | the stepwise Lars-Lasso algorithm are typically in congruence with the
225 | solution of the coordinate descent lasso_path function.
226 | Returns
227 | -------
228 | alphas : array-like of shape (n_alphas + 1,)
229 | Maximum of covariances (in absolute value) at each iteration.
230 | ``n_alphas`` is either ``max_iter``, ``n_features`` or the
231 | number of nodes in the path with ``alpha >= alpha_min``, whichever
232 | is smaller.
233 | active : array-like of shape (n_alphas,)
234 | Indices of active variables at the end of the path.
235 | coefs : array-like of shape (n_features, n_alphas + 1)
236 | Coefficients along the path
237 | n_iter : int
238 | Number of iterations run. Returned only if return_n_iter is set
239 | to True.
240 | See Also
241 | --------
242 | lars_path
243 | lasso_path
244 | lasso_path_gram
245 | LassoLars
246 | Lars
247 | LassoLarsCV
248 | LarsCV
249 | sklearn.decomposition.sparse_encode
250 | References
251 | ----------
252 | .. [1] "Least Angle Regression", Efron et al.
253 | http://statweb.stanford.edu/~tibs/ftp/lars.pdf
254 | .. [2] `Wikipedia entry on the Least-angle regression
255 |        <https://en.wikipedia.org/wiki/Least-angle_regression>`_
256 | .. [3] `Wikipedia entry on the Lasso
257 |        <https://en.wikipedia.org/wiki/Lasso_(statistics)>`_
258 | """
259 | return _lars_path_solver(
260 | X=None, y=None, Xy=Xy, Gram=Gram, n_samples=n_samples,
261 | max_iter=max_iter, alpha_min=alpha_min, method=method,
262 | copy_X=copy_X, eps=eps, copy_Gram=copy_Gram,
263 | verbose=verbose, return_path=return_path,
264 | return_n_iter=return_n_iter, positive=positive)
265 |
266 |
267 | def _lars_path_solver(
268 | X,
269 | y,
270 | Xy=None,
271 | Gram=None,
272 | n_samples=None,
273 | max_iter=500,
274 | alpha_min=0,
275 | method="lar",
276 | copy_X=True,
277 | eps=np.finfo(float).eps,
278 | copy_Gram=True,
279 | verbose=0,
280 | return_path=True,
281 | return_n_iter=False,
282 | positive=False,
283 | testing=False,
284 | alpha=0.05,
285 | testing_stop=False,
286 | testing_verbose=False,
287 | ):
288 | """Compute Least Angle Regression or Lasso path using LARS algorithm [1]
289 | The optimization objective for the case method='lasso' is::
290 | (1 / (2 * n_samples)) * ||y - Xw||^2_2 + alpha * ||w||_1
291 | in the case of method='lars', the objective function is only known in
292 | the form of an implicit equation (see discussion in [1])
293 | Read more in the :ref:`User Guide `.
294 | Parameters
295 | ----------
296 | X : None or ndarray of shape (n_samples, n_features)
297 | Input data. Note that if X is None then Gram must be specified,
298 | i.e., cannot be None or False.
299 | y : None or ndarray of shape (n_samples,)
300 | Input targets.
301 | Xy : array-like of shape (n_samples,) or (n_samples, n_targets), \
302 | default=None
303 | `Xy = np.dot(X.T, y)` that can be precomputed. It is useful
304 | only when the Gram matrix is precomputed.
305 | Gram : None, 'auto' or array-like of shape (n_features, n_features), \
306 | default=None
307 | Precomputed Gram matrix `(X' * X)`, if ``'auto'``, the Gram
308 | matrix is precomputed from the given X, if there are more samples
309 | than features.
310 | n_samples : int or float, default=None
311 | Equivalent size of sample. If `None`, it will be `n_samples`.
312 | max_iter : int, default=500
313 | Maximum number of iterations to perform, set to infinity for no limit.
314 | alpha_min : float, default=0
315 | Minimum correlation along the path. It corresponds to the
316 | regularization parameter alpha parameter in the Lasso.
317 | method : {'lar', 'lasso'}, default='lar'
318 | Specifies the returned model. Select ``'lar'`` for Least Angle
319 | Regression, ``'lasso'`` for the Lasso.
320 | copy_X : bool, default=True
321 | If ``False``, ``X`` is overwritten.
322 | eps : float, default=np.finfo(float).eps
323 | The machine-precision regularization in the computation of the
324 | Cholesky diagonal factors. Increase this for very ill-conditioned
325 | systems. Unlike the ``tol`` parameter in some iterative
326 | optimization-based algorithms, this parameter does not control
327 | the tolerance of the optimization.
328 | copy_Gram : bool, default=True
329 | If ``False``, ``Gram`` is overwritten.
330 | verbose : int, default=0
331 | Controls output verbosity.
332 | return_path : bool, default=True
333 | If ``return_path==True`` returns the entire path, else returns only the
334 | last point of the path.
335 | return_n_iter : bool, default=False
336 | Whether to return the number of iterations.
337 | positive : bool, default=False
338 | Restrict coefficients to be >= 0.
339 | This option is only allowed with method 'lasso'. Note that the model
340 | coefficients will not converge to the ordinary-least-squares solution
341 | for small values of alpha. Only coefficients up to the smallest alpha
342 | value (``alphas_[alphas_ > 0.].min()`` when fit_path=True) reached by
343 | the stepwise Lars-Lasso algorithm are typically in congruence with the
344 | solution of the coordinate descent lasso_path function.
345 | testing : bool, default=False
346 |         Whether to conduct a hypothesis test each time a new variable enters
347 |         the active set.
348 |     alpha : float, default=0.05
349 |         Significance level of the hypothesis tests. Only used if testing is True.
350 |     testing_stop : bool, default=False
351 |         If True, stop computing the rest of the path as soon as a test yields
352 |         an insignificant result. Only takes effect when testing is set to True.
353 |     testing_verbose : bool, default=False
354 |         Controls output verbosity of the hypothesis testing procedure.
355 | Returns
356 | -------
357 | alphas : array-like of shape (n_alphas + 1,)
358 | Maximum of covariances (in absolute value) at each iteration.
359 | ``n_alphas`` is either ``max_iter``, ``n_features`` or the
360 | number of nodes in the path with ``alpha >= alpha_min``, whichever
361 | is smaller.
362 | active : array-like of shape (n_alphas,)
363 | Indices of active variables at the end of the path.
364 | coefs : array-like of shape (n_features, n_alphas + 1)
365 | Coefficients along the path
366 | n_iter : int
367 | Number of iterations run. Returned only if return_n_iter is set
368 | to True.
369 |     test_result : dict
370 |         Testing results of the form ``[test_stats, new_n]``, keyed by the step
371 |         at which each test was performed. Returned only if testing is True.
372 | See Also
373 | --------
374 | lasso_path
375 | LassoLars
376 | Lars
377 | LassoLarsCV
378 | LarsCV
379 | sklearn.decomposition.sparse_encode
380 | References
381 | ----------
382 | .. [1] "Least Angle Regression", Efron et al.
383 | http://statweb.stanford.edu/~tibs/ftp/lars.pdf
384 |     .. [2] `Wikipedia entry on the Least-angle regression
385 |            <https://en.wikipedia.org/wiki/Least-angle_regression>`_
386 |     .. [3] `Wikipedia entry on the Lasso
387 |            <https://en.wikipedia.org/wiki/Lasso_(statistics)>`_
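    |     Examples
    |     --------
    |     A minimal, illustrative sketch on synthetic data (the exact path and
    |     test results depend on the random draw):
    |     >>> import numpy as np
    |     >>> from slime_lm._least_angle import _lars_path_solver
    |     >>> rng = np.random.RandomState(0)
    |     >>> X = rng.randn(100, 5)
    |     >>> y = X[:, 0] + 0.1 * rng.randn(100)
    |     >>> out = _lars_path_solver(X, y, method='lasso', testing=True)
    |     >>> alphas, active, coefs, test_result = out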
388 | """
389 | if method == "lar" and positive:
390 | raise ValueError(
391 | "Positive constraint not supported for 'lar' " "coding method."
392 | )
393 |
394 | n_samples = n_samples if n_samples is not None else y.size
395 |
396 | if Xy is None:
397 | Cov = np.dot(X.T, y)
398 | else:
399 | Cov = Xy.copy()
400 |
401 | if Gram is None or Gram is False:
402 | Gram = None
403 | if X is None:
404 | raise ValueError('X and Gram cannot both be unspecified.')
405 | elif isinstance(Gram, str) and Gram == 'auto' or Gram is True:
406 | if Gram is True or X.shape[0] > X.shape[1]:
407 | Gram = np.dot(X.T, X)
408 | else:
409 | Gram = None
410 | elif copy_Gram:
411 | Gram = Gram.copy()
412 |
413 | if Gram is None:
414 | n_features = X.shape[1]
415 | else:
416 | n_features = Cov.shape[0]
417 | if Gram.shape != (n_features, n_features):
418 | raise ValueError('The shapes of the inputs Gram and Xy'
419 | ' do not match.')
420 |
421 | if copy_X and X is not None and Gram is None:
422 | # force copy. setting the array to be fortran-ordered
423 | # speeds up the calculation of the (partial) Gram matrix
424 | # and allows to easily swap columns
425 | X = X.copy('F')
426 |
427 | max_features = min(max_iter, n_features)
428 |
429 | dtypes = set(a.dtype for a in (X, y, Xy, Gram) if a is not None)
430 | if len(dtypes) == 1:
431 | # use the precision level of input data if it is consistent
432 | return_dtype = next(iter(dtypes))
433 | else:
434 | # fallback to double precision otherwise
435 | return_dtype = np.float64
436 |
437 | if return_path:
438 | coefs = np.zeros((max_features + 1, n_features), dtype=return_dtype)
439 | alphas = np.zeros(max_features + 1, dtype=return_dtype)
440 | else:
441 | coef, prev_coef = (np.zeros(n_features, dtype=return_dtype),
442 | np.zeros(n_features, dtype=return_dtype))
443 | alpha, prev_alpha = (np.array([0.], dtype=return_dtype),
444 | np.array([0.], dtype=return_dtype))
445 | # above better ideas?
446 |
447 | n_iter, n_active = 0, 0
448 | active, indices = list(), np.arange(n_features)
449 | # holds the sign of covariance
450 | sign_active = np.empty(max_features, dtype=np.int8)
451 | drop = False
452 |
453 | # will hold the cholesky factorization. Only lower part is
454 | # referenced.
455 | if Gram is None:
456 | L = np.empty((max_features, max_features), dtype=X.dtype)
457 | swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (X,))
458 | else:
459 | L = np.empty((max_features, max_features), dtype=Gram.dtype)
460 | swap, nrm2 = linalg.get_blas_funcs(('swap', 'nrm2'), (Cov,))
461 | solve_cholesky, = get_lapack_funcs(('potrs',), (L,))
462 |
463 | if verbose:
464 | if verbose > 1:
465 | print("Step\t\tAdded\t\tDropped\t\tActive set size\t\tC")
466 | else:
467 | sys.stdout.write('.')
468 | sys.stdout.flush()
469 |
470 | tiny32 = np.finfo(np.float32).tiny # to avoid division by 0 warning
471 | equality_tolerance = np.finfo(np.float32).eps
472 |
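    |     # state for the per-step hypothesis tests: current residual, current
    |     # coefficient vector, and a dict collecting the test results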
473 | residual = y - 0
474 | coef = np.zeros(n_features)
475 | test_result = {}
476 |
477 | if Gram is not None:
478 | Gram_copy = Gram.copy()
479 | Cov_copy = Cov.copy()
480 |
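    |     # one-sided critical value of the standard normal at level `alpha`,
    |     # used by the pairwise tests below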
481 | z_score = stats.norm.ppf(1 - alpha)
482 | while True:
483 | if not testing:
484 | if Cov.size:
485 | if positive:
486 | C_idx = np.argmax(Cov)
487 | else:
488 | C_idx = np.argmax(np.abs(Cov))
489 |
490 | C_ = Cov[C_idx]
491 |
492 | if positive:
493 | C = C_
494 | else:
495 | C = np.fabs(C_)
496 | else:
497 | C = 0.
498 | else:
499 |             # not implemented when positive is set to True
500 | if Cov.size:
501 | if positive:
502 | C_idx = np.argmax(Cov)
503 | else:
504 | C_idx = np.argmax(np.abs(Cov))
505 | if Cov.size > 1:
506 | C_idx_second = np.abs(Cov).argsort()[-2]
507 |
508 | x1 = X.T[n_active + C_idx]
509 | x2 = X.T[n_active + C_idx_second]
510 |
511 | residual = y - np.dot(X[:, :n_active], coef[active])
512 | u = np.array([np.dot(x1, residual), np.dot(x2, residual)]) / len(y)
513 | cov = np.cov(x1 * residual, x2 * residual)
514 |
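    |                     # u holds the (scaled) correlations of the two leading
    |                     # candidates with the current residual; cov is the 2x2
    |                     # sample covariance of x1*residual and x2*residual. The
    |                     # four sign cases below compute a one-sided z-test of
    |                     # whether |u[0]| significantly exceeds |u[1]|; when the
    |                     # statistic is negative, new_n estimates the sample size
    |                     # that would be needed to reach significance.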
515 | new_n = len(y)
516 | if u[0] >= 0 and u[1] >= 0:
517 | test_stats = u[0] - u[1] - z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y))
518 | if test_stats < 0:
519 | z_alpha = (u[0] - u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y))
520 | new_n = new_n * (z_score / z_alpha) ** 2
521 | elif u[0] >= 0 and u[1] < 0:
522 | test_stats = u[0] + u[1] - z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] + cov[0][1] + cov[1][0]) / len(y))
523 | if test_stats < 0:
524 |                             z_alpha = (u[0] + u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] + cov[0][1] + cov[1][0]) / len(y))
525 | new_n = new_n * (z_score / z_alpha) ** 2
526 | elif u[0] < 0 and u[1] >= 0:
527 | test_stats = -(u[0] + u[1] + z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] + cov[0][1] + cov[1][0]) / len(y)))
528 | if test_stats < 0:
529 |                             z_alpha = (-u[0] - u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] + cov[0][1] + cov[1][0]) / len(y))
530 | new_n = new_n * (z_score / z_alpha) ** 2
531 | else:
532 | test_stats = -(u[0] - u[1] + z_score * np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y)))
533 | if test_stats < 0:
534 | z_alpha = (-u[0] + u[1]) / np.sqrt(2 * (cov[0][0] + cov[1][1] - cov[0][1] - cov[1][0]) / len(y))
535 | new_n = new_n * (z_score / z_alpha) ** 2
536 |
537 | test_result[n_active + 1] = [test_stats, new_n]
538 |
539 | if testing_verbose:
540 |                         print("Selecting variable " + str(n_active + 1) + ": ")
541 | print("Correlations: " + str(np.round(u, 4)))
542 | print("Test statistics: " + str(round(test_stats, 4)))
543 |
544 | if testing_stop:
545 | if test_stats < 0:
546 | if testing_verbose:
547 | print("Not enough samples!")
548 | return alphas, active, coefs.T, test_result
549 | else:
550 | test_result[n_active + 1] = [0, 0]
551 |
552 | C_ = Cov[C_idx]
553 |
554 | if positive:
555 | C = C_
556 | else:
557 | C = np.fabs(C_)
558 | else:
559 | C = 0.
560 |
561 | if return_path:
562 | alpha = alphas[n_iter, np.newaxis]
563 | coef = coefs[n_iter]
564 | prev_alpha = alphas[n_iter - 1, np.newaxis]
565 | prev_coef = coefs[n_iter - 1]
566 |
567 | alpha[0] = C / n_samples
568 | if alpha[0] <= alpha_min + equality_tolerance: # early stopping
569 | if abs(alpha[0] - alpha_min) > equality_tolerance:
570 | # interpolation factor 0 <= ss < 1
571 | if n_iter > 0:
572 | # In the first iteration, all alphas are zero, the formula
573 | # below would make ss a NaN
574 | ss = ((prev_alpha[0] - alpha_min) /
575 | (prev_alpha[0] - alpha[0]))
576 | coef[:] = prev_coef + ss * (coef - prev_coef)
577 | alpha[0] = alpha_min
578 | if return_path:
579 | coefs[n_iter] = coef
580 | break
581 |
582 | if n_iter >= max_iter or n_active >= n_features:
583 | break
584 | if not drop:
585 |
586 | ##########################################################
587 | # Append x_j to the Cholesky factorization of (Xa * Xa') #
588 | # #
589 | # ( L 0 ) #
590 | # L -> ( ) , where L * w = Xa' x_j #
591 | # ( w z ) and z = ||x_j|| #
592 | # #
593 | ##########################################################
594 |
595 | if positive:
596 | sign_active[n_active] = np.ones_like(C_)
597 | else:
598 | sign_active[n_active] = np.sign(C_)
599 | m, n = n_active, C_idx + n_active
600 |
601 | Cov[C_idx], Cov[0] = swap(Cov[C_idx], Cov[0])
602 | indices[n], indices[m] = indices[m], indices[n]
603 | Cov_not_shortened = Cov
604 | Cov = Cov[1:] # remove Cov[0]
605 |
606 | if Gram is None:
607 | X.T[n], X.T[m] = swap(X.T[n], X.T[m])
608 | c = nrm2(X.T[n_active]) ** 2
609 | L[n_active, :n_active] = \
610 | np.dot(X.T[n_active], X.T[:n_active].T)
611 | else:
612 |                 # swap only works in place if the matrix is
613 |                 # Fortran-contiguous ...
614 | Gram[m], Gram[n] = swap(Gram[m], Gram[n])
615 | Gram[:, m], Gram[:, n] = swap(Gram[:, m], Gram[:, n])
616 | c = Gram[n_active, n_active]
617 | L[n_active, :n_active] = Gram[n_active, :n_active]
618 |
619 | # Update the cholesky decomposition for the Gram matrix
620 | if n_active:
621 | linalg.solve_triangular(L[:n_active, :n_active],
622 | L[n_active, :n_active],
623 | trans=0, lower=1,
624 | overwrite_b=True,
625 | **SOLVE_TRIANGULAR_ARGS)
626 |
627 | v = np.dot(L[n_active, :n_active], L[n_active, :n_active])
628 | diag = max(np.sqrt(np.abs(c - v)), eps)
629 | L[n_active, n_active] = diag
630 |
631 | if diag < 1e-7:
632 | # The system is becoming too ill-conditioned.
633 | # We have degenerate vectors in our active set.
634 | # We'll 'drop for good' the last regressor added.
635 |
636 | # Note: this case is very rare. It is no longer triggered by
637 | # the test suite. The `equality_tolerance` margin added in 0.16
638 | # to get early stopping to work consistently on all versions of
639 | # Python including 32 bit Python under Windows seems to make it
640 | # very difficult to trigger the 'drop for good' strategy.
641 | warnings.warn('Regressors in active set degenerate. '
642 | 'Dropping a regressor, after %i iterations, '
643 | 'i.e. alpha=%.3e, '
644 | 'with an active set of %i regressors, and '
645 | 'the smallest cholesky pivot element being %.3e.'
646 | ' Reduce max_iter or increase eps parameters.'
647 | % (n_iter, alpha, n_active, diag),
648 | ConvergenceWarning)
649 |
650 | # XXX: need to figure a 'drop for good' way
651 | Cov = Cov_not_shortened
652 | Cov[0] = 0
653 | Cov[C_idx], Cov[0] = swap(Cov[C_idx], Cov[0])
654 | continue
655 |
656 | active.append(indices[n_active])
657 | n_active += 1
658 |
659 | if verbose > 1:
660 | print("%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, active[-1], '',
661 | n_active, C))
662 |
663 | if method == 'lasso' and n_iter > 0 and prev_alpha[0] < alpha[0]:
664 | # alpha is increasing. This is because the updates of Cov are
665 | # bringing in too much numerical error that is greater than
666 |             # the remaining correlation with the
667 | # regressors. Time to bail out
668 | warnings.warn('Early stopping the lars path, as the residues '
669 | 'are small and the current value of alpha is no '
670 | 'longer well controlled. %i iterations, alpha=%.3e, '
671 | 'previous alpha=%.3e, with an active set of %i '
672 | 'regressors.'
673 | % (n_iter, alpha, prev_alpha, n_active),
674 | ConvergenceWarning)
675 | break
676 |
677 | # least squares solution
678 | least_squares, _ = solve_cholesky(L[:n_active, :n_active],
679 | sign_active[:n_active],
680 | lower=True)
681 |
682 | if least_squares.size == 1 and least_squares == 0:
683 | # This happens because sign_active[:n_active] = 0
684 | least_squares[...] = 1
685 | AA = 1.
686 | else:
687 | # is this really needed ?
688 | AA = 1. / np.sqrt(np.sum(least_squares * sign_active[:n_active]))
689 |
690 | if not np.isfinite(AA):
691 | # L is too ill-conditioned
692 | i = 0
693 | L_ = L[:n_active, :n_active].copy()
694 | while not np.isfinite(AA):
695 | L_.flat[::n_active + 1] += (2 ** i) * eps
696 | least_squares, _ = solve_cholesky(
697 | L_, sign_active[:n_active], lower=True)
698 | tmp = max(np.sum(least_squares * sign_active[:n_active]),
699 | eps)
700 | AA = 1. / np.sqrt(tmp)
701 | i += 1
702 | least_squares *= AA
703 |
704 | if Gram is None:
705 | # equiangular direction of variables in the active set
706 | eq_dir = np.dot(X.T[:n_active].T, least_squares)
707 |             # correlation between each inactive variable and the
708 |             # equiangular vector
709 | corr_eq_dir = np.dot(X.T[n_active:], eq_dir)
710 | else:
711 | # if huge number of features, this takes 50% of time, I
712 | # think could be avoided if we just update it using an
713 | # orthogonal (QR) decomposition of X
714 | corr_eq_dir = np.dot(Gram[:n_active, n_active:].T,
715 | least_squares)
716 |
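    |         # LARS step length: g1 (and g2 in the signed case) is the smallest
    |         # positive step at which an inactive variable attains the same
    |         # absolute correlation as the active set; C / AA is the step that
    |         # drives the active correlations to zero.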
717 | g1 = arrayfuncs.min_pos((C - Cov) / (AA - corr_eq_dir + tiny32))
718 | if positive:
719 | gamma_ = min(g1, C / AA)
720 | else:
721 | g2 = arrayfuncs.min_pos((C + Cov) / (AA + corr_eq_dir + tiny32))
722 | gamma_ = min(g1, g2, C / AA)
723 |
724 | # TODO: better names for these variables: z
725 | drop = False
726 | z = -coef[active] / (least_squares + tiny32)
727 | z_pos = arrayfuncs.min_pos(z)
728 | if z_pos < gamma_:
729 | # some coefficients have changed sign
730 | idx = np.where(z == z_pos)[0][::-1]
731 |
732 | # update the sign, important for LAR
733 | sign_active[idx] = -sign_active[idx]
734 |
735 | if method == 'lasso':
736 | gamma_ = z_pos
737 | drop = True
738 |
739 | n_iter += 1
740 |
741 | if return_path:
742 | if n_iter >= coefs.shape[0]:
743 | del coef, alpha, prev_alpha, prev_coef
744 | # resize the coefs and alphas array
745 | add_features = 2 * max(1, (max_features - n_active))
746 | coefs = np.resize(coefs, (n_iter + add_features, n_features))
747 | coefs[-add_features:] = 0
748 | alphas = np.resize(alphas, n_iter + add_features)
749 | alphas[-add_features:] = 0
750 | coef = coefs[n_iter]
751 | prev_coef = coefs[n_iter - 1]
752 | else:
753 | # mimic the effect of incrementing n_iter on the array references
754 | prev_coef = coef
755 | prev_alpha[0] = alpha[0]
756 | coef = np.zeros_like(coef)
757 |
758 | coef[active] = prev_coef[active] + gamma_ * least_squares
759 |
760 | # update correlations
761 | Cov -= gamma_ * corr_eq_dir
762 |
763 | # See if any coefficient has changed sign
764 | if drop and method == 'lasso':
765 |
766 |             # handle the case when idx has more than one element
767 | for ii in idx:
768 | arrayfuncs.cholesky_delete(L[:n_active, :n_active], ii)
769 |
770 | n_active -= 1
771 |             # handle the case when idx has more than one element
772 | drop_idx = [active.pop(ii) for ii in idx]
773 |
774 | if Gram is None:
775 | # propagate dropped variable
776 | for ii in idx:
777 | for i in range(ii, n_active):
778 | X.T[i], X.T[i + 1] = swap(X.T[i], X.T[i + 1])
779 |                         # keep the index bookkeeping in sync with the column swap
780 | indices[i], indices[i + 1] = indices[i + 1], indices[i]
781 |
782 | # TODO: this could be updated
783 | residual = y - np.dot(X[:, :n_active], coef[active])
784 | temp = np.dot(X.T[n_active], residual)
785 |
786 | Cov = np.r_[temp, Cov]
787 | else:
788 | for ii in idx:
789 | for i in range(ii, n_active):
790 | indices[i], indices[i + 1] = indices[i + 1], indices[i]
791 | Gram[i], Gram[i + 1] = swap(Gram[i], Gram[i + 1])
792 | Gram[:, i], Gram[:, i + 1] = swap(Gram[:, i],
793 | Gram[:, i + 1])
794 |
795 | # Cov_n = Cov_j + x_j * X + increment(betas) TODO:
796 | # will this still work with multiple drops ?
797 |
798 | # recompute covariance. Probably could be done better
799 | # wrong as Xy is not swapped with the rest of variables
800 |
801 | # TODO: this could be updated
802 | temp = Cov_copy[drop_idx] - np.dot(Gram_copy[drop_idx], coef)
803 | Cov = np.r_[temp, Cov]
804 |
805 | sign_active = np.delete(sign_active, idx)
806 | sign_active = np.append(sign_active, 0.) # just to maintain size
807 | if verbose > 1:
808 | print("%s\t\t%s\t\t%s\t\t%s\t\t%s" % (n_iter, '', drop_idx,
809 | n_active, abs(temp)))
810 |
811 | if return_path:
812 | # resize coefs in case of early stop
813 | alphas = alphas[:n_iter + 1]
814 | coefs = coefs[:n_iter + 1]
815 |
816 | if return_n_iter:
817 | return alphas, active, coefs.T, n_iter
818 | else:
819 | if testing:
820 | return alphas, active, coefs.T, test_result
821 | else:
822 | return alphas, active, coefs.T
823 | else:
824 | if return_n_iter:
825 | return alpha, active, coef, n_iter
826 | else:
827 | return alpha, active, coef
--------------------------------------------------------------------------------