├── .github
│   ├── pull_request_template.md
│   └── workflows
│       └── build.yml
├── .gitignore
├── .pylintrc
├── CITATION.cff
├── LICENSE
├── README.md
├── codecov.yml
├── demo
│   ├── __init__.py
│   └── feature_overview
│       ├── __init__.py
│       ├── feature_overview.ipynb
│       ├── healthcare.png
│       ├── missing_embeddings.py
│       ├── no_missing_embeddings.py
│       └── paper_example_image.png
├── example_pipelines
│   ├── __init__.py
│   ├── _pipelines.py
│   ├── adult_complex
│   │   ├── __init__.py
│   │   ├── adult_complex.png
│   │   ├── adult_complex.py
│   │   ├── adult_test.csv
│   │   └── adult_train.csv
│   ├── adult_simple
│   │   ├── __init__.py
│   │   ├── adult_simple.ipynb
│   │   ├── adult_simple.png
│   │   └── adult_simple.py
│   ├── compas
│   │   ├── __init__.py
│   │   ├── compas.png
│   │   ├── compas.py
│   │   ├── compas_test.csv
│   │   └── compas_train.csv
│   └── healthcare
│       ├── __init__.py
│       ├── _gensim_wrapper.py
│       ├── custom_monkeypatching
│       │   ├── __init__.py
│       │   └── patch_healthcare_utils.py
│       ├── healthcare.png
│       ├── healthcare.py
│       ├── healthcare_utils.py
│       ├── histories.csv
│       └── patients.csv
├── experiments
│   ├── __init__.py
│   └── performance
│       ├── __init__.py
│       ├── _benchmark_utils.py
│       ├── _empty_inspection.py
│       └── performance_benchmarks.ipynb
├── mlinspect
│   ├── __init__.py
│   ├── _inspector_result.py
│   ├── _pipeline_inspector.py
│   ├── backends
│   │   ├── __init__.py
│   │   ├── _all_backends.py
│   │   ├── _backend.py
│   │   ├── _backend_utils.py
│   │   ├── _iter_creation.py
│   │   ├── _pandas_backend.py
│   │   └── _sklearn_backend.py
│   ├── checks
│   │   ├── __init__.py
│   │   ├── _check.py
│   │   ├── _no_bias_introduced_for.py
│   │   ├── _no_illegal_features.py
│   │   └── _similar_removal_probabilities_for.py
│   ├── inspections
│   │   ├── __init__.py
│   │   ├── _arg_capturing.py
│   │   ├── _column_propagation.py
│   │   ├── _completeness_of_columns.py
│   │   ├── _count_distinct_of_columns.py
│   │   ├── _histogram_for_columns.py
│   │   ├── _inspection.py
│   │   ├── _inspection_input.py
│   │   ├── _inspection_result.py
│   │   ├── _intersectional_histogram_for_columns.py
│   │   ├── _lineage.py
│   │   └── _materialize_first_output_rows.py
│   ├── instrumentation
│   │   ├── __init__.py
│   │   ├── _call_capture_transformer.py
│   │   ├── _dag_node.py
│   │   └── _pipeline_executor.py
│   ├── monkeypatching
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── _mlinspect_ndarray.py
│   │   ├── _monkey_patching_utils.py
│   │   ├── _patch_numpy.py
│   │   ├── _patch_pandas.py
│   │   ├── _patch_sklearn.py
│   │   └── _patch_statsmodels.py
│   ├── testing
│   │   ├── __init__.py
│   │   ├── _random_annotation_testing_inspection.py
│   │   └── _testing_helper_utils.py
│   ├── utils
│   │   ├── __init__.py
│   │   └── _utils.py
│   └── visualisation
│       ├── __init__.py
│       └── _visualisation.py
├── requirements
│   ├── requirements.dev.txt
│   └── requirements.txt
├── setup.cfg
├── setup.py
└── test
    ├── __init__.py
    ├── backends
    │   ├── __init__.py
    │   ├── test_pandas_backend.py
    │   └── test_sklearn_backend.py
    ├── checks
    │   ├── __init__.py
    │   ├── test_no_bias_introduced_for.py
    │   ├── test_no_illegal_features.py
    │   └── test_similar_removal_probablities_for.py
    ├── demo
    │   ├── __init__.py
    │   └── feature_overview
    │       ├── __init__.py
    │       ├── test_feature_overview.py
    │       ├── test_missing_embeddings.py
    │       └── test_no_missing_embeddings.py
    ├── example_pipelines
    │   ├── __init__.py
    │   ├── test_adult_complex.py
    │   ├── test_adult_simple.py
    │   ├── test_compas.py
    │   ├── test_healthcare.py
    │   └── test_patch_healthcare_utils.py
    ├── experiments
    │   ├── __init__.py
    │   └── performance
    │       ├── __init__.py
    │       ├── test_benchmark_utils.py
    │       └── test_performance_benchmarks.py
    ├── inspections
    │   ├── __init__.py
    │   ├── test_arg_capturing.py
    │   ├── test_column_propagation.py
    │   ├── test_completeness_of_columns.py
    │   ├── test_count_distinct_of_columns.py
    │   ├── test_histogram_for_columns.py
    │   ├── test_intersectional_histogram_for_columns.py
    │   ├── test_lineage.py
    │   └── test_materialize_first_output_rows.py
    ├── instrumentation
    │   ├── __init__.py
    │   └── test_pipeline_executor.py
    ├── monkeypatching
    │   ├── __init__.py
    │   ├── test_patch_numpy.py
    │   ├── test_patch_pandas.py
    │   ├── test_patch_sklearn.py
    │   └── test_patch_statsmodels.py
    ├── test_pipeline_inspector.py
    ├── utils
    │   ├── __init__.py
    │   └── test_utils.py
    └── visualisation
        ├── __init__.py
        └── test_visualisation.py
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | *Issue #, if available:*
2 |
3 | *Description of changes:*
4 |
5 |
6 | By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.
7 |
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: Lint and Test
2 |
3 | on: [ push, pull_request ]
4 |
5 | jobs:
6 | python:
7 | runs-on: ubuntu-latest
8 | strategy:
9 | matrix:
10 | python-version: [ '3.10' ]
11 |
12 | steps:
13 | - uses: actions/checkout@v2
14 | - name: Set up Python ${{ matrix.python-version }}
15 | uses: actions/setup-python@v2
16 | with:
17 | python-version: ${{ matrix.python-version }}
18 |
19 | - name: Setup Graphviz
20 | uses: ts-graphviz/setup-graphviz@v1
21 |
22 | - name: Cache dependencies
23 | uses: actions/cache@v2
24 | with:
25 | path: ${{ env.pythonLocation }}
26 | key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements/requirements.txt') }}-${{ hashFiles('requirements/requirements.dev.txt') }}
27 |
28 | - name: Install dependencies
29 | env:
30 | SETUPTOOLS_USE_DISTUTILS: stdlib
31 | run: |
32 | python -m pip install --upgrade pip
33 | pip install --upgrade --upgrade-strategy eager ".[dev]"
34 |
35 | - name: Unit Tests
36 | run: python -m pytest
37 |
38 | - name: Upload Coverage Report
39 | uses: codecov/codecov-action@v1
40 | with:
41 | token: ${{ secrets.CODECOV_TOKEN }}
42 | files: ./coverage.xml
43 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # PyCharm
132 | .idea
133 |
134 | # MacOS
135 | .DS_Store
136 |
137 | # VS Code
138 | .vscode
139 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Grafberger"
5 | given-names: "Stefan"
6 | orcid: "https://orcid.org/0000-0002-9884-9517"
7 | - family-names: "Groth"
8 | given-names: "Paul"
9 | orcid: "https://orcid.org/0000-0003-0183-6910"
10 | - family-names: "Stoyanovich"
11 | given-names: "Julia"
12 | - family-names: "Schelter"
13 | given-names: "Sebastian"
14 | title: "Data Distribution Debugging in Machine Learning Pipelines"
15 | doi: 10.1007/s00778-021-00726-w
16 | url: "https://github.com/stefan-grafberger/mlinspect"
17 | preferred-citation:
18 | type: article
19 | authors:
20 | - family-names: "Grafberger"
21 | given-names: "Stefan"
22 | orcid: "https://orcid.org/0000-0002-9884-9517"
23 | - family-names: "Groth"
24 | given-names: "Paul"
25 | orcid: "https://orcid.org/0000-0003-0183-6910"
26 | - family-names: "Stoyanovich"
27 | given-names: "Julia"
28 | - family-names: "Schelter"
29 | given-names: "Sebastian"
30 | title: "Data Distribution Debugging in Machine Learning Pipelines"
31 | doi: 10.1007/s00778-021-00726-w
32 | date-released: 2022-01-31
33 |
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | mlinspect
2 | ================================
3 |
4 | [](https://github.com/stefan-grafberger/mlinspect)
5 | [](https://github.com/stefan-grafberger/mlinspect/blob/master/LICENSE)
6 | [](https://github.com/stefan-grafberger/mlinspect/actions/workflows/build.yml)
7 | [](https://codecov.io/gh/stefan-grafberger/mlinspect)
8 |
9 | Inspect ML Pipelines in Python in the form of a DAG
10 |
11 | ## Run mlinspect locally
12 |
13 | Prerequisite: Python 3.10
14 |
15 | 1. Clone this repository
16 | 2. Set up the environment
17 |
18 | `cd mlinspect`
19 | `python -m venv venv`
20 | `source venv/bin/activate`
21 |
22 | 3. If you want to use the visualisation functions we provide, install graphviz, which cannot be installed via pip
23 |
24 | `Linux: ` `apt-get install graphviz`
25 | `macOS: ` `brew install graphviz`
26 |
27 | 4. Install pip dependencies
28 |
29 | `SETUPTOOLS_USE_DISTUTILS=stdlib pip install -e .[dev]`
30 |
31 | 5. To ensure everything works, you can run the tests (without graphviz, the visualisation test will fail)
32 |
33 | `python setup.py test`
34 |
35 | ## How to use mlinspect
36 | mlinspect makes it easy to analyze your pipeline and automatically check for common issues.
37 | ```python
38 | from mlinspect import PipelineInspector
39 | from mlinspect.inspections import MaterializeFirstOutputRows
40 | from mlinspect.checks import NoBiasIntroducedFor
41 |
42 | IPYNB_PATH = ...
43 |
44 | inspector_result = PipelineInspector\
45 | .on_pipeline_from_ipynb_file(IPYNB_PATH)\
46 | .add_required_inspection(MaterializeFirstOutputRows(5))\
47 | .add_check(NoBiasIntroducedFor(['race']))\
48 | .execute()
49 |
50 | extracted_dag = inspector_result.dag
51 | dag_node_to_inspection_results = inspector_result.dag_node_to_inspection_results
52 | check_to_check_results = inspector_result.check_to_check_results
53 | ```
54 |
55 | ## Detailed Example
56 | We prepared a [demo notebook](demo/feature_overview/feature_overview.ipynb) to showcase mlinspect and its features.
57 |
58 | ## Supported libraries and API functions
59 | mlinspect already supports a selection of API functions from `pandas` and `scikit-learn`. Extending this coverage to additional API functions and libraries is an ongoing effort. However, mlinspect won't simply crash when it encounters functions it doesn't recognize yet. For more information, please see [here](mlinspect/monkeypatching/README.md).
60 |
61 | ## Notes
62 | * For debugging in PyCharm, set the pytest flag `--no-cov` ([Link](https://stackoverflow.com/questions/34870962/how-to-debug-py-test-in-pycharm-when-coverage-is-enabled))
63 |
64 | ## Publications
65 | * [Stefan Grafberger, Paul Groth, Julia Stoyanovich, Sebastian Schelter (2022). Data Distribution Debugging in Machine Learning Pipelines. The VLDB Journal — The International Journal on Very Large Data Bases (Special Issue on Data Science for Responsible Data Management).](https://stefan-grafberger.com/mlinspect-journal.pdf)
66 | * [Stefan Grafberger, Shubha Guha, Julia Stoyanovich, Sebastian Schelter (2021). mlinspect: a Data Distribution Debugger for Machine Learning Pipelines. ACM SIGMOD (demo).](https://stefan-grafberger.com/mlinspect-demo.pdf)
67 | * [Stefan Grafberger, Julia Stoyanovich, Sebastian Schelter (2020). Lightweight Inspection of Data Preprocessing in Native Machine Learning Pipelines. Conference on Innovative Data Systems Research (CIDR).](https://stefan-grafberger.com/mlinspect-cidr.pdf)
68 |
69 | ## License
70 | This library is licensed under the Apache 2.0 License.
71 |
--------------------------------------------------------------------------------
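For quick experiments, the same fluent API also accepts the pipeline as a plain string via `PipelineInspector.on_pipeline_from_string` (defined in `mlinspect/_pipeline_inspector.py` below). A minimal sketch; the toy pipeline here is illustrative and not part of the repository:

```python
from mlinspect import PipelineInspector

PIPELINE_CODE = """
import pandas as pd

df = pd.DataFrame({'age': [25, 61], 'income': [1000, 2000]})
df = df.dropna()
"""

inspector_result = PipelineInspector \
    .on_pipeline_from_string(PIPELINE_CODE) \
    .execute()

print(list(inspector_result.dag.nodes))  # one DagNode per recognized pandas call
```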
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default:
5 | target: auto
6 | threshold: 2%
7 | patch: off
8 |
--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/__init__.py
--------------------------------------------------------------------------------
/demo/feature_overview/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/feature_overview/__init__.py
--------------------------------------------------------------------------------
/demo/feature_overview/healthcare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/feature_overview/healthcare.png
--------------------------------------------------------------------------------
/demo/feature_overview/missing_embeddings.py:
--------------------------------------------------------------------------------
1 | """
2 | The MissingEmbedding Inspection
3 | """
4 | import dataclasses
5 | from typing import Iterable, List
6 |
7 | from mlinspect import FunctionInfo
8 | from mlinspect.inspections import Inspection, InspectionInputUnaryOperator
9 |
10 |
11 | @dataclasses.dataclass(frozen=True, eq=True)
12 | class MissingEmbeddingsInfo:
13 | """
14 | Info about potentially missing embeddings
15 | """
16 | missing_embedding_count: int
17 | missing_embeddings_examples: List[str]
18 |
19 |
20 | class MissingEmbeddings(Inspection):
21 | """
22 | A simple example inspection
23 | """
24 |
25 | def __init__(self, example_threshold=10):
26 | self._is_embedding_operator = False
27 | self._missing_embedding_count = 0
28 | self._missing_embeddings_examples = []
29 | self.example_threshold = example_threshold
30 |
31 | def visit_operator(self, inspection_input) -> Iterable[any]:
32 | """
33 | Visit an operator
34 | """
35 | # pylint: disable=too-many-branches, too-many-statements
36 | if isinstance(inspection_input, InspectionInputUnaryOperator) and \
37 | inspection_input.operator_context.function_info == \
38 | FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer'):
39 | # TODO: Are there existing word embedding transformers for sklearn we can use this for?
40 | self._is_embedding_operator = True
41 | for row in inspection_input.row_iterator:
42 | # Count missing embeddings
43 | embedding_array = row.output[0]
44 | is_zero_vector = not embedding_array.any()
45 | if is_zero_vector:
46 | self._missing_embedding_count += 1
47 | if len(self._missing_embeddings_examples) < self.example_threshold:
48 | self._missing_embeddings_examples.append(row.input[0])
49 | yield None
50 | else:
51 | for _ in inspection_input.row_iterator:
52 | yield None
53 |
54 | def get_operator_annotation_after_visit(self) -> any:
55 | if self._is_embedding_operator:
56 | assert self._missing_embedding_count is not None # May only be called after the operator visit is finished
57 | result = MissingEmbeddingsInfo(self._missing_embedding_count, self._missing_embeddings_examples)
58 | self._missing_embedding_count = 0
59 | self._is_embedding_operator = False
60 | self._missing_embeddings_examples = []
61 | return result
62 | return None
63 |
64 | @property
65 | def inspection_id(self):
66 | return self.example_threshold
67 |
--------------------------------------------------------------------------------
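This inspection can be registered like any built-in one. A sketch of a run against the bundled healthcare pipeline, assuming its dependencies (gensim, keras/scikeras) are installed; the custom monkey patching module is needed so that `MyW2VTransformer` calls are captured at all. The keyed lookup with a fresh, equal instance is the same pattern the `NoMissingEmbeddings` check in the next file uses:

```python
from demo.feature_overview.missing_embeddings import MissingEmbeddings, MissingEmbeddingsInfo
from example_pipelines import HEALTHCARE_PY
from example_pipelines.healthcare import custom_monkeypatching
from mlinspect import PipelineInspector

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file(HEALTHCARE_PY) \
    .add_custom_monkey_patching_module(custom_monkeypatching) \
    .add_required_inspection(MissingEmbeddings(example_threshold=10)) \
    .execute()

# An equal instance retrieves the annotation stored for each DAG node
for dag_node, node_results in inspector_result.dag_node_to_inspection_results.items():
    info = node_results.get(MissingEmbeddings(example_threshold=10))
    if isinstance(info, MissingEmbeddingsInfo) and info.missing_embedding_count > 0:
        print(dag_node, info.missing_embeddings_examples)
```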
/demo/feature_overview/no_missing_embeddings.py:
--------------------------------------------------------------------------------
1 | """
2 | The NoMissingEmbeddings Check
3 | """
4 | import collections
5 | import dataclasses
6 | from typing import Iterable, Dict
7 |
8 | from demo.feature_overview.missing_embeddings import MissingEmbeddings, MissingEmbeddingsInfo
9 | from mlinspect import DagNode
10 | from mlinspect.checks import Check, CheckStatus, CheckResult
11 | from mlinspect.inspections import Inspection, InspectionResult
12 |
13 |
14 | ILLEGAL_FEATURES = {"race", "gender", "age"}
15 |
16 |
17 | @dataclasses.dataclass
18 | class NoMissingEmbeddingsResult(CheckResult):
19 | """
20 | The result of the NoMissingEmbeddings check
21 | """
22 | dag_node_to_missing_embeddings: Dict[DagNode, MissingEmbeddingsInfo]
23 |
24 |
25 | class NoMissingEmbeddings(Check):
26 | """
27 | Does the pipeline produce missing (all-zero) embeddings?
28 | """
29 | # pylint: disable=unnecessary-pass, too-few-public-methods
30 |
31 | def __init__(self, example_threshold=10):
32 | self.example_threshold = example_threshold
33 |
34 | @property
35 | def check_id(self):
36 | """The id of the Constraints"""
37 | return self.example_threshold
38 |
39 | @property
40 | def required_inspections(self) -> Iterable[Inspection]:
41 | """The id of the check"""
42 | return [MissingEmbeddings(self.example_threshold)]
43 |
44 | def evaluate(self, inspection_result: InspectionResult) -> CheckResult:
45 | """Evaluate the check"""
46 | dag_node_to_missing_embeddings = {}
47 | for dag_node, dag_node_inspection_result in inspection_result.dag_node_to_inspection_results.items():
48 | if MissingEmbeddings(self.example_threshold) in dag_node_inspection_result:
49 | missing_embedding_info = dag_node_inspection_result[MissingEmbeddings(self.example_threshold)]
50 | assert missing_embedding_info is None or isinstance(missing_embedding_info, MissingEmbeddingsInfo)
51 | if missing_embedding_info is not None and missing_embedding_info.missing_embedding_count > 0:
52 | dag_node_to_missing_embeddings[dag_node] = missing_embedding_info
53 | if dag_node_to_missing_embeddings:
54 | description = "Missing embeddings were found!"
55 | result = NoMissingEmbeddingsResult(self, CheckStatus.FAILURE, description, dag_node_to_missing_embeddings)
56 | else:
57 | result = NoMissingEmbeddingsResult(self, CheckStatus.SUCCESS, None, collections.OrderedDict())
58 | return result
59 |
--------------------------------------------------------------------------------
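Wiring the check into a run is analogous; the check pulls in its required `MissingEmbeddings` inspection automatically via `required_inspections`. A sketch under the same assumptions as above, and assuming checks, like inspections, can be looked up by an equal instance:

```python
from demo.feature_overview.no_missing_embeddings import NoMissingEmbeddings
from example_pipelines import HEALTHCARE_PY
from example_pipelines.healthcare import custom_monkeypatching
from mlinspect import PipelineInspector
from mlinspect.checks import CheckStatus

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file(HEALTHCARE_PY) \
    .add_custom_monkey_patching_module(custom_monkeypatching) \
    .add_check(NoMissingEmbeddings()) \
    .execute()

check_result = inspector_result.check_to_check_results[NoMissingEmbeddings()]
if check_result.status == CheckStatus.FAILURE:
    print(check_result.description)
    for dag_node, info in check_result.dag_node_to_missing_embeddings.items():
        print(dag_node, info.missing_embedding_count)
```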
/demo/feature_overview/paper_example_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/feature_overview/paper_example_image.png
--------------------------------------------------------------------------------
/example_pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Packages and classes we want to expose to users
3 | """
4 | from ._pipelines import ADULT_SIMPLE_PY, ADULT_SIMPLE_IPYNB, ADULT_SIMPLE_PNG, ADULT_COMPLEX_PY, ADULT_COMPLEX_PNG, \
5 | COMPAS_PY, COMPAS_PNG, HEALTHCARE_PY, HEALTHCARE_PNG
6 |
7 | __all__ = [
8 | 'ADULT_SIMPLE_PY', 'ADULT_SIMPLE_IPYNB', 'ADULT_SIMPLE_PNG',
9 | 'ADULT_COMPLEX_PY', 'ADULT_COMPLEX_PNG',
10 | 'COMPAS_PY', 'COMPAS_PNG',
11 | 'HEALTHCARE_PY', 'HEALTHCARE_PNG',
12 | ]
13 |
--------------------------------------------------------------------------------
/example_pipelines/_pipelines.py:
--------------------------------------------------------------------------------
1 | """
2 | Paths to the example pipelines and their corresponding DAG images
3 | """
4 | import os
5 |
6 | from mlinspect.utils import get_project_root
7 |
8 | ADULT_SIMPLE_PY = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.py")
9 | ADULT_SIMPLE_IPYNB = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.ipynb")
10 | ADULT_SIMPLE_PNG = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.png")
11 |
12 | ADULT_COMPLEX_PY = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_complex.py")
13 | ADULT_COMPLEX_PNG = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_complex.png")
14 |
15 | COMPAS_PY = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas.py")
16 | COMPAS_PNG = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas.png")
17 |
18 | HEALTHCARE_PY = os.path.join(str(get_project_root()), "example_pipelines", "healthcare", "healthcare.py")
19 | HEALTHCARE_PNG = os.path.join(str(get_project_root()), "example_pipelines", "healthcare", "healthcare.png")
20 |
--------------------------------------------------------------------------------
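These constants make the bundled pipelines addressable from tests and demos without hard-coded paths; they resolve against the repository root regardless of the current working directory. A tiny usage sketch:

```python
import os

from example_pipelines import COMPAS_PY, COMPAS_PNG

assert os.path.isfile(COMPAS_PY)   # the pipeline source
assert os.path.isfile(COMPAS_PNG)  # the corresponding DAG image
```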
/example_pipelines/adult_complex/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_complex/__init__.py
--------------------------------------------------------------------------------
/example_pipelines/adult_complex/adult_complex.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_complex/adult_complex.png
--------------------------------------------------------------------------------
/example_pipelines/adult_complex/adult_complex.py:
--------------------------------------------------------------------------------
1 | """
2 | An example pipeline
3 | """
4 | import os
5 |
6 | import pandas as pd
7 | import numpy as np
8 | from sklearn import preprocessing
9 | from sklearn.compose import ColumnTransformer
10 | from sklearn.impute import SimpleImputer
11 | from sklearn.pipeline import Pipeline
12 | from sklearn.preprocessing import OneHotEncoder, StandardScaler
13 | from sklearn.tree import DecisionTreeClassifier
14 |
15 | from mlinspect.utils import get_project_root
16 |
17 | train_file = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_train.csv")
18 | train_data = pd.read_csv(train_file, na_values='?', index_col=0)
19 | test_file = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_test.csv")
20 | test_data = pd.read_csv(test_file, na_values='?', index_col=0)
21 |
22 | train_labels = preprocessing.label_binarize(train_data['income-per-year'], classes=['>50K', '<=50K'])
23 | test_labels = preprocessing.label_binarize(test_data['income-per-year'], classes=['>50K', '<=50K'])
24 |
25 | nested_categorical_feature_transformation = Pipeline([
26 | ('impute', SimpleImputer(missing_values=np.nan, strategy='most_frequent')),
27 | ('encode', OneHotEncoder(handle_unknown='ignore'))
28 | ])
29 |
30 | nested_feature_transformation = ColumnTransformer(transformers=[
31 | ('categorical', nested_categorical_feature_transformation, ['education', 'workclass']),
32 | ('numeric', StandardScaler(), ['age', 'hours-per-week'])
33 | ])
34 |
35 | nested_income_pipeline = Pipeline([
36 | ('features', nested_feature_transformation),
37 | ('classifier', DecisionTreeClassifier())])
38 |
39 | nested_income_pipeline.fit(train_data, train_labels)
40 |
41 | print(nested_income_pipeline.score(test_data, test_labels))
42 |
--------------------------------------------------------------------------------
/example_pipelines/adult_simple/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_simple/__init__.py
--------------------------------------------------------------------------------
/example_pipelines/adult_simple/adult_simple.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stdout",
10 | "output_type": "stream",
11 | "text": [
12 | "pipeline start\n",
13 | "pipeline finished\n"
14 | ]
15 | }
16 | ],
17 | "source": [
18 | "\"\"\"\n",
19 | "An example pipeline\n",
20 | "\"\"\"\n",
21 | "import os\n",
22 | "import pandas as pd\n",
23 | "\n",
24 | "from sklearn import compose, preprocessing, tree, pipeline\n",
25 | "from mlinspect.utils import get_project_root\n",
26 | "\n",
27 | "print('pipeline start')\n",
28 | "train_file = os.path.join(str(get_project_root()), \"example_pipelines\", \"adult_complex\", \"adult_train.csv\")\n",
29 | "raw_data = pd.read_csv(train_file, na_values='?', index_col=0)\n",
30 | "\n",
31 | "data = raw_data.dropna()\n",
32 | "\n",
33 | "labels = preprocessing.label_binarize(data['income-per-year'], classes=['>50K', '<=50K'])\n",
34 | "\n",
35 | "feature_transformation = compose.ColumnTransformer(transformers=[\n",
36 | " ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass']),\n",
37 | " ('numeric', preprocessing.StandardScaler(), ['age', 'hours-per-week'])\n",
38 | "])\n",
39 | "\n",
40 | "\n",
41 | "income_pipeline = pipeline.Pipeline([\n",
42 | " ('features', feature_transformation),\n",
43 | " ('classifier', tree.DecisionTreeClassifier())])\n",
44 | "\n",
45 | "income_pipeline.fit(data, labels)\n",
46 | "\n",
47 | "\n",
48 | "print('pipeline finished')"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": []
57 | }
58 | ],
59 | "metadata": {
60 | "kernelspec": {
61 | "display_name": "Python 3",
62 | "language": "python",
63 | "name": "python3"
64 | },
65 | "language_info": {
66 | "codemirror_mode": {
67 | "name": "ipython",
68 | "version": 3
69 | },
70 | "file_extension": ".py",
71 | "mimetype": "text/x-python",
72 | "name": "python",
73 | "nbconvert_exporter": "python",
74 | "pygments_lexer": "ipython3",
75 | "version": "3.8.5"
76 | }
77 | },
78 | "nbformat": 4,
79 | "nbformat_minor": 2
80 | }
81 |
--------------------------------------------------------------------------------
/example_pipelines/adult_simple/adult_simple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_simple/adult_simple.png
--------------------------------------------------------------------------------
/example_pipelines/adult_simple/adult_simple.py:
--------------------------------------------------------------------------------
1 | """
2 | An example pipeline
3 | """
4 | import os
5 | import pandas as pd
6 |
7 | from sklearn import compose, preprocessing, tree, pipeline
8 | from mlinspect.utils import get_project_root
9 |
10 | print('pipeline start')
11 | train_file = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_train.csv")
12 | raw_data = pd.read_csv(train_file, na_values='?', index_col=0)
13 |
14 | data = raw_data.dropna()
15 |
16 | labels = preprocessing.label_binarize(data['income-per-year'], classes=['>50K', '<=50K'])
17 |
18 | feature_transformation = compose.ColumnTransformer(transformers=[
19 | ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass']),
20 | ('numeric', preprocessing.StandardScaler(), ['age', 'hours-per-week'])
21 | ])
22 |
23 |
24 | income_pipeline = pipeline.Pipeline([
25 | ('features', feature_transformation),
26 | ('classifier', tree.DecisionTreeClassifier())])
27 |
28 | income_pipeline.fit(data, labels)
29 |
30 |
31 | print('pipeline finished')
32 |
--------------------------------------------------------------------------------
/example_pipelines/compas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/compas/__init__.py
--------------------------------------------------------------------------------
/example_pipelines/compas/compas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/compas/compas.png
--------------------------------------------------------------------------------
/example_pipelines/compas/compas.py:
--------------------------------------------------------------------------------
1 | """
2 | An example pipeline
3 | """
4 | import os
5 |
6 | import pandas as pd
7 | from sklearn.compose import ColumnTransformer
8 | from sklearn.impute import SimpleImputer
9 | from sklearn.linear_model import LogisticRegression
10 | from sklearn.pipeline import Pipeline
11 | from sklearn.preprocessing import OneHotEncoder, KBinsDiscretizer, label_binarize
12 |
13 | from mlinspect.utils import get_project_root
14 |
15 | train_file = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas_train.csv")
16 | train_data = pd.read_csv(train_file, na_values='?', index_col=0)
17 | test_file = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas_test.csv")
18 | test_data = pd.read_csv(test_file, na_values='?', index_col=0)
19 |
20 | train_data = train_data[
21 | ['sex', 'dob', 'age', 'c_charge_degree', 'race', 'score_text', 'priors_count', 'days_b_screening_arrest',
22 | 'decile_score', 'is_recid', 'two_year_recid', 'c_jail_in', 'c_jail_out']]
23 | test_data = test_data[
24 | ['sex', 'dob', 'age', 'c_charge_degree', 'race', 'score_text', 'priors_count', 'days_b_screening_arrest',
25 | 'decile_score', 'is_recid', 'two_year_recid', 'c_jail_in', 'c_jail_out']]
26 |
27 | train_data = train_data[(train_data['days_b_screening_arrest'] <= 30) & (train_data['days_b_screening_arrest'] >= -30)]
28 | train_data = train_data[train_data['is_recid'] != -1]
29 | train_data = train_data[train_data['c_charge_degree'] != "O"]
30 | train_data = train_data[train_data['score_text'] != 'N/A']
31 |
32 | train_data = train_data.replace('Medium', "Low")
33 | test_data = test_data.replace('Medium', "Low")
34 |
35 | train_labels = label_binarize(train_data['score_text'], classes=['High', 'Low'])
36 | test_labels = label_binarize(test_data['score_text'], classes=['High', 'Low'])
37 |
38 | impute1_and_onehot = Pipeline([('imputer1', SimpleImputer(strategy='most_frequent')),
39 | ('onehot', OneHotEncoder(handle_unknown='ignore'))])
40 | impute2_and_bin = Pipeline([('imputer2', SimpleImputer(strategy='mean')),
41 | ('discretizer', KBinsDiscretizer(n_bins=4, encode='ordinal', strategy='uniform'))])
42 |
43 | featurizer = ColumnTransformer(transformers=[
44 | ('impute1_and_onehot', impute1_and_onehot, ['is_recid']),
45 | ('impute2_and_bin', impute2_and_bin, ['age'])
46 | ])
47 |
48 | pipeline = Pipeline([
49 | ('features', featurizer),
50 | ('classifier', LogisticRegression())
51 | ])
52 |
53 | pipeline.fit(train_data, train_labels.ravel())
54 | print(pipeline.score(test_data, test_labels.ravel()))
55 |
--------------------------------------------------------------------------------
/example_pipelines/healthcare/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/healthcare/__init__.py
--------------------------------------------------------------------------------
/example_pipelines/healthcare/_gensim_wrapper.py:
--------------------------------------------------------------------------------
1 | # pylint: disable-all
2 | # Author: Chinmaya Pancholi
3 | # Copyright (C) 2017 Radim Rehurek
4 | # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html
5 | """Scikit learn interface for :class:`~gensim.models.word2vec.Word2Vec`.
6 | Follows scikit-learn API conventions to facilitate using gensim along with scikit-learn.
7 | """
8 | import numpy as np
9 | import six
10 | from sklearn.base import TransformerMixin, BaseEstimator
11 | from sklearn.exceptions import NotFittedError
12 |
13 | from gensim import models
14 |
15 |
16 | class W2VTransformer(TransformerMixin, BaseEstimator):
17 | """Base Word2Vec module, wraps :class:`~gensim.models.word2vec.Word2Vec`.
18 | For more information please have a look at `Tomas Mikolov, Kai Chen, Greg Corrado, Jeffrey Dean: "Efficient
19 | Estimation of Word Representations in Vector Space" <https://arxiv.org/abs/1301.3781>`_.
20 | """
21 | def __init__(self, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1,
22 | workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0,
23 | trim_rule=None, sorted_vocab=1, batch_words=10000):
24 | """
25 | Parameters
26 | ----------
27 | size : int
28 | Dimensionality of the feature vectors.
29 | alpha : float
30 | The initial learning rate.
31 | window : int
32 | The maximum distance between the current and predicted word within a sentence.
33 | min_count : int
34 | Ignores all words with total frequency lower than this.
35 | max_vocab_size : int
36 | Limits the RAM during vocabulary building; if there are more unique
37 | words than this, then prune the infrequent ones. Every 10 million word types need about 1GB of RAM.
38 | Set to `None` for no limit.
39 | sample : float
40 | The threshold for configuring which higher-frequency words are randomly downsampled,
41 | useful range is (0, 1e-5).
42 | seed : int
43 | Seed for the random number generator. Initial vectors for each word are seeded with a hash of
44 | the concatenation of word + `str(seed)`. Note that for a fully deterministically-reproducible run,
45 | you must also limit the model to a single worker thread (`workers=1`), to eliminate ordering jitter
46 | from OS thread scheduling. (In Python 3, reproducibility between interpreter launches also requires
47 | use of the `PYTHONHASHSEED` environment variable to control hash randomization).
48 | workers : int
49 | Use these many worker threads to train the model (=faster training with multicore machines).
50 | min_alpha : float
51 | Learning rate will linearly drop to `min_alpha` as training progresses.
52 | sg : int {1, 0}
53 | Defines the training algorithm. If 1, skip-gram is employed; otherwise, CBOW is used.
54 | hs : int {1,0}
55 | If 1, hierarchical softmax will be used for model training.
56 | If set to 0, and `negative` is non-zero, negative sampling will be used.
57 | negative : int
58 | If > 0, negative sampling will be used, the int for negative specifies how many "noise words"
59 | should be drawn (usually between 5-20).
60 | If set to 0, no negative sampling is used.
61 | cbow_mean : int {1,0}
62 | If 0, use the sum of the context word vectors. If 1, use the mean, only applies when cbow is used.
63 | hashfxn : callable (object -> int), optional
64 | A hashing function. Used to create an initial random reproducible vector by hashing the random seed.
65 | iter : int
66 | Number of iterations (epochs) over the corpus.
67 | null_word : int {1, 0}
68 | If 1, a null pseudo-word will be created for padding when using concatenative L1 (run-of-words)
69 | trim_rule : function
70 | Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary,
71 | be trimmed away, or handled using the default (discard if word count < min_count).
72 | Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`),
73 | or a callable that accepts parameters (word, count, min_count) and returns either
74 | :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`.
75 | Note: The rule, if given, is only used to prune vocabulary during build_vocab() and is not stored as part
76 | of the model.
77 | sorted_vocab : int {1,0}
78 | If 1, sort the vocabulary by descending frequency before assigning word indexes.
79 | batch_words : int
80 | Target size (in words) for batches of examples passed to worker threads (and
81 | thus cython routines).(Larger batches will be passed if individual
82 | texts are longer than 10000 words, but the standard cython code truncates to that maximum.)
83 | """
84 | self.gensim_model = None
85 | self.size = size
86 | self.alpha = alpha
87 | self.window = window
88 | self.min_count = min_count
89 | self.max_vocab_size = max_vocab_size
90 | self.sample = sample
91 | self.seed = seed
92 | self.workers = workers
93 | self.min_alpha = min_alpha
94 | self.sg = sg
95 | self.hs = hs
96 | self.negative = negative
97 | self.cbow_mean = int(cbow_mean)
98 | self.hashfxn = hashfxn
99 | self.iter = iter
100 | self.null_word = null_word
101 | self.trim_rule = trim_rule
102 | self.sorted_vocab = sorted_vocab
103 | self.batch_words = batch_words
104 |
105 | def fit(self, X, y=None):
106 | """Fit the model according to the given training data.
107 | Parameters
108 | ----------
109 | X : iterable of iterables of str
110 | The input corpus. X can be simply a list of lists of tokens, but for larger corpora,
111 | consider an iterable that streams the sentences directly from disk/network.
112 | See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus`
113 | or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples.
114 | Returns
115 | -------
116 | :class:`~gensim.sklearn_api.w2vmodel.W2VTransformer`
117 | The trained model.
118 | """
119 | self.gensim_model = models.Word2Vec(
120 | sentences=X, vector_size=self.size, alpha=self.alpha,
121 | window=self.window, min_count=self.min_count, max_vocab_size=self.max_vocab_size,
122 | sample=self.sample, seed=self.seed, workers=self.workers, min_alpha=self.min_alpha,
123 | sg=self.sg, hs=self.hs, negative=self.negative, cbow_mean=self.cbow_mean,
124 | hashfxn=self.hashfxn, null_word=self.null_word, trim_rule=self.trim_rule,
125 | sorted_vocab=self.sorted_vocab, batch_words=self.batch_words
126 | )
127 | return self
128 |
129 | def transform(self, words):
130 | """Get the word vectors the input words.
131 | Parameters
132 | ----------
133 | words : {iterable of str, str}
134 | Word or a collection of words to be transformed.
135 | Returns
136 | -------
137 | np.ndarray of shape [`len(words)`, `size`]
138 | A 2D array where each row is the vector of one word.
139 | """
140 | if self.gensim_model is None:
141 | raise NotFittedError(
142 | "This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method."
143 | )
144 |
145 | # The input as array of array
146 | if isinstance(words, six.string_types):
147 | words = [words]
148 | vectors = [self.gensim_model.wv[word] for word in words]
149 | return np.reshape(np.array(vectors), (len(words), self.size))
150 |
151 | def partial_fit(self, X):
152 | raise NotImplementedError(
153 | "'partial_fit' has not been implemented for W2VTransformer. "
154 | "However, the model can be updated with a fixed vocabulary using Gensim API call."
155 | )
156 |
--------------------------------------------------------------------------------
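Standalone, the wrapper behaves like a regular scikit-learn transformer: `fit` expects an iterable of token lists and `transform` maps words to their learned vectors. A minimal sketch, assuming gensim 4.x (the wrapper passes `vector_size`, so gensim 3.x will not work) and `six` are installed. Note that this base class raises a `KeyError` for out-of-vocabulary words; that is the gap `MyW2VTransformer` in `healthcare_utils.py` closes with zero vectors:

```python
from example_pipelines.healthcare._gensim_wrapper import W2VTransformer

sentences = [['hello', 'world'], ['hello', 'mlinspect']]
transformer = W2VTransformer(size=10, min_count=1, seed=42, workers=1)
transformer.fit(sentences)

vectors = transformer.transform(['hello', 'world'])
print(vectors.shape)  # (2, 10): one 10-dimensional vector per input word
```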
/example_pipelines/healthcare/custom_monkeypatching/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/healthcare/custom_monkeypatching/__init__.py
--------------------------------------------------------------------------------
/example_pipelines/healthcare/custom_monkeypatching/patch_healthcare_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Monkey patching for healthcare_utils
3 | """
4 | import gorilla
5 |
6 | from example_pipelines.healthcare import healthcare_utils
7 | from example_pipelines.healthcare import _gensim_wrapper
8 | from mlinspect.backends._sklearn_backend import SklearnBackend
9 | from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType
10 | from mlinspect.instrumentation._dag_node import DagNode, BasicCodeLocation, DagNodeDetails
11 | from mlinspect.instrumentation._pipeline_executor import singleton
12 | from mlinspect.monkeypatching._monkey_patching_utils import add_dag_node, \
13 | get_input_info, execute_patched_func_no_op_id, get_optional_code_info_or_none
14 | from mlinspect.monkeypatching._mlinspect_ndarray import MlinspectNdarray
15 |
16 |
17 | class SklearnMyW2VTransformerPatching:
18 | """ Patches for healthcare_utils.MyW2VTransformer"""
19 |
20 | # pylint: disable=too-few-public-methods
21 |
22 | @gorilla.patch(_gensim_wrapper.W2VTransformer, name='__init__', settings=gorilla.Settings(allow_hit=True))
23 | def patched__init__(self, *, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1,
24 | workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5,
25 | null_word=0, trim_rule=None, sorted_vocab=1, batch_words=10000,
26 | mlinspect_caller_filename=None, mlinspect_lineno=None,
27 | mlinspect_optional_code_reference=None, mlinspect_optional_source_code=None,
28 | mlinspect_fit_transform_active=False):
29 | """ Patch for ('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer') """
30 | # pylint: disable=no-method-argument, attribute-defined-outside-init, too-many-locals, redefined-builtin,
31 | # pylint: disable=invalid-name
32 | original = gorilla.get_original_attribute(_gensim_wrapper.W2VTransformer, '__init__')
33 |
34 | self.mlinspect_caller_filename = mlinspect_caller_filename
35 | self.mlinspect_lineno = mlinspect_lineno
36 | self.mlinspect_optional_code_reference = mlinspect_optional_code_reference
37 | self.mlinspect_optional_source_code = mlinspect_optional_source_code
38 | self.mlinspect_fit_transform_active = mlinspect_fit_transform_active
39 |
40 | self.mlinspect_non_data_func_args = {'size': size, 'alpha': alpha, 'window': window,
41 | 'min_count': min_count, 'max_vocab_size': max_vocab_size, 'sample': sample,
42 | 'seed': seed, 'workers': workers, 'min_alpha': min_alpha, 'sg': sg,
43 | 'hs': hs, 'negative': negative, 'cbow_mean': cbow_mean, 'iter': iter,
44 | 'null_word': null_word, 'trim_rule': trim_rule,
45 | 'sorted_vocab': sorted_vocab, 'batch_words': batch_words}
46 |
47 | def execute_inspections(_, caller_filename, lineno, optional_code_reference, optional_source_code):
48 | """ Execute inspections, add DAG node """
49 | original(self, hashfxn=hashfxn, **self.mlinspect_non_data_func_args)
50 |
51 | self.mlinspect_caller_filename = caller_filename
52 | self.mlinspect_lineno = lineno
53 | self.mlinspect_optional_code_reference = optional_code_reference
54 | self.mlinspect_optional_source_code = optional_source_code
55 |
56 | return execute_patched_func_no_op_id(original, execute_inspections, self, hashfxn=hashfxn,
57 | **self.mlinspect_non_data_func_args)
58 |
59 | @gorilla.patch(healthcare_utils.MyW2VTransformer, name='fit_transform', settings=gorilla.Settings(allow_hit=True))
60 | def patched_fit_transform(self, *args, **kwargs):
61 | """ Patch for ('example_pipelines.healthcare.healthcare_utils.MyW2VTransformer', 'fit_transform') """
62 | # pylint: disable=no-method-argument
63 | self.mlinspect_fit_transform_active = True # pylint: disable=attribute-defined-outside-init
64 | original = gorilla.get_original_attribute(healthcare_utils.MyW2VTransformer, 'fit_transform')
65 | function_info = FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer')
66 | input_info = get_input_info(args[0], self.mlinspect_caller_filename, self.mlinspect_lineno, function_info,
67 | self.mlinspect_optional_code_reference, self.mlinspect_optional_source_code)
68 |
69 | operator_context = OperatorContext(OperatorType.TRANSFORMER, function_info)
70 | input_infos = SklearnBackend.before_call(operator_context, [input_info.annotated_dfobject])
71 | result = original(self, input_infos[0].result_data, *args[1:], **kwargs)
72 | backend_result = SklearnBackend.after_call(operator_context,
73 | input_infos,
74 | result,
75 | self.mlinspect_non_data_func_args)
76 | new_return_value = backend_result.annotated_dfobject.result_data
77 | assert isinstance(new_return_value, MlinspectNdarray)
78 | dag_node = DagNode(singleton.get_next_op_id(),
79 | BasicCodeLocation(self.mlinspect_caller_filename, self.mlinspect_lineno),
80 | operator_context,
81 | DagNodeDetails("Word2Vec: fit_transform", ['array']),
82 | get_optional_code_info_or_none(self.mlinspect_optional_code_reference,
83 | self.mlinspect_optional_source_code))
84 | add_dag_node(dag_node, [input_info.dag_node], backend_result)
85 | self.mlinspect_fit_transform_active = False # pylint: disable=attribute-defined-outside-init
86 | return new_return_value
87 |
88 | @gorilla.patch(healthcare_utils.MyW2VTransformer, name='transform', settings=gorilla.Settings(allow_hit=True))
89 | def patched_transform(self, *args, **kwargs):
90 | """ Patch for ('example_pipelines.healthcare.healthcare_utils.MyW2VTransformer', 'transform') """
91 | # pylint: disable=no-method-argument
92 | original = gorilla.get_original_attribute(healthcare_utils.MyW2VTransformer, 'transform')
93 | if not self.mlinspect_fit_transform_active:
94 | function_info = FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer')
95 | input_info = get_input_info(args[0], self.mlinspect_caller_filename, self.mlinspect_lineno, function_info,
96 | self.mlinspect_optional_code_reference, self.mlinspect_optional_source_code)
97 |
98 | operator_context = OperatorContext(OperatorType.TRANSFORMER, function_info)
99 | input_infos = SklearnBackend.before_call(operator_context, [input_info.annotated_dfobject])
100 | result = original(self, input_infos[0].result_data, *args[1:], **kwargs)
101 | backend_result = SklearnBackend.after_call(operator_context,
102 | input_infos,
103 | result,
104 | self.mlinspect_non_data_func_args)
105 | new_return_value = backend_result.annotated_dfobject.result_data
106 | assert isinstance(new_return_value, MlinspectNdarray)
107 | dag_node = DagNode(singleton.get_next_op_id(),
108 | BasicCodeLocation(self.mlinspect_caller_filename, self.mlinspect_lineno),
109 | operator_context,
110 | DagNodeDetails("Word2Vec: transform", ['array']),
111 | get_optional_code_info_or_none(self.mlinspect_optional_code_reference,
112 | self.mlinspect_optional_source_code))
113 | add_dag_node(dag_node, [input_info.dag_node], backend_result)
114 | else:
115 | new_return_value = original(self, *args, **kwargs)
116 | return new_return_value
117 |
--------------------------------------------------------------------------------
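Patch classes like this one become active once their module is registered with the inspector: the builder collects the modules (see `add_custom_monkey_patching_module` in `mlinspect/_pipeline_inspector.py`) and hands them to the pipeline executor via `custom_monkey_patching`, which applies the gorilla patches before executing the instrumented pipeline. A registration sketch:

```python
from example_pipelines import HEALTHCARE_PY
from example_pipelines.healthcare import custom_monkeypatching
from mlinspect import PipelineInspector

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file(HEALTHCARE_PY) \
    .add_custom_monkey_patching_module(custom_monkeypatching) \
    .execute()

# With the patches registered, MyW2VTransformer appears as a proper
# TRANSFORMER node in the extracted DAG
print(len(inspector_result.dag))
```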
/example_pipelines/healthcare/healthcare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/healthcare/healthcare.png
--------------------------------------------------------------------------------
/example_pipelines/healthcare/healthcare.py:
--------------------------------------------------------------------------------
1 | """Predicting which patients are at a higher risk of complications"""
2 | import warnings
3 | import os
4 | import pandas as pd
5 | from scikeras.wrappers import KerasClassifier
6 | from sklearn.compose import ColumnTransformer
7 | from sklearn.impute import SimpleImputer
8 | from sklearn.model_selection import train_test_split
9 | from sklearn.pipeline import Pipeline
10 | from sklearn.preprocessing import OneHotEncoder, StandardScaler
11 | from example_pipelines.healthcare.healthcare_utils import MyW2VTransformer, \
12 | create_model
13 | from mlinspect.utils import get_project_root
14 |
15 | # FutureWarning: Sklearn 0.24 made a change that breaks remainder='drop', that change will be fixed
16 | # in an upcoming version: https://github.com/scikit-learn/scikit-learn/pull/19263
17 | warnings.filterwarnings('ignore')
18 |
19 | COUNTIES_OF_INTEREST = ['county2', 'county3']
20 |
21 | patients = pd.read_csv(os.path.join(str(get_project_root()), "example_pipelines", "healthcare",
22 | "patients.csv"), na_values='?')
23 | histories = pd.read_csv(os.path.join(str(get_project_root()), "example_pipelines", "healthcare",
24 | "histories.csv"), na_values='?')
25 |
26 | data = patients.merge(histories, on=['ssn'])
27 | complications = data.groupby('age_group') \
28 | .agg(mean_complications=('complications', 'mean'))
29 | data = data.merge(complications, on=['age_group'])
30 | data['label'] = data['complications'] > 1.2 * data['mean_complications']
31 | data = data[['smoker', 'last_name', 'county', 'num_children', 'race', 'income', 'label']]
32 | data = data[data['county'].isin(COUNTIES_OF_INTEREST)]
33 |
34 | impute_and_one_hot_encode = Pipeline([
35 | ('impute', SimpleImputer(strategy='most_frequent')),
36 | ('encode', OneHotEncoder(sparse=False, handle_unknown='ignore'))
37 | ])
38 | featurisation = ColumnTransformer(transformers=[
39 | ("impute_and_one_hot_encode", impute_and_one_hot_encode, ['smoker', 'county', 'race']),
40 | ('word2vec', MyW2VTransformer(min_count=2), ['last_name']),
41 | ('numeric', StandardScaler(), ['num_children', 'income']),
42 | ], remainder='drop')
43 | neural_net = KerasClassifier(model=create_model, epochs=10, batch_size=1, verbose=0,
44 | hidden_layer_sizes=(9, 9,), loss="binary_crossentropy")
45 | pipeline = Pipeline([
46 | ('features', featurisation),
47 | ('learner', neural_net)])
48 |
49 | train_data, test_data = train_test_split(data)
50 | model = pipeline.fit(train_data, train_data['label'])
51 | print(f"Mean accuracy: {model.score(test_data, test_data['label'])}")
52 |
--------------------------------------------------------------------------------
/example_pipelines/healthcare/healthcare_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for the healthcare example pipeline: a custom word2vec transformer and a Keras model factory
3 | """
4 | import numpy
5 | from keras import Input
6 | from sklearn.exceptions import NotFittedError
7 | from tensorflow.keras.models import Sequential
8 | from tensorflow.keras.layers import Dense
9 |
10 | from example_pipelines.healthcare._gensim_wrapper import W2VTransformer
11 |
12 |
13 | class MyW2VTransformer(W2VTransformer):
14 | """Some custom w2v transformer."""
15 | # pylint: disable-all
16 |
17 | def partial_fit(self, X):
18 | super().partial_fit([X])
19 |
20 | def fit(self, X, y=None):
21 | X = X.iloc[:, 0].tolist()
22 | return super().fit([X], y)
23 |
24 | def transform(self, words):
25 | words = words.iloc[:, 0].tolist()
26 | if self.gensim_model is None:
27 | raise NotFittedError(
28 | "This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method."
29 | )
30 |
31 | # The input as array of array
32 | vectors = []
33 | for word in words:
34 | if word in self.gensim_model.wv:
35 | vectors.append(self.gensim_model.wv[word])
36 | else:
37 | vectors.append(numpy.zeros(self.size))
38 | return numpy.reshape(numpy.array(vectors), (len(words), self.size))
39 |
40 |
41 | def create_model(meta, hidden_layer_sizes):
42 | n_features_in_ = meta["n_features_in_"]
43 | n_classes_ = meta["n_classes_"]
44 | model = Sequential()
45 | model.add(Input(shape=(n_features_in_,)))
46 | for hidden_layer_size in hidden_layer_sizes:
47 | model.add(Dense(hidden_layer_size, activation="relu"))
48 | model.add(Dense(1, activation="sigmoid"))
49 | return model
50 |
--------------------------------------------------------------------------------
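`create_model` follows the scikeras model-builder convention: during `fit`, scikeras calls it with a `meta` dictionary describing the training data, plus any custom parameters (here `hidden_layer_sizes`) forwarded from the `KerasClassifier` constructor, as in `healthcare.py` above. A standalone sketch with a hand-written `meta` dictionary for illustration:

```python
from example_pipelines.healthcare.healthcare_utils import create_model

# scikeras fills `meta` in automatically during fit; we fake it here
meta = {"n_features_in_": 12, "n_classes_": 2}
model = create_model(meta, hidden_layer_sizes=(9, 9))
model.summary()  # two ReLU hidden layers followed by a sigmoid output unit
```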
/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/experiments/__init__.py
--------------------------------------------------------------------------------
/experiments/performance/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/experiments/performance/__init__.py
--------------------------------------------------------------------------------
/experiments/performance/_empty_inspection.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple empty inspection
3 | """
4 | from typing import Iterable
5 |
6 | from mlinspect.inspections._inspection import Inspection
7 |
8 |
9 | class EmptyInspection(Inspection):
10 | """
11 | An empty inspection for performance experiments
12 | """
13 |
14 | def __init__(self, inspection_id):
15 | self._id = inspection_id
16 |
17 | @property
18 | def inspection_id(self):
19 | return self._id
20 |
21 | def visit_operator(self, inspection_input) -> Iterable[any]:
22 | """
23 | Visit an operator
24 | """
25 | for _ in inspection_input.row_iterator:
26 | yield None
27 |
28 | def get_operator_annotation_after_visit(self) -> any:
29 | return None
30 |
--------------------------------------------------------------------------------
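Because this inspection consumes every row but computes nothing, comparing an instrumented run with `EmptyInspection` against an uninstrumented run isolates the overhead of mlinspect's instrumentation itself. A benchmarking sketch using one of the bundled pipelines:

```python
import timeit

from example_pipelines import ADULT_SIMPLE_PY
from experiments.performance._empty_inspection import EmptyInspection
from mlinspect import PipelineInspector

def run_instrumented():
    PipelineInspector \
        .on_pipeline_from_py_file(ADULT_SIMPLE_PY) \
        .add_required_inspection(EmptyInspection(0)) \
        .execute()

print(timeit.timeit(run_instrumented, number=3))  # seconds for three runs
```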
/mlinspect/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Packages and classes we want to expose to users
3 | """
4 | from ._pipeline_inspector import PipelineInspector
5 | from ._inspector_result import InspectorResult
6 | from .inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType
7 | from .instrumentation._dag_node import DagNode, BasicCodeLocation, DagNodeDetails, OptionalCodeInfo, CodeReference
8 |
9 | __all__ = [
10 | 'utils',
11 | 'inspections',
12 | 'checks',
13 | 'visualisation',
14 | 'PipelineInspector', 'InspectorResult',
15 | 'DagNode', 'OperatorType',
16 | 'BasicCodeLocation', 'OperatorContext', 'DagNodeDetails', 'OptionalCodeInfo', 'FunctionInfo', 'CodeReference'
17 | ]
18 |
--------------------------------------------------------------------------------
/mlinspect/_inspector_result.py:
--------------------------------------------------------------------------------
1 | """
2 | Data class used as result of the PipelineExecutor
3 | """
4 | import dataclasses
5 | from typing import Dict
6 |
7 | import networkx
8 |
9 | from mlinspect.checks._check import Check, CheckResult
10 | from mlinspect.inspections._inspection import Inspection
11 |
12 |
13 | @dataclasses.dataclass
14 | class InspectorResult:
15 | """
16 | The class the PipelineExecutor returns
17 | """
18 | dag: networkx.DiGraph
19 | dag_node_to_inspection_results: Dict[any, Dict[Inspection, any]] # First any is DagNode
20 | check_to_check_results: Dict[Check, CheckResult]
21 |
--------------------------------------------------------------------------------
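`dag` is a plain `networkx.DiGraph` whose nodes are `DagNode` instances, so standard graph operations apply, and the two dictionaries are keyed by `DagNode` and `Check` respectively. A small sketch of walking a result in data-flow order:

```python
import networkx

from example_pipelines import ADULT_SIMPLE_PY
from mlinspect import PipelineInspector
from mlinspect.inspections import MaterializeFirstOutputRows

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file(ADULT_SIMPLE_PY) \
    .add_required_inspection(MaterializeFirstOutputRows(2)) \
    .execute()

# Topological order follows the data flow through the pipeline
for dag_node in networkx.topological_sort(inspector_result.dag):
    node_results = inspector_result.dag_node_to_inspection_results.get(dag_node, {})
    for inspection, annotation in node_results.items():
        print(dag_node, inspection, annotation)
```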
/mlinspect/_pipeline_inspector.py:
--------------------------------------------------------------------------------
1 | """
2 | User-facing API for inspecting the pipeline
3 | """
4 | from typing import Iterable, Dict, List
5 |
6 | from pandas import DataFrame
7 |
8 | from mlinspect.inspections._inspection import Inspection
9 | from .checks._check import Check, CheckResult
10 | from ._inspector_result import InspectorResult
11 | from .instrumentation._pipeline_executor import singleton
12 |
13 |
14 | class PipelineInspectorBuilder:
15 | """
16 | The fluent API builder to build an inspection run
17 | """
18 |
19 |     def __init__(self, notebook_path: Optional[str] = None,
20 |                  python_path: Optional[str] = None,
21 |                  python_code: Optional[str] = None
22 |                  ) -> None:
23 | self.track_code_references = True
24 | self.monkey_patching_modules = []
25 | self.notebook_path = notebook_path
26 | self.python_path = python_path
27 | self.python_code = python_code
28 | self.inspections = []
29 | self.checks = []
30 |
31 | def add_required_inspection(self, inspection: Inspection):
32 | """
33 |         Add an inspection
34 | """
35 | self.inspections.append(inspection)
36 | return self
37 |
38 | def add_required_inspections(self, inspections: Iterable[Inspection]):
39 | """
40 | Add a list of inspections
41 | """
42 | self.inspections.extend(inspections)
43 | return self
44 |
45 | def add_check(self, check: Check):
46 | """
47 |         Add a check
48 | """
49 | self.checks.append(check)
50 | return self
51 |
52 | def add_checks(self, checks: Iterable[Check]):
53 | """
54 |         Add a list of checks
55 | """
56 | self.checks.extend(checks)
57 | return self
58 |
59 | def set_code_reference_tracking(self, track_code_references: bool):
60 | """
61 | Set whether to track code references. The default is tracking them.
62 | """
63 | self.track_code_references = track_code_references
64 | return self
65 |
66 | def add_custom_monkey_patching_modules(self, module_list: List):
67 | """
68 | Add additional monkey patching modules.
69 | """
70 | self.monkey_patching_modules.extend(module_list)
71 | return self
72 |
73 | def add_custom_monkey_patching_module(self, module: any):
74 | """
75 | Add an additional monkey patching module.
76 | """
77 | self.monkey_patching_modules.append(module)
78 | return self
79 |
80 | def execute(self) -> InspectorResult:
81 | """
82 | Instrument and execute the pipeline
83 | """
84 | return singleton.run(notebook_path=self.notebook_path,
85 | python_path=self.python_path,
86 | python_code=self.python_code,
87 | inspections=self.inspections,
88 | checks=self.checks,
89 | custom_monkey_patching=self.monkey_patching_modules)
90 |
91 |
92 | class PipelineInspector:
93 | """
94 | The entry point to the fluent API to build an inspection run
95 | """
96 | @staticmethod
97 | def on_pipeline_from_py_file(path: str) -> PipelineInspectorBuilder:
98 | """Inspect a pipeline from a .py file."""
99 | return PipelineInspectorBuilder(python_path=path)
100 |
101 | @staticmethod
102 | def on_pipeline_from_ipynb_file(path: str) -> PipelineInspectorBuilder:
103 | """Inspect a pipeline from a .ipynb file."""
104 | return PipelineInspectorBuilder(notebook_path=path)
105 |
106 | @staticmethod
107 | def on_pipeline_from_string(code: str) -> PipelineInspectorBuilder:
108 | """Inspect a pipeline from a string."""
109 | return PipelineInspectorBuilder(python_code=code)
110 |
111 | @staticmethod
112 | def check_results_as_data_frame(check_to_check_results: Dict[Check, CheckResult]) -> DataFrame:
113 | """
114 | Get a pandas DataFrame with an overview of the CheckResults
115 | """
116 | check_names = []
117 | status = []
118 | descriptions = []
119 | for check_result in check_to_check_results.values():
120 |             check_names.append(str(check_result.check))
121 | status.append(check_result.status)
122 | descriptions.append(check_result.description)
123 | return DataFrame(zip(check_names, status, descriptions), columns=["check_name", "status", "description"])
124 |
--------------------------------------------------------------------------------
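A usage sketch for the fluent API above, assuming the healthcare example pipeline from this repository (the sensitive column names age_group and race come from that example):

from mlinspect import PipelineInspector
from mlinspect.checks import NoBiasIntroducedFor, NoIllegalFeatures
from mlinspect.inspections import MaterializeFirstOutputRows

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file("example_pipelines/healthcare/healthcare.py") \
    .add_required_inspection(MaterializeFirstOutputRows(5)) \
    .add_check(NoBiasIntroducedFor(["age_group", "race"])) \
    .add_check(NoIllegalFeatures()) \
    .execute()

# One row per check with its status and description
check_overview = PipelineInspector.check_results_as_data_frame(
    inspector_result.check_to_check_results)
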
/mlinspect/backends/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/backends/__init__.py
--------------------------------------------------------------------------------
/mlinspect/backends/_all_backends.py:
--------------------------------------------------------------------------------
1 | """
2 | Get all available backends
3 | """
4 | from typing import List
5 |
6 | from ._backend import Backend
7 | from ._pandas_backend import PandasBackend
8 | from ._sklearn_backend import SklearnBackend
9 |
10 |
11 | def get_all_backends() -> List[Backend]:
12 | """Get the list of all currently available backends"""
13 | return [PandasBackend(), SklearnBackend()]
14 |
--------------------------------------------------------------------------------
/mlinspect/backends/_backend.py:
--------------------------------------------------------------------------------
1 | """
2 | The Interface for the different instrumentation backends
3 | """
4 | import abc
5 | import dataclasses
6 | from types import MappingProxyType
7 | from typing import List, Dict
8 |
9 | from mlinspect.inspections import Inspection
10 |
11 |
12 | @dataclasses.dataclass(frozen=True)
13 | class AnnotatedDfObject:
14 | """ A dataframe-like object and its annotations """
15 | result_data: any
16 | result_annotation: any
17 |
18 |
19 | @dataclasses.dataclass(frozen=True)
20 | class BackendResult:
21 | """ The annotated dataframe and the annotations for the current DAG node """
22 | annotated_dfobject: AnnotatedDfObject
23 | dag_node_annotation: Dict[Inspection, any]
24 | optional_second_annotated_dfobject: AnnotatedDfObject = None
25 | optional_second_dag_node_annotation: Dict[Inspection, any] = None
26 |
27 |
28 | class Backend(metaclass=abc.ABCMeta):
29 | """
30 | The Interface for the different instrumentation backends
31 | """
32 |
33 | @staticmethod
34 | @abc.abstractmethod
35 | def before_call(operator_context, input_infos: List[AnnotatedDfObject]) \
36 | -> List[AnnotatedDfObject]:
37 |         """Called before the patched function executes; may inspect or modify the annotated inputs"""
38 | # pylint: disable=too-many-arguments, unused-argument
39 | raise NotImplementedError
40 |
41 |
42 | @staticmethod
43 | @abc.abstractmethod
44 | def after_call(operator_context, input_infos: List[AnnotatedDfObject], return_value,
45 | non_data_function_args: Dict[str, any] = MappingProxyType({})) -> BackendResult:
46 |         """Called after the patched function executes; wraps its return value in a BackendResult"""
47 | # pylint: disable=too-many-arguments, unused-argument
48 | raise NotImplementedError
49 |
--------------------------------------------------------------------------------
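A minimal sketch of a custom backend against the interface above; this hypothetical pass-through backend is not part of the repository and records no annotations:

from types import MappingProxyType
from typing import Dict, List

from mlinspect.backends._backend import AnnotatedDfObject, Backend, BackendResult


class PassThroughBackend(Backend):
    """A hypothetical backend that forwards inputs unchanged"""

    @staticmethod
    def before_call(operator_context, input_infos: List[AnnotatedDfObject]) -> List[AnnotatedDfObject]:
        # No preprocessing of the annotated inputs
        return input_infos

    @staticmethod
    def after_call(operator_context, input_infos: List[AnnotatedDfObject], return_value,
                   non_data_function_args: Dict[str, any] = MappingProxyType({})) -> BackendResult:
        # Wrap the untouched return value without any inspection annotations
        return BackendResult(AnnotatedDfObject(return_value, None), {})
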
/mlinspect/backends/_backend_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some utility functions for the different instrumentation backends
3 | """
4 | # pylint: disable=unnecessary-dunder-call
5 | import itertools
6 |
7 | import numpy
8 | from pandas import DataFrame, Series
9 | from scipy.sparse import csr_matrix
10 |
11 | from ._backend import AnnotatedDfObject
12 | from ..inspections._inspection_input import ColumnInfo
13 | from ..monkeypatching._mlinspect_ndarray import MlinspectNdarray
14 |
15 |
16 | def get_annotation_rows(input_annotations, inspection_index):
17 | """
18 |     In the pandas backend, we store annotations in a data frame; for the sklearn transformers, lists are enough
19 | """
20 | if isinstance(input_annotations, DataFrame):
21 | annotations_for_inspection = input_annotations.iloc[:, inspection_index]
22 | assert isinstance(annotations_for_inspection, Series)
23 | else:
24 | annotations_for_inspection = input_annotations[inspection_index]
25 | assert isinstance(annotations_for_inspection, list)
26 | annotation_rows = annotations_for_inspection.__iter__()
27 | return annotation_rows
28 |
29 |
30 | def build_annotation_df_from_iters(inspections, annotation_iterators):
31 | """
32 | Build the annotations dataframe
33 | """
34 | annotation_iterators = itertools.zip_longest(*annotation_iterators)
35 | inspection_names = [str(inspection) for inspection in inspections]
36 | annotations_df = DataFrame(annotation_iterators, columns=inspection_names)
37 | return annotations_df
38 |
39 |
40 | def build_annotation_list_from_iters(annotation_iterators):
41 | """
42 |     Build the annotations list
43 | """
44 | annotation_lists = [list(iterator) for iterator in annotation_iterators]
45 | return list(annotation_lists)
46 |
47 |
48 | def get_iterator_for_type(data, np_nditer_with_refs=False, columns=None):
49 | """
50 | Create an efficient iterator for the data.
51 | Automatically detects the data type and fails if it cannot handle that data type.
52 | """
53 | if isinstance(data, DataFrame):
54 | iterator = get_df_row_iterator(data)
55 | elif isinstance(data, numpy.ndarray):
56 | # TODO: Measure performance impact of np_nditer_with_refs. To support arbitrary pipelines, remove this
57 | # or check the type of the standard iterator. It seems the nditer variant is faster but does not always work
58 | iterator = get_numpy_array_row_iterator(data, np_nditer_with_refs, columns)
59 | elif isinstance(data, Series):
60 | iterator = get_series_row_iterator(data, columns)
61 | elif isinstance(data, csr_matrix):
62 | iterator = get_csr_row_iterator(data, columns)
63 | elif isinstance(data, list):
64 | iterator = get_list_row_iterator(data, columns)
65 | else:
66 | raise NotImplementedError(f"TODO: Support type {type(data)}!")
67 | return iterator
68 |
69 |
70 | def create_wrapper_with_annotations(annotations_df, return_value) -> AnnotatedDfObject:
71 | """
72 | Create a wrapper based on the data type of the return value and store the annotations in it.
73 | """
74 | if isinstance(return_value, numpy.ndarray):
75 | return_value = MlinspectNdarray(return_value)
76 | new_return_value = AnnotatedDfObject(return_value, annotations_df)
77 | elif isinstance(return_value, DataFrame):
78 | # Remove index columns that may have been created
79 | if "mlinspect_index" in return_value.columns:
80 | return_value = return_value.drop("mlinspect_index", axis=1)
81 | elif "mlinspect_index_x" in return_value.columns:
82 | return_value = return_value.drop(["mlinspect_index_x", "mlinspect_index_y"], axis=1)
83 | assert "mlinspect_index" not in return_value.columns
84 | assert "mlinspect_index_x" not in return_value.columns
85 |
86 | new_return_value = AnnotatedDfObject(return_value, annotations_df)
87 | elif isinstance(return_value, (Series, csr_matrix)):
88 | return_value.annotations = annotations_df
89 | new_return_value = AnnotatedDfObject(return_value, annotations_df)
90 | elif return_value is None:
91 | new_return_value = AnnotatedDfObject(None, annotations_df)
92 | else:
93 | raise NotImplementedError(f"A type that is still unsupported was found: {return_value}")
94 | return new_return_value
95 |
96 |
97 | def get_df_row_iterator(dataframe):
98 | """
99 | Create an efficient iterator for the data frame rows.
100 |     The implementation is inspired by the pandas DataFrame.itertuples method
101 | """
102 | # Performance tips:
103 | # https://stackoverflow.com/questions/16476924/how-to-iterate-over-rows-in-a-dataframe-in-pandas
104 | arrays = []
105 | column_info = ColumnInfo(list(dataframe.columns.values))
106 | arrays.extend(dataframe.iloc[:, k] for k in range(0, len(dataframe.columns)))
107 |
108 | return column_info, map(tuple, zip(*arrays))
109 |
110 |
111 | def get_series_row_iterator(series, columns=None):
112 | """
113 | Create an efficient iterator for the data frame rows.
114 | The implementation is inspired by the implementation of the pandas DataFrame.itertuple method
115 | """
116 | if columns:
117 | column_info = ColumnInfo(columns)
118 | elif series.name:
119 | column_info = ColumnInfo([series.name])
120 | else:
121 | column_info = ColumnInfo(["array"])
122 | numpy_iterator = series.__iter__()
123 |
124 | return column_info, map(tuple, zip(numpy_iterator))
125 |
126 |
127 | def get_numpy_array_row_iterator(nparray, nditer=False, columns=None):
128 | """
129 |     Create an efficient iterator for the numpy array rows.
130 |     The implementation is inspired by the pandas DataFrame.itertuples method
131 | """
132 | if columns:
133 | column_info = ColumnInfo(columns)
134 | else:
135 | column_info = ColumnInfo(["array"])
136 | if nditer is True:
137 | numpy_iterator = numpy.nditer(nparray, ["refs_ok"])
138 | else:
139 | numpy_iterator = nparray.__iter__()
140 |
141 | return column_info, map(tuple, zip(numpy_iterator))
142 |
143 |
144 | def get_list_row_iterator(list_data, columns=None):
145 | """
146 |     Create an efficient iterator for the list rows.
147 |     The implementation is inspired by the pandas DataFrame.itertuples method
148 | """
149 | if columns:
150 | column_info = ColumnInfo(columns)
151 | else:
152 | column_info = ColumnInfo(["array"])
153 | numpy_iterator = list_data.__iter__()
154 |
155 | return column_info, map(tuple, zip(numpy_iterator))
156 |
157 |
158 | def get_csr_row_iterator(csr, columns=None):
159 | """
160 | Create an efficient iterator for csr rows.
161 |     The implementation is inspired by the pandas DataFrame.itertuples method
162 | """
163 | # TODO: Maybe there is a way to use sparse rows that is faster
164 | # However, this is the fastest way I discovered so far
165 | np_array = csr.toarray()
166 | if columns:
167 | column_info = ColumnInfo(columns)
168 | else:
169 | column_info = ColumnInfo(["array"])
170 | numpy_iterator = np_array.__iter__()
171 |
172 | return column_info, map(tuple, zip(numpy_iterator))
173 |
--------------------------------------------------------------------------------
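A small sketch of what the row iterators above produce, using toy data: each helper returns a ColumnInfo plus an iterator of plain tuples, one per row.

from pandas import DataFrame

from mlinspect.backends._backend_utils import get_df_row_iterator

df = DataFrame({"age": [34, 51], "county": ["a", "b"]})
column_info, row_iterator = get_df_row_iterator(df)
print(column_info.fields)    # ['age', 'county']
print(list(row_iterator))    # row tuples: (34, 'a') and (51, 'b')
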
/mlinspect/backends/_sklearn_backend.py:
--------------------------------------------------------------------------------
1 | """
2 | The scikit-learn backend
3 | """
4 | from types import MappingProxyType
5 | from typing import List, Dict
6 |
7 | import pandas
8 |
9 | from ._backend import Backend, AnnotatedDfObject, BackendResult
10 | from ._iter_creation import iter_input_annotation_output_sink_op, iter_input_annotation_output_nary_op
11 | from ._pandas_backend import execute_inspection_visits_unary_operator, store_inspection_outputs, \
12 | execute_inspection_visits_data_source
13 | from .. import OperatorType
14 | from ..instrumentation._pipeline_executor import singleton
15 |
16 |
17 | class SklearnBackend(Backend):
18 | """
19 | The scikit-learn backend
20 | """
21 |
22 | @staticmethod
23 | def before_call(operator_context, input_infos: List[AnnotatedDfObject]):
24 |         """Called before the patched function executes; may inspect or modify the annotated inputs"""
25 | # pylint: disable=too-many-arguments
26 | if operator_context.operator == OperatorType.TRAIN_TEST_SPLIT:
27 | pandas_df = input_infos[0].result_data
28 | assert isinstance(pandas_df, pandas.DataFrame)
29 | pandas_df['mlinspect_index'] = range(0, len(pandas_df))
30 | return input_infos
31 |
32 | @staticmethod
33 | def after_call(operator_context, input_infos: List[AnnotatedDfObject], return_value,
34 | non_data_function_args: Dict[str, any] = MappingProxyType({})) \
35 | -> BackendResult:
36 |         """Called after the patched function executes; wraps its return value in a BackendResult"""
37 | # pylint: disable=too-many-arguments
38 | if operator_context.operator == OperatorType.DATA_SOURCE:
39 | return_value = execute_inspection_visits_data_source(operator_context, return_value, non_data_function_args)
40 | elif operator_context.operator == OperatorType.TRAIN_TEST_SPLIT:
41 | train_data, test_data = return_value
42 | train_return_value = execute_inspection_visits_unary_operator(operator_context,
43 | input_infos[0].result_data,
44 | input_infos[0].result_annotation,
45 | train_data,
46 | True, non_data_function_args)
47 | test_return_value = execute_inspection_visits_unary_operator(operator_context,
48 | input_infos[0].result_data,
49 | input_infos[0].result_annotation,
50 | test_data,
51 | True, non_data_function_args)
52 | input_infos[0].result_data.drop("mlinspect_index", axis=1, inplace=True)
53 | train_data.drop("mlinspect_index", axis=1, inplace=True)
54 | test_data.drop("mlinspect_index", axis=1, inplace=True)
55 | return_value = BackendResult(train_return_value.annotated_dfobject,
56 | train_return_value.dag_node_annotation,
57 | test_return_value.annotated_dfobject,
58 | test_return_value.dag_node_annotation)
59 | elif operator_context.operator in {OperatorType.PROJECTION, OperatorType.PROJECTION_MODIFY,
60 | OperatorType.TRANSFORMER, OperatorType.TRAIN_DATA, OperatorType.TRAIN_LABELS,
61 | OperatorType.TEST_DATA, OperatorType.TEST_LABELS}:
62 | return_value = execute_inspection_visits_unary_operator(operator_context, input_infos[0].result_data,
63 | input_infos[0].result_annotation, return_value,
64 | False, non_data_function_args)
65 | elif operator_context.operator == OperatorType.ESTIMATOR:
66 | return_value = execute_inspection_visits_sink_op(operator_context,
67 | input_infos[0].result_data,
68 | input_infos[0].result_annotation,
69 | input_infos[1].result_data,
70 | input_infos[1].result_annotation,
71 | non_data_function_args)
72 | elif operator_context.operator == OperatorType.SCORE:
73 | return_value = execute_inspection_visits_nary_op(operator_context,
74 | input_infos,
75 | return_value,
76 | non_data_function_args)
77 | elif operator_context.operator == OperatorType.CONCATENATION:
78 | return_value = execute_inspection_visits_nary_op(operator_context, input_infos, return_value,
79 | non_data_function_args)
80 | else:
81 | raise NotImplementedError(f"SklearnBackend doesn't know any operations of type "
82 | f"'{operator_context.operator}' yet!")
83 | return return_value
84 |
85 |
86 | # -------------------------------------------------------
87 | # Execute inspections functions
88 | # -------------------------------------------------------
89 |
90 | def execute_inspection_visits_sink_op(operator_context, data, data_annotation, target,
91 | target_annotation, non_data_function_args) -> BackendResult:
92 | """ Execute inspections """
93 | # pylint: disable=too-many-arguments
94 | inspection_count = len(singleton.inspections)
95 | iterators_for_inspections = iter_input_annotation_output_sink_op(inspection_count,
96 | data,
97 | data_annotation,
98 | target,
99 | target_annotation,
100 | operator_context,
101 | non_data_function_args)
102 | annotation_iterators = execute_visits(iterators_for_inspections)
103 | return_value = store_inspection_outputs(annotation_iterators, None)
104 | return return_value
105 |
106 |
107 | def execute_inspection_visits_nary_op(operator_context, annotated_dfs: List[AnnotatedDfObject],
108 | return_value_df, non_data_function_args) -> BackendResult:
109 | """Execute inspections"""
110 | # pylint: disable=too-many-arguments
111 | inspection_count = len(singleton.inspections)
112 | iterators_for_inspections = iter_input_annotation_output_nary_op(inspection_count,
113 | annotated_dfs,
114 | return_value_df,
115 | operator_context,
116 | non_data_function_args)
117 | annotation_iterators = execute_visits(iterators_for_inspections)
118 | return_value = store_inspection_outputs(annotation_iterators, return_value_df)
119 | return return_value
120 |
121 |
122 | def execute_visits(iterators_for_inspections):
123 | """
124 |     After creating the iterators required for the operator type, execute the
125 |     generic inspection visits
126 | """
127 | annotation_iterators = []
128 | for inspection_index, inspection in enumerate(singleton.inspections):
129 | iterator_for_inspection = iterators_for_inspections[inspection_index]
130 | annotations_iterator = inspection.visit_operator(iterator_for_inspection)
131 | annotation_iterators.append(annotations_iterator)
132 | return annotation_iterators
133 |
--------------------------------------------------------------------------------
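The TRAIN_TEST_SPLIT handling above hinges on the temporary mlinspect_index column: before_call numbers the input rows, so after the split each output row can be matched back to its input annotations before the column is dropped again. A standalone sketch of that mechanism with toy data (not the backend code itself):

from pandas import DataFrame
from sklearn.model_selection import train_test_split

df = DataFrame({"age": [34, 51, 27, 62]})
df["mlinspect_index"] = range(0, len(df))      # before_call: number the input rows
train, test = train_test_split(df, test_size=0.5, random_state=42)
# after_call: the surviving index values identify the input annotation of each output row
print(train["mlinspect_index"].tolist(), test["mlinspect_index"].tolist())
train.drop("mlinspect_index", axis=1, inplace=True)   # finally, the helper column is removed
test.drop("mlinspect_index", axis=1, inplace=True)
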
/mlinspect/checks/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Packages and classes we want to expose to users
3 | """
4 | from ._no_bias_introduced_for import NoBiasIntroducedFor, NoBiasIntroducedForResult, BiasDistributionChange
5 | from ._no_illegal_features import NoIllegalFeatures, NoIllegalFeaturesResult
6 | from ._similar_removal_probabilities_for import SimilarRemovalProbabilitiesFor, SimilarRemovalProbabilitiesForResult, \
7 | RemovalProbabilities
8 | from ._check import Check, CheckResult, CheckStatus
9 |
10 | __all__ = [
11 | # General classes
12 | 'Check',
13 | 'CheckResult',
14 | 'CheckStatus',
15 | # Native checks
16 | 'NoBiasIntroducedFor', 'NoBiasIntroducedForResult', 'BiasDistributionChange',
17 | 'NoIllegalFeatures', 'NoIllegalFeaturesResult',
18 | 'SimilarRemovalProbabilitiesFor', 'SimilarRemovalProbabilitiesForResult', 'RemovalProbabilities'
19 | ]
20 |
--------------------------------------------------------------------------------
/mlinspect/checks/_check.py:
--------------------------------------------------------------------------------
1 | """
2 | The Interface for the Checks
3 | """
4 | from __future__ import annotations
5 |
6 | import abc
7 | import dataclasses
8 | from enum import Enum
9 | from typing import Iterable, Optional
10 |
11 | from mlinspect.inspections._inspection import Inspection
12 | from mlinspect.inspections._inspection_result import InspectionResult
13 |
14 |
15 | class CheckStatus(Enum):
16 | """
17 | The result of the check
18 | """
19 | SUCCESS = "Success"
20 | FAILURE = "Failure"
21 |
22 |
23 | @dataclasses.dataclass
24 | class CheckResult:
25 | """
26 |     The result of evaluating a check: its status and an optional description
27 | """
28 | check: Check
29 | status: CheckStatus
30 |     description: Optional[str]
31 |
32 |
33 | class Check(metaclass=abc.ABCMeta):
34 | """
35 | Checks like no_bias_introduced
36 | """
37 | # pylint: disable=unnecessary-pass, too-few-public-methods
38 |
39 | @property
40 | def check_id(self):
41 | """The id of the Check"""
42 | return None
43 |
44 | @property
45 | @abc.abstractmethod
46 | def required_inspections(self) -> Iterable[Inspection]:
47 | """Inspections required to evaluate this check"""
48 | raise NotImplementedError
49 |
50 | @abc.abstractmethod
51 | def evaluate(self, inspection_result: InspectionResult) -> CheckResult:
52 | """Evaluate the check"""
53 | raise NotImplementedError
54 |
55 | def __eq__(self, other):
56 | """Checks must implement equals"""
57 | return (isinstance(other, self.__class__) and
58 | self.check_id == other.check_id)
59 |
60 | def __hash__(self):
61 | """Checks must be hashable"""
62 | return hash((self.__class__.__name__, self.check_id))
63 |
64 | def __repr__(self):
65 | """Checks must have a str representation"""
66 | return f"{self.__class__.__name__}({self.check_id or ''})"
67 |
--------------------------------------------------------------------------------
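A minimal sketch of a custom check against the interface above; this hypothetical check is not part of the repository and fails whenever the pipeline DAG grows beyond a given size:

from typing import Iterable

from mlinspect.checks._check import Check, CheckResult, CheckStatus
from mlinspect.inspections._inspection import Inspection
from mlinspect.inspections._inspection_result import InspectionResult


class MaxDagSizeCheck(Check):
    """A hypothetical check that fails for DAGs with too many operators"""

    def __init__(self, max_node_count: int):
        self.max_node_count = max_node_count

    @property
    def check_id(self):
        return self.max_node_count

    @property
    def required_inspections(self) -> Iterable[Inspection]:
        return []   # works on the DAG alone

    def evaluate(self, inspection_result: InspectionResult) -> CheckResult:
        node_count = len(inspection_result.dag.nodes)
        if node_count > self.max_node_count:
            return CheckResult(self, CheckStatus.FAILURE, f"DAG has {node_count} nodes")
        return CheckResult(self, CheckStatus.SUCCESS, None)
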
/mlinspect/checks/_no_illegal_features.py:
--------------------------------------------------------------------------------
1 | """
2 | A check to ensure that the model does not use sensitive attributes like race as features
3 | """
4 | from __future__ import annotations
5 |
6 | import dataclasses
7 | from typing import List, Iterable
8 |
9 | from mlinspect.checks._check import Check, CheckStatus, CheckResult
10 | from mlinspect.inspections._inspection import Inspection
11 | from mlinspect.inspections._inspection_input import OperatorType
12 | from mlinspect.inspections._inspection_result import InspectionResult
13 |
14 | ILLEGAL_FEATURES = {"race", "gender", "age"}
15 |
16 |
17 | @dataclasses.dataclass
18 | class NoIllegalFeaturesResult(CheckResult):
19 | """
20 | Does the pipeline use illegal features?
21 | """
22 | illegal_features: List[str]
23 |
24 |
25 | class NoIllegalFeatures(Check):
26 | """
27 | Does the model get sensitive attributes like race as feature?
28 | """
29 | # pylint: disable=unnecessary-pass, too-few-public-methods
30 |
31 | def __init__(self, additional_illegal_feature_names=None):
32 | if additional_illegal_feature_names is None:
33 | additional_illegal_feature_names = []
34 | self.additional_illegal_feature_names = additional_illegal_feature_names
35 |
36 | @property
37 | def required_inspections(self) -> Iterable[Inspection]:
38 | """The inspections required for the check"""
39 | return []
40 |
41 | @property
42 | def check_id(self):
43 |         """The id of the Check"""
44 | return tuple(self.additional_illegal_feature_names)
45 |
46 | def evaluate(self, inspection_result: InspectionResult) -> CheckResult:
47 | """Evaluate the check"""
48 | # TODO: Make this robust and add extensive testing
49 | dag = inspection_result.dag
50 | train_data_nodes = [node for node in dag.nodes if node.operator_info.operator == OperatorType.TRAIN_DATA]
51 | used_columns = []
52 | for train_data_node in train_data_nodes:
53 | used_columns.extend(self.get_used_columns(dag, train_data_node))
54 | forbidden_columns = {*ILLEGAL_FEATURES, *self.additional_illegal_feature_names}
55 | used_illegal_columns = list(set(used_columns).intersection(forbidden_columns))
56 | if used_illegal_columns:
57 | description = f"Used illegal columns: {used_illegal_columns}"
58 | result = NoIllegalFeaturesResult(self, CheckStatus.FAILURE, description, used_illegal_columns)
59 | else:
60 | result = NoIllegalFeaturesResult(self, CheckStatus.SUCCESS, None, [])
61 | return result
62 |
63 | def get_used_columns(self, dag, current_node):
64 | """
65 | Get the output column of the current dag node. If the current dag node is, e.g., a concatenation,
66 | check the parents of the current dag node.
67 | """
68 | columns = current_node.details.columns
69 | if columns is not None and columns != ["array"]:
70 | result = columns
71 | else:
72 | parent_columns = []
73 | for parent in dag.predecessors(current_node):
74 | parent_columns.extend(self.get_used_columns(dag, parent))
75 | result = parent_columns
76 | return result
77 |
--------------------------------------------------------------------------------
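A usage sketch for the check above; the extra feature name is illustrative. Because checks implement equality based on their check_id, a fresh instance can be used to look up the result:

from mlinspect import PipelineInspector
from mlinspect.checks import NoIllegalFeatures

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file("example_pipelines/adult_simple/adult_simple.py") \
    .add_check(NoIllegalFeatures(additional_illegal_feature_names=["religion"])) \
    .execute()

check_result = inspector_result.check_to_check_results[NoIllegalFeatures(["religion"])]
print(check_result.status, check_result.illegal_features)
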
/mlinspect/inspections/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Packages and classes we want to expose to users
3 | """
4 | from ._arg_capturing import ArgumentCapturing
5 | from ._completeness_of_columns import CompletenessOfColumns
6 | from ._count_distinct_of_columns import CountDistinctOfColumns
7 | from ._inspection import Inspection
8 | from ._inspection_result import InspectionResult
9 | from ._inspection_input import InspectionInputUnaryOperator, InspectionInputDataSource, InspectionInputSinkOperator, \
10 | InspectionInputNAryOperator
11 | from ._histogram_for_columns import HistogramForColumns
12 | from ._intersectional_histogram_for_columns import IntersectionalHistogramForColumns
13 | from ._lineage import RowLineage
14 | from ._materialize_first_output_rows import MaterializeFirstOutputRows
15 | from ._column_propagation import ColumnPropagation
16 |
17 | __all__ = [
18 | # For defining custom inspections
19 | 'Inspection', 'InspectionResult',
20 | 'InspectionInputUnaryOperator', 'InspectionInputDataSource', 'InspectionInputSinkOperator',
21 | 'InspectionInputNAryOperator',
22 | # Native inspections
23 | 'HistogramForColumns',
24 | 'ColumnPropagation',
25 | 'IntersectionalHistogramForColumns',
26 | 'RowLineage',
27 | 'MaterializeFirstOutputRows',
28 | 'CompletenessOfColumns',
29 | 'CountDistinctOfColumns',
30 | 'ArgumentCapturing'
31 | ]
32 |
--------------------------------------------------------------------------------
/mlinspect/inspections/_arg_capturing.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple inspection to capture important function call arguments like estimator hyperparameters
3 | """
4 | from typing import Iterable
5 |
6 | from ._inspection import Inspection
7 |
8 |
9 | class ArgumentCapturing(Inspection):
10 | """
11 | A simple inspection to capture important function call arguments like estimator hyperparameters
12 | """
13 |
14 | def __init__(self):
15 | self._captured_arguments = None
16 |
17 | @property
18 | def inspection_id(self):
19 | return None
20 |
21 | def visit_operator(self, inspection_input) -> Iterable[any]:
22 | """
23 | Visit an operator
24 | """
25 | self._captured_arguments = inspection_input.non_data_function_args
26 |
27 | for _ in inspection_input.row_iterator:
28 | yield None
29 |
30 | def get_operator_annotation_after_visit(self) -> any:
31 | captured_args = self._captured_arguments
32 | self._captured_arguments = None
33 | return captured_args
34 |
--------------------------------------------------------------------------------
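A sketch of how the captured arguments can be read back after a run: the annotation this inspection attaches to each DAG node is the dictionary of non-data call arguments, e.g. estimator hyperparameters:

from mlinspect import OperatorType, PipelineInspector
from mlinspect.inspections import ArgumentCapturing

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file("example_pipelines/adult_complex/adult_complex.py") \
    .add_required_inspection(ArgumentCapturing()) \
    .execute()

for dag_node, inspection_results in inspector_result.dag_node_to_inspection_results.items():
    if dag_node.operator_info.operator == OperatorType.ESTIMATOR:
        print(inspection_results[ArgumentCapturing()])   # the estimator's captured arguments
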
/mlinspect/inspections/_completeness_of_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | An inspection to compute the ratio of non-null values in output columns
3 | """
4 | from typing import Iterable
5 |
6 | import pandas
7 |
8 | from mlinspect.inspections._inspection import Inspection
9 | from mlinspect.inspections._inspection_input import OperatorType, InspectionInputSinkOperator
10 |
11 |
12 | class CompletenessOfColumns(Inspection):
13 | """
14 | An inspection to compute the completeness of columns
15 | """
16 |
17 | def __init__(self, columns):
18 | self._present_column_names = []
19 | self._null_value_counts = []
20 | self._total_counts = []
21 | self._operator_type = None
22 | self.columns = columns
23 |
24 | @property
25 | def inspection_id(self):
26 | return tuple(self.columns)
27 |
28 | def visit_operator(self, inspection_input) -> Iterable[any]:
29 | """
30 | Visit an operator
31 | """
32 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals
33 | self._present_column_names = []
34 | self._null_value_counts = []
35 | self._total_counts = []
36 | self._operator_type = inspection_input.operator_context.operator
37 |
38 | if not isinstance(inspection_input, InspectionInputSinkOperator):
39 | present_columns_index = []
40 | for column_name in self.columns:
41 | column_present = column_name in inspection_input.output_columns.fields
42 | if column_present:
43 | column_index = inspection_input.output_columns.get_index_of_column(column_name)
44 | present_columns_index.append(column_index)
45 | self._present_column_names.append(column_name)
46 | self._null_value_counts.append(0)
47 | self._total_counts.append(0)
48 | for row in inspection_input.row_iterator:
49 | for present_column_index, column_index in enumerate(present_columns_index):
50 | column_value = row.output[column_index]
51 | is_null = pandas.isna(column_value)
52 | self._null_value_counts[present_column_index] += int(is_null)
53 | self._total_counts[present_column_index] += 1
54 | yield None
55 | else:
56 | for _ in inspection_input.row_iterator:
57 | yield None
58 |
59 | def get_operator_annotation_after_visit(self) -> any:
60 |         assert self._operator_type
61 |         if self._operator_type is OperatorType.ESTIMATOR:
62 |             self._operator_type = None
63 |             return None
64 |         completeness_results = {}
65 |         for column_index, column_name in enumerate(self._present_column_names):
66 |             null_value_count = self._null_value_counts[column_index]
67 |             total_count = self._total_counts[column_index]
68 |             completeness_results[column_name] = (total_count - null_value_count) / total_count
69 |         self._operator_type = None
70 |         return completeness_results
71 |
--------------------------------------------------------------------------------
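A usage sketch for the inspection above, with column names taken from the healthcare example: the annotation per DAG node maps each requested column that is present to its ratio of non-null values.

from mlinspect import PipelineInspector
from mlinspect.inspections import CompletenessOfColumns

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file("example_pipelines/healthcare/healthcare.py") \
    .add_required_inspection(CompletenessOfColumns(["smoker", "county"])) \
    .execute()

for dag_node, inspection_results in inspector_result.dag_node_to_inspection_results.items():
    completeness = inspection_results[CompletenessOfColumns(["smoker", "county"])]
    print(dag_node, completeness)   # e.g. {'smoker': 0.95, 'county': 1.0}; None for the estimator
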
/mlinspect/inspections/_count_distinct_of_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | An inspection to compute the number of distinct values in output columns
3 | """
4 | from typing import Iterable
5 |
6 | from mlinspect.inspections._inspection import Inspection
7 | from mlinspect.inspections._inspection_input import OperatorType, InspectionInputSinkOperator
8 |
9 |
10 | class CountDistinctOfColumns(Inspection):
11 | """
12 | An inspection to compute the number of distinct values of columns
13 | """
14 |
15 | def __init__(self, columns):
16 | self._present_column_names = []
17 | self._distinct_value_sets = []
18 | self._operator_type = None
19 | self.columns = columns
20 |
21 | @property
22 | def inspection_id(self):
23 | return tuple(self.columns)
24 |
25 | def visit_operator(self, inspection_input) -> Iterable[any]:
26 | """
27 | Visit an operator
28 | """
29 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals
30 | self._present_column_names = []
31 | self._distinct_value_sets = []
32 | self._operator_type = inspection_input.operator_context.operator
33 |
34 | if not isinstance(inspection_input, InspectionInputSinkOperator):
35 | present_columns_index = []
36 | for column_name in self.columns:
37 | column_present = column_name in inspection_input.output_columns.fields
38 | if column_present:
39 | column_index = inspection_input.output_columns.get_index_of_column(column_name)
40 | present_columns_index.append(column_index)
41 | self._present_column_names.append(column_name)
42 | self._distinct_value_sets.append(set())
43 | for row in inspection_input.row_iterator:
44 | for present_column_index, column_index in enumerate(present_columns_index):
45 | column_value = row.output[column_index]
46 | self._distinct_value_sets[present_column_index].add(column_value)
47 | yield None
48 | else:
49 | for _ in inspection_input.row_iterator:
50 | yield None
51 |
52 | def get_operator_annotation_after_visit(self) -> any:
53 |         assert self._operator_type
54 |         if self._operator_type is OperatorType.ESTIMATOR:
55 |             self._operator_type = None
56 |             return None
57 |         count_distinct_results = {}
58 |         for column_index, column_name in enumerate(self._present_column_names):
59 |             count_distinct_results[column_name] = len(self._distinct_value_sets[column_index])
60 |         self._distinct_value_sets = []
61 |         self._operator_type = None
62 |         return count_distinct_results
63 |
--------------------------------------------------------------------------------
/mlinspect/inspections/_histogram_for_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple inspection to compute histograms of sensitive groups in the data
3 | """
4 | from typing import Iterable
5 |
6 | from mlinspect.inspections._inspection import Inspection
7 | from mlinspect.inspections._inspection_input import InspectionInputDataSource, \
8 | InspectionInputUnaryOperator, InspectionInputNAryOperator, OperatorType, FunctionInfo
9 |
10 |
11 | class HistogramForColumns(Inspection):
12 | """
13 | An inspection to compute group membership histograms for multiple columns
14 | """
15 |
16 | def __init__(self, sensitive_columns):
17 | self._histogram_op_output = None
18 | self._operator_type = None
19 | self.sensitive_columns = sensitive_columns
20 |
21 | @property
22 | def inspection_id(self):
23 | return tuple(self.sensitive_columns)
24 |
25 | def visit_operator(self, inspection_input) -> Iterable[any]:
26 | """
27 | Visit an operator
28 | """
29 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals, too-many-nested-blocks
30 |         current_count = -1
31 |
32 | histogram_maps = []
33 | for _ in self.sensitive_columns:
34 | histogram_maps.append({})
35 |
36 | self._operator_type = inspection_input.operator_context.operator
37 |
38 | if isinstance(inspection_input, InspectionInputUnaryOperator):
39 | sensitive_columns_present = []
40 | sensitive_columns_index = []
41 | for column in self.sensitive_columns:
42 | column_present = column in inspection_input.input_columns.fields
43 | sensitive_columns_present.append(column_present)
44 | column_index = inspection_input.input_columns.get_index_of_column(column)
45 | sensitive_columns_index.append(column_index)
46 | if inspection_input.operator_context.function_info == FunctionInfo('sklearn.impute._base', 'SimpleImputer'):
47 | for row in inspection_input.row_iterator:
48 | current_count += 1
49 | column_values = []
50 | for check_index, _ in enumerate(self.sensitive_columns):
51 | if sensitive_columns_present[check_index]:
52 | column_value = row.output[0][sensitive_columns_index[check_index]]
53 | else:
54 | column_value = row.annotation[check_index]
55 | column_values.append(column_value)
56 | group_count = histogram_maps[check_index].get(column_value, 0)
57 | group_count += 1
58 | histogram_maps[check_index][column_value] = group_count
59 | yield column_values
60 | else:
61 | for row in inspection_input.row_iterator:
62 | current_count += 1
63 | column_values = []
64 | for check_index, _ in enumerate(self.sensitive_columns):
65 | if sensitive_columns_present[check_index]:
66 | column_value = row.input[sensitive_columns_index[check_index]]
67 | else:
68 | column_value = row.annotation[check_index]
69 | column_values.append(column_value)
70 | group_count = histogram_maps[check_index].get(column_value, 0)
71 | group_count += 1
72 | histogram_maps[check_index][column_value] = group_count
73 | yield column_values
74 | elif isinstance(inspection_input, InspectionInputDataSource):
75 | sensitive_columns_present = []
76 | sensitive_columns_index = []
77 | for column in self.sensitive_columns:
78 | column_present = column in inspection_input.output_columns.fields
79 | sensitive_columns_present.append(column_present)
80 | column_index = inspection_input.output_columns.get_index_of_column(column)
81 | sensitive_columns_index.append(column_index)
82 | for row in inspection_input.row_iterator:
83 | current_count += 1
84 | column_values = []
85 | for check_index, _ in enumerate(self.sensitive_columns):
86 | if sensitive_columns_present[check_index]:
87 | column_value = row.output[sensitive_columns_index[check_index]]
88 | column_values.append(column_value)
89 | group_count = histogram_maps[check_index].get(column_value, 0)
90 | group_count += 1
91 | histogram_maps[check_index][column_value] = group_count
92 | else:
93 | column_values.append(None)
94 | yield column_values
95 | elif isinstance(inspection_input, InspectionInputNAryOperator):
96 | sensitive_columns_present = []
97 | sensitive_columns_index = []
98 | for column in self.sensitive_columns:
99 | column_present = column in inspection_input.output_columns.fields
100 | sensitive_columns_present.append(column_present)
101 | column_index = inspection_input.output_columns.get_index_of_column(column)
102 | sensitive_columns_index.append(column_index)
103 | for row in inspection_input.row_iterator:
104 | current_count += 1
105 | column_values = []
106 | for check_index, _ in enumerate(self.sensitive_columns):
107 | if sensitive_columns_present[check_index]:
108 | column_value = row.output[sensitive_columns_index[check_index]]
109 | column_values.append(column_value)
110 | group_count = histogram_maps[check_index].get(column_value, 0)
111 | group_count += 1
112 | histogram_maps[check_index][column_value] = group_count
113 | else:
114 | if sensitive_columns_present[check_index]:
115 | column_value = row.output[sensitive_columns_index[check_index]]
116 | else:
117 | column_value_candidates = [annotation[check_index] for annotation in row.annotation
118 | if annotation[check_index] is not None]
119 | if len(column_value_candidates) >= 1:
120 | column_value = column_value_candidates[0]
121 | else:
122 | column_value = None
123 | column_values.append(column_value)
124 | group_count = histogram_maps[check_index].get(column_value, 0)
125 | group_count += 1
126 | histogram_maps[check_index][column_value] = group_count
127 | yield column_values
128 | else:
129 | for _ in inspection_input.row_iterator:
130 | yield None
131 |
132 | self._histogram_op_output = {}
133 | for check_index, column in enumerate(self.sensitive_columns):
134 | self._histogram_op_output[column] = histogram_maps[check_index]
135 |
136 | def get_operator_annotation_after_visit(self) -> any:
137 | assert self._operator_type
138 | if self._operator_type is not OperatorType.ESTIMATOR:
139 | result = self._histogram_op_output
140 | self._histogram_op_output = None
141 | self._operator_type = None
142 | return result
143 | self._operator_type = None
144 | return None
145 |
--------------------------------------------------------------------------------
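A usage sketch for the histogram inspection above, using the compas example pipeline: the annotation per DAG node maps each sensitive column to a dictionary from group value to row count.

from mlinspect import PipelineInspector
from mlinspect.inspections import HistogramForColumns

inspector_result = PipelineInspector \
    .on_pipeline_from_py_file("example_pipelines/compas/compas.py") \
    .add_required_inspection(HistogramForColumns(["race"])) \
    .execute()

for dag_node, inspection_results in inspector_result.dag_node_to_inspection_results.items():
    histograms = inspection_results[HistogramForColumns(["race"])]
    if histograms is not None:   # None for the estimator
        print(dag_node, histograms["race"])   # e.g. {'African-American': ..., 'Caucasian': ...}
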
/mlinspect/inspections/_inspection.py:
--------------------------------------------------------------------------------
1 | """
2 | The Interface for the Inspections
3 | """
4 | import abc
5 | from typing import Union, Iterable
6 |
7 | from mlinspect.inspections._inspection_input import InspectionInputDataSource, \
8 | InspectionInputUnaryOperator, InspectionInputNAryOperator, InspectionInputSinkOperator
9 |
10 |
11 | class Inspection(metaclass=abc.ABCMeta):
12 | """
13 | The Interface for the Inspections
14 | """
15 |
16 | @property
17 | def inspection_id(self):
18 | """The id of the inspection"""
19 | return None
20 |
21 | @abc.abstractmethod
22 | def visit_operator(self, inspection_input: Union[InspectionInputDataSource, InspectionInputUnaryOperator,
23 | InspectionInputNAryOperator, InspectionInputSinkOperator])\
24 | -> Iterable[any]:
25 | """Visit an operator in the DAG"""
26 | raise NotImplementedError
27 |
28 | @abc.abstractmethod
29 | def get_operator_annotation_after_visit(self) -> any:
30 | """Get the output to be included in the DAG"""
31 | raise NotImplementedError
32 |
33 | def __eq__(self, other):
34 | """Inspections must implement equals"""
35 | return (isinstance(other, self.__class__) and
36 | self.inspection_id == other.inspection_id)
37 |
38 | def __hash__(self):
39 | """Inspections must be hashable"""
40 | return hash((self.__class__.__name__, self.inspection_id))
41 |
42 | def __repr__(self):
43 | """Inspections must have a str representation"""
44 | return f"{self.__class__.__name__}({self.inspection_id})"
45 |
--------------------------------------------------------------------------------
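A minimal sketch of a custom inspection against the interface above; this hypothetical inspection is not part of the repository and simply counts the rows each operator outputs:

from typing import Iterable

from mlinspect.inspections import Inspection


class RowCountInspection(Inspection):
    """A hypothetical inspection that counts the rows flowing through each operator"""

    def __init__(self):
        self._row_count = 0

    def visit_operator(self, inspection_input) -> Iterable[any]:
        self._row_count = 0
        for _ in inspection_input.row_iterator:
            self._row_count += 1
            yield None   # no per-row annotation to propagate

    def get_operator_annotation_after_visit(self) -> any:
        return self._row_count
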
/mlinspect/inspections/_inspection_input.py:
--------------------------------------------------------------------------------
1 | """
2 | Data classes used as input for the inspections
3 | """
4 | import dataclasses
5 | from enum import Enum
6 | from typing import Tuple, List, Iterable, Dict, Optional
7 |
8 |
9 | @dataclasses.dataclass(frozen=True)
10 | class ColumnInfo:
11 | """
12 |     A class we use to efficiently pass the column names of pandas/sklearn rows
13 | """
14 | fields: List[str]
15 |
16 | def get_index_of_column(self, column_name):
17 | """
18 | Get the values index for some column
19 | """
20 | if column_name in self.fields:
21 | return self.fields.index(column_name)
22 | return None
23 |
24 | def __eq__(self, other):
25 | return (isinstance(other, ColumnInfo) and
26 | self.fields == other.fields)
27 |
28 |
29 | @dataclasses.dataclass(frozen=True)
30 | class FunctionInfo:
31 | """
32 | Contains the function name and its path
33 | """
34 | module: str
35 | function_name: str
36 |
37 |
38 | class OperatorType(Enum):
39 | """
40 | The different operator types in our DAG
41 | """
42 | DATA_SOURCE = "Data Source"
43 | MISSING_OP = "Encountered unsupported operation! Fallback: Data Source"
44 | SELECTION = "Selection"
45 | PROJECTION = "Projection"
46 | PROJECTION_MODIFY = "Projection (Modify)"
47 | TRANSFORMER = "Transformer"
48 | CONCATENATION = "Concatenation"
49 | ESTIMATOR = "Estimator"
50 | SCORE = "Score"
51 | TRAIN_DATA = "Train Data"
52 | TRAIN_LABELS = "Train Labels"
53 | TEST_DATA = "Test Data"
54 | TEST_LABELS = "Test Labels"
55 | JOIN = "Join"
56 | GROUP_BY_AGG = "Groupby and Aggregate"
57 | TRAIN_TEST_SPLIT = "Train Test Split"
58 |
59 |
60 | @dataclasses.dataclass(frozen=True)
61 | class OperatorContext:
62 | """
63 | Additional context for the inspection. Contains, most importantly, the operator type.
64 | """
65 | operator: OperatorType
66 |     function_info: Optional[FunctionInfo]
67 |
68 |
69 | @dataclasses.dataclass(frozen=True)
70 | class InspectionRowDataSource:
71 | """
72 | Wrapper class for the only operator without a parent: a Data Source
73 | """
74 | output: Tuple
75 |
76 |
77 | @dataclasses.dataclass(frozen=True)
78 | class InspectionInputDataSource:
79 | """
80 |     The input to inspections visiting a Data Source operator
81 | """
82 | operator_context: OperatorContext
83 | output_columns: ColumnInfo
84 | row_iterator: Iterable[InspectionRowDataSource]
85 | non_data_function_args: Dict[str, any]
86 |
87 |
88 | @dataclasses.dataclass(frozen=True)
89 | class InspectionRowUnaryOperator:
90 | """
91 | Wrapper class for the operators with one parent like Selections and Projections
92 | """
93 | input: Tuple
94 | annotation: any
95 | output: Tuple
96 |
97 |
98 | @dataclasses.dataclass(frozen=True)
99 | class InspectionInputUnaryOperator:
100 | """
101 |     The input to inspections visiting an operator with a single parent
102 | """
103 | operator_context: OperatorContext
104 | input_columns: ColumnInfo
105 | output_columns: ColumnInfo
106 | row_iterator: Iterable[InspectionRowUnaryOperator]
107 | non_data_function_args: Dict[str, any]
108 |
109 |
110 | @dataclasses.dataclass(frozen=True)
111 | class InspectionRowNAryOperator:
112 | """
113 | Wrapper class for the operators with multiple parents like Concatenations
114 | """
115 | inputs: Tuple[Tuple]
116 | annotation: Tuple[any]
117 | output: Tuple
118 |
119 |
120 | @dataclasses.dataclass(frozen=True)
121 | class InspectionInputNAryOperator:
122 | """
123 |     The input to inspections visiting an operator with multiple parents
124 | """
125 | operator_context: OperatorContext
126 | inputs_columns: List[ColumnInfo]
127 | output_columns: ColumnInfo
128 | row_iterator: Iterable[InspectionRowNAryOperator]
129 | non_data_function_args: Dict[str, any]
130 |
131 |
132 | @dataclasses.dataclass(frozen=True)
133 | class InspectionRowSinkOperator:
134 | """
135 | Wrapper class for operators like Estimators that only get fitted
136 | """
137 | input: Tuple[Tuple]
138 | annotation: any
139 |
140 |
141 | @dataclasses.dataclass(frozen=True)
142 | class InspectionInputSinkOperator:
143 | """
144 |     The input to inspections visiting a sink operator like an Estimator that only gets fitted
145 | """
146 | operator_context: OperatorContext
147 | inputs_columns: List[ColumnInfo]
148 | row_iterator: Iterable[InspectionRowSinkOperator]
149 | non_data_function_args: Dict[str, any]
150 |
--------------------------------------------------------------------------------
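A small sketch of the column lookup inspections perform on the data classes above:

from mlinspect.inspections._inspection_input import ColumnInfo

column_info = ColumnInfo(["age", "county", "income"])
print(column_info.get_index_of_column("county"))    # 1
print(column_info.get_index_of_column("missing"))   # None, the column is not present
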
/mlinspect/inspections/_inspection_result.py:
--------------------------------------------------------------------------------
1 | """
2 | Data class used as result of the PipelineExecutor
3 | """
4 | import dataclasses
5 | from typing import Dict
6 |
7 | import networkx
8 |
9 | from mlinspect.instrumentation._dag_node import DagNode
10 | from mlinspect.inspections._inspection import Inspection
11 |
12 |
13 | @dataclasses.dataclass
14 | class InspectionResult:
15 | """
16 | The class the PipelineExecutor returns
17 | """
18 | dag: networkx.DiGraph
19 | dag_node_to_inspection_results: Dict[DagNode, Dict[Inspection, any]]
20 |
--------------------------------------------------------------------------------
/mlinspect/inspections/_intersectional_histogram_for_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | An inspection to compute histograms of intersectional group memberships
3 | """
4 | from typing import Iterable, List
5 |
6 | from mlinspect.inspections._inspection import Inspection
7 | from mlinspect.inspections._inspection_input import InspectionInputDataSource, \
8 | InspectionInputUnaryOperator, InspectionInputNAryOperator, OperatorType, FunctionInfo
9 |
10 |
11 | class IntersectionalHistogramForColumns(Inspection):
12 | """
13 | An inspection to compute intersectional group memberships
14 | """
15 |
16 | def __init__(self, sensitive_columns: List[str]):
17 | self._histogram_op_output = None
18 | self._operator_type = None
19 | self.sensitive_columns = sensitive_columns
20 |
21 | @property
22 | def inspection_id(self):
23 | return tuple(self.sensitive_columns)
24 |
25 | def visit_operator(self, inspection_input) -> Iterable[any]:
26 | """
27 | Visit an operator
28 | """
29 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals
30 |         current_count = -1
31 |
32 | histogram_map = {}
33 |
34 | self._operator_type = inspection_input.operator_context.operator
35 |
36 | if isinstance(inspection_input, InspectionInputUnaryOperator):
37 | sensitive_columns_present = []
38 | sensitive_columns_index = []
39 | for column in self.sensitive_columns:
40 | column_present = column in inspection_input.input_columns.fields
41 | sensitive_columns_present.append(column_present)
42 | column_index = inspection_input.input_columns.get_index_of_column(column)
43 | sensitive_columns_index.append(column_index)
44 | if inspection_input.operator_context.function_info == FunctionInfo('sklearn.impute._base', 'SimpleImputer'):
45 | for row in inspection_input.row_iterator:
46 | current_count += 1
47 | column_values = []
48 | for check_index, _ in enumerate(self.sensitive_columns):
49 | if sensitive_columns_present[check_index]:
50 | column_value = row.output[0][sensitive_columns_index[check_index]]
51 | else:
52 | column_value = row.annotation[check_index]
53 | column_values.append(column_value)
54 | update_histograms(column_values, histogram_map)
55 | yield column_values
56 | else:
57 | for row in inspection_input.row_iterator:
58 | current_count += 1
59 | column_values = []
60 | for check_index, _ in enumerate(self.sensitive_columns):
61 | if sensitive_columns_present[check_index]:
62 | column_value = row.input[sensitive_columns_index[check_index]]
63 | else:
64 | column_value = row.annotation[check_index]
65 | column_values.append(column_value)
66 | update_histograms(column_values, histogram_map)
67 | yield column_values
68 | elif isinstance(inspection_input, InspectionInputDataSource):
69 | sensitive_columns_present = []
70 | sensitive_columns_index = []
71 | for column in self.sensitive_columns:
72 | column_present = column in inspection_input.output_columns.fields
73 | sensitive_columns_present.append(column_present)
74 | column_index = inspection_input.output_columns.get_index_of_column(column)
75 | sensitive_columns_index.append(column_index)
76 | for row in inspection_input.row_iterator:
77 | current_count += 1
78 | column_values = []
79 | for check_index, _ in enumerate(self.sensitive_columns):
80 | if sensitive_columns_present[check_index]:
81 | column_value = row.output[sensitive_columns_index[check_index]]
82 | column_values.append(column_value)
83 | else:
84 | column_values.append(None)
85 | update_histograms(column_values, histogram_map)
86 | yield column_values
87 | elif isinstance(inspection_input, InspectionInputNAryOperator):
88 | sensitive_columns_present = []
89 | sensitive_columns_index = []
90 | for column in self.sensitive_columns:
91 | column_present = column in inspection_input.output_columns.fields
92 | sensitive_columns_present.append(column_present)
93 | column_index = inspection_input.output_columns.get_index_of_column(column)
94 | sensitive_columns_index.append(column_index)
95 | for row in inspection_input.row_iterator:
96 | current_count += 1
97 | column_values = []
98 | for check_index, _ in enumerate(self.sensitive_columns):
99 | if sensitive_columns_present[check_index]:
100 | column_value = row.output[sensitive_columns_index[check_index]]
101 | column_values.append(column_value)
102 | else:
103 | column_values.append(None)
104 | update_histograms(column_values, histogram_map)
105 | yield column_values
106 | else:
107 | for _ in inspection_input.row_iterator:
108 | yield None
109 |
110 | self._histogram_op_output = histogram_map
111 |
112 | def get_operator_annotation_after_visit(self) -> any:
113 | assert self._operator_type
114 | if self._operator_type is not OperatorType.ESTIMATOR:
115 | result = self._histogram_op_output
116 | self._histogram_op_output = None
117 | self._operator_type = None
118 | return result
119 | self._operator_type = None
120 | return None
121 |
122 |
123 | def update_histograms(column_values, histogram_map):
124 | """Update the histograms with the intersectional information"""
125 | value_tuple = tuple(column_values)
126 | group_count = histogram_map.get(value_tuple, 0)
127 | group_count += 1
128 | histogram_map[value_tuple] = group_count
129 |
--------------------------------------------------------------------------------
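The helper above keys the histogram by the tuple of all sensitive values, which is what makes the counts intersectional. A small standalone demonstration with illustrative group values:

from mlinspect.inspections._intersectional_histogram_for_columns import update_histograms

histogram_map = {}
update_histograms(["female", ">=40"], histogram_map)
update_histograms(["female", "<40"], histogram_map)
update_histograms(["female", ">=40"], histogram_map)
print(histogram_map)   # {('female', '>=40'): 2, ('female', '<40'): 1}
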
/mlinspect/inspections/_lineage.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple inspection for lineage tracking
3 | """
4 | import dataclasses
5 | from typing import Iterable, List
6 |
7 | from pandas import DataFrame, Series
8 |
9 | from mlinspect.inspections._inspection import Inspection
10 | from mlinspect.inspections._inspection_input import InspectionInputUnaryOperator, \
11 | InspectionInputSinkOperator, InspectionInputDataSource, InspectionInputNAryOperator, OperatorType
12 |
13 |
14 | @dataclasses.dataclass(frozen=True)
15 | class LineageId:
16 | """
17 | A lineage id class
18 | """
19 | operator_id: int
20 | row_id: int
21 |
22 |
23 | class RowLineage(Inspection):
24 | """
25 | A simple inspection for row-level lineage tracking
26 | """
27 | # TODO: Add an option to pass a list of lineage ids to this inspection. Then it materializes all related tuples.
28 | # To do this efficiently, we do not want to do expensive membership tests. We can collect all base LineageIds
29 | # in a set and then it is enough to check for set memberships in InspectionInputDataSource inspection inputs.
30 | # This set membership can be used as a 'materialize' flag we use as annotation. Then we simply need to check this
31 | # flag to check whether to materialize rows.
32 | # pylint: disable=too-many-instance-attributes
33 |
34 | ALL_ROWS = -1
35 |
36 | def __init__(self, row_count: int, operator_type_restriction: List[OperatorType] = None):
37 | self.row_count = row_count
38 | if operator_type_restriction is not None:
39 | self.operator_type_restriction = set(operator_type_restriction)
40 | self._inspection_id = (self.row_count, *self.operator_type_restriction)
41 | else:
42 | self.operator_type_restriction = None
43 | self._inspection_id = self.row_count
44 | self._operator_count = -1
45 | self._op_output = None
46 | self._op_lineage = None
47 | self._output_columns = None
48 | self._is_sink = False
49 | self._materialize_for_this_operator = None
50 |
51 | def visit_operator(self, inspection_input) -> Iterable[any]:
52 |         """Visit an operator, generate lineage ids at data sources and propagate them to derived rows"""
53 | # pylint: disable=too-many-branches, too-many-statements
54 | self._operator_count += 1
55 | self._op_output = []
56 | self._op_lineage = []
57 | current_count = -1
58 | self._materialize_for_this_operator = (self.operator_type_restriction is None) or \
59 | (inspection_input.operator_context.operator
60 | in self.operator_type_restriction)
61 |
62 | if not isinstance(inspection_input, InspectionInputSinkOperator):
63 | self._output_columns = inspection_input.output_columns.fields
64 | else:
65 | self._is_sink = True
66 |
67 | if isinstance(inspection_input, InspectionInputDataSource):
68 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS:
69 | for row in inspection_input.row_iterator:
70 | current_count += 1
71 | annotation = {LineageId(self._operator_count, current_count)}
72 | self._op_output.append(row.output)
73 | self._op_lineage.append(annotation)
74 | yield annotation
75 | elif self._materialize_for_this_operator:
76 | for row in inspection_input.row_iterator:
77 | current_count += 1
78 | annotation = {LineageId(self._operator_count, current_count)}
79 | if current_count < self.row_count:
80 | self._op_output.append(row.output)
81 | self._op_lineage.append(annotation)
82 | yield annotation
83 | else:
84 | for _ in inspection_input.row_iterator:
85 | current_count += 1
86 | annotation = {LineageId(self._operator_count, current_count)}
87 | yield annotation
88 | elif isinstance(inspection_input, InspectionInputNAryOperator):
89 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS:
90 | for row in inspection_input.row_iterator:
91 | current_count += 1
92 |
93 | annotation = set.union(*row.annotation)
94 | self._op_output.append(row.output)
95 | self._op_lineage.append(annotation)
96 | yield annotation
97 | elif self._materialize_for_this_operator:
98 | for row in inspection_input.row_iterator:
99 | current_count += 1
100 |
101 | annotation = set.union(*row.annotation)
102 | if current_count < self.row_count:
103 | self._op_output.append(row.output)
104 | self._op_lineage.append(annotation)
105 | yield annotation
106 | else:
107 | for row in inspection_input.row_iterator:
108 | current_count += 1
109 | annotation = set.union(*row.annotation)
110 | yield annotation
111 | elif isinstance(inspection_input, InspectionInputSinkOperator):
112 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS:
113 | for row in inspection_input.row_iterator:
114 | current_count += 1
115 |
116 | annotation = set.union(*row.annotation)
117 | self._op_lineage.append(annotation)
118 | yield annotation
119 | elif self._materialize_for_this_operator:
120 | for row in inspection_input.row_iterator:
121 | current_count += 1
122 |
123 | annotation = set.union(*row.annotation)
124 | if current_count < self.row_count:
125 | self._op_lineage.append(annotation)
126 | yield annotation
127 | else:
128 | for row in inspection_input.row_iterator:
129 | current_count += 1
130 | annotation = set.union(*row.annotation)
131 | yield annotation
132 | elif isinstance(inspection_input, InspectionInputUnaryOperator):
133 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS:
134 | for row in inspection_input.row_iterator:
135 | current_count += 1
136 | annotation = row.annotation
137 | self._op_output.append(row.output)
138 | self._op_lineage.append(annotation)
139 | yield annotation
140 | elif self._materialize_for_this_operator:
141 | for row in inspection_input.row_iterator:
142 | current_count += 1
143 | annotation = row.annotation
144 |
145 | if current_count < self.row_count:
146 | self._op_output.append(row.output)
147 | self._op_lineage.append(annotation)
148 | yield annotation
149 | else:
150 | for row in inspection_input.row_iterator:
151 | current_count += 1
152 | annotation = row.annotation
153 | yield annotation
154 | else:
155 | assert False
156 |
157 | def get_operator_annotation_after_visit(self) -> any:
158 | if not self._materialize_for_this_operator:
159 | result = None
160 | elif not self._is_sink:
161 | assert self._op_lineage
162 | result = DataFrame(self._op_output, columns=self._output_columns)
163 | result["mlinspect_lineage"] = self._op_lineage
164 | else:
165 | assert self._op_lineage
166 | lineage_series = Series(self._op_lineage)
167 | result = DataFrame(lineage_series, columns=["mlinspect_lineage"])
168 | self._op_output = None
169 | self._op_lineage = None
170 | self._output_columns = None
171 | self._is_sink = False
172 | return result
173 |
174 | @property
175 | def inspection_id(self):
176 | return self._inspection_id
177 |
--------------------------------------------------------------------------------
/mlinspect/inspections/_materialize_first_output_rows.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple inspection to materialise operator outputs
3 | """
4 | from typing import Iterable
5 |
6 | from pandas import DataFrame
7 |
8 | from ._inspection import Inspection
9 | from ._inspection_input import InspectionInputSinkOperator, OperatorType
10 |
11 |
12 | class MaterializeFirstOutputRows(Inspection):
13 | """
14 |     An inspection that materialises the first row_count output rows of each operator
15 | """
16 |
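    # Usage sketch (hypothetical pipeline string `code`; see the tests for real examples):
    #   result = PipelineInspector.on_pipeline_from_string(code) \
    #       .add_required_inspection(MaterializeFirstOutputRows(5)) \
    #       .execute()
    #   first_rows_per_node = {node: inspections[MaterializeFirstOutputRows(5)]
    #                          for node, inspections in result.dag_node_to_inspection_results.items()}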
17 | def __init__(self, row_count: int):
18 | self.row_count = row_count
19 | self._analyzer_id = self.row_count
20 | self._first_rows_op_output = None
21 | self._operator_type = None
22 | self._output_columns = None
23 |
24 | @property
25 | def inspection_id(self):
26 | return self._analyzer_id
27 |
28 | def visit_operator(self, inspection_input) -> Iterable[any]:
29 | """
30 | Visit an operator
31 | """
32 |         current_count = -1
33 | operator_output = []
34 | self._operator_type = inspection_input.operator_context.operator
35 |
36 | if not isinstance(inspection_input, InspectionInputSinkOperator):
37 | self._output_columns = inspection_input.output_columns.fields
38 | for row in inspection_input.row_iterator:
39 | current_count += 1
40 | if current_count < self.row_count:
41 | operator_output.append(row.output)
42 | yield None
43 | else:
44 | for _ in inspection_input.row_iterator:
45 | yield None
46 |
47 | self._first_rows_op_output = operator_output
48 |
49 | def get_operator_annotation_after_visit(self) -> any:
50 | assert self._operator_type
51 | if self._operator_type is not OperatorType.ESTIMATOR:
52 | assert self._first_rows_op_output and self._output_columns is not None # Visit must be finished
53 | result = DataFrame(self._first_rows_op_output, columns=self._output_columns)
54 | self._first_rows_op_output = None
55 | self._operator_type = None
56 | self._output_columns = None
57 | return result
58 | self._operator_type = None
59 | self._output_columns = None
60 | return None
61 |
--------------------------------------------------------------------------------
/mlinspect/instrumentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/instrumentation/__init__.py
--------------------------------------------------------------------------------
/mlinspect/instrumentation/_call_capture_transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Inserts function call capturing into the DAG
3 | """
4 | import ast
5 |
6 |
7 | class CallCaptureTransformer(ast.NodeTransformer):
8 | """
9 | ast.NodeTransformer to replace calls with captured calls
10 | """
11 |
12 | def visit_Call(self, node):
13 | """
14 | Instrument all function calls
15 | """
16 | # pylint: disable=invalid-name
17 | ast.NodeTransformer.generic_visit(self, node)
18 | self.call_add_set_code_reference(node)
19 | return node
20 |
21 | def visit_Subscript(self, node):
22 | """
23 | Instrument all subscript calls
24 | """
25 | # pylint: disable=invalid-name
26 | ast.NodeTransformer.generic_visit(self, node)
27 | if isinstance(node.ctx, ast.Store):
28 | # Needed to get the parent assign node for subscript assigns.
29 | # Without this, "pandas_df['baz'] = baz + 1" would only be "pandas_df['baz']"
30 | code_reference_from_node = node.parents[0]
31 | else:
32 | code_reference_from_node = node
33 | self.subscript_add_set_code_reference(node, code_reference_from_node)
34 | return node
35 |
36 | @staticmethod
37 | def call_add_set_code_reference(node):
38 | """
39 | When a method is called, capture the arguments of the method before executing it
40 | """
41 |         # We need to use a keyword-argument call to capture the arguments because of the
42 |         # evaluation order and because keyword arguments may contain function calls.
43 |         # https://stackoverflow.com/questions/17948369/is-it-safe-to-rely-on-python-function-arguments-evaluation-order
44 |         # We could instrument only the functions we actually patch, based on their names,
45 |         # but static detection by function name is unreliable, so we skip this for now.
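        # Illustration (hypothetical code reference numbers): a call like
        #     df.dropna(axis=0)
        # is rewritten into
        #     df.dropna(**set_code_reference_call(5, 5, 5, 22, axis=0))
        # so set_code_reference_call can record the code reference and pass the original
        # keyword arguments through unchanged.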
46 | kwargs = node.keywords
47 | call_node = CallCaptureTransformer.create_set_code_reference_node_call(node, kwargs)
48 | new_kwargs_node = ast.keyword(value=call_node, arg=None)
49 | node.keywords = [new_kwargs_node]
50 |
51 | @staticmethod
52 | def create_set_code_reference_node_call(node, kwargs):
53 | """
54 | Create the set_code_reference function call ast node that then gets inserted into the AST
55 | """
56 | call_node = ast.Call(func=ast.Name(id='set_code_reference_call', ctx=ast.Load()),
57 |                              args=[ast.Constant(value=node.lineno, kind=None),
58 |                                    ast.Constant(value=node.col_offset, kind=None),
59 |                                    ast.Constant(value=node.end_lineno, kind=None),
60 |                                    ast.Constant(value=node.end_col_offset, kind=None)],
61 | keywords=kwargs)
62 | return call_node
63 |
64 | @staticmethod
65 | def create_set_code_reference_node_subscript(node, kwargs):
66 | """
67 | Create the set_code_reference function call ast node that then gets inserted into the AST
68 | """
69 | call_node = ast.Call(func=ast.Name(id='set_code_reference_subscript', ctx=ast.Load()),
70 |                              args=[ast.Constant(value=node.lineno, kind=None),
71 |                                    ast.Constant(value=node.col_offset, kind=None),
72 |                                    ast.Constant(value=node.end_lineno, kind=None),
73 |                                    ast.Constant(value=node.end_col_offset, kind=None),
74 | kwargs],
75 | keywords=[])
76 | return call_node
77 |
78 | @staticmethod
79 | def subscript_add_set_code_reference(node, code_reference_from_node):
80 | """
81 | When the __getitem__ method of some object is called, capture the arguments of the method before executing it
82 | """
83 | subscript_arg = node.slice
84 | call_node = CallCaptureTransformer.create_set_code_reference_node_subscript(code_reference_from_node,
85 | subscript_arg)
86 | node.slice = call_node
87 |
--------------------------------------------------------------------------------
/mlinspect/instrumentation/_dag_node.py:
--------------------------------------------------------------------------------
1 | """
2 | The Nodes used in the DAG as nodes for the networkx.DiGraph
3 | """
4 | import dataclasses
5 | from typing import List, Optional
6 |
7 | from mlinspect.inspections._inspection_input import OperatorContext
8 |
9 |
10 | @dataclasses.dataclass(frozen=True)
11 | class CodeReference:
12 | """
13 | Identifies a function call in the user pipeline code
14 | """
15 | lineno: int
16 | col_offset: int
17 | end_lineno: int
18 | end_col_offset: int
19 |
20 |
21 | @dataclasses.dataclass
22 | class BasicCodeLocation:
23 | """
24 | Basic information that can be collected even if `set_code_reference_tracking` is disabled
25 | """
26 | caller_filename: str
27 | lineno: int
28 |
29 |
30 | @dataclasses.dataclass
31 | class OptionalCodeInfo:
32 | """
33 | The additional information collected by mlinspect if `set_code_reference_tracking` is enabled
34 | """
35 | code_reference: CodeReference
36 | source_code: str
37 |
38 |
39 | @dataclasses.dataclass
40 | class DagNodeDetails:
41 | """
42 | Additional info about the DAG node
43 | """
44 |     description: Optional[str] = None
45 |     columns: Optional[List[str]] = None
46 |
47 |
48 | @dataclasses.dataclass
49 | class DagNode:
50 | """
51 | A DAG Node
52 | """
53 |
54 | node_id: int
55 | code_location: BasicCodeLocation
56 | operator_info: OperatorContext
57 | details: DagNodeDetails
58 |     optional_code_info: Optional[OptionalCodeInfo] = None
59 |
60 | def __hash__(self):
61 | return hash(self.node_id)
62 |
--------------------------------------------------------------------------------
/mlinspect/monkeypatching/README.md:
--------------------------------------------------------------------------------
1 | # Support for different libraries and API functions
2 |
3 | ## Handling of unknown functions
4 | * Extending mlinspect to support more and more API functions and libraries will be an ongoing effort. External contributions are very welcome!
5 | * However, mlinspect doesn't simply crash when it encounters unknown functions.
6 | * mlinspect ignores functions it doesn't recognize. If a function it does recognize receives input produced by one or more relevant unknown function calls, it creates a `MISSING_OP` node for them. The inspections also get to see this unknown input; from their perspective, it's just a new data source.
7 | * Example:
8 | ```python
9 | import networkx
10 | from inspect import cleandoc
11 | from testfixtures import compare
12 | from mlinspect import OperatorType, OperatorContext, FunctionInfo, PipelineInspector, CodeReference, DagNode, BasicCodeLocation, DagNodeDetails, \
13 | OptionalCodeInfo
14 |
15 |
16 | test_code = cleandoc("""
17 | from inspect import cleandoc
18 | import pandas
19 | from mlinspect.testing._testing_helper_utils import black_box_df_op
20 |
21 | df = black_box_df_op()
22 | df = df.dropna()
23 | """)
24 |
25 | extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag
26 |
27 | expected_dag = networkx.DiGraph()
28 | expected_missing_op = DagNode(-1,
29 | BasicCodeLocation("", 5),
30 | OperatorContext(OperatorType.MISSING_OP, None),
31 | DagNodeDetails('Warning! Operator :5 (df.dropna()) encountered a '
32 | 'DataFrame resulting from an operation without mlinspect support!',
33 | ['A']),
34 | OptionalCodeInfo(CodeReference(5, 5, 5, 16), 'df.dropna()'))
35 | expected_select = DagNode(0,
36 | BasicCodeLocation("", 5),
37 | OperatorContext(OperatorType.SELECTION, FunctionInfo('pandas.core.frame', 'dropna')),
38 | DagNodeDetails('dropna', ['A']),
39 | OptionalCodeInfo(CodeReference(5, 5, 5, 16), 'df.dropna()'))
40 | expected_dag.add_edge(expected_missing_op, expected_select)
41 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag))
42 | ```
43 |
44 | ## Pandas
45 | * The implementation can be found mainly [here](./_patch_pandas.py)
46 | * The [tests](../../test/monkeypatching/test_patch_pandas.py) are probably more useful to look at
47 | * Currently supported functions:
48 |
49 | | Function Call | Operator |
50 | | ------------- |:-------------:|
51 | | `('pandas.io.parsers', 'read_csv')` | Data Source |
52 | | `('pandas.core.frame', 'DataFrame')` | Data Source |
53 | | `('pandas.core.series', 'Series')` | Data Source |
54 | | `('pandas.core.frame', '__getitem__')`, arg type: strings | Projection|
55 | | `('pandas.core.frame', '__getitem__')`, arg type: series | Selection |
56 | | `('pandas.core.frame', 'dropna')` | Selection |
57 | | `('pandas.core.frame', 'replace')` | Projection (Mod) |
58 | | `('pandas.core.frame', '__setitem__')` | Projection (Mod) |
59 | | `('pandas.core.frame', 'merge')` | Join |
60 | | `('pandas.core.frame', 'groupby')` | Nothing (until a following agg call) |
61 | | `('pandas.core.groupbygeneric', 'agg')` | Groupby/Agg |
62 |
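For orientation, a minimal sketch of a pipeline that only uses functions from the table above (the exact DAG contents depend on the pipeline and library versions):

```python
from inspect import cleandoc
from mlinspect import PipelineInspector

test_code = cleandoc("""
    import pandas as pd

    df_a = pd.DataFrame({'A': ['x', 'y', 'x'], 'B': [1, 2, 3]})  # Data Source
    df_b = pd.DataFrame({'B': [1, 2, 4], 'C': [5., None, 7.]})   # Data Source
    df = df_a.merge(df_b, on='B')                                # Join
    df = df.dropna()                                             # Selection
    df = df.replace('x', 'z')                                    # Projection (Mod)
""")

extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag
```
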
63 | ## Sklearn
64 | * The implementation can be found mainly [here](./_patch_sklearn.py)
65 | * The [tests](../../test/monkeypatching/test_patch_sklearn.py) are probably more useful to look at
66 | * Currently supported functions:
67 |
68 | | Function Call | Operator |
69 | | ------------- |:-------------:|
70 | | `('sklearn.compose._column_transformer', 'ColumnTransformer')`, column selection | Projection |
71 | | `('sklearn.preprocessing._label', 'label_binarize')` | Projection (Mod) |
72 | | `('sklearn.compose._column_transformer', 'ColumnTransformer')`, concatenation | Concatenation |
73 | | `('sklearn.model_selection._split', 'train_test_split')` | Split (Train/Test) |
74 | | `('sklearn.preprocessing._encoders', 'OneHotEncoder')`, arg type: strings | Transformer |
75 | | `('sklearn.preprocessing._data', 'StandardScaler')` | Transformer |
76 | | `('sklearn.impute._base', 'SimpleImputer')` | Transformer |
77 | | `('sklearn.feature_extraction.text', 'HashingVectorizer')` | Transformer |
78 | | `('sklearn.preprocessing._discretization', 'KBinsDiscretizer')` | Transformer |
79 | | `('sklearn.preprocessing._function_transformer', 'FunctionTransformer')` | Transformer |
80 | | `('sklearn.tree._classes', 'DecisionTreeClassifier')` | Estimator |
81 | | `('sklearn.linear_model._stochastic_gradient', 'SGDClassifier')` | Estimator |
82 | | `('tensorflow.python.keras.wrappers.scikit_learn', 'KerasClassifier')` | Estimator |
83 | | `('sklearn.linear_model._logistic', 'LogisticRegression')` | Estimator |
84 |
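A similar sketch for sklearn, closely following the project's tests (exact node details depend on library versions):

```python
from inspect import cleandoc
from mlinspect import PipelineInspector

test_code = cleandoc("""
    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import OneHotEncoder, StandardScaler, label_binarize
    from sklearn.tree import DecisionTreeClassifier

    data = pd.DataFrame({'age': [1, 2, 10, 5], 'cat': ['a', 'b', 'a', 'c'],
                         'target': ['no', 'no', 'yes', 'yes']})
    features = ColumnTransformer(transformers=[
        ('numeric', StandardScaler(), ['age']),                  # Transformer
        ('categorical', OneHotEncoder(sparse=False), ['cat'])])  # Transformer
    pipeline = Pipeline([('features', features),
                         ('classifier', DecisionTreeClassifier())])  # Estimator
    labels = label_binarize(data['target'], classes=['no', 'yes'])   # Projection (Mod)
    pipeline.fit(data, labels)
""")

extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag
```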
85 |
86 | ## Numpy
87 | * The implementation can be found mainly [here](./_patch_numpy.py)
88 | * The [tests](../../test/monkeypatching/test_patch_numpy.py) are probably more useful to look at
89 | * Currently supported functions:
90 |
91 | | Function Call | Operator |
92 | | ------------- |:-------------:|
93 | | `('numpy.random', 'random')` | Data Source |
94 |
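A one-line sketch; the call shows up as a Data Source node:

```python
from mlinspect import PipelineInspector

extracted_dag = PipelineInspector.on_pipeline_from_string(
    "import numpy as np\nx = np.random.random(100)").execute().dag
```
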
95 | ## Statsmodels
96 | * The implementation can be found mainly [here](./_patch_statsmodels.py)
97 | * The [tests](../../test/monkeypatching/test_patch_statsmodels.py) are probably more useful to look at
98 | * Currently supported functions:
99 |
100 | | Function Call | Operator |
101 | | ------------- |:-------------:|
102 | | `('statsmodels.datasets', 'get_rdataset')` | Data Source |
103 | | `('statsmodels.api', 'add_constant')` | Projection (Mod) |
104 | | `('statsmodel.api', 'OLS')`, numpy syntax | Estimator |
105 |
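And a compact sketch for the numpy-based OLS syntax referenced above:

```python
from inspect import cleandoc
from mlinspect import PipelineInspector

test_code = cleandoc("""
    import numpy as np
    from statsmodels import api as sm

    x = np.random.random(100)   # Data Source
    y = np.random.random(100)   # Data Source
    x = sm.add_constant(x)      # Projection (Mod)
    model = sm.OLS(y, x).fit()  # Estimator
""")

extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag
```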
--------------------------------------------------------------------------------
/mlinspect/monkeypatching/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/monkeypatching/__init__.py
--------------------------------------------------------------------------------
/mlinspect/monkeypatching/_mlinspect_ndarray.py:
--------------------------------------------------------------------------------
1 | """
2 | A numpy ndarray subclass that carries mlinspect annotations
3 | """
4 | import numpy
5 |
6 |
7 | class MlinspectNdarray(numpy.ndarray):
8 | """
9 | A wrapper for numpy ndarrays to store our additional annotations.
10 | See https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html
11 | """
12 |
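    # Sketch (hypothetical annotation value): the annotations survive view casting and ravel:
    #   arr = MlinspectNdarray(numpy.zeros(4), _mlinspect_annotation='some annotation')
    #   assert arr.ravel()._mlinspect_annotation == 'some annotation'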
13 | def __new__(cls, input_array, _mlinspect_dag_node=None, _mlinspect_annotation=None):
14 | # Input array is an already formed ndarray instance
15 | # We first cast to be our class type
16 | obj = numpy.asarray(input_array).view(cls)
17 | # add the new attribute to the created instance
18 | obj._mlinspect_dag_node = _mlinspect_dag_node
19 | obj._mlinspect_annotation = _mlinspect_annotation
20 | # Finally, we must return the newly created object:
21 | return obj
22 |
23 | def __array_finalize__(self, obj):
24 | # pylint: disable=attribute-defined-outside-init
25 | # see InfoArray.__array_finalize__ for comments
26 | if obj is None:
27 | return
28 | self._mlinspect_dag_node = getattr(obj, '_mlinspect_dag_node', None)
29 | self._mlinspect_annotation = getattr(obj, '_mlinspect_annotation', None)
30 |
31 | def ravel(self, order='C'):
32 | # pylint: disable=no-member
33 | result = super().ravel(order)
34 | assert isinstance(result, MlinspectNdarray)
35 | result._mlinspect_dag_node = self._mlinspect_dag_node # pylint: disable=protected-access
36 | result._mlinspect_annotation = self._mlinspect_annotation # pylint: disable=protected-access
37 | return result
38 |
--------------------------------------------------------------------------------
/mlinspect/monkeypatching/_patch_numpy.py:
--------------------------------------------------------------------------------
1 | """
2 | Monkey patching for numpy
3 | """
4 | import gorilla
5 | from numpy import random
6 |
7 | from mlinspect import DagNode, BasicCodeLocation, DagNodeDetails
8 | from mlinspect.backends._sklearn_backend import SklearnBackend
9 | from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType
10 | from mlinspect.monkeypatching._monkey_patching_utils import execute_patched_func, add_dag_node, \
11 | get_optional_code_info_or_none
12 |
13 |
14 | @gorilla.patches(random)
15 | class NumpyRandomPatching:
16 |     """ Patches for numpy.random """
17 |
18 | # pylint: disable=too-few-public-methods,no-self-argument
19 |
20 | @gorilla.name('random')
21 | @gorilla.settings(allow_hit=True)
22 | def patched_random(*args, **kwargs):
23 | """ Patch for ('numpy.random', 'random') """
24 | # pylint: disable=no-method-argument
25 | original = gorilla.get_original_attribute(random, 'random')
26 |
27 | def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code):
28 | """ Execute inspections, add DAG node """
29 | function_info = FunctionInfo('numpy.random', 'random')
30 | operator_context = OperatorContext(OperatorType.DATA_SOURCE, function_info)
31 | input_infos = SklearnBackend.before_call(operator_context, [])
32 | result = original(*args, **kwargs)
33 | backend_result = SklearnBackend.after_call(operator_context, input_infos, result)
34 |
35 | dag_node = DagNode(op_id,
36 | BasicCodeLocation(caller_filename, lineno),
37 | operator_context,
38 | DagNodeDetails("random", ['array']),
39 | get_optional_code_info_or_none(optional_code_reference, optional_source_code))
40 | add_dag_node(dag_node, [], backend_result)
41 | new_return_value = backend_result.annotated_dfobject.result_data
42 | return new_return_value
43 |
44 | return execute_patched_func(original, execute_inspections, *args, **kwargs)
45 |
--------------------------------------------------------------------------------
/mlinspect/monkeypatching/_patch_statsmodels.py:
--------------------------------------------------------------------------------
1 | """
2 | Monkey patching for statsmodels
3 | """
4 | import gorilla
5 | from statsmodels import api
6 | from statsmodels.api import datasets
7 |
8 | from mlinspect import DagNode, BasicCodeLocation, DagNodeDetails
9 | from mlinspect.backends._pandas_backend import PandasBackend
10 | from mlinspect.backends._sklearn_backend import SklearnBackend
11 | from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType
12 | from mlinspect.instrumentation._pipeline_executor import singleton
13 | from mlinspect.monkeypatching._monkey_patching_utils import execute_patched_func, add_dag_node, \
14 | get_optional_code_info_or_none, get_input_info, execute_patched_func_no_op_id, add_train_data_node, \
15 | add_train_label_node
16 |
17 |
18 | @gorilla.patches(api)
19 | class StatsmodelApiPatching:
20 |     """ Patches for statsmodels """
21 | # pylint: disable=too-few-public-methods,no-self-argument
22 |
23 | @gorilla.name('add_constant')
24 | @gorilla.settings(allow_hit=True)
25 |     def patched_add_constant(*args, **kwargs):
26 | """ Patch for ('statsmodel.api', 'add_constant') """
27 | # pylint: disable=no-method-argument
28 | original = gorilla.get_original_attribute(api, 'add_constant')
29 |
30 | def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code):
31 | """ Execute inspections, add DAG node """
32 | function_info = FunctionInfo('statsmodel.api', 'add_constant')
33 | input_info = get_input_info(args[0], caller_filename, lineno, function_info, optional_code_reference,
34 | optional_source_code)
35 |
36 | operator_context = OperatorContext(OperatorType.PROJECTION_MODIFY, function_info)
37 | input_infos = SklearnBackend.before_call(operator_context, [input_info.annotated_dfobject])
38 | result = original(input_infos[0].result_data, *args[1:], **kwargs)
39 | backend_result = SklearnBackend.after_call(operator_context,
40 | input_infos,
41 | result)
42 | new_return_value = backend_result.annotated_dfobject.result_data
43 |
44 | dag_node = DagNode(op_id,
45 | BasicCodeLocation(caller_filename, lineno),
46 | operator_context,
47 | DagNodeDetails("Adds const column", ["array"]),
48 | get_optional_code_info_or_none(optional_code_reference, optional_source_code))
49 | add_dag_node(dag_node, [input_info.dag_node], backend_result)
50 |
51 | return new_return_value
52 | return execute_patched_func(original, execute_inspections, *args, **kwargs)
53 |
54 |
55 | @gorilla.patches(datasets)
56 | class StatsmodelsDatasetPatching:
57 |     """ Patches for statsmodels datasets """
58 |
59 | # pylint: disable=too-few-public-methods,no-self-argument
60 |
61 | @gorilla.name('get_rdataset')
62 | @gorilla.settings(allow_hit=True)
63 |     def patched_get_rdataset(*args, **kwargs):
64 | """ Patch for ('statsmodels.datasets', 'get_rdataset') """
65 | # pylint: disable=no-method-argument
66 | original = gorilla.get_original_attribute(datasets, 'get_rdataset')
67 |
68 | def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code):
69 | """ Execute inspections, add DAG node """
70 | function_info = FunctionInfo('statsmodels.datasets', 'get_rdataset')
71 |
72 | operator_context = OperatorContext(OperatorType.DATA_SOURCE, function_info)
73 | input_infos = PandasBackend.before_call(operator_context, [])
74 | result = original(*args, **kwargs)
75 | backend_result = PandasBackend.after_call(operator_context,
76 | input_infos,
77 | result.data)
78 | result.data = backend_result.annotated_dfobject.result_data
79 | dag_node = DagNode(op_id,
80 | BasicCodeLocation(caller_filename, lineno),
81 | operator_context,
82 | DagNodeDetails(result.title, list(result.data.columns)),
83 | get_optional_code_info_or_none(optional_code_reference, optional_source_code))
84 | add_dag_node(dag_node, [], backend_result)
85 | return result
86 |
87 | return execute_patched_func(original, execute_inspections, *args, **kwargs)
88 |
89 |
90 | @gorilla.patches(api.OLS)
91 | class StatsmodelsOlsPatching:
92 | """ Patches for statsmodel OLS"""
93 |
94 | # pylint: disable=too-few-public-methods
95 |
96 | @gorilla.name('__init__')
97 | @gorilla.settings(allow_hit=True)
98 | def patched__init__(self, *args, **kwargs):
99 | """ Patch for ('statsmodel.api', 'OLS') """
100 | # pylint: disable=no-method-argument, attribute-defined-outside-init, too-many-locals
101 | original = gorilla.get_original_attribute(api.OLS, '__init__')
102 |
103 | def execute_inspections(_, caller_filename, lineno, optional_code_reference, optional_source_code):
104 | """ Execute inspections, add DAG node """
105 | original(self, *args, **kwargs)
106 |
107 | self.mlinspect_caller_filename = caller_filename
108 | self.mlinspect_lineno = lineno
109 | self.mlinspect_optional_code_reference = optional_code_reference
110 | self.mlinspect_optional_source_code = optional_source_code
111 |
112 | return execute_patched_func_no_op_id(original, execute_inspections, *args, **kwargs)
113 |
114 | @gorilla.name('fit')
115 | @gorilla.settings(allow_hit=True)
116 | def patched_fit(self, *args, **kwargs):
117 | """ Patch for ('statsmodel.api.OLS', 'fit') """
118 | # pylint: disable=no-method-argument, too-many-locals
119 | original = gorilla.get_original_attribute(api.OLS, 'fit')
120 | function_info = FunctionInfo('statsmodel.api.OLS', 'fit')
121 |
122 | # Train data
123 | # pylint: disable=no-member
124 | data_backend_result, train_data_node, train_data_result = add_train_data_node(self,
125 | self.data.exog,
126 | function_info)
127 | self.data.exog = train_data_result
128 | # pylint: disable=no-member
129 | label_backend_result, train_labels_node, train_labels_result = add_train_label_node(self, self.data.endog,
130 | function_info)
131 | self.data.endog = train_labels_result
132 |
133 | # Estimator
134 | operator_context = OperatorContext(OperatorType.ESTIMATOR, function_info)
135 | input_dfs = [data_backend_result.annotated_dfobject, label_backend_result.annotated_dfobject]
136 | input_infos = SklearnBackend.before_call(operator_context, input_dfs)
137 | result = original(self, *args, **kwargs)
138 | estimator_backend_result = SklearnBackend.after_call(operator_context,
139 | input_infos,
140 | None)
141 |
142 | dag_node = DagNode(singleton.get_next_op_id(),
143 | BasicCodeLocation(self.mlinspect_caller_filename, self.mlinspect_lineno),
144 | operator_context,
145 |                            DagNodeDetails("OLS", []),
146 | get_optional_code_info_or_none(self.mlinspect_optional_code_reference,
147 | self.mlinspect_optional_source_code))
148 | add_dag_node(dag_node, [train_data_node, train_labels_node], estimator_backend_result)
149 | return result
150 |
--------------------------------------------------------------------------------
/mlinspect/testing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/testing/__init__.py
--------------------------------------------------------------------------------
/mlinspect/testing/_random_annotation_testing_inspection.py:
--------------------------------------------------------------------------------
1 | """
2 | A simple analyzer for testing annotation propagation
3 | """
4 | import random
5 | from typing import Iterable
6 |
7 | from mlinspect.inspections._inspection import Inspection
8 | from mlinspect.inspections._inspection_input import InspectionRowUnaryOperator, InspectionRowNAryOperator, \
9 | InspectionRowSinkOperator
10 |
11 |
12 | class RandomAnnotationTestingInspection(Inspection):
13 | """
14 | A simple analyzer for testing annotation propagation
15 | """
16 |
17 | def __init__(self, row_count: int):
18 | self.operator_count = 0
19 | self.row_count = row_count
20 | self._analyzer_id = self.row_count
21 | self._operator_output = None
22 | self.rows_to_random_numbers_operator_0 = {}
23 |
24 | def visit_operator(self, inspection_input) -> Iterable[any]:
25 | """Visit an operator, generate random number annotations and check whether they get propagated correctly"""
26 | # pylint: disable=too-many-branches
27 | operator_output = []
28 |         current_count = -1
29 |
30 | if self.operator_count == 0:
31 | for row in inspection_input.row_iterator:
32 | current_count += 1
33 | if current_count < self.row_count:
34 | random_number = random.randint(0, 10000)
35 | output_tuple = row.output
36 | self.rows_to_random_numbers_operator_0[output_tuple] = random_number
37 | operator_output.append(row.output)
38 | yield random_number
39 | else:
40 | yield None
41 | elif self.operator_count == 1:
42 | filtered_rows = 0
43 | for row in inspection_input.row_iterator:
44 | current_count += 1
45 | assert isinstance(row, InspectionRowUnaryOperator) # This analyzer is really only for testing
46 | annotation = row.annotation
47 | if current_count < self.row_count:
48 | output_tuple = row.output
49 | if output_tuple in self.rows_to_random_numbers_operator_0:
50 | random_number = self.rows_to_random_numbers_operator_0[output_tuple]
51 | assert annotation == random_number # Test whether the annotation got propagated correctly
52 | else:
53 | filtered_rows += 1
54 | assert filtered_rows != self.row_count # If all rows got filtered, this test is useless
55 | operator_output.append(row.output)
56 | yield annotation
57 | else:
58 | for row in inspection_input.row_iterator:
59 | assert isinstance(row, (InspectionRowUnaryOperator, InspectionRowNAryOperator,
60 | InspectionRowSinkOperator))
61 | if isinstance(row, InspectionRowUnaryOperator):
62 | annotation = row.annotation
63 | else:
64 | annotation = row.annotation[0]
65 | yield annotation
66 | self.operator_count += 1
67 | self._operator_output = operator_output
68 |
69 | def get_operator_annotation_after_visit(self) -> any:
70 | assert self._operator_output or self.operator_count > 1
71 | result = self._operator_output
72 | self._operator_output = None
73 | return result
74 |
75 | @property
76 | def inspection_id(self):
77 | return self._analyzer_id
78 |
--------------------------------------------------------------------------------
/mlinspect/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Packages and classes we want to expose to users
3 | """
4 | from ._utils import get_project_root
5 |
6 | __all__ = [
7 | 'get_project_root',
8 | ]
9 |
--------------------------------------------------------------------------------
/mlinspect/utils/_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Some useful utils for the project
3 | """
4 | from pathlib import Path
5 |
6 |
7 | def get_project_root() -> Path:
8 | """Returns the project root folder."""
9 | return Path(__file__).parent.parent.parent
10 |
--------------------------------------------------------------------------------
/mlinspect/visualisation/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Packages and classes we want to expose to users
3 | """
4 | from ._visualisation import get_dag_as_pretty_string, save_fig_to_path
5 |
6 | __all__ = [
7 | 'get_dag_as_pretty_string',
8 | 'save_fig_to_path',
9 | ]
10 |
--------------------------------------------------------------------------------
/mlinspect/visualisation/_visualisation.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions to visualise the extracted DAG
3 | """
4 | from inspect import cleandoc
5 |
6 | import networkx
7 | from networkx.drawing.nx_agraph import to_agraph
8 |
9 | from mlinspect import DagNode
10 |
11 |
12 | def save_fig_to_path(extracted_dag, filename):
13 | """
14 | Create a figure of the extracted DAG and save it with some filename
15 | """
16 |
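    # Usage sketch (assuming `inspector_result` comes from PipelineInspector...execute()):
    #   save_fig_to_path(inspector_result.dag, 'pipeline_dag.png')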
17 | def get_new_node_label(node: DagNode):
18 | label = cleandoc(f"""
19 | {node.node_id}: {node.operator_info.operator.value} (L{node.code_location.lineno})
20 | {node.details.description or ""}
21 | """)
22 | return label
23 |
24 | # noinspection PyTypeChecker
25 | extracted_dag = networkx.relabel_nodes(extracted_dag, get_new_node_label)
26 |
27 | agraph = to_agraph(extracted_dag)
28 | agraph.layout('dot')
29 | agraph.draw(filename)
30 |
31 |
32 | def get_dag_as_pretty_string(extracted_dag):
33 | """
34 |     Return the extracted DAG as a pretty string
35 | """
36 |
37 | def get_new_node_label(node: DagNode):
38 | description = ""
39 | if node.details.description:
40 | description = f"({node.details.description})"
41 |
42 | label = f"{node.operator_info.operator.value}{description}"
43 | return label
44 |
45 | # noinspection PyTypeChecker
46 | extracted_dag = networkx.relabel_nodes(extracted_dag, get_new_node_label)
47 |
48 | agraph = to_agraph(extracted_dag)
49 | return agraph.to_string()
50 |
--------------------------------------------------------------------------------
/requirements/requirements.dev.txt:
--------------------------------------------------------------------------------
1 | pylint==2.17.7
2 | pytest==8.0.0
3 | pytest-pylint==0.21.0
4 | pytest-runner==5.2
5 | pytest-cov==2.10.1
6 | pytest-pycharm==0.7.0
7 | pytest-mock==3.3.1
8 | gensim==4.3.0
9 | keras==2.15.0
10 | jupyter==1.0.0
11 | importnb==0.6.2
12 | seaborn==0.11.0
13 |
--------------------------------------------------------------------------------
/requirements/requirements.txt:
--------------------------------------------------------------------------------
1 | scikit-learn==1.3.2
2 | protobuf==3.20.3
3 | pandas==1.5.3
4 | numpy==1.23.5
5 | six==1.15.0
6 | nbformat==5.0.8
7 | nbconvert==6.4.5
8 | ipython==7.25.0
9 | astpretty==2.0.0
10 | astmonkey==0.3.6
11 | networkx==2.5
12 | more-itertools==8.6.0
13 | pygraphviz==1.7
14 | testfixtures==6.17.1
15 | matplotlib==3.4.2
16 | gorilla==0.4.0
17 | astunparse==1.6.3
18 | setuptools==57.0.0
19 | scipy==1.11.3
20 | statsmodels==0.14.1
21 | tensorflow==2.15.0
22 | scikeras==0.12
23 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [aliases]
2 | test=pytest
3 |
4 | [tool:pytest]
5 | addopts = --pylint --pylint-rcfile=./.pylintrc --cov=mlinspect --cov-report=xml
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | For 'python setup.py develop' and 'python setup.py test'
3 | """
4 | import os
5 | from setuptools import setup, find_packages
6 |
7 | ROOT = os.path.dirname(__file__)
8 |
9 | with open(os.path.join(ROOT, "requirements", "requirements.txt"), encoding="utf-8") as f:
10 | required = f.read().splitlines()
11 |
12 | with open(os.path.join(ROOT, "requirements", "requirements.dev.txt"), encoding="utf-8") as f:
13 | test_required = f.read().splitlines()
14 |
15 | with open("README.md", "r", encoding="utf-8") as fh:
16 | long_description = fh.read()
17 |
18 | setup(
19 | name="mlinspect",
20 | version="0.0.1.dev0",
21 | description="Inspect ML Pipelines in the form of a DAG",
22 | author='Stefan Grafberger',
23 | author_email='stefangrafberger@gmail.com',
24 | long_description=long_description,
25 | long_description_content_type="text/markdown",
26 | packages=find_packages(),
27 | include_package_data=True,
28 | install_requires=required,
29 | tests_require=test_required,
30 | extras_require={'dev': test_required},
31 | license='Apache License 2.0',
32 | python_requires='==3.10.*',
33 | classifiers=[
34 | 'License :: OSI Approved :: Apache Software License',
35 | 'Programming Language :: Python :: 3 :: Only',
36 | 'Programming Language :: Python :: 3.10'
37 | ]
38 | )
39 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/__init__.py
--------------------------------------------------------------------------------
/test/backends/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/backends/__init__.py
--------------------------------------------------------------------------------
/test/backends/test_pandas_backend.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the pandas backend works
3 | """
4 | from mlinspect.testing._testing_helper_utils import get_pandas_read_csv_and_dropna_code, \
5 | run_random_annotation_testing_analyzer, run_row_index_annotation_testing_analyzer, run_multiple_test_analyzers
6 |
7 |
8 | def test_pandas_backend_random_annotation_propagation():
9 | """
10 | Tests whether the pandas backend works
11 | """
12 | code = get_pandas_read_csv_and_dropna_code()
13 | random_annotation_analyzer_result = run_random_annotation_testing_analyzer(code)
14 | assert len(random_annotation_analyzer_result) == 3
15 |
16 |
17 | def test_pandas_backend_row_index_annotation_propagation():
18 | """
19 | Tests whether the pandas backend works
20 | """
21 | code = get_pandas_read_csv_and_dropna_code()
22 | lineage_result = run_row_index_annotation_testing_analyzer(code)
23 | assert len(lineage_result) == 3
24 |
25 |
26 | def test_pandas_backend_annotation_propagation_multiple_analyzers():
27 | """
28 | Tests whether the pandas backend works
29 | """
30 | code = get_pandas_read_csv_and_dropna_code()
31 |
32 | dag_node_to_inspection_results, analyzers = run_multiple_test_analyzers(code)
33 |
34 | for inspection_result in dag_node_to_inspection_results.values():
35 | for analyzer in analyzers:
36 | assert analyzer in inspection_result
37 |
--------------------------------------------------------------------------------
/test/backends/test_sklearn_backend.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the sklearn backend works
3 | """
4 |
5 | from mlinspect.testing._testing_helper_utils import run_random_annotation_testing_analyzer, \
6 | run_row_index_annotation_testing_analyzer, run_multiple_test_analyzers
7 | from example_pipelines import ADULT_SIMPLE_PY
8 |
9 |
10 | def test_sklearn_backend_random_annotation_propagation():
11 | """
12 | Tests whether the sklearn backend works
13 | """
14 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file:
15 | code = file.read()
16 |
17 | random_annotation_analyzer_result = run_random_annotation_testing_analyzer(code)
18 | assert len(random_annotation_analyzer_result) == 12
19 |
20 |
21 | def test_sklearn_backend_row_index_annotation_propagation():
22 | """
23 | Tests whether the sklearn backend works
24 | """
25 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file:
26 | code = file.read()
27 | lineage_result = run_row_index_annotation_testing_analyzer(code)
28 | assert len(lineage_result) == 12
29 |
30 |
31 | def test_sklearn_backend_annotation_propagation_multiple_analyzers():
32 | """
33 | Tests whether the sklearn backend works
34 | """
35 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file:
36 | code = file.read()
37 |
38 | dag_node_to_inspection_results, analyzers = run_multiple_test_analyzers(code)
39 |
40 | for inspection_result in dag_node_to_inspection_results.values():
41 | for analyzer in analyzers:
42 | assert analyzer in inspection_result
43 |
--------------------------------------------------------------------------------
/test/checks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/checks/__init__.py
--------------------------------------------------------------------------------
/test/checks/test_no_bias_introduced_for.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether NoBiasIntroducedFor works
3 | """
4 | import math
5 | from inspect import cleandoc
6 |
7 | from pandas import DataFrame
8 | from testfixtures import compare
9 |
10 | from mlinspect import DagNode, BasicCodeLocation, OperatorContext, OperatorType, FunctionInfo, DagNodeDetails, \
11 | OptionalCodeInfo
12 | from mlinspect._pipeline_inspector import PipelineInspector
13 | from mlinspect.checks import CheckStatus, NoBiasIntroducedFor, \
14 | NoBiasIntroducedForResult
15 | from mlinspect.checks._no_bias_introduced_for import BiasDistributionChange
16 | from mlinspect.instrumentation._dag_node import CodeReference
17 |
18 |
19 | def test_no_bias_introduced_for_merge():
20 | """
21 |     Tests whether NoBiasIntroducedFor works for joins
22 | """
23 | test_code = cleandoc("""
24 | import pandas as pd
25 |
26 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 'B': [1, 2, 4, 5, 7]})
27 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]})
28 | df_merged = df_a.merge(df_b, on='B')
29 | """)
30 |
31 | inspector_result = PipelineInspector \
32 | .on_pipeline_from_string(test_code) \
33 | .add_check(NoBiasIntroducedFor(['A'])) \
34 | .execute()
35 |
36 | check_result = inspector_result.check_to_check_results[NoBiasIntroducedFor(['A'])]
37 | expected_result = get_expected_check_result_merge()
38 | compare(check_result, expected_result)
39 |
40 |
41 | def test_no_bias_introduced_simple_imputer():
42 | """
43 |     Tests whether NoBiasIntroducedFor works for the SimpleImputer
44 | """
45 | test_code = cleandoc("""
46 | import pandas as pd
47 | from sklearn.impute import SimpleImputer
48 | import numpy as np
49 |
50 | df = pd.DataFrame({'A': ['cat_a', np.nan, 'cat_a', 'cat_c']})
51 | imputer = SimpleImputer(missing_values=np.nan, strategy='most_frequent')
52 | imputed_data = imputer.fit_transform(df)
53 | """)
54 |
55 | inspector_result = PipelineInspector \
56 | .on_pipeline_from_string(test_code) \
57 | .add_check(NoBiasIntroducedFor(['A'])) \
58 | .execute()
59 |
60 | check_result = inspector_result.check_to_check_results[NoBiasIntroducedFor(['A'])]
61 | expected_result = get_expected_check_result_simple_imputer()
62 | compare(check_result, expected_result)
63 |
64 |
65 | def get_expected_check_result_merge():
66 | """ Expected result for the code snippet in test_no_bias_introduced_for_merge"""
67 | failing_dag_node = DagNode(2,
68 | BasicCodeLocation('', 5),
69 | OperatorContext(OperatorType.JOIN, FunctionInfo('pandas.core.frame', 'merge')),
70 | DagNodeDetails("on 'B'", ['A', 'B', 'C']),
71 | OptionalCodeInfo(CodeReference(5, 12, 5, 36), "df_a.merge(df_b, on='B')"))
72 |
73 | change_df = DataFrame({'sensitive_column_value': ['cat_a', 'cat_b', 'cat_c'],
74 | 'count_before': [2, 2, 1],
75 | 'count_after': [2, 1, 1],
76 | 'ratio_before': [0.4, 0.4, 0.2],
77 | 'ratio_after': [0.5, 0.25, 0.25],
78 | 'relative_ratio_change': [(0.5 - 0.4) / 0.4, (.25 - 0.4) / 0.4, (0.25 - 0.2) / 0.2]})
79 | expected_distribution_change = BiasDistributionChange(failing_dag_node, False, (.25 - 0.4) / 0.4, change_df)
80 | expected_dag_node_to_change = {failing_dag_node: {'A': expected_distribution_change}}
81 | failure_message = 'A Join causes a min_relative_ratio_change of \'A\' by -0.37500000000000006, a value below the ' \
82 | 'configured minimum threshold -0.3!'
83 | expected_result = NoBiasIntroducedForResult(NoBiasIntroducedFor(['A']), CheckStatus.FAILURE, failure_message,
84 | expected_dag_node_to_change)
85 | return expected_result
86 |
87 |
88 | def get_expected_check_result_simple_imputer():
89 | """ Expected result for the code snippet in test_no_bias_introduced_for_simple_imputer"""
90 | imputer_dag_node = DagNode(1,
91 | BasicCodeLocation('', 6),
92 | OperatorContext(OperatorType.TRANSFORMER,
93 | FunctionInfo('sklearn.impute._base', 'SimpleImputer')),
94 | DagNodeDetails('Simple Imputer: fit_transform', ['A']),
95 | OptionalCodeInfo(CodeReference(6, 10, 6, 72),
96 | "SimpleImputer(missing_values=np.nan, strategy='most_frequent')"))
97 |
98 | change_df = DataFrame({'sensitive_column_value': ['cat_a', 'cat_c', math.nan],
99 | 'count_before': [2, 1, 1],
100 | 'count_after': [3, 1, 0],
101 | 'ratio_before': [0.5, 0.25, 0.25],
102 | 'ratio_after': [0.75, 0.25, 0.],
103 | 'relative_ratio_change': [0.5, 0., -1.]})
104 | expected_distribution_change = BiasDistributionChange(imputer_dag_node, True, 0., change_df)
105 | expected_dag_node_to_change = {imputer_dag_node: {'A': expected_distribution_change}}
106 | expected_result = NoBiasIntroducedForResult(NoBiasIntroducedFor(['A']), CheckStatus.SUCCESS, None,
107 | expected_dag_node_to_change)
108 | return expected_result
109 |
--------------------------------------------------------------------------------
/test/checks/test_no_illegal_features.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether NoIllegalFeatures works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare, SequenceComparison, StringComparison
7 |
8 | from mlinspect._pipeline_inspector import PipelineInspector
9 | from mlinspect.checks import NoIllegalFeatures, CheckStatus, NoIllegalFeaturesResult
10 |
11 |
12 | def test_no_illegal_features():
13 | """
14 |     Tests whether NoIllegalFeatures works for a sklearn pipeline
15 | """
16 | test_code = cleandoc("""
17 | import pandas as pd
18 | from sklearn.preprocessing import label_binarize, StandardScaler, OneHotEncoder
19 | from sklearn.compose import ColumnTransformer
20 | from sklearn.pipeline import Pipeline
21 | from sklearn.tree import DecisionTreeClassifier
22 |
23 | data = pd.DataFrame({'age': [1, 2, 10, 5], 'B': ['cat_a', 'cat_b', 'cat_a', 'cat_c'],
24 | 'C': ['cat_a', 'cat_b', 'cat_a', 'cat_c'], 'target': ['no', 'no', 'yes', 'yes']})
25 |
26 | column_transformer = ColumnTransformer(transformers=[
27 | ('numeric', StandardScaler(), ['age']),
28 | ('categorical', OneHotEncoder(sparse=False), ['B', 'C'])
29 | ])
30 |
31 | income_pipeline = Pipeline([
32 | ('features', column_transformer),
33 | ('classifier', DecisionTreeClassifier())])
34 |
35 | labels = label_binarize(data['target'], classes=['no', 'yes'])
36 | income_pipeline.fit(data, labels)
37 | """)
38 |
39 | inspector_result = PipelineInspector \
40 | .on_pipeline_from_string(test_code) \
41 | .add_check(NoIllegalFeatures(['C'])) \
42 | .execute()
43 |
44 | check_result = inspector_result.check_to_check_results[NoIllegalFeatures(['C'])]
45 | # pylint: disable=anomalous-backslash-in-string
46 | expected_result = NoIllegalFeaturesResult(NoIllegalFeatures(['C']), CheckStatus.FAILURE,
47 | StringComparison("Used illegal columns\: .*"),
48 | SequenceComparison('C', 'age', ordered=False))
49 | compare(check_result, expected_result)
50 |
--------------------------------------------------------------------------------
/test/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/demo/__init__.py
--------------------------------------------------------------------------------
/test/demo/feature_overview/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/demo/feature_overview/__init__.py
--------------------------------------------------------------------------------
/test/demo/feature_overview/test_feature_overview.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the feature overview demo works
3 | """
4 | import os
5 |
6 | from importnb import Notebook
7 | import matplotlib
8 |
9 | from mlinspect.utils import get_project_root
10 |
11 |
12 | DEMO_NB_FILE = os.path.join(str(get_project_root()), "demo", "feature_overview", "feature_overview.ipynb")
13 |
14 |
15 | def test_demo_nb():
16 | """
17 | Tests whether the demo notebook works
18 | """
19 | matplotlib.use("template") # Disable plt.show when executing nb as part of this test
20 | Notebook.load(DEMO_NB_FILE)
21 |
--------------------------------------------------------------------------------
/test/demo/feature_overview/test_missing_embeddings.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether MissingEmbeddings works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare
7 |
8 | from demo.feature_overview.missing_embeddings import MissingEmbeddings, MissingEmbeddingsInfo
9 | from example_pipelines.healthcare import custom_monkeypatching
10 | from mlinspect._pipeline_inspector import PipelineInspector
11 |
12 |
13 | def test_missing_embeddings():
14 | """
15 |     Tests whether MissingEmbeddings works for the word2vec transformer
16 | """
17 | test_code = cleandoc("""
18 | import pandas as pd
19 | from example_pipelines.healthcare.healthcare_utils import MyW2VTransformer
20 |
21 | df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c']})
22 | word_to_vec = MyW2VTransformer(min_count=2, size=2, workers=1)
23 | encoded_data = word_to_vec.fit_transform(df)
24 | """)
25 |
26 | inspector_result = PipelineInspector \
27 | .on_pipeline_from_string(test_code) \
28 | .add_required_inspection(MissingEmbeddings(10)) \
29 | .add_custom_monkey_patching_module(custom_monkeypatching) \
30 | .execute()
31 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
32 |
33 | missing_embeddings_output = inspection_results[0][MissingEmbeddings(10)]
34 | expected_result = None
35 | compare(missing_embeddings_output, expected_result)
36 |
37 | missing_embeddings_output = inspection_results[1][MissingEmbeddings(10)]
38 | expected_result = MissingEmbeddingsInfo(2, ['cat_b', 'cat_c'])
39 | compare(missing_embeddings_output, expected_result)
40 |
--------------------------------------------------------------------------------
/test/demo/feature_overview/test_no_missing_embeddings.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether NoMissingEmbeddings works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare
7 |
8 | from demo.feature_overview.missing_embeddings import MissingEmbeddingsInfo
9 | from demo.feature_overview.no_missing_embeddings import NoMissingEmbeddings, NoMissingEmbeddingsResult
10 | from example_pipelines.healthcare import custom_monkeypatching
11 | from mlinspect import DagNode, BasicCodeLocation, OperatorContext, OperatorType, FunctionInfo, DagNodeDetails, \
12 | OptionalCodeInfo
13 | from mlinspect._pipeline_inspector import PipelineInspector
14 | from mlinspect.checks import CheckStatus
15 | from mlinspect.instrumentation._dag_node import CodeReference
16 |
17 |
18 | def test_no_missing_embeddings():
19 | """
20 |     Tests whether NoMissingEmbeddings works for the word2vec transformer
21 | """
22 | test_code = cleandoc("""
23 | import pandas as pd
24 | from example_pipelines.healthcare.healthcare_utils import MyW2VTransformer
25 |
26 | df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c']})
27 | word_to_vec = MyW2VTransformer(min_count=2, size=2, workers=1)
28 | encoded_data = word_to_vec.fit_transform(df)
29 | """)
30 |
31 | inspector_result = PipelineInspector \
32 | .on_pipeline_from_string(test_code) \
33 | .add_check(NoMissingEmbeddings()) \
34 | .add_custom_monkey_patching_module(custom_monkeypatching) \
35 | .execute()
36 |
37 | check_result = inspector_result.check_to_check_results[NoMissingEmbeddings()]
38 | expected_failed_dag_node_with_result = {
39 | DagNode(1,
40 | BasicCodeLocation('', 5),
41 | OperatorContext(OperatorType.TRANSFORMER,
42 | FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer')),
43 | DagNodeDetails('Word2Vec: fit_transform', ['array']),
44 | OptionalCodeInfo(CodeReference(5, 14, 5, 62), 'MyW2VTransformer(min_count=2, size=2, workers=1)'))
45 | : MissingEmbeddingsInfo(2, ['cat_b', 'cat_c'])}
46 | expected_result = NoMissingEmbeddingsResult(NoMissingEmbeddings(10), CheckStatus.FAILURE,
47 | 'Missing embeddings were found!', expected_failed_dag_node_with_result)
48 | compare(check_result, expected_result)
49 |
--------------------------------------------------------------------------------
/test/example_pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/example_pipelines/__init__.py
--------------------------------------------------------------------------------
/test/example_pipelines/test_adult_complex.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the adult_complex test pipeline works
3 | """
4 | import ast
5 |
6 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected
7 | from example_pipelines import ADULT_COMPLEX_PY, ADULT_COMPLEX_PNG
8 |
9 |
10 | def test_py_pipeline_runs():
11 | """
12 | Tests whether the .py version of the pipeline works
13 | """
14 | with open(ADULT_COMPLEX_PY, encoding="utf-8") as file:
15 | text = file.read()
16 | parsed_ast = ast.parse(text)
17 | exec(compile(parsed_ast, filename="", mode="exec"))
18 |
19 |
20 | def test_instrumented_py_pipeline_runs():
21 | """
22 | Tests whether the pipeline works with instrumentation
23 | """
24 | dag = run_and_assert_all_op_outputs_inspected(ADULT_COMPLEX_PY, ["race"], ADULT_COMPLEX_PNG)
25 | assert len(dag) == 24
26 |
--------------------------------------------------------------------------------
/test/example_pipelines/test_adult_simple.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the adult_simple test pipeline works
3 | """
4 | import ast
5 |
6 | import nbformat
7 | from nbconvert import PythonExporter
8 |
9 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected
10 | from example_pipelines import ADULT_SIMPLE_PY, ADULT_SIMPLE_IPYNB, ADULT_SIMPLE_PNG
11 |
12 |
13 | def test_py_pipeline_runs():
14 | """
15 | Tests whether the .py version of the pipeline works
16 | """
17 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file:
18 | text = file.read()
19 | parsed_ast = ast.parse(text)
20 | exec(compile(parsed_ast, filename="", mode="exec"))
21 |
22 |
23 | def test_nb_pipeline_runs():
24 | """
25 | Tests whether the .ipynb version of the pipeline works
26 | """
27 | with open(ADULT_SIMPLE_IPYNB, encoding="utf-8") as file:
28 | notebook = nbformat.reads(file.read(), nbformat.NO_CONVERT)
29 | exporter = PythonExporter()
30 |
31 | code, _ = exporter.from_notebook_node(notebook)
32 | parsed_ast = ast.parse(code)
33 | exec(compile(parsed_ast, filename="", mode="exec"))
34 |
35 |
36 | def test_instrumented_py_pipeline_runs():
37 | """
38 | Tests whether the pipeline works with instrumentation
39 | """
40 | dag = run_and_assert_all_op_outputs_inspected(ADULT_SIMPLE_PY, ["race"], ADULT_SIMPLE_PNG)
41 | assert len(dag) == 12
42 |
--------------------------------------------------------------------------------
/test/example_pipelines/test_compas.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the compas test pipeline works
3 | """
4 | import ast
5 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected
6 | from example_pipelines import COMPAS_PY, COMPAS_PNG
7 |
8 |
9 | def test_py_pipeline_runs():
10 | """
11 | Tests whether the .py version of the pipeline works
12 | """
13 | with open(COMPAS_PY, encoding="utf-8") as file:
14 | text = file.read()
15 | parsed_ast = ast.parse(text)
16 | exec(compile(parsed_ast, filename="", mode="exec"))
17 |
18 |
19 | def test_instrumented_py_pipeline_runs():
20 | """
21 | Tests whether the pipeline works with instrumentation
22 | """
23 | dag = run_and_assert_all_op_outputs_inspected(COMPAS_PY, ['sex', 'race'], COMPAS_PNG)
24 | assert len(dag) == 39
25 |
--------------------------------------------------------------------------------
/test/example_pipelines/test_healthcare.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the healthcare demo works
3 | """
4 | import ast
5 | import pandas as pd
6 | from scikeras.wrappers import KerasClassifier
7 | from sklearn.preprocessing import StandardScaler, label_binarize
8 |
9 | from example_pipelines.healthcare import custom_monkeypatching
10 | from example_pipelines.healthcare.healthcare_utils import create_model, MyW2VTransformer
11 | from example_pipelines import HEALTHCARE_PY, HEALTHCARE_PNG
12 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected
13 |
14 |
15 | def test_my_word_to_vec_transformer():
16 | """
17 | Tests whether MyW2VTransformer works
18 | """
19 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c']})
20 | word_to_vec = MyW2VTransformer(min_count=2, size=2, workers=1)
21 | encoded_data = word_to_vec.fit_transform(pandas_df)
22 | assert encoded_data.shape == (4, 2)
23 |
24 |
25 | def test_my_keras_classifier():
26 | """
27 |     Tests whether the KerasClassifier wrapper works with create_model
28 | """
29 | pandas_df = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [0, 1, 2, 3], 'target': ['no', 'no', 'yes', 'yes']})
30 |
31 | train = StandardScaler().fit_transform(pandas_df[['A', 'B']])
32 | target = label_binarize(pandas_df[['target']], classes=['no', 'yes'])
33 |
34 | clf = KerasClassifier(build_fn=create_model, epochs=2, batch_size=1, verbose=0,
35 | hidden_layer_sizes=(9, 9,), loss="binary_crossentropy")
36 | clf.fit(train, target)
37 |
38 | test_predict = clf.predict([[0., 0.], [0.6, 0.6]])
39 | assert test_predict.shape == (2, 1)
40 |
41 |
42 | def test_py_pipeline_runs():
43 | """
44 | Tests whether the pipeline works without instrumentation
45 | """
46 | with open(HEALTHCARE_PY, encoding="utf-8") as file:
47 | healthcare_code = file.read()
48 | parsed_ast = ast.parse(healthcare_code)
49 | exec(compile(parsed_ast, filename="", mode="exec"))
50 |
51 |
52 | def test_instrumented_py_pipeline_runs():
53 | """
54 | Tests whether the pipeline works with instrumentation
55 | """
56 | dag = run_and_assert_all_op_outputs_inspected(HEALTHCARE_PY, ["age_group", "race"], HEALTHCARE_PNG,
57 | [custom_monkeypatching])
58 | assert len(dag) == 37
59 |
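An aside on the shapes asserted above: with exactly two classes, sklearn's label_binarize returns an (n_samples, 1) column of 0/1 values, which is why the target is a column vector and predict comes back with shape (2, 1). A quick check:

from sklearn.preprocessing import label_binarize

binarized = label_binarize(['no', 'no', 'yes', 'yes'], classes=['no', 'yes'])
assert binarized.shape == (4, 1)
assert binarized.tolist() == [[0], [0], [1], [1]]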
--------------------------------------------------------------------------------
/test/experiments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/experiments/__init__.py
--------------------------------------------------------------------------------
/test/experiments/performance/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/experiments/performance/__init__.py
--------------------------------------------------------------------------------
/test/experiments/performance/test_benchmark_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the benchmark utils work
3 | """
4 |
5 | from experiments.performance._benchmark_utils import do_op_instrumentation_benchmarks, OperatorBenchmarkType, \
6 | do_op_inspections_benchmarks, do_full_pipeline_benchmarks, PipelineBenchmarkType
7 |
8 |
9 | def test_instrumentation_benchmarks():
10 | """
11 | Tests whether the operator instrumentation benchmarks work
12 | """
13 | for op_type in OperatorBenchmarkType:
14 | benchmark_results = do_op_instrumentation_benchmarks(100, op_type)
15 |
16 | assert benchmark_results["no mlinspect"]
17 | assert benchmark_results["no inspection"]
18 | assert benchmark_results["one inspection"]
19 | assert benchmark_results["two inspections"]
20 | assert benchmark_results["three inspections"]
21 |
22 |
23 | def test_inspection_benchmarks():
24 | """
25 | Tests whether the operator inspection benchmarks work
26 | """
27 | for op_type in OperatorBenchmarkType:
28 | benchmark_results = do_op_inspections_benchmarks(100, op_type)
29 |
30 | assert benchmark_results["empty inspection"]
31 | assert benchmark_results["MaterializeFirstOutputRows(10)"]
32 | assert benchmark_results["RowLineage(10)"]
33 | assert benchmark_results["HistogramForColumns(['group_col_1'])"]
34 | assert benchmark_results["HistogramForColumns(['group_col_1', 'group_col_2', 'group_col_3'])"]
35 |
36 |
37 | def test_full_pipeline_benchmarks():
38 | """
39 | Tests whether the full pipeline benchmarks work
40 | """
41 | for pipeline in PipelineBenchmarkType:
42 | benchmark_results = do_full_pipeline_benchmarks(pipeline)
43 |
44 | assert benchmark_results["no mlinspect"]
45 | assert benchmark_results["no inspection"]
46 | assert benchmark_results["one inspection"]
47 | assert benchmark_results["two inspections"]
48 | assert benchmark_results["three inspections"]
49 |
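A usage sketch for the benchmark helpers (picking an arbitrary enum member, since specific member names are not shown here):

from experiments.performance._benchmark_utils import (
    do_op_instrumentation_benchmarks, OperatorBenchmarkType)

op_type = next(iter(OperatorBenchmarkType))  # any one benchmark type
results = do_op_instrumentation_benchmarks(100, op_type)
# The result dict uses the keys asserted in the tests above.
print(results["no mlinspect"], results["three inspections"])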
--------------------------------------------------------------------------------
/test/experiments/performance/test_performance_benchmarks.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the performance benchmark notebook works
3 | """
4 | import os
5 |
6 | import matplotlib
7 | from importnb import Notebook
8 |
9 | from mlinspect.utils import get_project_root
10 |
11 | EXPERIMENT_NB_FILE = os.path.join(str(get_project_root()), "experiments", "performance", "performance_benchmarks.ipynb")
12 |
13 |
14 | def test_experiment_nb():
15 | """
16 | Tests whether the experiment notebook works
17 | """
18 | matplotlib.use("template") # Disable plt.show when executing nb as part of this test
19 | Notebook.load(EXPERIMENT_NB_FILE)
20 |
--------------------------------------------------------------------------------
/test/inspections/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/inspections/__init__.py
--------------------------------------------------------------------------------
/test/inspections/test_column_propagation.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether ColumnPropagation works
3 | """
4 | from inspect import cleandoc
5 |
6 | import pandas
7 | from pandas import DataFrame
8 |
9 | from mlinspect._pipeline_inspector import PipelineInspector
10 | from mlinspect.inspections import ColumnPropagation
11 |
12 |
13 | def test_propagation_merge():
14 | """
15 | Tests whether ColumnPropagation works for joins
16 | """
17 | test_code = cleandoc("""
18 | import pandas as pd
19 |
20 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 'B': [1, 2, 4, 5, 7]})
21 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]})
22 | df_merged = df_a.merge(df_b, on='B')
23 | """)
24 |
25 | inspector_result = PipelineInspector \
26 | .on_pipeline_from_string(test_code) \
27 | .add_required_inspection(ColumnPropagation(["A"], 2)) \
28 | .execute()
29 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
30 |
31 | propagation_output = inspection_results[0][ColumnPropagation(["A"], 2)]
32 | expected_df = DataFrame([['cat_a', 1, 'cat_a'], ['cat_b', 2, 'cat_b']], columns=['A', 'B', 'mlinspect_A'])
33 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
34 |
35 | propagation_output = inspection_results[1][ColumnPropagation(["A"], 2)]
36 | expected_df = DataFrame([[1, 1., None], [2, 5., None]], columns=['B', 'C', 'mlinspect_A'])
37 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
38 |
39 | propagation_output = inspection_results[2][ColumnPropagation(["A"], 2)]
40 | expected_df = DataFrame([['cat_a', 1, 1., 'cat_a'], ['cat_b', 2, 5., 'cat_b']],
41 | columns=['A', 'B', 'C', 'mlinspect_A'])
42 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
43 |
44 |
45 | def test_propagation_projection():
46 | """
47 | Tests whether ColumnPropagation works for projections
48 | """
49 | test_code = cleandoc("""
50 | import pandas as pd
51 |
52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'],
53 | 'B': [1, 2, 4, 5, 7], 'C': [2, 2, 10, 5, 7]})
54 | pandas_df = pandas_df[['B', 'C']]
55 | pandas_df = pandas_df[['C']]
56 | """)
57 |
58 | inspector_result = PipelineInspector \
59 | .on_pipeline_from_string(test_code) \
60 | .add_required_inspection(ColumnPropagation(["A"], 2)) \
61 | .execute()
62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
63 |
64 | propagation_output = inspection_results[0][ColumnPropagation(["A"], 2)]
65 | expected_df = DataFrame([['cat_a', 1, 2, 'cat_a'], ['cat_b', 2, 2, 'cat_b']], columns=['A', 'B', 'C', 'mlinspect_A'])
66 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
67 |
68 | propagation_output = inspection_results[1][ColumnPropagation(["A"], 2)]
69 | expected_df = DataFrame([[1, 2, 'cat_a'], [2, 2, 'cat_b']], columns=['B', 'C', 'mlinspect_A'])
70 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
71 |
72 | propagation_output = inspection_results[2][ColumnPropagation(["A"], 2)]
73 | expected_df = DataFrame([[2, 'cat_a'], [2, 'cat_b']], columns=['C', 'mlinspect_A'])
74 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
75 |
76 |
77 | def test_propagation_score():
78 | """
79 | Tests whether ColumnPropagation works for score calls
80 | """
81 | test_code = cleandoc("""
82 | import pandas as pd
83 | from sklearn.preprocessing import label_binarize, StandardScaler
84 | from sklearn.tree import DecisionTreeClassifier
85 | import numpy as np
86 |
87 | df = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [0, 1, 2, 3], 'cat_col': ['cat_a', 'cat_b', 'cat_a', 'cat_a'],
88 | 'target': ['no', 'no', 'yes', 'yes']})
89 |
90 | train = StandardScaler().fit_transform(df[['A', 'B']])
91 | target = label_binarize(df['target'], classes=['no', 'yes'])
92 |
93 | clf = DecisionTreeClassifier()
94 | clf = clf.fit(train, target)
95 |
96 | test_df = pd.DataFrame({'A': [0., 0.6], 'B': [0., 0.6], 'cat_col': ['cat_a', 'cat_b'],
97 | 'target': ['no', 'yes']})
98 | test_labels = label_binarize(test_df['target'], classes=['no', 'yes'])
99 | test_score = clf.score(test_df[['A', 'B']], test_labels)
100 | assert test_score == 1.0
101 | """)
102 |
103 | inspector_result = PipelineInspector \
104 | .on_pipeline_from_string(test_code) \
105 | .add_required_inspection(ColumnPropagation(["cat_col"], 2)) \
106 | .execute()
107 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
108 |
109 | propagation_output = inspection_results[14][ColumnPropagation(["cat_col"], 2)]
110 | expected_df = DataFrame([[0, 'cat_a'], [1, 'cat_b']], columns=['array', 'mlinspect_cat_col'])
111 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True))
112 |
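The assertions above rely on the positional order of dag_node_to_inspection_results. A sketch of a more explicit lookup, selecting the join node by its operator type (continuing from test_propagation_merge above; the operator_info attribute is used the same way in test_pipeline_inspector.py):

from mlinspect import OperatorType

for dag_node, results in inspector_result.dag_node_to_inspection_results.items():
    if dag_node.operator_info.operator == OperatorType.JOIN:
        join_propagation_output = results[ColumnPropagation(["A"], 2)]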
--------------------------------------------------------------------------------
/test/inspections/test_completeness_of_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether CompletenessOfColumns works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare
7 |
8 | from mlinspect._pipeline_inspector import PipelineInspector
9 | from mlinspect.inspections import CompletenessOfColumns
10 |
11 |
12 | def test_completeness_merge():
13 | """
14 | Tests whether CompletenessOfColumns works for joins
15 | """
16 | test_code = cleandoc("""
17 | import numpy as np
18 | import pandas as pd
19 |
20 | df_a = pd.DataFrame({'A': ['cat_a', None, 'cat_a', 'cat_c', None], 'B': [1, 2, 4, 5, 7]})
21 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, np.nan], 'C': [1, 5, 4, 11, None]})
22 | df_merged = df_a.merge(df_b, on='B')
23 | """)
24 |
25 | inspector_result = PipelineInspector \
26 | .on_pipeline_from_string(test_code) \
27 | .add_required_inspection(CompletenessOfColumns(['A', 'B'])) \
28 | .execute()
29 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
30 |
31 | completeness_output = inspection_results[0][CompletenessOfColumns(['A', 'B'])]
32 | expected_completeness = {'A': 0.6, 'B': 1.0}
33 | compare(completeness_output, expected_completeness)
34 |
35 | completeness_output = inspection_results[1][CompletenessOfColumns(['A', 'B'])]
36 | expected_completeness = {'B': 0.8}
37 | compare(completeness_output, expected_completeness)
38 |
39 | completeness_output = inspection_results[2][CompletenessOfColumns(['A', 'B'])]
40 | expected_completeness = {'A': 2/3, 'B': 1.0}
41 | compare(completeness_output, expected_completeness)
42 |
43 |
44 | def test_completeness_projection():
45 | """
46 | Tests whether CompletenessOfColumns works for projections
47 | """
48 | test_code = cleandoc("""
49 | import pandas as pd
50 | import numpy as np
51 |
52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', None, 'cat_c', 'cat_b'],
53 | 'B': [1, None, np.nan, None, 7], 'C': [2, 2, 10, 5, 7]})
54 | pandas_df = pandas_df[['B', 'C']]
55 | pandas_df = pandas_df[['C']]
56 | """)
57 |
58 | inspector_result = PipelineInspector \
59 | .on_pipeline_from_string(test_code) \
60 | .add_required_inspection(CompletenessOfColumns(['A', 'B'])) \
61 | .execute()
62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
63 |
64 | completeness_output = inspection_results[0][CompletenessOfColumns(['A', 'B'])]
65 | expected_completeness = {'A': 0.8, 'B': 0.4}
66 | compare(completeness_output, expected_completeness)
67 |
68 | completeness_output = inspection_results[1][CompletenessOfColumns(['A', 'B'])]
69 | expected_completeness = {'B': 0.4}
70 | compare(completeness_output, expected_completeness)
71 |
72 | completeness_output = inspection_results[2][CompletenessOfColumns(['A', 'B'])]
73 | expected_completeness = {}
74 | compare(completeness_output, expected_completeness)
75 |
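Cross-check, independent of mlinspect: completeness is the share of non-null values per column, so plain pandas reproduces the expected fractions:

import numpy as np
import pandas as pd

pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', None, 'cat_c', 'cat_b'],
                          'B': [1, None, np.nan, None, 7], 'C': [2, 2, 10, 5, 7]})
assert pandas_df['A'].notna().mean() == 0.8  # 4 of 5 values present
assert pandas_df['B'].notna().mean() == 0.4  # 2 of 5 values present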
--------------------------------------------------------------------------------
/test/inspections/test_count_distinct_of_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether CountDistinctOfColumns works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare
7 |
8 | from mlinspect._pipeline_inspector import PipelineInspector
9 | from mlinspect.inspections import CountDistinctOfColumns
10 |
11 |
12 | def test_count_distinct_merge():
13 | """
14 | Tests whether CountDistinctOfColumns works for joins
15 | """
16 | test_code = cleandoc("""
17 | import numpy as np
18 | import pandas as pd
19 |
20 | df_a = pd.DataFrame({'A': ['cat_a', None, 'cat_a', 'cat_c', None], 'B': [1, 2, 4, 5, 7]})
21 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, np.nan], 'C': [1, 5, 4, 11, None]})
22 | df_merged = df_a.merge(df_b, on='B')
23 | """)
24 |
25 | inspector_result = PipelineInspector \
26 | .on_pipeline_from_string(test_code) \
27 | .add_required_inspection(CountDistinctOfColumns(['A', 'B'])) \
28 | .execute()
29 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
30 |
31 | count_distinct_output = inspection_results[0][CountDistinctOfColumns(['A', 'B'])]
32 | expected_count_distinct = {'A': 3, 'B': 5}
33 | compare(count_distinct_output, expected_count_distinct)
34 |
35 | count_distinct_output = inspection_results[1][CountDistinctOfColumns(['A', 'B'])]
36 | expected_count_distinct = {'B': 5}
37 | compare(count_distinct_output, expected_count_distinct)
38 |
39 | count_distinct_output = inspection_results[2][CountDistinctOfColumns(['A', 'B'])]
40 | expected_count_distinct = {'A': 2, 'B': 3}
41 | compare(count_distinct_output, expected_count_distinct)
42 |
43 |
44 | def test_count_distinct_projection():
45 | """
46 | Tests whether CountDistinctOfColumns works for projections
47 | """
48 | test_code = cleandoc("""
49 | import pandas as pd
50 | import numpy as np
51 |
52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', None, 'cat_c', 'cat_b'],
53 | 'B': [1, None, np.nan, None, 7], 'C': [2, 2, 10, 5, 7]})
54 | pandas_df = pandas_df[['B', 'C']]
55 | pandas_df = pandas_df[['C']]
56 | """)
57 |
58 | inspector_result = PipelineInspector \
59 | .on_pipeline_from_string(test_code) \
60 | .add_required_inspection(CountDistinctOfColumns(['A', 'B'])) \
61 | .execute()
62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
63 |
64 | count_distinct_output = inspection_results[0][CountDistinctOfColumns(['A', 'B'])]
65 | expected_count_distinct = {'A': 4, 'B': 5}
66 | compare(count_distinct_output, expected_count_distinct)
67 |
68 | count_distinct_output = inspection_results[1][CountDistinctOfColumns(['A', 'B'])]
69 | expected_count_distinct = {'B': 5}
70 | compare(count_distinct_output, expected_count_distinct)
71 |
72 | count_distinct_output = inspection_results[2][CountDistinctOfColumns(['A', 'B'])]
73 | expected_count_distinct = {}
74 | compare(count_distinct_output, expected_count_distinct)
75 |
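Note on the 'B' expectations above: in the float column, both None entries are coerced to NaN, and NaN != NaN, so a pairwise-equality notion of "distinct" counts every missing value separately, giving 5 rather than 3. A pure-Python cross-check of that reading:

import numpy as np
import pandas as pd

vals = pd.Series([1, None, np.nan, None, 7]).tolist()  # [1.0, nan, nan, nan, 7.0]
# Count values that equal no earlier value; each NaN is unequal even to itself.
distinct = sum(all(v != w for w in vals[:i]) for i, v in enumerate(vals))
assert distinct == 5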
--------------------------------------------------------------------------------
/test/inspections/test_histogram_for_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether HistogramForColumns works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare
7 |
8 | from mlinspect._pipeline_inspector import PipelineInspector
9 | from mlinspect.inspections import HistogramForColumns
10 |
11 |
12 | def test_histogram_merge():
13 | """
14 | Tests whether HistogramForColumns works for joins
15 | """
16 | test_code = cleandoc("""
17 | import pandas as pd
18 |
19 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 'B': [1, 2, 4, 5, 7]})
20 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]})
21 | df_merged = df_a.merge(df_b, on='B')
22 | """)
23 |
24 | inspector_result = PipelineInspector \
25 | .on_pipeline_from_string(test_code) \
26 | .add_required_inspection(HistogramForColumns(["A"])) \
27 | .execute()
28 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
29 |
30 | histogram_output = inspection_results[0][HistogramForColumns(["A"])]
31 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}}
32 | compare(histogram_output, expected_histogram)
33 |
34 | histogram_output = inspection_results[1][HistogramForColumns(["A"])]
35 | expected_histogram = {'A': {}}
36 | compare(histogram_output, expected_histogram)
37 |
38 | histogram_output = inspection_results[2][HistogramForColumns(["A"])]
39 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 1, 'cat_c': 1}}
40 | compare(histogram_output, expected_histogram)
41 |
42 |
43 | def test_histogram_projection():
44 | """
45 | Tests whether HistogramForColumns works for projections
46 | """
47 | test_code = cleandoc("""
48 | import pandas as pd
49 |
50 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'],
51 | 'B': [1, 2, 4, 5, 7], 'C': [2, 2, 10, 5, 7]})
52 | pandas_df = pandas_df[['B', 'C']]
53 | pandas_df = pandas_df[['C']]
54 | """)
55 |
56 | inspector_result = PipelineInspector \
57 | .on_pipeline_from_string(test_code) \
58 | .add_required_inspection(HistogramForColumns(["A"])) \
59 | .execute()
60 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
61 |
62 | histogram_output = inspection_results[0][HistogramForColumns(["A"])]
63 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}}
64 | compare(histogram_output, expected_histogram)
65 |
66 | histogram_output = inspection_results[1][HistogramForColumns(["A"])]
67 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}}
68 | compare(histogram_output, expected_histogram)
69 |
70 | histogram_output = inspection_results[2][HistogramForColumns(["A"])]
71 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}}
72 | compare(histogram_output, expected_histogram)
73 |
74 |
75 | def test_histogram_score():
76 | """
77 | Tests whether HistogramForColumns works for score calls
78 | """
79 | test_code = cleandoc("""
80 | import pandas as pd
81 | from sklearn.preprocessing import label_binarize, StandardScaler
82 | from sklearn.tree import DecisionTreeClassifier
83 | import numpy as np
84 |
85 | df = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [0, 1, 2, 3], 'target': ['no', 'no', 'yes', 'yes']})
86 |
87 | train = StandardScaler().fit_transform(df[['A', 'B']])
88 | target = label_binarize(df['target'], classes=['no', 'yes'])
89 |
90 | clf = DecisionTreeClassifier()
91 | clf = clf.fit(train, target)
92 |
93 | test_df = pd.DataFrame({'A': [0., 0.6], 'B': [0., 0.6], 'target': ['no', 'yes']})
94 | test_labels = label_binarize(test_df['target'], classes=['no', 'yes'])
95 | test_score = clf.score(test_df[['A', 'B']], test_labels)
96 | assert test_score == 1.0
97 | """)
98 |
99 | inspector_result = PipelineInspector \
100 | .on_pipeline_from_string(test_code) \
101 | .add_required_inspection(HistogramForColumns(["target"])) \
102 | .execute()
103 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
104 |
105 | histogram_output = inspection_results[14][HistogramForColumns(["target"])]
106 | expected_histogram = {'target': {'no': 1, 'yes': 1}}
107 | compare(histogram_output, expected_histogram)
108 |
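Cross-check: on the raw data, the expected histogram is simply the value counts of column 'A'. Note that the projection test keeps reporting this histogram even after 'A' is projected away, because mlinspect propagates the sensitive column alongside the visible data:

import pandas as pd

pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b']})
assert pandas_df['A'].value_counts().to_dict() == {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}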
--------------------------------------------------------------------------------
/test/inspections/test_intersectional_histogram_for_columns.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether IntersectionalHistogramForColumns works
3 | """
4 | from inspect import cleandoc
5 |
6 | from testfixtures import compare
7 |
8 | from mlinspect._pipeline_inspector import PipelineInspector
9 | from mlinspect.inspections import IntersectionalHistogramForColumns
10 |
11 |
12 | def test_histogram_merge():
13 | """
14 | Tests whether IntersectionalHistogramForColumns works for joins
15 | """
16 | test_code = cleandoc("""
17 | import pandas as pd
18 |
19 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_b', 'cat_b'],
20 | 'B': [1, 2, 4, 5, 7],
21 | 'C': [True, False, True, True, True]})
22 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'D': [1, 5, 4, 11, None]})
23 | df_merged = df_a.merge(df_b, on='B')
24 | """)
25 |
26 | inspector_result = PipelineInspector \
27 | .on_pipeline_from_string(test_code) \
28 | .add_required_inspection(IntersectionalHistogramForColumns(["A", "C"])) \
29 | .execute()
30 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
31 |
32 | histogram_output = inspection_results[0][IntersectionalHistogramForColumns(["A", "C"])]
33 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_b', True): 2}
34 | compare(histogram_output, expected_histogram)
35 |
36 | histogram_output = inspection_results[1][IntersectionalHistogramForColumns(["A", "C"])]
37 | expected_histogram = {(None, None): 5}
38 | compare(histogram_output, expected_histogram)
39 |
40 | histogram_output = inspection_results[2][IntersectionalHistogramForColumns(["A", "C"])]
41 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_b', True): 1}
42 | compare(histogram_output, expected_histogram)
43 |
44 |
45 | def test_histogram_projection():
46 | """
47 | Tests whether IntersectionalHistogramForColumns works for projections
48 | """
49 | test_code = cleandoc("""
50 | import pandas as pd
51 |
52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'],
53 | 'B': [1, 2, 4, 5, 7], 'C': [True, False, True, True, True]})
54 | pandas_df = pandas_df[['B', 'C']]
55 | pandas_df = pandas_df[['C']]
56 | """)
57 |
58 | inspector_result = PipelineInspector \
59 | .on_pipeline_from_string(test_code) \
60 | .add_required_inspection(IntersectionalHistogramForColumns(["A", "C"])) \
61 | .execute()
62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
63 |
64 | histogram_output = inspection_results[0][IntersectionalHistogramForColumns(["A", "C"])]
65 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_c', True): 1, ('cat_b', True): 1}
66 | compare(histogram_output, expected_histogram)
67 |
68 | histogram_output = inspection_results[1][IntersectionalHistogramForColumns(["A", "C"])]
69 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_c', True): 1, ('cat_b', True): 1}
70 | compare(histogram_output, expected_histogram)
71 |
72 | histogram_output = inspection_results[2][IntersectionalHistogramForColumns(["A", "C"])]
73 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_c', True): 1, ('cat_b', True): 1}
74 | compare(histogram_output, expected_histogram)
75 |
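Cross-check: an intersectional histogram counts tuples over the grouped columns, which collections.Counter reproduces directly:

from collections import Counter
import pandas as pd

pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'],
                          'C': [True, False, True, True, True]})
counts = Counter(zip(pandas_df['A'], pandas_df['C']))
assert counts == {('cat_a', True): 2, ('cat_b', False): 1,
                  ('cat_c', True): 1, ('cat_b', True): 1}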
--------------------------------------------------------------------------------
/test/inspections/test_lineage.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether RowLineage works
3 | """
4 | from inspect import cleandoc
5 |
6 | import pandas
7 | import numpy as np
8 | from pandas import DataFrame
9 |
10 | from mlinspect import OperatorType
11 | from mlinspect._pipeline_inspector import PipelineInspector
12 | from mlinspect.inspections import RowLineage
13 | from mlinspect.inspections._lineage import LineageId
14 |
15 |
16 | def test_row_lineage_merge():
17 | """
18 | Tests whether RowLineage works for joins
19 | """
20 | test_code = cleandoc("""
21 | import pandas as pd
22 |
23 | df_a = pd.DataFrame({'A': [0, 2, 4, 8, 5], 'B': [1, 2, 4, 5, 7]})
24 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]})
25 | df_merged = df_a.merge(df_b, on='B')
26 | """)
27 |
28 | inspector_result = PipelineInspector \
29 | .on_pipeline_from_string(test_code) \
30 | .add_required_inspection(RowLineage(2)) \
31 | .execute()
32 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
33 |
34 | lineage_output = inspection_results[0][RowLineage(2)]
35 | expected_lineage_df = DataFrame([[0, 1, {LineageId(0, 0)}],
36 | [2, 2, {LineageId(0, 1)}]],
37 | columns=['A', 'B', 'mlinspect_lineage'])
38 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
39 |
40 | lineage_output = inspection_results[1][RowLineage(2)]
41 | expected_lineage_df = DataFrame([[1, 1., {LineageId(1, 0)}],
42 | [2, 5., {LineageId(1, 1)}]],
43 | columns=['B', 'C', 'mlinspect_lineage'])
44 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
45 |
46 | lineage_output = inspection_results[2][RowLineage(2)]
47 | expected_lineage_df = DataFrame([[0, 1, 1., {LineageId(0, 0), LineageId(1, 0)}],
48 | [2, 2, 5., {LineageId(0, 1), LineageId(1, 1)}]],
49 | columns=['A', 'B', 'C', 'mlinspect_lineage'])
50 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
51 |
52 |
53 | def test_row_lineage_concat():
54 | """
55 | Tests whether RowLineage works for concats
56 | """
57 | test_code = cleandoc("""
58 | import pandas as pd
59 | from sklearn.preprocessing import StandardScaler, OneHotEncoder
60 | from sklearn.compose import ColumnTransformer
61 |
62 | df = pd.DataFrame({'A': [1, 2, 10, 5], 'B': ['cat_a', 'cat_b', 'cat_a', 'cat_c']})
63 | column_transformer = ColumnTransformer(transformers=[
64 | ('numeric', StandardScaler(), ['A']),
65 | ('categorical', OneHotEncoder(sparse=False), ['B'])
66 | ])
67 | encoded_data = column_transformer.fit_transform(df)
68 | """)
69 |
70 | inspector_result = PipelineInspector \
71 | .on_pipeline_from_string(test_code) \
72 | .add_required_inspection(RowLineage(2)) \
73 | .execute()
74 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
75 |
76 | lineage_output = inspection_results[0][RowLineage(2)]
77 | expected_lineage_df = DataFrame([[1, 'cat_a', {LineageId(0, 0)}],
78 | [2, 'cat_b', {LineageId(0, 1)}]],
79 | columns=['A', 'B', 'mlinspect_lineage'])
80 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
81 |
82 | lineage_output = inspection_results[1][RowLineage(2)]
83 | expected_lineage_df = DataFrame([[1, {LineageId(0, 0)}],
84 | [2, {LineageId(0, 1)}]],
85 | columns=['A', 'mlinspect_lineage'])
86 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
87 |
88 | lineage_output = inspection_results[2][RowLineage(2)]
89 | expected_lineage_df = DataFrame([[np.array([-1.0]), {LineageId(0, 0)}],
90 | [np.array([-0.7142857142857143]), {LineageId(0, 1)}]],
91 | columns=['array', 'mlinspect_lineage'])
92 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
93 |
94 | lineage_output = inspection_results[3][RowLineage(2)]
95 | expected_lineage_df = DataFrame([['cat_a', {LineageId(0, 0)}],
96 | ['cat_b', {LineageId(0, 1)}]],
97 | columns=['B', 'mlinspect_lineage'])
98 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
99 |
100 | lineage_output = inspection_results[4][RowLineage(2)]
101 | expected_lineage_df = DataFrame([[np.array([1., 0., 0.]), {LineageId(0, 0)}],
102 | [np.array([0., 1., 0.]), {LineageId(0, 1)}]],
103 | columns=['array', 'mlinspect_lineage'])
104 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
105 |
106 | lineage_output = inspection_results[5][RowLineage(2)]
107 | expected_lineage_df = DataFrame([[np.array([-1.0, 1., 0., 0.]), {LineageId(0, 0)}],
108 | [np.array([-0.7142857142857143, 0., 1., 0.]), {LineageId(0, 1)}]],
109 | columns=['array', 'mlinspect_lineage'])
110 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
111 |
112 |
113 | def test_all_rows_for_op_type():
114 | """
115 | Tests whether RowLineage works for materialising all data from specific operators
116 | """
117 | test_code = cleandoc("""
118 | import pandas as pd
119 |
120 | df_a = pd.DataFrame({'A': [0, 2], 'B': [1, 2]})
121 | df_b = pd.DataFrame({'B': [1, 2], 'C': [1, 5]})
122 | df_merged = df_a.merge(df_b, on='B')
123 | """)
124 | row_lineage = RowLineage(RowLineage.ALL_ROWS, [OperatorType.DATA_SOURCE])
125 | inspector_result = PipelineInspector \
126 | .on_pipeline_from_string(test_code) \
127 | .add_required_inspection(row_lineage) \
128 | .execute()
129 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values())
130 |
131 | lineage_output = inspection_results[0][row_lineage]
132 | expected_lineage_df = DataFrame([[0, 1, {LineageId(0, 0)}],
133 | [2, 2, {LineageId(0, 1)}]],
134 | columns=['A', 'B', 'mlinspect_lineage'])
135 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
136 |
137 | lineage_output = inspection_results[1][row_lineage]
138 | expected_lineage_df = DataFrame([[1, 1, {LineageId(1, 0)}],
139 | [2, 5, {LineageId(1, 1)}]],
140 | columns=['B', 'C', 'mlinspect_lineage'])
141 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True))
142 |
143 | lineage_output = inspection_results[2][row_lineage]
144 | assert lineage_output is None
145 |
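Illustration of the lineage sets asserted above: every row carries a set of LineageId(data source, row index) annotations, and a join output row gets the union of its input rows' sets:

from mlinspect.inspections._lineage import LineageId

left_row_lineage = {LineageId(0, 0)}   # first row of data source 0
right_row_lineage = {LineageId(1, 0)}  # first row of data source 1
assert left_row_lineage | right_row_lineage == {LineageId(0, 0), LineageId(1, 0)}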
--------------------------------------------------------------------------------
/test/instrumentation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/instrumentation/__init__.py
--------------------------------------------------------------------------------
/test/monkeypatching/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/monkeypatching/__init__.py
--------------------------------------------------------------------------------
/test/monkeypatching/test_patch_numpy.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the monkey patching works for all patched numpy methods
3 | """
4 | from inspect import cleandoc
5 |
6 | import pandas
7 | from pandas import DataFrame
8 | from testfixtures import compare
9 |
10 | from mlinspect import OperatorContext, FunctionInfo, OperatorType
11 | from mlinspect.inspections._lineage import RowLineage, LineageId
12 | from mlinspect.instrumentation import _pipeline_executor
13 | from mlinspect.instrumentation._dag_node import DagNode, CodeReference, BasicCodeLocation, DagNodeDetails, \
14 | OptionalCodeInfo
15 |
16 |
17 | def test_numpy_random():
18 | """
19 | Tests whether the monkey patching of ('numpy.random', 'random') works
20 | """
21 | test_code = cleandoc("""
22 | import numpy as np
23 | np.random.seed(42)
24 | test = np.random.random(100)
25 | assert len(test) == 100
26 | """)
27 |
28 | inspector_result = _pipeline_executor.singleton.run(python_code=test_code, track_code_references=True,
29 | inspections=[RowLineage(2)])
30 | extracted_node: DagNode = list(inspector_result.dag.nodes)[0]
31 |
32 | expected_node = DagNode(0,
33 | BasicCodeLocation("", 3),
34 | OperatorContext(OperatorType.DATA_SOURCE, FunctionInfo('numpy.random', 'random')),
35 | DagNodeDetails('random', ['array']),
36 | OptionalCodeInfo(CodeReference(3, 7, 3, 28), "np.random.random(100)"))
37 | compare(extracted_node, expected_node)
38 |
39 | inspection_results_data_source = inspector_result.dag_node_to_inspection_results[extracted_node]
40 | lineage_output = inspection_results_data_source[RowLineage(2)]
41 | expected_lineage_df = DataFrame([[0.5, {LineageId(0, 0)}],
42 | [0.5, {LineageId(0, 1)}]],
43 | columns=['array', 'mlinspect_lineage'])
44 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True),
45 | atol=1)  # atol=1 makes the random draws irrelevant; only shape, columns and lineage are compared
46 |
--------------------------------------------------------------------------------
/test/test_pipeline_inspector.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the fluent API works
3 | """
4 |
5 | import networkx
6 | from testfixtures import compare
7 |
8 | from example_pipelines.healthcare import custom_monkeypatching
9 | from example_pipelines import ADULT_SIMPLE_PY, ADULT_SIMPLE_IPYNB, HEALTHCARE_PY
10 | from mlinspect import PipelineInspector, OperatorType
11 | from mlinspect.checks import CheckStatus, NoBiasIntroducedFor, NoIllegalFeatures
12 | from mlinspect.inspections import HistogramForColumns, MaterializeFirstOutputRows
13 | from mlinspect.testing._testing_helper_utils import get_expected_dag_adult_easy
14 |
15 |
16 | def test_inspector_adult_easy_py_pipeline():
17 | """
18 | Tests whether the .py version of the inspector works
19 | """
20 | inspector_result = PipelineInspector\
21 | .on_pipeline_from_py_file(ADULT_SIMPLE_PY)\
22 | .add_required_inspection(MaterializeFirstOutputRows(5))\
23 | .add_check(NoBiasIntroducedFor(['race']))\
24 | .add_check(NoIllegalFeatures())\
25 | .execute()
26 | extracted_dag = inspector_result.dag
27 | expected_dag = get_expected_dag_adult_easy(ADULT_SIMPLE_PY)
28 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag))
29 |
30 | assert HistogramForColumns(['race']) in list(inspector_result.dag_node_to_inspection_results.values())[0]
31 | check_to_check_results = inspector_result.check_to_check_results
32 | assert check_to_check_results[NoBiasIntroducedFor(['race'])].status == CheckStatus.SUCCESS
33 | assert check_to_check_results[NoIllegalFeatures()].status == CheckStatus.FAILURE
34 |
35 |
36 | def test_inspector_adult_easy_py_pipeline_without_inspections():
37 | """
38 | Tests whether the .py version of the inspector works without inspections and checks
39 | """
40 | inspector_result = PipelineInspector\
41 | .on_pipeline_from_py_file(ADULT_SIMPLE_PY)\
42 | .execute()
43 | extracted_dag = inspector_result.dag
44 | expected_dag = get_expected_dag_adult_easy(ADULT_SIMPLE_PY)
45 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag))
46 |
47 |
48 | def test_inspector_adult_easy_ipynb_pipeline():
49 | """
50 | Tests whether the .ipynb version of the inspector works
51 | """
52 | inspector_result = PipelineInspector\
53 | .on_pipeline_from_ipynb_file(ADULT_SIMPLE_IPYNB)\
54 | .add_required_inspection(MaterializeFirstOutputRows(5)) \
55 | .add_check(NoBiasIntroducedFor(['race'])) \
56 | .add_check(NoIllegalFeatures()) \
57 | .execute()
58 | extracted_dag = inspector_result.dag
59 | expected_dag = get_expected_dag_adult_easy(ADULT_SIMPLE_IPYNB, 6)
60 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag))
61 |
62 | assert HistogramForColumns(['race']) in list(inspector_result.dag_node_to_inspection_results.values())[0]
63 | check_to_check_results = inspector_result.check_to_check_results
64 | assert check_to_check_results[NoBiasIntroducedFor(['race'])].status == CheckStatus.SUCCESS
65 | assert check_to_check_results[NoIllegalFeatures()].status == CheckStatus.FAILURE
66 |
67 |
68 | def test_inspector_adult_easy_str_pipeline():
69 | """
70 | Tests whether the str version of the inspector works
71 | """
72 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file:
73 | code = file.read()
74 |
75 | inspector_result = PipelineInspector\
76 | .on_pipeline_from_string(code)\
77 | .add_required_inspection(MaterializeFirstOutputRows(5)) \
78 | .add_check(NoBiasIntroducedFor(['race'])) \
79 | .add_check(NoIllegalFeatures()) \
80 | .execute()
81 | extracted_dag = inspector_result.dag
82 | expected_dag = get_expected_dag_adult_easy("")
83 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag))
84 |
85 | assert HistogramForColumns(['race']) in list(inspector_result.dag_node_to_inspection_results.values())[0]
86 | check_to_check_results = inspector_result.check_to_check_results
87 | assert check_to_check_results[NoBiasIntroducedFor(['race'])].status == CheckStatus.SUCCESS
88 | assert check_to_check_results[NoIllegalFeatures()].status == CheckStatus.FAILURE
89 |
90 |
91 | def test_inspector_additional_module():
92 | """
93 | Tests whether adding a custom monkey patching module works
94 | """
95 | inspector_result = PipelineInspector \
96 | .on_pipeline_from_py_file(HEALTHCARE_PY) \
97 | .add_required_inspection(MaterializeFirstOutputRows(5)) \
98 | .add_custom_monkey_patching_module(custom_monkeypatching) \
99 | .execute()
100 |
101 | assert_healthcare_pipeline_output_complete(inspector_result)
102 |
103 |
104 | def test_inspector_additional_modules():
105 | """
106 | Tests whether adding multiple custom monkey patching modules works
107 | """
108 | inspector_result = PipelineInspector \
109 | .on_pipeline_from_py_file(HEALTHCARE_PY) \
110 | .add_required_inspection(MaterializeFirstOutputRows(5)) \
111 | .add_custom_monkey_patching_modules([custom_monkeypatching]) \
112 | .execute()
113 |
114 | assert_healthcare_pipeline_output_complete(inspector_result)
115 |
116 |
117 | def assert_healthcare_pipeline_output_complete(inspector_result):
118 | """ Assert that the healthcare DAG was extracted completely """
119 | for dag_node, inspection_result in inspector_result.dag_node_to_inspection_results.items():
120 | assert dag_node.operator_info.operator != OperatorType.MISSING_OP
121 | assert MaterializeFirstOutputRows(5) in inspection_result
122 | if dag_node.operator_info.operator is not OperatorType.ESTIMATOR:
123 | assert inspection_result[MaterializeFirstOutputRows(5)] is not None
124 | else:
125 | assert inspection_result[MaterializeFirstOutputRows(5)] is None
126 | assert len(inspector_result.dag) == 37
127 |
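A usage sketch for reading check results (continuing from the checked runs above, e.g. test_inspector_adult_easy_py_pipeline): check_to_check_results is keyed by the check objects themselves, so a summary is a simple iteration:

for check, check_result in inspector_result.check_to_check_results.items():
    print(type(check).__name__, check_result.status)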
--------------------------------------------------------------------------------
/test/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/utils/__init__.py
--------------------------------------------------------------------------------
/test/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the utils work
3 | """
4 | from pathlib import Path
5 |
6 | from mlinspect.utils._utils import get_project_root
7 |
8 |
9 | def test_get_project_root():
10 | """
11 | Tests whether get_project_root works
12 | """
13 | assert get_project_root() == Path(__file__).parent.parent.parent
14 |
--------------------------------------------------------------------------------
/test/visualisation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/visualisation/__init__.py
--------------------------------------------------------------------------------
/test/visualisation/test_visualisation.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests whether the visualisation of the resulting DAG works
3 | """
4 | import os
5 |
6 | from mlinspect.utils import get_project_root
7 | from mlinspect.visualisation import save_fig_to_path, get_dag_as_pretty_string
8 | from mlinspect.testing._testing_helper_utils import get_expected_dag_adult_easy
9 |
10 |
11 | def test_save_fig_to_path():
12 | """
13 | Tests whether save_fig_to_path works
14 | """
15 | extracted_dag = get_expected_dag_adult_easy("")
16 |
17 | filename = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.png")
18 | save_fig_to_path(extracted_dag, filename)
19 |
20 | assert os.path.isfile(filename)
21 |
22 |
23 | def test_get_dag_as_pretty_string():
24 | """
25 | Tests whether get_dag_as_pretty_string works
26 | """
27 | extracted_dag = get_expected_dag_adult_easy("")
28 |
29 | pretty_string = get_dag_as_pretty_string(extracted_dag)
30 |
31 | assert pretty_string  # the rendered DAG should be a non-empty string
32 |
--------------------------------------------------------------------------------