├── .github ├── pull_request_template.md └── workflows │ └── build.yml ├── .gitignore ├── .pylintrc ├── CITATION.cff ├── LICENSE ├── README.md ├── codecov.yml ├── demo ├── __init__.py └── feature_overview │ ├── __init__.py │ ├── feature_overview.ipynb │ ├── healthcare.png │ ├── missing_embeddings.py │ ├── no_missing_embeddings.py │ └── paper_example_image.png ├── example_pipelines ├── __init__.py ├── _pipelines.py ├── adult_complex │ ├── __init__.py │ ├── adult_complex.png │ ├── adult_complex.py │ ├── adult_test.csv │ └── adult_train.csv ├── adult_simple │ ├── __init__.py │ ├── adult_simple.ipynb │ ├── adult_simple.png │ └── adult_simple.py ├── compas │ ├── __init__.py │ ├── compas.png │ ├── compas.py │ ├── compas_test.csv │ └── compas_train.csv └── healthcare │ ├── __init__.py │ ├── _gensim_wrapper.py │ ├── custom_monkeypatching │ ├── __init__.py │ └── patch_healthcare_utils.py │ ├── healthcare.png │ ├── healthcare.py │ ├── healthcare_utils.py │ ├── histories.csv │ └── patients.csv ├── experiments ├── __init__.py └── performance │ ├── __init__.py │ ├── _benchmark_utils.py │ ├── _empty_inspection.py │ └── performance_benchmarks.ipynb ├── mlinspect ├── __init__.py ├── _inspector_result.py ├── _pipeline_inspector.py ├── backends │ ├── __init__.py │ ├── _all_backends.py │ ├── _backend.py │ ├── _backend_utils.py │ ├── _iter_creation.py │ ├── _pandas_backend.py │ └── _sklearn_backend.py ├── checks │ ├── __init__.py │ ├── _check.py │ ├── _no_bias_introduced_for.py │ ├── _no_illegal_features.py │ └── _similar_removal_probabilities_for.py ├── inspections │ ├── __init__.py │ ├── _arg_capturing.py │ ├── _column_propagation.py │ ├── _completeness_of_columns.py │ ├── _count_distinct_of_columns.py │ ├── _histogram_for_columns.py │ ├── _inspection.py │ ├── _inspection_input.py │ ├── _inspection_result.py │ ├── _intersectional_histogram_for_columns.py │ ├── _lineage.py │ └── _materialize_first_output_rows.py ├── instrumentation │ ├── __init__.py │ ├── _call_capture_transformer.py │ ├── _dag_node.py │ └── _pipeline_executor.py ├── monkeypatching │ ├── README.md │ ├── __init__.py │ ├── _mlinspect_ndarray.py │ ├── _monkey_patching_utils.py │ ├── _patch_numpy.py │ ├── _patch_pandas.py │ ├── _patch_sklearn.py │ └── _patch_statsmodels.py ├── testing │ ├── __init__.py │ ├── _random_annotation_testing_inspection.py │ └── _testing_helper_utils.py ├── utils │ ├── __init__.py │ └── _utils.py └── visualisation │ ├── __init__.py │ └── _visualisation.py ├── requirements ├── requirements.dev.txt └── requirements.txt ├── setup.cfg ├── setup.py └── test ├── __init__.py ├── backends ├── __init__.py ├── test_pandas_backend.py └── test_sklearn_backend.py ├── checks ├── __init__.py ├── test_no_bias_introduced_for.py ├── test_no_illegal_features.py └── test_similar_removal_probablities_for.py ├── demo ├── __init__.py └── feature_overview │ ├── __init__.py │ ├── test_feature_overview.py │ ├── test_missing_embeddings.py │ └── test_no_missing_embeddings.py ├── example_pipelines ├── __init__.py ├── test_adult_complex.py ├── test_adult_simple.py ├── test_compas.py ├── test_healthcare.py └── test_patch_healthcare_utils.py ├── experiments ├── __init__.py └── performance │ ├── __init__.py │ ├── test_benchmark_utils.py │ └── test_performance_benchmarks.py ├── inspections ├── __init__.py ├── test_arg_capturing.py ├── test_column_propagation.py ├── test_completeness_of_columns.py ├── test_count_distinct_of_columns.py ├── test_histogram_for_columns.py ├── test_intersectional_histogram_for_columns.py ├── test_lineage.py └── 
test_materialize_first_output_rows.py ├── instrumentation ├── __init__.py └── test_pipeline_executor.py ├── monkeypatching ├── __init__.py ├── test_patch_numpy.py ├── test_patch_pandas.py ├── test_patch_sklearn.py └── test_patch_statsmodels.py ├── test_pipeline_inspector.py ├── utils ├── __init__.py └── test_utils.py └── visualisation ├── __init__.py └── test_visualisation.py /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | *Issue #, if available:* 2 | 3 | *Description of changes:* 4 | 5 | 6 | By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license. 7 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Test 2 | 3 | on: [ push, pull_request ] 4 | 5 | jobs: 6 | python: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [ '3.10' ] 11 | 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: ${{ matrix.python-version }} 18 | 19 | - name: Setup Graphviz 20 | uses: ts-graphviz/setup-graphviz@v1 21 | 22 | - name: Cache dependencies 23 | uses: actions/cache@v2 24 | with: 25 | path: ${{ env.pythonLocation }} 26 | key: ${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements/requirements.txt') }}-${{ hashFiles('requirements/requirements.dev.txt') }} 27 | 28 | - name: Install dependencies 29 | env: 30 | SETUPTOOLS_USE_DISTUTILS: stdlib 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install --upgrade --upgrade-strategy eager ".[dev]" 34 | 35 | - name: Unit Tests 36 | run: python -m pytest 37 | 38 | - name: Upload Coverage Report 39 | uses: codecov/codecov-action@v1 40 | with: 41 | token: ${{ secrets.CODECOV_TOKEN }} 42 | files: ./coverage.xml 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm 132 | .idea 133 | 134 | # MacOS 135 | .DS_Store 136 | 137 | # VS Code 138 | .vscode 139 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Grafberger" 5 | given-names: "Stefan" 6 | orcid: "https://orcid.org/0000-0002-9884-9517" 7 | - family-names: "Groth" 8 | given-names: "Paul" 9 | orcid: "https://orcid.org/0000-0003-0183-6910" 10 | - family-names: "Stoyanovich" 11 | given-names: "Julia" 12 | - family-names: "Schelter" 13 | given-names: "Sebastian" 14 | title: "Data Distribution Debugging in Machine Learning Pipelines" 15 | doi: 10.1007/s00778-021-00726-w 16 | url: "https://github.com/stefan-grafberger/mlinspect" 17 | preferred-citation: 18 | type: article 19 | authors: 20 | - family-names: "Grafberger" 21 | given-names: "Stefan" 22 | orcid: "https://orcid.org/0000-0002-9884-9517" 23 | - family-names: "Groth" 24 | given-names: "Paul" 25 | orcid: "https://orcid.org/0000-0003-0183-6910" 26 | - family-names: "Stoyanovich" 27 | given-names: "Julia" 28 | - family-names: "Schelter" 29 | given-names: "Sebastian" 30 | title: "Data Distribution Debugging in Machine Learning Pipelines" 31 | doi: 10.1007/s00778-021-00726-w 32 | date-released: 2022-01-31 33 | 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | mlinspect 2 | ================================ 3 | 4 | [![mlinspect](https://img.shields.io/badge/🔎-mlinspect-green)](https://github.com/stefan-grafberger/mlinspect) 5 | [![GitHub license](https://img.shields.io/badge/License-Apache%202.0-yellowgreen.svg)](https://github.com/stefan-grafberger/mlinspect/blob/master/LICENSE) 6 | [![Build Status](https://github.com/stefan-grafberger/mlinspect/actions/workflows/build.yml/badge.svg)](https://github.com/stefan-grafberger/mlinspect/actions/workflows/build.yml) 7 | [![codecov](https://codecov.io/gh/stefan-grafberger/mlinspect/branch/master/graph/badge.svg?token=KTMNPBV1ZZ)](https://codecov.io/gh/stefan-grafberger/mlinspect) 8 | 9 | Inspect ML Pipelines in Python in the form of a DAG 10 | 11 | ## Run mlinspect locally 12 | 13 | Prerequisite: Python 3.10 14 | 15 | 1. Clone this repository 16 | 2. Set up the environment 17 | 18 | `cd mlinspect`
19 | `python -m venv venv`
20 | `source venv/bin/activate`
21 | 22 | 3. If you want to use the visualisation functions we provide, install graphviz, which cannot be installed via pip 23 | 24 | `Linux: ` `apt-get install graphviz`<br>
25 | `macOS: ` `brew install graphviz`<br>
26 | 27 | 4. Install pip dependencies 28 | 29 | `SETUPTOOLS_USE_DISTUTILS=stdlib pip install -e .[dev]`
30 | 31 | 5. To ensure everything works, you can run the tests (without graphviz, the visualisation test will fail) 32 | 33 | `python setup.py test`
34 | 35 | ## How to use mlinspect 36 | mlinspect makes it easy to analyze your pipeline and automatically check for common issues. 37 | ```python 38 | from mlinspect import PipelineInspector 39 | from mlinspect.inspections import MaterializeFirstOutputRows 40 | from mlinspect.checks import NoBiasIntroducedFor 41 | 42 | IPYNB_PATH = ... 43 | 44 | inspector_result = PipelineInspector\ 45 | .on_pipeline_from_ipynb_file(IPYNB_PATH)\ 46 | .add_required_inspection(MaterializeFirstOutputRows(5))\ 47 | .add_check(NoBiasIntroducedFor(['race']))\ 48 | .execute() 49 | 50 | extracted_dag = inspector_result.dag 51 | dag_node_to_inspection_results = inspector_result.dag_node_to_inspection_results 52 | check_to_check_results = inspector_result.check_to_check_results 53 | ``` 54 | 55 | ## Detailed Example 56 | We prepared a [demo notebook](demo/feature_overview/feature_overview.ipynb) to showcase mlinspect and its features. 57 | 58 | ## Supported libraries and API functions 59 | mlinspect already supports a selection of API functions from `pandas` and `scikit-learn`. Extending mlinspect to support more and more API functions and libraries will be an ongoing effort. However, mlinspect won't just crash when it encounters functions it doesn't recognize yet. For more information, please see [here](mlinspect/monkeypatching/README.md). 60 | 61 | ## Notes 62 | * For debugging in PyCharm, set the pytest flag `--no-cov` ([Link](https://stackoverflow.com/questions/34870962/how-to-debug-py-test-in-pycharm-when-coverage-is-enabled)) 63 | 64 | ## Publications 65 | * [Stefan Grafberger, Paul Groth, Julia Stoyanovich, Sebastian Schelter (2022). Data Distribution Debugging in Machine Learning Pipelines. The VLDB Journal — The International Journal on Very Large Data Bases (Special Issue on Data Science for Responsible Data Management).](https://stefan-grafberger.com/mlinspect-journal.pdf) 66 | * [Stefan Grafberger, Shubha Guha, Julia Stoyanovich, Sebastian Schelter (2021). mlinspect: a Data Distribution Debugger for Machine Learning Pipelines. ACM SIGMOD (demo).](https://stefan-grafberger.com/mlinspect-demo.pdf) 67 | * [Stefan Grafberger, Julia Stoyanovich, Sebastian Schelter (2020). Lightweight Inspection of Data Preprocessing in Native Machine Learning Pipelines. Conference on Innovative Data Systems Research (CIDR).](https://stefan-grafberger.com/mlinspect-cidr.pdf) 68 | 69 | ## License 70 | This library is licensed under the Apache 2.0 License. 
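## End-to-end example

A minimal end-to-end sketch combining the APIs shown above. Every class and path used here ships with this repository (`HEALTHCARE_PY` and the `custom_monkeypatching` module are part of the bundled examples), but actually executing the healthcare pipeline also requires its extra dependencies such as gensim and scikeras:

```python
from example_pipelines import HEALTHCARE_PY
from example_pipelines.healthcare import custom_monkeypatching
from mlinspect import PipelineInspector
from mlinspect.checks import NoBiasIntroducedFor

# Instrument and execute the bundled healthcare pipeline; the custom monkey
# patching module teaches mlinspect about the pipeline's MyW2VTransformer.
inspector_result = PipelineInspector\
    .on_pipeline_from_py_file(HEALTHCARE_PY)\
    .add_check(NoBiasIntroducedFor(['race']))\
    .add_custom_monkey_patching_module(custom_monkeypatching)\
    .execute()

# Summarise all check results in a single pandas DataFrame
results_df = PipelineInspector.check_results_as_data_frame(
    inspector_result.check_to_check_results)
print(results_df)
```

`check_results_as_data_frame` returns one row per check with the columns `check_name`, `status` and `description`.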
71 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | target: auto 6 | threshold: 2% 7 | patch: off 8 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/__init__.py -------------------------------------------------------------------------------- /demo/feature_overview/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/feature_overview/__init__.py -------------------------------------------------------------------------------- /demo/feature_overview/healthcare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/feature_overview/healthcare.png -------------------------------------------------------------------------------- /demo/feature_overview/missing_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MissingEmbeddings Inspection 3 | """ 4 | import dataclasses 5 | from typing import Iterable, List 6 | 7 | from mlinspect import FunctionInfo 8 | from mlinspect.inspections import Inspection, InspectionInputUnaryOperator 9 | 10 | 11 | @dataclasses.dataclass(frozen=True, eq=True) 12 | class MissingEmbeddingsInfo: 13 | """ 14 | Info about potentially missing embeddings 15 | """ 16 | missing_embedding_count: int 17 | missing_embeddings_examples: List[str] 18 | 19 | 20 | class MissingEmbeddings(Inspection): 21 | """ 22 | A simple example inspection 23 | """ 24 | 25 | def __init__(self, example_threshold=10): 26 | self._is_embedding_operator = False 27 | self._missing_embedding_count = 0 28 | self._missing_embeddings_examples = [] 29 | self.example_threshold = example_threshold 30 | 31 | def visit_operator(self, inspection_input) -> Iterable[any]: 32 | """ 33 | Visit an operator 34 | """ 35 | # pylint: disable=too-many-branches, too-many-statements 36 | if isinstance(inspection_input, InspectionInputUnaryOperator) and \ 37 | inspection_input.operator_context.function_info == \ 38 | FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer'): 39 | # TODO: Are there existing word embedding transformers for sklearn we can use this for?
40 | self._is_embedding_operator = True 41 | for row in inspection_input.row_iterator: 42 | # Count missing embeddings 43 | embedding_array = row.output[0] 44 | is_zero_vector = not embedding_array.any() 45 | if is_zero_vector: 46 | self._missing_embedding_count += 1 47 | if len(self._missing_embeddings_examples) < self.example_threshold: 48 | self._missing_embeddings_examples.append(row.input[0]) 49 | yield None 50 | else: 51 | for _ in inspection_input.row_iterator: 52 | yield None 53 | 54 | def get_operator_annotation_after_visit(self) -> any: 55 | if self._is_embedding_operator: 56 | assert self._missing_embedding_count is not None # May only be called after the operator visit is finished 57 | result = MissingEmbeddingsInfo(self._missing_embedding_count, self._missing_embeddings_examples) 58 | self._missing_embedding_count = 0 59 | self._is_embedding_operator = False 60 | self._missing_embeddings_examples = [] 61 | return result 62 | return None 63 | 64 | @property 65 | def inspection_id(self): 66 | return self.example_threshold 67 | -------------------------------------------------------------------------------- /demo/feature_overview/no_missing_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | The NoMissingEmbeddings Check 3 | """ 4 | import collections 5 | import dataclasses 6 | from typing import Iterable, Dict 7 | 8 | from demo.feature_overview.missing_embeddings import MissingEmbeddings, MissingEmbeddingsInfo 9 | from mlinspect import DagNode 10 | from mlinspect.checks import Check, CheckStatus, CheckResult 11 | from mlinspect.inspections import Inspection, InspectionResult 12 | 13 | 14 | ILLEGAL_FEATURES = {"race", "gender", "age"} 15 | 16 | 17 | @dataclasses.dataclass 18 | class NoMissingEmbeddingsResult(CheckResult): 19 | """ 20 | Which DAG nodes, if any, produced missing embeddings? 21 | """ 22 | dag_node_to_missing_embeddings: Dict[DagNode, MissingEmbeddingsInfo] 23 | 24 | 25 | class NoMissingEmbeddings(Check): 26 | """ 27 | Does the pipeline produce missing (all-zero) embeddings? 28 | """ 29 | # pylint: disable=unnecessary-pass, too-few-public-methods 30 | 31 | def __init__(self, example_threshold=10): 32 | self.example_threshold = example_threshold 33 | 34 | @property 35 | def check_id(self): 36 | """The id of the check""" 37 | return self.example_threshold 38 | 39 | @property 40 | def required_inspections(self) -> Iterable[Inspection]: 41 | """The inspections required to evaluate this check""" 42 | return [MissingEmbeddings(self.example_threshold)] 43 | 44 | def evaluate(self, inspection_result: InspectionResult) -> CheckResult: 45 | """Evaluate the check""" 46 | dag_node_to_missing_embeddings = {} 47 | for dag_node, dag_node_inspection_result in inspection_result.dag_node_to_inspection_results.items(): 48 | if MissingEmbeddings(self.example_threshold) in dag_node_inspection_result: 49 | missing_embedding_info = dag_node_inspection_result[MissingEmbeddings(self.example_threshold)] 50 | assert missing_embedding_info is None or isinstance(missing_embedding_info, MissingEmbeddingsInfo) 51 | if missing_embedding_info is not None and missing_embedding_info.missing_embedding_count > 0: 52 | dag_node_to_missing_embeddings[dag_node] = missing_embedding_info 53 | if dag_node_to_missing_embeddings: 54 | description = "Missing embeddings were found!"
55 | result = NoMissingEmbeddingsResult(self, CheckStatus.FAILURE, description, dag_node_to_missing_embeddings) 56 | else: 57 | result = NoMissingEmbeddingsResult(self, CheckStatus.SUCCESS, None, collections.OrderedDict()) 58 | return result 59 | -------------------------------------------------------------------------------- /demo/feature_overview/paper_example_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/demo/feature_overview/paper_example_image.png -------------------------------------------------------------------------------- /example_pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Packages and classes we want to expose to users 3 | """ 4 | from ._pipelines import ADULT_SIMPLE_PY, ADULT_SIMPLE_IPYNB, ADULT_SIMPLE_PNG, ADULT_COMPLEX_PY, ADULT_COMPLEX_PNG, \ 5 | COMPAS_PY, COMPAS_PNG, HEALTHCARE_PY, HEALTHCARE_PNG 6 | 7 | __all__ = [ 8 | 'ADULT_SIMPLE_PY', 'ADULT_SIMPLE_IPYNB', 'ADULT_SIMPLE_PNG', 9 | 'ADULT_COMPLEX_PY', 'ADULT_COMPLEX_PNG', 10 | 'COMPAS_PY', 'COMPAS_PNG', 11 | 'HEALTHCARE_PY', 'HEALTHCARE_PNG', 12 | ] 13 | -------------------------------------------------------------------------------- /example_pipelines/_pipelines.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some useful utils for the project 3 | """ 4 | import os 5 | 6 | from mlinspect.utils import get_project_root 7 | 8 | ADULT_SIMPLE_PY = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.py") 9 | ADULT_SIMPLE_IPYNB = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.ipynb") 10 | ADULT_SIMPLE_PNG = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.png") 11 | 12 | ADULT_COMPLEX_PY = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_complex.py") 13 | ADULT_COMPLEX_PNG = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_complex.png") 14 | 15 | COMPAS_PY = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas.py") 16 | COMPAS_PNG = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas.png") 17 | 18 | HEALTHCARE_PY = os.path.join(str(get_project_root()), "example_pipelines", "healthcare", "healthcare.py") 19 | HEALTHCARE_PNG = os.path.join(str(get_project_root()), "example_pipelines", "healthcare", "healthcare.png") 20 | -------------------------------------------------------------------------------- /example_pipelines/adult_complex/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_complex/__init__.py -------------------------------------------------------------------------------- /example_pipelines/adult_complex/adult_complex.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_complex/adult_complex.png -------------------------------------------------------------------------------- /example_pipelines/adult_complex/adult_complex.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | An example pipeline 3 | """ 4 | import os 5 | 6 | import pandas as pd 7 | import numpy as np 8 | from sklearn import preprocessing 9 | from sklearn.compose import ColumnTransformer 10 | from sklearn.impute import SimpleImputer 11 | from sklearn.pipeline import Pipeline 12 | from sklearn.preprocessing import OneHotEncoder, StandardScaler 13 | from sklearn.tree import DecisionTreeClassifier 14 | 15 | from mlinspect.utils import get_project_root 16 | 17 | train_file = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_train.csv") 18 | train_data = pd.read_csv(train_file, na_values='?', index_col=0) 19 | test_file = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_test.csv") 20 | test_data = pd.read_csv(test_file, na_values='?', index_col=0) 21 | 22 | train_labels = preprocessing.label_binarize(train_data['income-per-year'], classes=['>50K', '<=50K']) 23 | test_labels = preprocessing.label_binarize(test_data['income-per-year'], classes=['>50K', '<=50K']) 24 | 25 | nested_categorical_feature_transformation = Pipeline([ 26 | ('impute', SimpleImputer(missing_values=np.nan, strategy='most_frequent')), 27 | ('encode', OneHotEncoder(handle_unknown='ignore')) 28 | ]) 29 | 30 | nested_feature_transformation = ColumnTransformer(transformers=[ 31 | ('categorical', nested_categorical_feature_transformation, ['education', 'workclass']), 32 | ('numeric', StandardScaler(), ['age', 'hours-per-week']) 33 | ]) 34 | 35 | nested_income_pipeline = Pipeline([ 36 | ('features', nested_feature_transformation), 37 | ('classifier', DecisionTreeClassifier())]) 38 | 39 | nested_income_pipeline.fit(train_data, train_labels) 40 | 41 | print(nested_income_pipeline.score(test_data, test_labels)) 42 | -------------------------------------------------------------------------------- /example_pipelines/adult_simple/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_simple/__init__.py -------------------------------------------------------------------------------- /example_pipelines/adult_simple/adult_simple.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "pipeline start\n", 13 | "pipeline finished\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "\"\"\"\n", 19 | "An example pipeline\n", 20 | "\"\"\"\n", 21 | "import os\n", 22 | "import pandas as pd\n", 23 | "\n", 24 | "from sklearn import compose, preprocessing, tree, pipeline\n", 25 | "from mlinspect.utils import get_project_root\n", 26 | "\n", 27 | "print('pipeline start')\n", 28 | "train_file = os.path.join(str(get_project_root()), \"example_pipelines\", \"adult_complex\", \"adult_train.csv\")\n", 29 | "raw_data = pd.read_csv(train_file, na_values='?', index_col=0)\n", 30 | "\n", 31 | "data = raw_data.dropna()\n", 32 | "\n", 33 | "labels = preprocessing.label_binarize(data['income-per-year'], classes=['>50K', '<=50K'])\n", 34 | "\n", 35 | "feature_transformation = compose.ColumnTransformer(transformers=[\n", 36 | " ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass']),\n", 37 | " 
('numeric', preprocessing.StandardScaler(), ['age', 'hours-per-week'])\n", 38 | "])\n", 39 | "\n", 40 | "\n", 41 | "income_pipeline = pipeline.Pipeline([\n", 42 | " ('features', feature_transformation),\n", 43 | " ('classifier', tree.DecisionTreeClassifier())])\n", 44 | "\n", 45 | "income_pipeline.fit(data, labels)\n", 46 | "\n", 47 | "\n", 48 | "print('pipeline finished')" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [] 57 | } 58 | ], 59 | "metadata": { 60 | "kernelspec": { 61 | "display_name": "Python 3", 62 | "language": "python", 63 | "name": "python3" 64 | }, 65 | "language_info": { 66 | "codemirror_mode": { 67 | "name": "ipython", 68 | "version": 3 69 | }, 70 | "file_extension": ".py", 71 | "mimetype": "text/x-python", 72 | "name": "python", 73 | "nbconvert_exporter": "python", 74 | "pygments_lexer": "ipython3", 75 | "version": "3.8.5" 76 | } 77 | }, 78 | "nbformat": 4, 79 | "nbformat_minor": 2 80 | } 81 | -------------------------------------------------------------------------------- /example_pipelines/adult_simple/adult_simple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/adult_simple/adult_simple.png -------------------------------------------------------------------------------- /example_pipelines/adult_simple/adult_simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example pipeline 3 | """ 4 | import os 5 | import pandas as pd 6 | 7 | from sklearn import compose, preprocessing, tree, pipeline 8 | from mlinspect.utils import get_project_root 9 | 10 | print('pipeline start') 11 | train_file = os.path.join(str(get_project_root()), "example_pipelines", "adult_complex", "adult_train.csv") 12 | raw_data = pd.read_csv(train_file, na_values='?', index_col=0) 13 | 14 | data = raw_data.dropna() 15 | 16 | labels = preprocessing.label_binarize(data['income-per-year'], classes=['>50K', '<=50K']) 17 | 18 | feature_transformation = compose.ColumnTransformer(transformers=[ 19 | ('categorical', preprocessing.OneHotEncoder(handle_unknown='ignore'), ['education', 'workclass']), 20 | ('numeric', preprocessing.StandardScaler(), ['age', 'hours-per-week']) 21 | ]) 22 | 23 | 24 | income_pipeline = pipeline.Pipeline([ 25 | ('features', feature_transformation), 26 | ('classifier', tree.DecisionTreeClassifier())]) 27 | 28 | income_pipeline.fit(data, labels) 29 | 30 | 31 | print('pipeline finished') 32 | -------------------------------------------------------------------------------- /example_pipelines/compas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/compas/__init__.py -------------------------------------------------------------------------------- /example_pipelines/compas/compas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/compas/compas.png -------------------------------------------------------------------------------- /example_pipelines/compas/compas.py: -------------------------------------------------------------------------------- 1 | """ 2 | An example pipeline 3 | """ 
4 | import os 5 | 6 | import pandas as pd 7 | from sklearn.compose import ColumnTransformer 8 | from sklearn.impute import SimpleImputer 9 | from sklearn.linear_model import LogisticRegression 10 | from sklearn.pipeline import Pipeline 11 | from sklearn.preprocessing import OneHotEncoder, KBinsDiscretizer, label_binarize 12 | 13 | from mlinspect.utils import get_project_root 14 | 15 | train_file = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas_train.csv") 16 | train_data = pd.read_csv(train_file, na_values='?', index_col=0) 17 | test_file = os.path.join(str(get_project_root()), "example_pipelines", "compas", "compas_test.csv") 18 | test_data = pd.read_csv(test_file, na_values='?', index_col=0) 19 | 20 | train_data = train_data[ 21 | ['sex', 'dob', 'age', 'c_charge_degree', 'race', 'score_text', 'priors_count', 'days_b_screening_arrest', 22 | 'decile_score', 'is_recid', 'two_year_recid', 'c_jail_in', 'c_jail_out']] 23 | test_data = test_data[ 24 | ['sex', 'dob', 'age', 'c_charge_degree', 'race', 'score_text', 'priors_count', 'days_b_screening_arrest', 25 | 'decile_score', 'is_recid', 'two_year_recid', 'c_jail_in', 'c_jail_out']] 26 | 27 | train_data = train_data[(train_data['days_b_screening_arrest'] <= 30) & (train_data['days_b_screening_arrest'] >= -30)] 28 | train_data = train_data[train_data['is_recid'] != -1] 29 | train_data = train_data[train_data['c_charge_degree'] != "O"] 30 | train_data = train_data[train_data['score_text'] != 'N/A'] 31 | 32 | train_data = train_data.replace('Medium', "Low") 33 | test_data = test_data.replace('Medium', "Low") 34 | 35 | train_labels = label_binarize(train_data['score_text'], classes=['High', 'Low']) 36 | test_labels = label_binarize(test_data['score_text'], classes=['High', 'Low']) 37 | 38 | impute1_and_onehot = Pipeline([('imputer1', SimpleImputer(strategy='most_frequent')), 39 | ('onehot', OneHotEncoder(handle_unknown='ignore'))]) 40 | impute2_and_bin = Pipeline([('imputer2', SimpleImputer(strategy='mean')), 41 | ('discretizer', KBinsDiscretizer(n_bins=4, encode='ordinal', strategy='uniform'))]) 42 | 43 | featurizer = ColumnTransformer(transformers=[ 44 | ('impute1_and_onehot', impute1_and_onehot, ['is_recid']), 45 | ('impute2_and_bin', impute2_and_bin, ['age']) 46 | ]) 47 | 48 | pipeline = Pipeline([ 49 | ('features', featurizer), 50 | ('classifier', LogisticRegression()) 51 | ]) 52 | 53 | pipeline.fit(train_data, train_labels.ravel()) 54 | print(pipeline.score(test_data, test_labels.ravel())) 55 | -------------------------------------------------------------------------------- /example_pipelines/healthcare/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/healthcare/__init__.py -------------------------------------------------------------------------------- /example_pipelines/healthcare/_gensim_wrapper.py: -------------------------------------------------------------------------------- 1 | # pylint: disable-all 2 | # Author: Chinmaya Pancholi 3 | # Copyright (C) 2017 Radim Rehurek 4 | # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html 5 | """Scikit learn interface for :class:`~gensim.models.word2vec.Word2Vec`. 6 | Follows scikit-learn API conventions to facilitate using gensim along with scikit-learn. 
7 | """ 8 | import numpy as np 9 | import six 10 | from sklearn.base import TransformerMixin, BaseEstimator 11 | from sklearn.exceptions import NotFittedError 12 | 13 | from gensim import models 14 | 15 | 16 | class W2VTransformer(TransformerMixin, BaseEstimator): 17 | """Base Word2Vec module, wraps :class:`~gensim.models.word2vec.Word2Vec`. 18 | For more information please have a look at `Tomas Mikolov, Kai Chen, Greg Corrado, Jeffrey Dean: "Efficient 19 | Estimation of Word Representations in Vector Space" <https://arxiv.org/abs/1301.3781>`_. 20 | """ 21 | def __init__(self, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1, 22 | workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, 23 | trim_rule=None, sorted_vocab=1, batch_words=10000): 24 | """ 25 | Parameters 26 | ---------- 27 | size : int 28 | Dimensionality of the feature vectors. 29 | alpha : float 30 | The initial learning rate. 31 | window : int 32 | The maximum distance between the current and predicted word within a sentence. 33 | min_count : int 34 | Ignores all words with total frequency lower than this. 35 | max_vocab_size : int 36 | Limits the RAM during vocabulary building; if there are more unique 37 | words than this, then prune the infrequent ones. Every 10 million word types need about 1GB of RAM. 38 | Set to `None` for no limit. 39 | sample : float 40 | The threshold for configuring which higher-frequency words are randomly downsampled, 41 | useful range is (0, 1e-5). 42 | seed : int 43 | Seed for the random number generator. Initial vectors for each word are seeded with a hash of 44 | the concatenation of word + `str(seed)`. Note that for a fully deterministically-reproducible run, 45 | you must also limit the model to a single worker thread (`workers=1`), to eliminate ordering jitter 46 | from OS thread scheduling. (In Python 3, reproducibility between interpreter launches also requires 47 | use of the `PYTHONHASHSEED` environment variable to control hash randomization). 48 | workers : int 49 | Use this many worker threads to train the model (=faster training with multicore machines). 50 | min_alpha : float 51 | Learning rate will linearly drop to `min_alpha` as training progresses. 52 | sg : int {1, 0} 53 | Defines the training algorithm. If 1, skip-gram is employed; otherwise, CBOW is used. 54 | hs : int {1,0} 55 | If 1, hierarchical softmax will be used for model training. 56 | If set to 0, and `negative` is non-zero, negative sampling will be used. 57 | negative : int 58 | If > 0, negative sampling will be used, the int for negative specifies how many "noise words" 59 | should be drawn (usually between 5-20). 60 | If set to 0, no negative sampling is used. 61 | cbow_mean : int {1,0} 62 | If 0, use the sum of the context word vectors. If 1, use the mean, only applies when cbow is used. 63 | hashfxn : callable (object -> int), optional 64 | A hashing function. Used to create an initial random reproducible vector by hashing the random seed. 65 | iter : int 66 | Number of iterations (epochs) over the corpus. 67 | null_word : int {1, 0} 68 | If 1, a null pseudo-word will be created for padding when using concatenative L1 (run-of-words) 69 | trim_rule : function 70 | Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary, 71 | be trimmed away, or handled using the default (discard if word count < min_count).
72 | Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`), 73 | or a callable that accepts parameters (word, count, min_count) and returns either 74 | :attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`. 75 | Note: The rule, if given, is only used to prune vocabulary during build_vocab() and is not stored as part 76 | of the model. 77 | sorted_vocab : int {1,0} 78 | If 1, sort the vocabulary by descending frequency before assigning word indexes. 79 | batch_words : int 80 | Target size (in words) for batches of examples passed to worker threads (and 81 | thus cython routines).(Larger batches will be passed if individual 82 | texts are longer than 10000 words, but the standard cython code truncates to that maximum.) 83 | """ 84 | self.gensim_model = None 85 | self.size = size 86 | self.alpha = alpha 87 | self.window = window 88 | self.min_count = min_count 89 | self.max_vocab_size = max_vocab_size 90 | self.sample = sample 91 | self.seed = seed 92 | self.workers = workers 93 | self.min_alpha = min_alpha 94 | self.sg = sg 95 | self.hs = hs 96 | self.negative = negative 97 | self.cbow_mean = int(cbow_mean) 98 | self.hashfxn = hashfxn 99 | self.iter = iter 100 | self.null_word = null_word 101 | self.trim_rule = trim_rule 102 | self.sorted_vocab = sorted_vocab 103 | self.batch_words = batch_words 104 | 105 | def fit(self, X, y=None): 106 | """Fit the model according to the given training data. 107 | Parameters 108 | ---------- 109 | X : iterable of iterables of str 110 | The input corpus. X can be simply a list of lists of tokens, but for larger corpora, 111 | consider an iterable that streams the sentences directly from disk/network. 112 | See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus` 113 | or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples. 114 | Returns 115 | ------- 116 | :class:`~gensim.sklearn_api.w2vmodel.W2VTransformer` 117 | The trained model. 118 | """ 119 | self.gensim_model = models.Word2Vec( 120 | sentences=X, vector_size=self.size, alpha=self.alpha, 121 | window=self.window, min_count=self.min_count, max_vocab_size=self.max_vocab_size, 122 | sample=self.sample, seed=self.seed, workers=self.workers, min_alpha=self.min_alpha, 123 | sg=self.sg, hs=self.hs, negative=self.negative, cbow_mean=self.cbow_mean, 124 | hashfxn=self.hashfxn, null_word=self.null_word, trim_rule=self.trim_rule, 125 | sorted_vocab=self.sorted_vocab, batch_words=self.batch_words 126 | ) 127 | return self 128 | 129 | def transform(self, words): 130 | """Get the word vectors the input words. 131 | Parameters 132 | ---------- 133 | words : {iterable of str, str} 134 | Word or a collection of words to be transformed. 135 | Returns 136 | ------- 137 | np.ndarray of shape [`len(words)`, `size`] 138 | A 2D array where each row is the vector of one word. 139 | """ 140 | if self.gensim_model is None: 141 | raise NotFittedError( 142 | "This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method." 143 | ) 144 | 145 | # The input as array of array 146 | if isinstance(words, six.string_types): 147 | words = [words] 148 | vectors = [self.gensim_model.wv[word] for word in words] 149 | return np.reshape(np.array(vectors), (len(words), self.size)) 150 | 151 | def partial_fit(self, X): 152 | raise NotImplementedError( 153 | "'partial_fit' has not been implemented for W2VTransformer. 
" 154 | "However, the model can be updated with a fixed vocabulary using Gensim API call." 155 | ) 156 | -------------------------------------------------------------------------------- /example_pipelines/healthcare/custom_monkeypatching/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/healthcare/custom_monkeypatching/__init__.py -------------------------------------------------------------------------------- /example_pipelines/healthcare/custom_monkeypatching/patch_healthcare_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monkey patching for healthcare_utils 3 | """ 4 | import gorilla 5 | 6 | from example_pipelines.healthcare import healthcare_utils 7 | from example_pipelines.healthcare import _gensim_wrapper 8 | from mlinspect.backends._sklearn_backend import SklearnBackend 9 | from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType 10 | from mlinspect.instrumentation._dag_node import DagNode, BasicCodeLocation, DagNodeDetails 11 | from mlinspect.instrumentation._pipeline_executor import singleton 12 | from mlinspect.monkeypatching._monkey_patching_utils import add_dag_node, \ 13 | get_input_info, execute_patched_func_no_op_id, get_optional_code_info_or_none 14 | from mlinspect.monkeypatching._mlinspect_ndarray import MlinspectNdarray 15 | 16 | 17 | class SklearnMyW2VTransformerPatching: 18 | """ Patches for healthcare_utils.MyW2VTransformer""" 19 | 20 | # pylint: disable=too-few-public-methods 21 | 22 | @gorilla.patch(_gensim_wrapper.W2VTransformer, name='__init__', settings=gorilla.Settings(allow_hit=True)) 23 | def patched__init__(self, *, size=100, alpha=0.025, window=5, min_count=5, max_vocab_size=None, sample=1e-3, seed=1, 24 | workers=3, min_alpha=0.0001, sg=0, hs=0, negative=5, cbow_mean=1, hashfxn=hash, iter=5, 25 | null_word=0, trim_rule=None, sorted_vocab=1, batch_words=10000, 26 | mlinspect_caller_filename=None, mlinspect_lineno=None, 27 | mlinspect_optional_code_reference=None, mlinspect_optional_source_code=None, 28 | mlinspect_fit_transform_active=False): 29 | """ Patch for ('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer') """ 30 | # pylint: disable=no-method-argument, attribute-defined-outside-init, too-many-locals, redefined-builtin, 31 | # pylint: disable=invalid-name 32 | original = gorilla.get_original_attribute(_gensim_wrapper.W2VTransformer, '__init__') 33 | 34 | self.mlinspect_caller_filename = mlinspect_caller_filename 35 | self.mlinspect_lineno = mlinspect_lineno 36 | self.mlinspect_optional_code_reference = mlinspect_optional_code_reference 37 | self.mlinspect_optional_source_code = mlinspect_optional_source_code 38 | self.mlinspect_fit_transform_active = mlinspect_fit_transform_active 39 | 40 | self.mlinspect_non_data_func_args = {'size': size, 'alpha': alpha, 'window': window, 41 | 'min_count': min_count, 'max_vocab_size': max_vocab_size, 'sample': sample, 42 | 'seed': seed, 'workers': workers, 'min_alpha': min_alpha, 'sg': sg, 43 | 'hs': hs, 'negative': negative, 'cbow_mean': cbow_mean, 'iter': iter, 44 | 'null_word': null_word, 'trim_rule': trim_rule, 45 | 'sorted_vocab': sorted_vocab, 'batch_words': batch_words} 46 | 47 | def execute_inspections(_, caller_filename, lineno, optional_code_reference, optional_source_code): 48 | """ Execute inspections, add DAG node """ 49 | 
original(self, hashfxn=hashfxn, **self.mlinspect_non_data_func_args) 50 | 51 | self.mlinspect_caller_filename = caller_filename 52 | self.mlinspect_lineno = lineno 53 | self.mlinspect_optional_code_reference = optional_code_reference 54 | self.mlinspect_optional_source_code = optional_source_code 55 | 56 | return execute_patched_func_no_op_id(original, execute_inspections, self, hashfxn=hashfxn, 57 | **self.mlinspect_non_data_func_args) 58 | 59 | @gorilla.patch(healthcare_utils.MyW2VTransformer, name='fit_transform', settings=gorilla.Settings(allow_hit=True)) 60 | def patched_fit_transform(self, *args, **kwargs): 61 | """ Patch for ('example_pipelines.healthcare.healthcare_utils.MyW2VTransformer', 'fit_transform') """ 62 | # pylint: disable=no-method-argument 63 | self.mlinspect_fit_transform_active = True # pylint: disable=attribute-defined-outside-init 64 | original = gorilla.get_original_attribute(healthcare_utils.MyW2VTransformer, 'fit_transform') 65 | function_info = FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer') 66 | input_info = get_input_info(args[0], self.mlinspect_caller_filename, self.mlinspect_lineno, function_info, 67 | self.mlinspect_optional_code_reference, self.mlinspect_optional_source_code) 68 | 69 | operator_context = OperatorContext(OperatorType.TRANSFORMER, function_info) 70 | input_infos = SklearnBackend.before_call(operator_context, [input_info.annotated_dfobject]) 71 | result = original(self, input_infos[0].result_data, *args[1:], **kwargs) 72 | backend_result = SklearnBackend.after_call(operator_context, 73 | input_infos, 74 | result, 75 | self.mlinspect_non_data_func_args) 76 | new_return_value = backend_result.annotated_dfobject.result_data 77 | assert isinstance(new_return_value, MlinspectNdarray) 78 | dag_node = DagNode(singleton.get_next_op_id(), 79 | BasicCodeLocation(self.mlinspect_caller_filename, self.mlinspect_lineno), 80 | operator_context, 81 | DagNodeDetails("Word2Vec: fit_transform", ['array']), 82 | get_optional_code_info_or_none(self.mlinspect_optional_code_reference, 83 | self.mlinspect_optional_source_code)) 84 | add_dag_node(dag_node, [input_info.dag_node], backend_result) 85 | self.mlinspect_fit_transform_active = False # pylint: disable=attribute-defined-outside-init 86 | return new_return_value 87 | 88 | @gorilla.patch(healthcare_utils.MyW2VTransformer, name='transform', settings=gorilla.Settings(allow_hit=True)) 89 | def patched_transform(self, *args, **kwargs): 90 | """ Patch for ('example_pipelines.healthcare.healthcare_utils.MyW2VTransformer', 'transform') """ 91 | # pylint: disable=no-method-argument 92 | original = gorilla.get_original_attribute(healthcare_utils.MyW2VTransformer, 'transform') 93 | if not self.mlinspect_fit_transform_active: 94 | function_info = FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer') 95 | input_info = get_input_info(args[0], self.mlinspect_caller_filename, self.mlinspect_lineno, function_info, 96 | self.mlinspect_optional_code_reference, self.mlinspect_optional_source_code) 97 | 98 | operator_context = OperatorContext(OperatorType.TRANSFORMER, function_info) 99 | input_infos = SklearnBackend.before_call(operator_context, [input_info.annotated_dfobject]) 100 | result = original(self, input_infos[0].result_data, *args[1:], **kwargs) 101 | backend_result = SklearnBackend.after_call(operator_context, 102 | input_infos, 103 | result, 104 | self.mlinspect_non_data_func_args) 105 | new_return_value = backend_result.annotated_dfobject.result_data 106 | 
assert isinstance(new_return_value, MlinspectNdarray) 107 | dag_node = DagNode(singleton.get_next_op_id(), 108 | BasicCodeLocation(self.mlinspect_caller_filename, self.mlinspect_lineno), 109 | operator_context, 110 | DagNodeDetails("Word2Vec: transform", ['array']), 111 | get_optional_code_info_or_none(self.mlinspect_optional_code_reference, 112 | self.mlinspect_optional_source_code)) 113 | add_dag_node(dag_node, [input_info.dag_node], backend_result) 114 | else: 115 | new_return_value = original(self, *args, **kwargs) 116 | return new_return_value 117 | -------------------------------------------------------------------------------- /example_pipelines/healthcare/healthcare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/example_pipelines/healthcare/healthcare.png -------------------------------------------------------------------------------- /example_pipelines/healthcare/healthcare.py: -------------------------------------------------------------------------------- 1 | """Predicting which patients are at a higher risk of complications""" 2 | import warnings 3 | import os 4 | import pandas as pd 5 | from scikeras.wrappers import KerasClassifier 6 | from sklearn.compose import ColumnTransformer 7 | from sklearn.impute import SimpleImputer 8 | from sklearn.model_selection import train_test_split 9 | from sklearn.pipeline import Pipeline 10 | from sklearn.preprocessing import OneHotEncoder, StandardScaler 11 | from example_pipelines.healthcare.healthcare_utils import MyW2VTransformer, \ 12 | create_model 13 | from mlinspect.utils import get_project_root 14 | 15 | # FutureWarning: Sklearn 0.24 made a change that breaks remainder='drop', that change will be fixed 16 | # in an upcoming version: https://github.com/scikit-learn/scikit-learn/pull/19263 17 | warnings.filterwarnings('ignore') 18 | 19 | COUNTIES_OF_INTEREST = ['county2', 'county3'] 20 | 21 | patients = pd.read_csv(os.path.join(str(get_project_root()), "example_pipelines", "healthcare", 22 | "patients.csv"), na_values='?') 23 | histories = pd.read_csv(os.path.join(str(get_project_root()), "example_pipelines", "healthcare", 24 | "histories.csv"), na_values='?') 25 | 26 | data = patients.merge(histories, on=['ssn']) 27 | complications = data.groupby('age_group') \ 28 | .agg(mean_complications=('complications', 'mean')) 29 | data = data.merge(complications, on=['age_group']) 30 | data['label'] = data['complications'] > 1.2 * data['mean_complications'] 31 | data = data[['smoker', 'last_name', 'county', 'num_children', 'race', 'income', 'label']] 32 | data = data[data['county'].isin(COUNTIES_OF_INTEREST)] 33 | 34 | impute_and_one_hot_encode = Pipeline([ 35 | ('impute', SimpleImputer(strategy='most_frequent')), 36 | ('encode', OneHotEncoder(sparse=False, handle_unknown='ignore')) 37 | ]) 38 | featurisation = ColumnTransformer(transformers=[ 39 | ("impute_and_one_hot_encode", impute_and_one_hot_encode, ['smoker', 'county', 'race']), 40 | ('word2vec', MyW2VTransformer(min_count=2), ['last_name']), 41 | ('numeric', StandardScaler(), ['num_children', 'income']), 42 | ], remainder='drop') 43 | neural_net = KerasClassifier(model=create_model, epochs=10, batch_size=1, verbose=0, 44 | hidden_layer_sizes=(9, 9,), loss="binary_crossentropy") 45 | pipeline = Pipeline([ 46 | ('features', featurisation), 47 | ('learner', neural_net)]) 48 | 49 | train_data, test_data = train_test_split(data) 50 | model = 
pipeline.fit(train_data, train_data['label']) 51 | print(f"Mean accuracy: {model.score(test_data, test_data['label'])}") 52 | -------------------------------------------------------------------------------- /example_pipelines/healthcare/healthcare_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some useful utils for the project 3 | """ 4 | import numpy 5 | from keras import Input 6 | from sklearn.exceptions import NotFittedError 7 | from tensorflow.keras.models import Sequential 8 | from tensorflow.keras.layers import Dense 9 | 10 | from example_pipelines.healthcare._gensim_wrapper import W2VTransformer 11 | 12 | 13 | class MyW2VTransformer(W2VTransformer): 14 | """Some custom w2v transformer.""" 15 | # pylint: disable-all 16 | 17 | def partial_fit(self, X): 18 | super().partial_fit([X]) 19 | 20 | def fit(self, X, y=None): 21 | X = X.iloc[:, 0].tolist() 22 | return super().fit([X], y) 23 | 24 | def transform(self, words): 25 | words = words.iloc[:, 0].tolist() 26 | if self.gensim_model is None: 27 | raise NotFittedError( 28 | "This model has not been fitted yet. Call 'fit' with appropriate arguments before using this method." 29 | ) 30 | 31 | # The input as array of array 32 | vectors = [] 33 | for word in words: 34 | if word in self.gensim_model.wv: 35 | vectors.append(self.gensim_model.wv[word]) 36 | else: 37 | vectors.append(numpy.zeros(self.size)) 38 | return numpy.reshape(numpy.array(vectors), (len(words), self.size)) 39 | 40 | 41 | def create_model(meta, hidden_layer_sizes): 42 | n_features_in_ = meta["n_features_in_"] 43 | n_classes_ = meta["n_classes_"] 44 | model = Sequential() 45 | model.add(Input(shape=(n_features_in_,))) 46 | for hidden_layer_size in hidden_layer_sizes: 47 | model.add(Dense(hidden_layer_size, activation="relu")) 48 | model.add(Dense(1, activation="sigmoid")) 49 | return model 50 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/performance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/experiments/performance/__init__.py -------------------------------------------------------------------------------- /experiments/performance/_empty_inspection.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple empty inspection 3 | """ 4 | from typing import Iterable 5 | 6 | from mlinspect.inspections._inspection import Inspection 7 | 8 | 9 | class EmptyInspection(Inspection): 10 | """ 11 | An empty inspection for performance experiments 12 | """ 13 | 14 | def __init__(self, inspection_id): 15 | self._id = inspection_id 16 | 17 | @property 18 | def inspection_id(self): 19 | return self._id 20 | 21 | def visit_operator(self, inspection_input) -> Iterable[any]: 22 | """ 23 | Visit an operator 24 | """ 25 | for _ in inspection_input.row_iterator: 26 | yield None 27 | 28 | def get_operator_annotation_after_visit(self) -> any: 29 | return None 30 | -------------------------------------------------------------------------------- /mlinspect/__init__.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Packages and classes we want to expose to users 3 | """ 4 | from ._pipeline_inspector import PipelineInspector 5 | from ._inspector_result import InspectorResult 6 | from .inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType 7 | from .instrumentation._dag_node import DagNode, BasicCodeLocation, DagNodeDetails, OptionalCodeInfo, CodeReference 8 | 9 | __all__ = [ 10 | 'utils', 11 | 'inspections', 12 | 'checks', 13 | 'visualisation', 14 | 'PipelineInspector', 'InspectorResult', 15 | 'DagNode', 'OperatorType', 16 | 'BasicCodeLocation', 'OperatorContext', 'DagNodeDetails', 'OptionalCodeInfo', 'FunctionInfo', 'CodeReference' 17 | ] 18 | -------------------------------------------------------------------------------- /mlinspect/_inspector_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data class used as result of the PipelineExecutor 3 | """ 4 | import dataclasses 5 | from typing import Dict 6 | 7 | import networkx 8 | 9 | from mlinspect.checks._check import Check, CheckResult 10 | from mlinspect.inspections._inspection import Inspection 11 | 12 | 13 | @dataclasses.dataclass 14 | class InspectorResult: 15 | """ 16 | The class the PipelineExecutor returns 17 | """ 18 | dag: networkx.DiGraph 19 | dag_node_to_inspection_results: Dict[any, Dict[Inspection, any]] # First any is DagNode 20 | check_to_check_results: Dict[Check, CheckResult] 21 | -------------------------------------------------------------------------------- /mlinspect/_pipeline_inspector.py: -------------------------------------------------------------------------------- 1 | """ 2 | User-facing API for inspecting the pipeline 3 | """ 4 | from typing import Iterable, Dict, List 5 | 6 | from pandas import DataFrame 7 | 8 | from mlinspect.inspections._inspection import Inspection 9 | from .checks._check import Check, CheckResult 10 | from ._inspector_result import InspectorResult 11 | from .instrumentation._pipeline_executor import singleton 12 | 13 | 14 | class PipelineInspectorBuilder: 15 | """ 16 | The fluent API builder to build an inspection run 17 | """ 18 | 19 | def __init__(self, notebook_path: str or None = None, 20 | python_path: str or None = None, 21 | python_code: str or None = None 22 | ) -> None: 23 | self.track_code_references = True 24 | self.monkey_patching_modules = [] 25 | self.notebook_path = notebook_path 26 | self.python_path = python_path 27 | self.python_code = python_code 28 | self.inspections = [] 29 | self.checks = [] 30 | 31 | def add_required_inspection(self, inspection: Inspection): 32 | """ 33 | Add an inspection 34 | """ 35 | self.inspections.append(inspection) 36 | return self 37 | 38 | def add_required_inspections(self, inspections: Iterable[Inspection]): 39 | """ 40 | Add a list of inspections 41 | """ 42 | self.inspections.extend(inspections) 43 | return self 44 | 45 | def add_check(self, check: Check): 46 | """ 47 | Add a check 48 | """ 49 | self.checks.append(check) 50 | return self 51 | 52 | def add_checks(self, checks: Iterable[Check]): 53 | """ 54 | Add a list of checks 55 | """ 56 | self.checks.extend(checks) 57 | return self 58 | 59 | def set_code_reference_tracking(self, track_code_references: bool): 60 | """ 61 | Set whether to track code references. The default is tracking them.
62 | """ 63 | self.track_code_references = track_code_references 64 | return self 65 | 66 | def add_custom_monkey_patching_modules(self, module_list: List): 67 | """ 68 | Add additional monkey patching modules. 69 | """ 70 | self.monkey_patching_modules.extend(module_list) 71 | return self 72 | 73 | def add_custom_monkey_patching_module(self, module: any): 74 | """ 75 | Add an additional monkey patching module. 76 | """ 77 | self.monkey_patching_modules.append(module) 78 | return self 79 | 80 | def execute(self) -> InspectorResult: 81 | """ 82 | Instrument and execute the pipeline 83 | """ 84 | return singleton.run(notebook_path=self.notebook_path, 85 | python_path=self.python_path, 86 | python_code=self.python_code, 87 | inspections=self.inspections, 88 | checks=self.checks, 89 | custom_monkey_patching=self.monkey_patching_modules) 90 | 91 | 92 | class PipelineInspector: 93 | """ 94 | The entry point to the fluent API to build an inspection run 95 | """ 96 | @staticmethod 97 | def on_pipeline_from_py_file(path: str) -> PipelineInspectorBuilder: 98 | """Inspect a pipeline from a .py file.""" 99 | return PipelineInspectorBuilder(python_path=path) 100 | 101 | @staticmethod 102 | def on_pipeline_from_ipynb_file(path: str) -> PipelineInspectorBuilder: 103 | """Inspect a pipeline from a .ipynb file.""" 104 | return PipelineInspectorBuilder(notebook_path=path) 105 | 106 | @staticmethod 107 | def on_pipeline_from_string(code: str) -> PipelineInspectorBuilder: 108 | """Inspect a pipeline from a string.""" 109 | return PipelineInspectorBuilder(python_code=code) 110 | 111 | @staticmethod 112 | def check_results_as_data_frame(check_to_check_results: Dict[Check, CheckResult]) -> DataFrame: 113 | """ 114 | Get a pandas DataFrame with an overview of the CheckResults 115 | """ 116 | check_names = [] 117 | status = [] 118 | descriptions = [] 119 | for check_result in check_to_check_results.values(): 120 | check_names.append(check_result.check) 121 | status.append(check_result.status) 122 | descriptions.append(check_result.description) 123 | return DataFrame(zip(check_names, status, descriptions), columns=["check_name", "status", "description"]) 124 | -------------------------------------------------------------------------------- /mlinspect/backends/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/backends/__init__.py -------------------------------------------------------------------------------- /mlinspect/backends/_all_backends.py: -------------------------------------------------------------------------------- 1 | """ 2 | Get all available backends 3 | """ 4 | from typing import List 5 | 6 | from ._backend import Backend 7 | from ._pandas_backend import PandasBackend 8 | from ._sklearn_backend import SklearnBackend 9 | 10 | 11 | def get_all_backends() -> List[Backend]: 12 | """Get the list of all currently available backends""" 13 | return [PandasBackend(), SklearnBackend()] 14 | -------------------------------------------------------------------------------- /mlinspect/backends/_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Interface for the different instrumentation backends 3 | """ 4 | import abc 5 | import dataclasses 6 | from types import MappingProxyType 7 | from typing import List, Dict 8 | 9 | from mlinspect.inspections import Inspection 10 | 11 | 12 | 
@dataclasses.dataclass(frozen=True) 13 | class AnnotatedDfObject: 14 | """ A dataframe-like object and its annotations """ 15 | result_data: any 16 | result_annotation: any 17 | 18 | 19 | @dataclasses.dataclass(frozen=True) 20 | class BackendResult: 21 | """ The annotated dataframe and the annotations for the current DAG node """ 22 | annotated_dfobject: AnnotatedDfObject 23 | dag_node_annotation: Dict[Inspection, any] 24 | optional_second_annotated_dfobject: AnnotatedDfObject = None 25 | optional_second_dag_node_annotation: Dict[Inspection, any] = None 26 | 27 | 28 | class Backend(metaclass=abc.ABCMeta): 29 | """ 30 | The Interface for the different instrumentation backends 31 | """ 32 | 33 | @staticmethod 34 | @abc.abstractmethod 35 | def before_call(operator_context, input_infos: List[AnnotatedDfObject]) \ 36 | -> List[AnnotatedDfObject]: 37 | """The value or module a function may be called on""" 38 | # pylint: disable=too-many-arguments, unused-argument 39 | raise NotImplementedError 40 | 41 | 42 | @staticmethod 43 | @abc.abstractmethod 44 | def after_call(operator_context, input_infos: List[AnnotatedDfObject], return_value, 45 | non_data_function_args: Dict[str, any] = MappingProxyType({})) -> BackendResult: 46 | """The return value of some function""" 47 | # pylint: disable=too-many-arguments, unused-argument 48 | raise NotImplementedError 49 | -------------------------------------------------------------------------------- /mlinspect/backends/_backend_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some utility functions for the different instrumentation backends 3 | """ 4 | # pylint: disable=unnecessary-dunder-call 5 | import itertools 6 | 7 | import numpy 8 | from pandas import DataFrame, Series 9 | from scipy.sparse import csr_matrix 10 | 11 | from ._backend import AnnotatedDfObject 12 | from ..inspections._inspection_input import ColumnInfo 13 | from ..monkeypatching._mlinspect_ndarray import MlinspectNdarray 14 | 15 | 16 | def get_annotation_rows(input_annotations, inspection_index): 17 | """ 18 | In the pandas backend, we store annotations in a data frame; for the sklearn transformers, lists are enough 19 | """ 20 | if isinstance(input_annotations, DataFrame): 21 | annotations_for_inspection = input_annotations.iloc[:, inspection_index] 22 | assert isinstance(annotations_for_inspection, Series) 23 | else: 24 | annotations_for_inspection = input_annotations[inspection_index] 25 | assert isinstance(annotations_for_inspection, list) 26 | annotation_rows = annotations_for_inspection.__iter__() 27 | return annotation_rows 28 | 29 | 30 | def build_annotation_df_from_iters(inspections, annotation_iterators): 31 | """ 32 | Build the annotations dataframe 33 | """ 34 | annotation_iterators = itertools.zip_longest(*annotation_iterators) 35 | inspection_names = [str(inspection) for inspection in inspections] 36 | annotations_df = DataFrame(annotation_iterators, columns=inspection_names) 37 | return annotations_df 38 | 39 | 40 | def build_annotation_list_from_iters(annotation_iterators): 41 | """ 42 | Build the annotations list 43 | """ 44 | annotation_lists = [list(iterator) for iterator in annotation_iterators] 45 | return list(annotation_lists) 46 | 47 | 48 | def get_iterator_for_type(data, np_nditer_with_refs=False, columns=None): 49 | """ 50 | Create an efficient iterator for the data. 51 | Automatically detects the data type and fails if it cannot handle that data type. 
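Returns a (ColumnInfo, row iterator) pair, as produced by the per-type helpers below.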
52 | """ 53 | if isinstance(data, DataFrame): 54 | iterator = get_df_row_iterator(data) 55 | elif isinstance(data, numpy.ndarray): 56 | # TODO: Measure performance impact of np_nditer_with_refs. To support arbitrary pipelines, remove this 57 | # or check the type of the standard iterator. It seems the nditer variant is faster but does not always work 58 | iterator = get_numpy_array_row_iterator(data, np_nditer_with_refs, columns) 59 | elif isinstance(data, Series): 60 | iterator = get_series_row_iterator(data, columns) 61 | elif isinstance(data, csr_matrix): 62 | iterator = get_csr_row_iterator(data, columns) 63 | elif isinstance(data, list): 64 | iterator = get_list_row_iterator(data, columns) 65 | else: 66 | raise NotImplementedError(f"TODO: Support type {type(data)}!") 67 | return iterator 68 | 69 | 70 | def create_wrapper_with_annotations(annotations_df, return_value) -> AnnotatedDfObject: 71 | """ 72 | Create a wrapper based on the data type of the return value and store the annotations in it. 73 | """ 74 | if isinstance(return_value, numpy.ndarray): 75 | return_value = MlinspectNdarray(return_value) 76 | new_return_value = AnnotatedDfObject(return_value, annotations_df) 77 | elif isinstance(return_value, DataFrame): 78 | # Remove index columns that may have been created 79 | if "mlinspect_index" in return_value.columns: 80 | return_value = return_value.drop("mlinspect_index", axis=1) 81 | elif "mlinspect_index_x" in return_value.columns: 82 | return_value = return_value.drop(["mlinspect_index_x", "mlinspect_index_y"], axis=1) 83 | assert "mlinspect_index" not in return_value.columns 84 | assert "mlinspect_index_x" not in return_value.columns 85 | 86 | new_return_value = AnnotatedDfObject(return_value, annotations_df) 87 | elif isinstance(return_value, (Series, csr_matrix)): 88 | return_value.annotations = annotations_df 89 | new_return_value = AnnotatedDfObject(return_value, annotations_df) 90 | elif return_value is None: 91 | new_return_value = AnnotatedDfObject(None, annotations_df) 92 | else: 93 | raise NotImplementedError(f"A type that is still unsupported was found: {return_value}") 94 | return new_return_value 95 | 96 | 97 | def get_df_row_iterator(dataframe): 98 | """ 99 | Create an efficient iterator for the data frame rows. 100 | The implementation is inspired by the implementation of the pandas DataFrame.itertuple method 101 | """ 102 | # Performance tips: 103 | # https://stackoverflow.com/questions/16476924/how-to-iterate-over-rows-in-a-dataframe-in-pandas 104 | arrays = [] 105 | column_info = ColumnInfo(list(dataframe.columns.values)) 106 | arrays.extend(dataframe.iloc[:, k] for k in range(0, len(dataframe.columns))) 107 | 108 | return column_info, map(tuple, zip(*arrays)) 109 | 110 | 111 | def get_series_row_iterator(series, columns=None): 112 | """ 113 | Create an efficient iterator for the data frame rows. 114 | The implementation is inspired by the implementation of the pandas DataFrame.itertuple method 115 | """ 116 | if columns: 117 | column_info = ColumnInfo(columns) 118 | elif series.name: 119 | column_info = ColumnInfo([series.name]) 120 | else: 121 | column_info = ColumnInfo(["array"]) 122 | numpy_iterator = series.__iter__() 123 | 124 | return column_info, map(tuple, zip(numpy_iterator)) 125 | 126 | 127 | def get_numpy_array_row_iterator(nparray, nditer=False, columns=None): 128 | """ 129 | Create an efficient iterator for the data frame rows. 
130 | The implementation is inspired by the implementation of the pandas DataFrame.itertuples method 131 | """ 132 | if columns: 133 | column_info = ColumnInfo(columns) 134 | else: 135 | column_info = ColumnInfo(["array"]) 136 | if nditer is True: 137 | numpy_iterator = numpy.nditer(nparray, ["refs_ok"]) 138 | else: 139 | numpy_iterator = nparray.__iter__() 140 | 141 | return column_info, map(tuple, zip(numpy_iterator)) 142 | 143 | 144 | def get_list_row_iterator(list_data, columns=None): 145 | """ 146 | Create an efficient iterator for the list rows. 147 | The implementation is inspired by the implementation of the pandas DataFrame.itertuples method 148 | """ 149 | if columns: 150 | column_info = ColumnInfo(columns) 151 | else: 152 | column_info = ColumnInfo(["array"]) 153 | row_iterator = list_data.__iter__() 154 | 155 | return column_info, map(tuple, zip(row_iterator)) 156 | 157 | 158 | def get_csr_row_iterator(csr, columns=None): 159 | """ 160 | Create an efficient iterator for csr rows. 161 | The implementation is inspired by the implementation of the pandas DataFrame.itertuples method 162 | """ 163 | # TODO: Maybe there is a way to use sparse rows that is faster 164 | # However, this is the fastest way I discovered so far 165 | np_array = csr.toarray() 166 | if columns: 167 | column_info = ColumnInfo(columns) 168 | else: 169 | column_info = ColumnInfo(["array"]) 170 | numpy_iterator = np_array.__iter__() 171 | 172 | return column_info, map(tuple, zip(numpy_iterator)) 173 | -------------------------------------------------------------------------------- /mlinspect/backends/_sklearn_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | The scikit-learn backend 3 | """ 4 | from types import MappingProxyType 5 | from typing import List, Dict 6 | 7 | import pandas 8 | 9 | from ._backend import Backend, AnnotatedDfObject, BackendResult 10 | from ._iter_creation import iter_input_annotation_output_sink_op, iter_input_annotation_output_nary_op 11 | from ._pandas_backend import execute_inspection_visits_unary_operator, store_inspection_outputs, \ 12 | execute_inspection_visits_data_source 13 | from .. 
import OperatorType 14 | from ..instrumentation._pipeline_executor import singleton 15 | 16 | 17 | class SklearnBackend(Backend): 18 | """ 19 | The scikit-learn backend 20 | """ 21 | 22 | @staticmethod 23 | def before_call(operator_context, input_infos: List[AnnotatedDfObject]): 24 | """The value or module a function may be called on""" 25 | # pylint: disable=too-many-arguments 26 | if operator_context.operator == OperatorType.TRAIN_TEST_SPLIT: 27 | pandas_df = input_infos[0].result_data 28 | assert isinstance(pandas_df, pandas.DataFrame) 29 | pandas_df['mlinspect_index'] = range(0, len(pandas_df)) 30 | return input_infos 31 | 32 | @staticmethod 33 | def after_call(operator_context, input_infos: List[AnnotatedDfObject], return_value, 34 | non_data_function_args: Dict[str, any] = MappingProxyType({})) \ 35 | -> BackendResult: 36 | """The return value of some function""" 37 | # pylint: disable=too-many-arguments 38 | if operator_context.operator == OperatorType.DATA_SOURCE: 39 | return_value = execute_inspection_visits_data_source(operator_context, return_value, non_data_function_args) 40 | elif operator_context.operator == OperatorType.TRAIN_TEST_SPLIT: 41 | train_data, test_data = return_value 42 | train_return_value = execute_inspection_visits_unary_operator(operator_context, 43 | input_infos[0].result_data, 44 | input_infos[0].result_annotation, 45 | train_data, 46 | True, non_data_function_args) 47 | test_return_value = execute_inspection_visits_unary_operator(operator_context, 48 | input_infos[0].result_data, 49 | input_infos[0].result_annotation, 50 | test_data, 51 | True, non_data_function_args) 52 | input_infos[0].result_data.drop("mlinspect_index", axis=1, inplace=True) 53 | train_data.drop("mlinspect_index", axis=1, inplace=True) 54 | test_data.drop("mlinspect_index", axis=1, inplace=True) 55 | return_value = BackendResult(train_return_value.annotated_dfobject, 56 | train_return_value.dag_node_annotation, 57 | test_return_value.annotated_dfobject, 58 | test_return_value.dag_node_annotation) 59 | elif operator_context.operator in {OperatorType.PROJECTION, OperatorType.PROJECTION_MODIFY, 60 | OperatorType.TRANSFORMER, OperatorType.TRAIN_DATA, OperatorType.TRAIN_LABELS, 61 | OperatorType.TEST_DATA, OperatorType.TEST_LABELS}: 62 | return_value = execute_inspection_visits_unary_operator(operator_context, input_infos[0].result_data, 63 | input_infos[0].result_annotation, return_value, 64 | False, non_data_function_args) 65 | elif operator_context.operator == OperatorType.ESTIMATOR: 66 | return_value = execute_inspection_visits_sink_op(operator_context, 67 | input_infos[0].result_data, 68 | input_infos[0].result_annotation, 69 | input_infos[1].result_data, 70 | input_infos[1].result_annotation, 71 | non_data_function_args) 72 | elif operator_context.operator == OperatorType.SCORE: 73 | return_value = execute_inspection_visits_nary_op(operator_context, 74 | input_infos, 75 | return_value, 76 | non_data_function_args) 77 | elif operator_context.operator == OperatorType.CONCATENATION: 78 | return_value = execute_inspection_visits_nary_op(operator_context, input_infos, return_value, 79 | non_data_function_args) 80 | else: 81 | raise NotImplementedError(f"SklearnBackend doesn't know any operations of type " 82 | f"'{operator_context.operator}' yet!") 83 | return return_value 84 | 85 | 86 | # ------------------------------------------------------- 87 | # Execute inspections functions 88 | # ------------------------------------------------------- 89 | 90 | def 
execute_inspection_visits_sink_op(operator_context, data, data_annotation, target, 91 | target_annotation, non_data_function_args) -> BackendResult: 92 | """ Execute inspections """ 93 | # pylint: disable=too-many-arguments 94 | inspection_count = len(singleton.inspections) 95 | iterators_for_inspections = iter_input_annotation_output_sink_op(inspection_count, 96 | data, 97 | data_annotation, 98 | target, 99 | target_annotation, 100 | operator_context, 101 | non_data_function_args) 102 | annotation_iterators = execute_visits(iterators_for_inspections) 103 | return_value = store_inspection_outputs(annotation_iterators, None) 104 | return return_value 105 | 106 | 107 | def execute_inspection_visits_nary_op(operator_context, annotated_dfs: List[AnnotatedDfObject], 108 | return_value_df, non_data_function_args) -> BackendResult: 109 | """Execute inspections""" 110 | # pylint: disable=too-many-arguments 111 | inspection_count = len(singleton.inspections) 112 | iterators_for_inspections = iter_input_annotation_output_nary_op(inspection_count, 113 | annotated_dfs, 114 | return_value_df, 115 | operator_context, 116 | non_data_function_args) 117 | annotation_iterators = execute_visits(iterators_for_inspections) 118 | return_value = store_inspection_outputs(annotation_iterators, return_value_df) 119 | return return_value 120 | 121 | 122 | def execute_visits(iterators_for_inspections): 123 | """ 124 | After creating the iterators we need depending on the operator type, we need to execute the 125 | generic inspection visits 126 | """ 127 | annotation_iterators = [] 128 | for inspection_index, inspection in enumerate(singleton.inspections): 129 | iterator_for_inspection = iterators_for_inspections[inspection_index] 130 | annotations_iterator = inspection.visit_operator(iterator_for_inspection) 131 | annotation_iterators.append(annotations_iterator) 132 | return annotation_iterators 133 | -------------------------------------------------------------------------------- /mlinspect/checks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Packages and classes we want to expose to users 3 | """ 4 | from ._no_bias_introduced_for import NoBiasIntroducedFor, NoBiasIntroducedForResult, BiasDistributionChange 5 | from ._no_illegal_features import NoIllegalFeatures, NoIllegalFeaturesResult 6 | from ._similar_removal_probabilities_for import SimilarRemovalProbabilitiesFor, SimilarRemovalProbabilitiesForResult, \ 7 | RemovalProbabilities 8 | from ._check import Check, CheckResult, CheckStatus 9 | 10 | __all__ = [ 11 | # General classes 12 | 'Check', 13 | 'CheckResult', 14 | 'CheckStatus', 15 | # Native checks 16 | 'NoBiasIntroducedFor', 'NoBiasIntroducedForResult', 'BiasDistributionChange', 17 | 'NoIllegalFeatures', 'NoIllegalFeaturesResult', 18 | 'SimilarRemovalProbabilitiesFor', 'SimilarRemovalProbabilitiesForResult', 'RemovalProbabilities' 19 | ] 20 | -------------------------------------------------------------------------------- /mlinspect/checks/_check.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Interface for the Checks 3 | """ 4 | from __future__ import annotations 5 | 6 | import abc 7 | import dataclasses 8 | from enum import Enum 9 | from typing import Iterable 10 | 11 | from mlinspect.inspections._inspection import Inspection 12 | from mlinspect.inspections._inspection_result import InspectionResult 13 | 14 | 15 | class CheckStatus(Enum): 16 | """ 17 | The result of the check 18 | """ 19 | 
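# A check either passes or fails; a failing CheckResult additionally carries a human-readable description.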
SUCCESS = "Success" 20 | FAILURE = "Failure" 21 | 22 | 23 | @dataclasses.dataclass 24 | class CheckResult: 25 | """ 26 | Does this check cause an error or a warning if it fails? 27 | """ 28 | check: Check 29 | status: CheckStatus 30 | description: str or None 31 | 32 | 33 | class Check(metaclass=abc.ABCMeta): 34 | """ 35 | Checks like no_bias_introduced 36 | """ 37 | # pylint: disable=unnecessary-pass, too-few-public-methods 38 | 39 | @property 40 | def check_id(self): 41 | """The id of the Check""" 42 | return None 43 | 44 | @property 45 | @abc.abstractmethod 46 | def required_inspections(self) -> Iterable[Inspection]: 47 | """Inspections required to evaluate this check""" 48 | raise NotImplementedError 49 | 50 | @abc.abstractmethod 51 | def evaluate(self, inspection_result: InspectionResult) -> CheckResult: 52 | """Evaluate the check""" 53 | raise NotImplementedError 54 | 55 | def __eq__(self, other): 56 | """Checks must implement equals""" 57 | return (isinstance(other, self.__class__) and 58 | self.check_id == other.check_id) 59 | 60 | def __hash__(self): 61 | """Checks must be hashable""" 62 | return hash((self.__class__.__name__, self.check_id)) 63 | 64 | def __repr__(self): 65 | """Checks must have a str representation""" 66 | return f"{self.__class__.__name__}({self.check_id or ''})" 67 | -------------------------------------------------------------------------------- /mlinspect/checks/_no_illegal_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Interface for the Constraints 3 | """ 4 | from __future__ import annotations 5 | 6 | import dataclasses 7 | from typing import List, Iterable 8 | 9 | from mlinspect.checks._check import Check, CheckStatus, CheckResult 10 | from mlinspect.inspections._inspection import Inspection 11 | from mlinspect.inspections._inspection_input import OperatorType 12 | from mlinspect.inspections._inspection_result import InspectionResult 13 | 14 | ILLEGAL_FEATURES = {"race", "gender", "age"} 15 | 16 | 17 | @dataclasses.dataclass 18 | class NoIllegalFeaturesResult(CheckResult): 19 | """ 20 | Does the pipeline use illegal features? 21 | """ 22 | illegal_features: List[str] 23 | 24 | 25 | class NoIllegalFeatures(Check): 26 | """ 27 | Does the model get sensitive attributes like race as feature? 
28 | """ 29 | # pylint: disable=unnecessary-pass, too-few-public-methods 30 | 31 | def __init__(self, additional_illegal_feature_names=None): 32 | if additional_illegal_feature_names is None: 33 | additional_illegal_feature_names = [] 34 | self.additional_illegal_feature_names = additional_illegal_feature_names 35 | 36 | @property 37 | def required_inspections(self) -> Iterable[Inspection]: 38 | """The inspections required for the check""" 39 | return [] 40 | 41 | @property 42 | def check_id(self): 43 | """The id of the Constraints""" 44 | return tuple(self.additional_illegal_feature_names) 45 | 46 | def evaluate(self, inspection_result: InspectionResult) -> CheckResult: 47 | """Evaluate the check""" 48 | # TODO: Make this robust and add extensive testing 49 | dag = inspection_result.dag 50 | train_data_nodes = [node for node in dag.nodes if node.operator_info.operator == OperatorType.TRAIN_DATA] 51 | used_columns = [] 52 | for train_data_node in train_data_nodes: 53 | used_columns.extend(self.get_used_columns(dag, train_data_node)) 54 | forbidden_columns = {*ILLEGAL_FEATURES, *self.additional_illegal_feature_names} 55 | used_illegal_columns = list(set(used_columns).intersection(forbidden_columns)) 56 | if used_illegal_columns: 57 | description = f"Used illegal columns: {used_illegal_columns}" 58 | result = NoIllegalFeaturesResult(self, CheckStatus.FAILURE, description, used_illegal_columns) 59 | else: 60 | result = NoIllegalFeaturesResult(self, CheckStatus.SUCCESS, None, []) 61 | return result 62 | 63 | def get_used_columns(self, dag, current_node): 64 | """ 65 | Get the output column of the current dag node. If the current dag node is, e.g., a concatenation, 66 | check the parents of the current dag node. 67 | """ 68 | columns = current_node.details.columns 69 | if columns is not None and columns != ["array"]: 70 | result = columns 71 | else: 72 | parent_columns = [] 73 | for parent in dag.predecessors(current_node): 74 | parent_columns.extend(self.get_used_columns(dag, parent)) 75 | result = parent_columns 76 | return result 77 | -------------------------------------------------------------------------------- /mlinspect/inspections/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Packages and classes we want to expose to users 3 | """ 4 | from ._arg_capturing import ArgumentCapturing 5 | from ._completeness_of_columns import CompletenessOfColumns 6 | from ._count_distinct_of_columns import CountDistinctOfColumns 7 | from ._inspection import Inspection 8 | from ._inspection_result import InspectionResult 9 | from ._inspection_input import InspectionInputUnaryOperator, InspectionInputDataSource, InspectionInputSinkOperator, \ 10 | InspectionInputNAryOperator 11 | from ._histogram_for_columns import HistogramForColumns 12 | from ._intersectional_histogram_for_columns import IntersectionalHistogramForColumns 13 | from ._lineage import RowLineage 14 | from ._materialize_first_output_rows import MaterializeFirstOutputRows 15 | from ._column_propagation import ColumnPropagation 16 | 17 | __all__ = [ 18 | # For defining custom inspections 19 | 'Inspection', 'InspectionResult', 20 | 'InspectionInputUnaryOperator', 'InspectionInputDataSource', 'InspectionInputSinkOperator', 21 | 'InspectionInputNAryOperator', 22 | # Native inspections 23 | 'HistogramForColumns', 24 | 'ColumnPropagation', 25 | 'IntersectionalHistogramForColumns', 26 | 'RowLineage', 27 | 'MaterializeFirstOutputRows', 28 | 'CompletenessOfColumns', 29 | 
'CountDistinctOfColumns', 30 | 'ArgumentCapturing' 31 | ] 32 | -------------------------------------------------------------------------------- /mlinspect/inspections/_arg_capturing.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple inspection to capture important function call arguments like estimator hyperparameters 3 | """ 4 | from typing import Iterable 5 | 6 | from ._inspection import Inspection 7 | 8 | 9 | class ArgumentCapturing(Inspection): 10 | """ 11 | A simple inspection to capture important function call arguments like estimator hyperparameters 12 | """ 13 | 14 | def __init__(self): 15 | self._captured_arguments = None 16 | 17 | @property 18 | def inspection_id(self): 19 | return None 20 | 21 | def visit_operator(self, inspection_input) -> Iterable[any]: 22 | """ 23 | Visit an operator 24 | """ 25 | self._captured_arguments = inspection_input.non_data_function_args 26 | 27 | for _ in inspection_input.row_iterator: 28 | yield None 29 | 30 | def get_operator_annotation_after_visit(self) -> any: 31 | captured_args = self._captured_arguments 32 | self._captured_arguments = None 33 | return captured_args 34 | -------------------------------------------------------------------------------- /mlinspect/inspections/_completeness_of_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | An inspection to compute the ratio of non-null values in output columns 3 | """ 4 | from typing import Iterable 5 | 6 | import pandas 7 | 8 | from mlinspect.inspections._inspection import Inspection 9 | from mlinspect.inspections._inspection_input import OperatorType, InspectionInputSinkOperator 10 | 11 | 12 | class CompletenessOfColumns(Inspection): 13 | """ 14 | An inspection to compute the completeness of columns 15 | """ 16 | 17 | def __init__(self, columns): 18 | self._present_column_names = [] 19 | self._null_value_counts = [] 20 | self._total_counts = [] 21 | self._operator_type = None 22 | self.columns = columns 23 | 24 | @property 25 | def inspection_id(self): 26 | return tuple(self.columns) 27 | 28 | def visit_operator(self, inspection_input) -> Iterable[any]: 29 | """ 30 | Visit an operator 31 | """ 32 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals 33 | self._present_column_names = [] 34 | self._null_value_counts = [] 35 | self._total_counts = [] 36 | self._operator_type = inspection_input.operator_context.operator 37 | 38 | if not isinstance(inspection_input, InspectionInputSinkOperator): 39 | present_columns_index = [] 40 | for column_name in self.columns: 41 | column_present = column_name in inspection_input.output_columns.fields 42 | if column_present: 43 | column_index = inspection_input.output_columns.get_index_of_column(column_name) 44 | present_columns_index.append(column_index) 45 | self._present_column_names.append(column_name) 46 | self._null_value_counts.append(0) 47 | self._total_counts.append(0) 48 | for row in inspection_input.row_iterator: 49 | for present_column_index, column_index in enumerate(present_columns_index): 50 | column_value = row.output[column_index] 51 | is_null = pandas.isna(column_value) 52 | self._null_value_counts[present_column_index] += int(is_null) 53 | self._total_counts[present_column_index] += 1 54 | yield None 55 | else: 56 | for _ in inspection_input.row_iterator: 57 | yield None 58 | 59 | def get_operator_annotation_after_visit(self) -> any: 60 | assert self._operator_type 61 | if self._operator_type is not 
OperatorType.ESTIMATOR: 62 | completeness_results = {} 63 | for column_index, column_name in enumerate(self._present_column_names): 64 | null_value_count = self._null_value_counts[column_index] 65 | total_count = self._total_counts[column_index] 66 | completeness = (total_count - null_value_count) / total_count 67 | completeness_results[column_name] = completeness 68 | return completeness_results 69 | self._operator_type = None 70 | return None 71 | -------------------------------------------------------------------------------- /mlinspect/inspections/_count_distinct_of_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | An inspection to compute the number of distinct values in output columns 3 | """ 4 | from typing import Iterable 5 | 6 | from mlinspect.inspections._inspection import Inspection 7 | from mlinspect.inspections._inspection_input import OperatorType, InspectionInputSinkOperator 8 | 9 | 10 | class CountDistinctOfColumns(Inspection): 11 | """ 12 | An inspection to compute the number of distinct values of columns 13 | """ 14 | 15 | def __init__(self, columns): 16 | self._present_column_names = [] 17 | self._distinct_value_sets = [] 18 | self._operator_type = None 19 | self.columns = columns 20 | 21 | @property 22 | def inspection_id(self): 23 | return tuple(self.columns) 24 | 25 | def visit_operator(self, inspection_input) -> Iterable[any]: 26 | """ 27 | Visit an operator 28 | """ 29 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals 30 | self._present_column_names = [] 31 | self._distinct_value_sets = [] 32 | self._operator_type = inspection_input.operator_context.operator 33 | 34 | if not isinstance(inspection_input, InspectionInputSinkOperator): 35 | present_columns_index = [] 36 | for column_name in self.columns: 37 | column_present = column_name in inspection_input.output_columns.fields 38 | if column_present: 39 | column_index = inspection_input.output_columns.get_index_of_column(column_name) 40 | present_columns_index.append(column_index) 41 | self._present_column_names.append(column_name) 42 | self._distinct_value_sets.append(set()) 43 | for row in inspection_input.row_iterator: 44 | for present_column_index, column_index in enumerate(present_columns_index): 45 | column_value = row.output[column_index] 46 | self._distinct_value_sets[present_column_index].add(column_value) 47 | yield None 48 | else: 49 | for _ in inspection_input.row_iterator: 50 | yield None 51 | 52 | def get_operator_annotation_after_visit(self) -> any: 53 | assert self._operator_type 54 | if self._operator_type is not OperatorType.ESTIMATOR: 55 | distinct_count_results = {} 56 | for column_index, column_name in enumerate(self._present_column_names): 57 | distinct_value_count = len(self._distinct_value_sets[column_index]) 58 | distinct_count_results[column_name] = distinct_value_count 59 | del self._distinct_value_sets 60 | return distinct_count_results 61 | self._operator_type = None 62 | return None 63 | -------------------------------------------------------------------------------- /mlinspect/inspections/_histogram_for_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple inspection to compute histograms of sensitive groups in the data 3 | """ 4 | from typing import Iterable 5 | 6 | from mlinspect.inspections._inspection import Inspection 7 | from mlinspect.inspections._inspection_input import InspectionInputDataSource, \ 8 | InspectionInputUnaryOperator, 
InspectionInputNAryOperator, OperatorType, FunctionInfo 9 | 10 | 11 | class HistogramForColumns(Inspection): 12 | """ 13 | An inspection to compute group membership histograms for multiple columns 14 | """ 15 | 16 | def __init__(self, sensitive_columns): 17 | self._histogram_op_output = None 18 | self._operator_type = None 19 | self.sensitive_columns = sensitive_columns 20 | 21 | @property 22 | def inspection_id(self): 23 | return tuple(self.sensitive_columns) 24 | 25 | def visit_operator(self, inspection_input) -> Iterable[any]: 26 | """ 27 | Visit an operator 28 | """ 29 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals, too-many-nested-blocks 30 | current_count = - 1 31 | 32 | histogram_maps = [] 33 | for _ in self.sensitive_columns: 34 | histogram_maps.append({}) 35 | 36 | self._operator_type = inspection_input.operator_context.operator 37 | 38 | if isinstance(inspection_input, InspectionInputUnaryOperator): 39 | sensitive_columns_present = [] 40 | sensitive_columns_index = [] 41 | for column in self.sensitive_columns: 42 | column_present = column in inspection_input.input_columns.fields 43 | sensitive_columns_present.append(column_present) 44 | column_index = inspection_input.input_columns.get_index_of_column(column) 45 | sensitive_columns_index.append(column_index) 46 | if inspection_input.operator_context.function_info == FunctionInfo('sklearn.impute._base', 'SimpleImputer'): 47 | for row in inspection_input.row_iterator: 48 | current_count += 1 49 | column_values = [] 50 | for check_index, _ in enumerate(self.sensitive_columns): 51 | if sensitive_columns_present[check_index]: 52 | column_value = row.output[0][sensitive_columns_index[check_index]] 53 | else: 54 | column_value = row.annotation[check_index] 55 | column_values.append(column_value) 56 | group_count = histogram_maps[check_index].get(column_value, 0) 57 | group_count += 1 58 | histogram_maps[check_index][column_value] = group_count 59 | yield column_values 60 | else: 61 | for row in inspection_input.row_iterator: 62 | current_count += 1 63 | column_values = [] 64 | for check_index, _ in enumerate(self.sensitive_columns): 65 | if sensitive_columns_present[check_index]: 66 | column_value = row.input[sensitive_columns_index[check_index]] 67 | else: 68 | column_value = row.annotation[check_index] 69 | column_values.append(column_value) 70 | group_count = histogram_maps[check_index].get(column_value, 0) 71 | group_count += 1 72 | histogram_maps[check_index][column_value] = group_count 73 | yield column_values 74 | elif isinstance(inspection_input, InspectionInputDataSource): 75 | sensitive_columns_present = [] 76 | sensitive_columns_index = [] 77 | for column in self.sensitive_columns: 78 | column_present = column in inspection_input.output_columns.fields 79 | sensitive_columns_present.append(column_present) 80 | column_index = inspection_input.output_columns.get_index_of_column(column) 81 | sensitive_columns_index.append(column_index) 82 | for row in inspection_input.row_iterator: 83 | current_count += 1 84 | column_values = [] 85 | for check_index, _ in enumerate(self.sensitive_columns): 86 | if sensitive_columns_present[check_index]: 87 | column_value = row.output[sensitive_columns_index[check_index]] 88 | column_values.append(column_value) 89 | group_count = histogram_maps[check_index].get(column_value, 0) 90 | group_count += 1 91 | histogram_maps[check_index][column_value] = group_count 92 | else: 93 | column_values.append(None) 94 | yield column_values 95 | elif isinstance(inspection_input, 
InspectionInputNAryOperator): 96 | sensitive_columns_present = [] 97 | sensitive_columns_index = [] 98 | for column in self.sensitive_columns: 99 | column_present = column in inspection_input.output_columns.fields 100 | sensitive_columns_present.append(column_present) 101 | column_index = inspection_input.output_columns.get_index_of_column(column) 102 | sensitive_columns_index.append(column_index) 103 | for row in inspection_input.row_iterator: 104 | current_count += 1 105 | column_values = [] 106 | for check_index, _ in enumerate(self.sensitive_columns): 107 | if sensitive_columns_present[check_index]: 108 | column_value = row.output[sensitive_columns_index[check_index]] 109 | column_values.append(column_value) 110 | group_count = histogram_maps[check_index].get(column_value, 0) 111 | group_count += 1 112 | histogram_maps[check_index][column_value] = group_count 113 | else: 114 | if sensitive_columns_present[check_index]: 115 | column_value = row.output[sensitive_columns_index[check_index]] 116 | else: 117 | column_value_candidates = [annotation[check_index] for annotation in row.annotation 118 | if annotation[check_index] is not None] 119 | if len(column_value_candidates) >= 1: 120 | column_value = column_value_candidates[0] 121 | else: 122 | column_value = None 123 | column_values.append(column_value) 124 | group_count = histogram_maps[check_index].get(column_value, 0) 125 | group_count += 1 126 | histogram_maps[check_index][column_value] = group_count 127 | yield column_values 128 | else: 129 | for _ in inspection_input.row_iterator: 130 | yield None 131 | 132 | self._histogram_op_output = {} 133 | for check_index, column in enumerate(self.sensitive_columns): 134 | self._histogram_op_output[column] = histogram_maps[check_index] 135 | 136 | def get_operator_annotation_after_visit(self) -> any: 137 | assert self._operator_type 138 | if self._operator_type is not OperatorType.ESTIMATOR: 139 | result = self._histogram_op_output 140 | self._histogram_op_output = None 141 | self._operator_type = None 142 | return result 143 | self._operator_type = None 144 | return None 145 | -------------------------------------------------------------------------------- /mlinspect/inspections/_inspection.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Interface for the Inspection 3 | """ 4 | import abc 5 | from typing import Union, Iterable 6 | 7 | from mlinspect.inspections._inspection_input import InspectionInputDataSource, \ 8 | InspectionInputUnaryOperator, InspectionInputNAryOperator, InspectionInputSinkOperator 9 | 10 | 11 | class Inspection(metaclass=abc.ABCMeta): 12 | """ 13 | The Interface for the Inspections 14 | """ 15 | 16 | @property 17 | def inspection_id(self): 18 | """The id of the inspection""" 19 | return None 20 | 21 | @abc.abstractmethod 22 | def visit_operator(self, inspection_input: Union[InspectionInputDataSource, InspectionInputUnaryOperator, 23 | InspectionInputNAryOperator, InspectionInputSinkOperator])\ 24 | -> Iterable[any]: 25 | """Visit an operator in the DAG""" 26 | raise NotImplementedError 27 | 28 | @abc.abstractmethod 29 | def get_operator_annotation_after_visit(self) -> any: 30 | """Get the output to be included in the DAG""" 31 | raise NotImplementedError 32 | 33 | def __eq__(self, other): 34 | """Inspections must implement equals""" 35 | return (isinstance(other, self.__class__) and 36 | self.inspection_id == other.inspection_id) 37 | 38 | def __hash__(self): 39 | """Inspections must be hashable""" 40 | return 
hash((self.__class__.__name__, self.inspection_id)) 41 | 42 | def __repr__(self): 43 | """Inspections must have a str representation""" 44 | return f"{self.__class__.__name__}({self.inspection_id})" 45 | -------------------------------------------------------------------------------- /mlinspect/inspections/_inspection_input.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data classes used as input for the inspections 3 | """ 4 | import dataclasses 5 | from enum import Enum 6 | from typing import Tuple, List, Iterable, Dict 7 | 8 | 9 | @dataclasses.dataclass(frozen=True) 10 | class ColumnInfo: 11 | """ 12 | A class we use to efficiently pass the column names of pandas/sklearn data 13 | """ 14 | fields: List[str] 15 | 16 | def get_index_of_column(self, column_name): 17 | """ 18 | Get the index of some column in the row values 19 | """ 20 | if column_name in self.fields: 21 | return self.fields.index(column_name) 22 | return None 23 | 24 | def __eq__(self, other): 25 | return (isinstance(other, ColumnInfo) and 26 | self.fields == other.fields) 27 | 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class FunctionInfo: 31 | """ 32 | Contains the function name and its path 33 | """ 34 | module: str 35 | function_name: str 36 | 37 | 38 | class OperatorType(Enum): 39 | """ 40 | The different operator types in our DAG 41 | """ 42 | DATA_SOURCE = "Data Source" 43 | MISSING_OP = "Encountered unsupported operation! Fallback: Data Source" 44 | SELECTION = "Selection" 45 | PROJECTION = "Projection" 46 | PROJECTION_MODIFY = "Projection (Modify)" 47 | TRANSFORMER = "Transformer" 48 | CONCATENATION = "Concatenation" 49 | ESTIMATOR = "Estimator" 50 | SCORE = "Score" 51 | TRAIN_DATA = "Train Data" 52 | TRAIN_LABELS = "Train Labels" 53 | TEST_DATA = "Test Data" 54 | TEST_LABELS = "Test Labels" 55 | JOIN = "Join" 56 | GROUP_BY_AGG = "Groupby and Aggregate" 57 | TRAIN_TEST_SPLIT = "Train Test Split" 58 | 59 | 60 | @dataclasses.dataclass(frozen=True) 61 | class OperatorContext: 62 | """ 63 | Additional context for the inspection. Contains, most importantly, the operator type. 64 | """ 65 | operator: OperatorType 66 | function_info: FunctionInfo or None 67 | 68 | 69 | @dataclasses.dataclass(frozen=True) 70 | class InspectionRowDataSource: 71 | """ 72 | Wrapper class for the only operator without a parent: a Data Source 73 | """ 74 | output: Tuple 75 | 76 | 77 | @dataclasses.dataclass(frozen=True) 78 | class InspectionInputDataSource: 79 | """ 80 | Additional context for the inspection. Contains, most importantly, the operator type. 81 | """ 82 | operator_context: OperatorContext 83 | output_columns: ColumnInfo 84 | row_iterator: Iterable[InspectionRowDataSource] 85 | non_data_function_args: Dict[str, any] 86 | 87 | 88 | @dataclasses.dataclass(frozen=True) 89 | class InspectionRowUnaryOperator: 90 | """ 91 | Wrapper class for the operators with one parent like Selections and Projections 92 | """ 93 | input: Tuple 94 | annotation: any 95 | output: Tuple 96 | 97 | 98 | @dataclasses.dataclass(frozen=True) 99 | class InspectionInputUnaryOperator: 100 | """ 101 | Additional context for the inspection. Contains, most importantly, the operator type. 
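For unary operators, each element of row_iterator is an InspectionRowUnaryOperator that carries the input row, the annotation this inspection attached to it at the parent operator, and the output row.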
102 | """ 103 | operator_context: OperatorContext 104 | input_columns: ColumnInfo 105 | output_columns: ColumnInfo 106 | row_iterator: Iterable[InspectionRowUnaryOperator] 107 | non_data_function_args: Dict[str, any] 108 | 109 | 110 | @dataclasses.dataclass(frozen=True) 111 | class InspectionRowNAryOperator: 112 | """ 113 | Wrapper class for the operators with multiple parents like Concatenations 114 | """ 115 | inputs: Tuple[Tuple] 116 | annotation: Tuple[any] 117 | output: Tuple 118 | 119 | 120 | @dataclasses.dataclass(frozen=True) 121 | class InspectionInputNAryOperator: 122 | """ 123 | Additional context for the inspection. Contains, most importantly, the operator type. 124 | """ 125 | operator_context: OperatorContext 126 | inputs_columns: List[ColumnInfo] 127 | output_columns: ColumnInfo 128 | row_iterator: Iterable[InspectionRowNAryOperator] 129 | non_data_function_args: Dict[str, any] 130 | 131 | 132 | @dataclasses.dataclass(frozen=True) 133 | class InspectionRowSinkOperator: 134 | """ 135 | Wrapper class for operators like Estimators that only get fitted 136 | """ 137 | input: Tuple[Tuple] 138 | annotation: any 139 | 140 | 141 | @dataclasses.dataclass(frozen=True) 142 | class InspectionInputSinkOperator: 143 | """ 144 | Additional context for the inspection. Contains, most importantly, the operator type. 145 | """ 146 | operator_context: OperatorContext 147 | inputs_columns: List[ColumnInfo] 148 | row_iterator: Iterable[InspectionRowSinkOperator] 149 | non_data_function_args: Dict[str, any] 150 | -------------------------------------------------------------------------------- /mlinspect/inspections/_inspection_result.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data class used as result of the PipelineExecutor 3 | """ 4 | import dataclasses 5 | from typing import Dict 6 | 7 | import networkx 8 | 9 | from mlinspect.instrumentation._dag_node import DagNode 10 | from mlinspect.inspections._inspection import Inspection 11 | 12 | 13 | @dataclasses.dataclass 14 | class InspectionResult: 15 | """ 16 | The class the PipelineExecutor returns 17 | """ 18 | dag: networkx.DiGraph 19 | dag_node_to_inspection_results: Dict[DagNode, Dict[Inspection, any]] 20 | -------------------------------------------------------------------------------- /mlinspect/inspections/_intersectional_histogram_for_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | An inspection to compute histograms of intersectional group memberships 3 | """ 4 | from typing import Iterable, List 5 | 6 | from mlinspect.inspections._inspection import Inspection 7 | from mlinspect.inspections._inspection_input import InspectionInputDataSource, \ 8 | InspectionInputUnaryOperator, InspectionInputNAryOperator, OperatorType, FunctionInfo 9 | 10 | 11 | class IntersectionalHistogramForColumns(Inspection): 12 | """ 13 | An inspection to compute intersectional group memberships 14 | """ 15 | 16 | def __init__(self, sensitive_columns: List[str]): 17 | self._histogram_op_output = None 18 | self._operator_type = None 19 | self.sensitive_columns = sensitive_columns 20 | 21 | @property 22 | def inspection_id(self): 23 | return tuple(self.sensitive_columns) 24 | 25 | def visit_operator(self, inspection_input) -> Iterable[any]: 26 | """ 27 | Visit an operator 28 | """ 29 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals 30 | current_count = - 1 31 | 32 | histogram_map = {} 33 | 34 | self._operator_type = 
inspection_input.operator_context.operator 35 | 36 | if isinstance(inspection_input, InspectionInputUnaryOperator): 37 | sensitive_columns_present = [] 38 | sensitive_columns_index = [] 39 | for column in self.sensitive_columns: 40 | column_present = column in inspection_input.input_columns.fields 41 | sensitive_columns_present.append(column_present) 42 | column_index = inspection_input.input_columns.get_index_of_column(column) 43 | sensitive_columns_index.append(column_index) 44 | if inspection_input.operator_context.function_info == FunctionInfo('sklearn.impute._base', 'SimpleImputer'): 45 | for row in inspection_input.row_iterator: 46 | current_count += 1 47 | column_values = [] 48 | for check_index, _ in enumerate(self.sensitive_columns): 49 | if sensitive_columns_present[check_index]: 50 | column_value = row.output[0][sensitive_columns_index[check_index]] 51 | else: 52 | column_value = row.annotation[check_index] 53 | column_values.append(column_value) 54 | update_histograms(column_values, histogram_map) 55 | yield column_values 56 | else: 57 | for row in inspection_input.row_iterator: 58 | current_count += 1 59 | column_values = [] 60 | for check_index, _ in enumerate(self.sensitive_columns): 61 | if sensitive_columns_present[check_index]: 62 | column_value = row.input[sensitive_columns_index[check_index]] 63 | else: 64 | column_value = row.annotation[check_index] 65 | column_values.append(column_value) 66 | update_histograms(column_values, histogram_map) 67 | yield column_values 68 | elif isinstance(inspection_input, InspectionInputDataSource): 69 | sensitive_columns_present = [] 70 | sensitive_columns_index = [] 71 | for column in self.sensitive_columns: 72 | column_present = column in inspection_input.output_columns.fields 73 | sensitive_columns_present.append(column_present) 74 | column_index = inspection_input.output_columns.get_index_of_column(column) 75 | sensitive_columns_index.append(column_index) 76 | for row in inspection_input.row_iterator: 77 | current_count += 1 78 | column_values = [] 79 | for check_index, _ in enumerate(self.sensitive_columns): 80 | if sensitive_columns_present[check_index]: 81 | column_value = row.output[sensitive_columns_index[check_index]] 82 | column_values.append(column_value) 83 | else: 84 | column_values.append(None) 85 | update_histograms(column_values, histogram_map) 86 | yield column_values 87 | elif isinstance(inspection_input, InspectionInputNAryOperator): 88 | sensitive_columns_present = [] 89 | sensitive_columns_index = [] 90 | for column in self.sensitive_columns: 91 | column_present = column in inspection_input.output_columns.fields 92 | sensitive_columns_present.append(column_present) 93 | column_index = inspection_input.output_columns.get_index_of_column(column) 94 | sensitive_columns_index.append(column_index) 95 | for row in inspection_input.row_iterator: 96 | current_count += 1 97 | column_values = [] 98 | for check_index, _ in enumerate(self.sensitive_columns): 99 | if sensitive_columns_present[check_index]: 100 | column_value = row.output[sensitive_columns_index[check_index]] 101 | column_values.append(column_value) 102 | else: 103 | column_values.append(None) 104 | update_histograms(column_values, histogram_map) 105 | yield column_values 106 | else: 107 | for _ in inspection_input.row_iterator: 108 | yield None 109 | 110 | self._histogram_op_output = histogram_map 111 | 112 | def get_operator_annotation_after_visit(self) -> any: 113 | assert self._operator_type 114 | if self._operator_type is not OperatorType.ESTIMATOR: 115 | 
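# Estimators are sink operators without output rows, so a histogram is only reported for the other operator types.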
result = self._histogram_op_output 116 | self._histogram_op_output = None 117 | self._operator_type = None 118 | return result 119 | self._operator_type = None 120 | return None 121 | 122 | 123 | def update_histograms(column_values, histogram_map): 124 | """Update the histograms with the intersectional information""" 125 | value_tuple = tuple(column_values) 126 | group_count = histogram_map.get(value_tuple, 0) 127 | group_count += 1 128 | histogram_map[value_tuple] = group_count 129 | -------------------------------------------------------------------------------- /mlinspect/inspections/_lineage.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple inspection for lineage tracking 3 | """ 4 | import dataclasses 5 | from typing import Iterable, List 6 | 7 | from pandas import DataFrame, Series 8 | 9 | from mlinspect.inspections._inspection import Inspection 10 | from mlinspect.inspections._inspection_input import InspectionInputUnaryOperator, \ 11 | InspectionInputSinkOperator, InspectionInputDataSource, InspectionInputNAryOperator, OperatorType 12 | 13 | 14 | @dataclasses.dataclass(frozen=True) 15 | class LineageId: 16 | """ 17 | A lineage id class 18 | """ 19 | operator_id: int 20 | row_id: int 21 | 22 | 23 | class RowLineage(Inspection): 24 | """ 25 | A simple inspection for row-level lineage tracking 26 | """ 27 | # TODO: Add an option to pass a list of lineage ids to this inspection. Then it materializes all related tuples. 28 | # To do this efficiently, we do not want to do expensive membership tests. We can collect all base LineageIds 29 | # in a set and then it is enough to check for set memberships in InspectionInputDataSource inspection inputs. 30 | # This set membership can be used as a 'materialize' flag we use as annotation. Then we simply need to check this 31 | # flag to check whether to materialize rows. 
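# A usage sketch (the path variable is illustrative): materialise lineage for the first five rows
# of every operator's output:
#     from mlinspect import PipelineInspector
#     from mlinspect.inspections import RowLineage
#     result = PipelineInspector.on_pipeline_from_py_file(pipeline_path) \
#         .add_required_inspection(RowLineage(5)) \
#         .execute()
# The materialised DataFrames then carry an extra 'mlinspect_lineage' column with sets of LineageIds.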
32 | # pylint: disable=too-many-instance-attributes 33 | 34 | ALL_ROWS = -1 35 | 36 | def __init__(self, row_count: int, operator_type_restriction: List[OperatorType] = None): 37 | self.row_count = row_count 38 | if operator_type_restriction is not None: 39 | self.operator_type_restriction = set(operator_type_restriction) 40 | self._inspection_id = (self.row_count, *self.operator_type_restriction) 41 | else: 42 | self.operator_type_restriction = None 43 | self._inspection_id = self.row_count 44 | self._operator_count = -1 45 | self._op_output = None 46 | self._op_lineage = None 47 | self._output_columns = None 48 | self._is_sink = False 49 | self._materialize_for_this_operator = None 50 | 51 | def visit_operator(self, inspection_input) -> Iterable[any]: 52 | """Visit an operator, generate row index number annotations and check whether they get propagated correctly""" 53 | # pylint: disable=too-many-branches, too-many-statements 54 | self._operator_count += 1 55 | self._op_output = [] 56 | self._op_lineage = [] 57 | current_count = -1 58 | self._materialize_for_this_operator = (self.operator_type_restriction is None) or \ 59 | (inspection_input.operator_context.operator 60 | in self.operator_type_restriction) 61 | 62 | if not isinstance(inspection_input, InspectionInputSinkOperator): 63 | self._output_columns = inspection_input.output_columns.fields 64 | else: 65 | self._is_sink = True 66 | 67 | if isinstance(inspection_input, InspectionInputDataSource): 68 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS: 69 | for row in inspection_input.row_iterator: 70 | current_count += 1 71 | annotation = {LineageId(self._operator_count, current_count)} 72 | self._op_output.append(row.output) 73 | self._op_lineage.append(annotation) 74 | yield annotation 75 | elif self._materialize_for_this_operator: 76 | for row in inspection_input.row_iterator: 77 | current_count += 1 78 | annotation = {LineageId(self._operator_count, current_count)} 79 | if current_count < self.row_count: 80 | self._op_output.append(row.output) 81 | self._op_lineage.append(annotation) 82 | yield annotation 83 | else: 84 | for _ in inspection_input.row_iterator: 85 | current_count += 1 86 | annotation = {LineageId(self._operator_count, current_count)} 87 | yield annotation 88 | elif isinstance(inspection_input, InspectionInputNAryOperator): 89 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS: 90 | for row in inspection_input.row_iterator: 91 | current_count += 1 92 | 93 | annotation = set.union(*row.annotation) 94 | self._op_output.append(row.output) 95 | self._op_lineage.append(annotation) 96 | yield annotation 97 | elif self._materialize_for_this_operator: 98 | for row in inspection_input.row_iterator: 99 | current_count += 1 100 | 101 | annotation = set.union(*row.annotation) 102 | if current_count < self.row_count: 103 | self._op_output.append(row.output) 104 | self._op_lineage.append(annotation) 105 | yield annotation 106 | else: 107 | for row in inspection_input.row_iterator: 108 | current_count += 1 109 | annotation = set.union(*row.annotation) 110 | yield annotation 111 | elif isinstance(inspection_input, InspectionInputSinkOperator): 112 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS: 113 | for row in inspection_input.row_iterator: 114 | current_count += 1 115 | 116 | annotation = set.union(*row.annotation) 117 | self._op_lineage.append(annotation) 118 | yield annotation 119 | elif self._materialize_for_this_operator: 120 
| for row in inspection_input.row_iterator: 121 | current_count += 1 122 | 123 | annotation = set.union(*row.annotation) 124 | if current_count < self.row_count: 125 | self._op_lineage.append(annotation) 126 | yield annotation 127 | else: 128 | for row in inspection_input.row_iterator: 129 | current_count += 1 130 | annotation = set.union(*row.annotation) 131 | yield annotation 132 | elif isinstance(inspection_input, InspectionInputUnaryOperator): 133 | if self._materialize_for_this_operator and self.row_count == RowLineage.ALL_ROWS: 134 | for row in inspection_input.row_iterator: 135 | current_count += 1 136 | annotation = row.annotation 137 | self._op_output.append(row.output) 138 | self._op_lineage.append(annotation) 139 | yield annotation 140 | elif self._materialize_for_this_operator: 141 | for row in inspection_input.row_iterator: 142 | current_count += 1 143 | annotation = row.annotation 144 | 145 | if current_count < self.row_count: 146 | self._op_output.append(row.output) 147 | self._op_lineage.append(annotation) 148 | yield annotation 149 | else: 150 | for row in inspection_input.row_iterator: 151 | current_count += 1 152 | annotation = row.annotation 153 | yield annotation 154 | else: 155 | assert False 156 | 157 | def get_operator_annotation_after_visit(self) -> any: 158 | if not self._materialize_for_this_operator: 159 | result = None 160 | elif not self._is_sink: 161 | assert self._op_lineage 162 | result = DataFrame(self._op_output, columns=self._output_columns) 163 | result["mlinspect_lineage"] = self._op_lineage 164 | else: 165 | assert self._op_lineage 166 | lineage_series = Series(self._op_lineage) 167 | result = DataFrame(lineage_series, columns=["mlinspect_lineage"]) 168 | self._op_output = None 169 | self._op_lineage = None 170 | self._output_columns = None 171 | self._is_sink = False 172 | return result 173 | 174 | @property 175 | def inspection_id(self): 176 | return self._inspection_id 177 | -------------------------------------------------------------------------------- /mlinspect/inspections/_materialize_first_output_rows.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple inspection to materialise operator outputs 3 | """ 4 | from typing import Iterable 5 | 6 | from pandas import DataFrame 7 | 8 | from ._inspection import Inspection 9 | from ._inspection_input import InspectionInputSinkOperator, OperatorType 10 | 11 | 12 | class MaterializeFirstOutputRows(Inspection): 13 | """ 14 | A simple example analyzer 15 | """ 16 | 17 | def __init__(self, row_count: int): 18 | self.row_count = row_count 19 | self._analyzer_id = self.row_count 20 | self._first_rows_op_output = None 21 | self._operator_type = None 22 | self._output_columns = None 23 | 24 | @property 25 | def inspection_id(self): 26 | return self._analyzer_id 27 | 28 | def visit_operator(self, inspection_input) -> Iterable[any]: 29 | """ 30 | Visit an operator 31 | """ 32 | current_count = - 1 33 | operator_output = [] 34 | self._operator_type = inspection_input.operator_context.operator 35 | 36 | if not isinstance(inspection_input, InspectionInputSinkOperator): 37 | self._output_columns = inspection_input.output_columns.fields 38 | for row in inspection_input.row_iterator: 39 | current_count += 1 40 | if current_count < self.row_count: 41 | operator_output.append(row.output) 42 | yield None 43 | else: 44 | for _ in inspection_input.row_iterator: 45 | yield None 46 | 47 | self._first_rows_op_output = operator_output 48 | 49 | def 
get_operator_annotation_after_visit(self) -> any: 50 | assert self._operator_type 51 | if self._operator_type is not OperatorType.ESTIMATOR: 52 | assert self._first_rows_op_output and self._output_columns is not None # Visit must be finished 53 | result = DataFrame(self._first_rows_op_output, columns=self._output_columns) 54 | self._first_rows_op_output = None 55 | self._operator_type = None 56 | self._output_columns = None 57 | return result 58 | self._operator_type = None 59 | self._output_columns = None 60 | return None 61 | -------------------------------------------------------------------------------- /mlinspect/instrumentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/instrumentation/__init__.py -------------------------------------------------------------------------------- /mlinspect/instrumentation/_call_capture_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inserts function call capturing into the DAG 3 | """ 4 | import ast 5 | 6 | 7 | class CallCaptureTransformer(ast.NodeTransformer): 8 | """ 9 | ast.NodeTransformer to replace calls with captured calls 10 | """ 11 | 12 | def visit_Call(self, node): 13 | """ 14 | Instrument all function calls 15 | """ 16 | # pylint: disable=invalid-name 17 | ast.NodeTransformer.generic_visit(self, node) 18 | self.call_add_set_code_reference(node) 19 | return node 20 | 21 | def visit_Subscript(self, node): 22 | """ 23 | Instrument all subscript calls 24 | """ 25 | # pylint: disable=invalid-name 26 | ast.NodeTransformer.generic_visit(self, node) 27 | if isinstance(node.ctx, ast.Store): 28 | # Needed to get the parent assign node for subscript assigns. 29 | # Without this, "pandas_df['baz'] = baz + 1" would only be "pandas_df['baz']" 30 | code_reference_from_node = node.parents[0] 31 | else: 32 | code_reference_from_node = node 33 | self.subscript_add_set_code_reference(node, code_reference_from_node) 34 | return node 35 | 36 | @staticmethod 37 | def call_add_set_code_reference(node): 38 | """ 39 | When a method is called, capture the arguments of the method before executing it 40 | """ 41 | # We need to use a keyword argument call to capture stuff because the eval order and because 42 | # keyword arguments may contain function calls. 
43 | # https://stackoverflow.com/questions/17948369/is-it-safe-to-rely-on-python-function-arguments-evaluation-order 44 | # Here we can consider instrumenting only functions we patch based on the name 45 | # But the detection based on the static function name is unreliable, so we will skip this for now 46 | kwargs = node.keywords 47 | call_node = CallCaptureTransformer.create_set_code_reference_node_call(node, kwargs) 48 | new_kwargs_node = ast.keyword(value=call_node, arg=None) 49 | node.keywords = [new_kwargs_node] 50 | 51 | @staticmethod 52 | def create_set_code_reference_node_call(node, kwargs): 53 | """ 54 | Create the set_code_reference function call ast node that then gets inserted into the AST 55 | """ 56 | call_node = ast.Call(func=ast.Name(id='set_code_reference_call', ctx=ast.Load()), 57 | args=[ast.Constant(n=node.lineno, kind=None), 58 | ast.Constant(n=node.col_offset, kind=None), 59 | ast.Constant(n=node.end_lineno, kind=None), 60 | ast.Constant(n=node.end_col_offset, kind=None)], 61 | keywords=kwargs) 62 | return call_node 63 | 64 | @staticmethod 65 | def create_set_code_reference_node_subscript(node, kwargs): 66 | """ 67 | Create the set_code_reference function call ast node that then gets inserted into the AST 68 | """ 69 | call_node = ast.Call(func=ast.Name(id='set_code_reference_subscript', ctx=ast.Load()), 70 | args=[ast.Constant(n=node.lineno, kind=None), 71 | ast.Constant(n=node.col_offset, kind=None), 72 | ast.Constant(n=node.end_lineno, kind=None), 73 | ast.Constant(n=node.end_col_offset, kind=None), 74 | kwargs], 75 | keywords=[]) 76 | return call_node 77 | 78 | @staticmethod 79 | def subscript_add_set_code_reference(node, code_reference_from_node): 80 | """ 81 | When the __getitem__ method of some object is called, capture the arguments of the method before executing it 82 | """ 83 | subscript_arg = node.slice 84 | call_node = CallCaptureTransformer.create_set_code_reference_node_subscript(code_reference_from_node, 85 | subscript_arg) 86 | node.slice = call_node 87 | -------------------------------------------------------------------------------- /mlinspect/instrumentation/_dag_node.py: -------------------------------------------------------------------------------- 1 | """ 2 | The Nodes used in the DAG as nodes for the networkx.DiGraph 3 | """ 4 | import dataclasses 5 | from typing import List 6 | 7 | from mlinspect.inspections._inspection_input import OperatorContext 8 | 9 | 10 | @dataclasses.dataclass(frozen=True) 11 | class CodeReference: 12 | """ 13 | Identifies a function call in the user pipeline code 14 | """ 15 | lineno: int 16 | col_offset: int 17 | end_lineno: int 18 | end_col_offset: int 19 | 20 | 21 | @dataclasses.dataclass 22 | class BasicCodeLocation: 23 | """ 24 | Basic information that can be collected even if `set_code_reference_tracking` is disabled 25 | """ 26 | caller_filename: str 27 | lineno: int 28 | 29 | 30 | @dataclasses.dataclass 31 | class OptionalCodeInfo: 32 | """ 33 | The additional information collected by mlinspect if `set_code_reference_tracking` is enabled 34 | """ 35 | code_reference: CodeReference 36 | source_code: str 37 | 38 | 39 | @dataclasses.dataclass 40 | class DagNodeDetails: 41 | """ 42 | Additional info about the DAG node 43 | """ 44 | description: str or None = None 45 | columns: List[str] = None 46 | 47 | 48 | @dataclasses.dataclass 49 | class DagNode: 50 | """ 51 | A DAG Node 52 | """ 53 | 54 | node_id: int 55 | code_location: BasicCodeLocation 56 | operator_info: OperatorContext 57 | details: DagNodeDetails 58 | 
optional_code_info: OptionalCodeInfo or None = None 59 | 60 | def __hash__(self): 61 | return hash(self.node_id) 62 | -------------------------------------------------------------------------------- /mlinspect/monkeypatching/README.md: -------------------------------------------------------------------------------- 1 | # Support for different libraries and API functions 2 | 3 | ## Handling of unknown functions 4 | * Extending mlinspect to support more and more API functions and libraries will be an ongoing effort. External contributions are very welcome! 5 | * However, mlinspect doesn't crash when it encounters unknown functions. 6 | * mlinspect simply ignores functions it doesn't recognize. If a recognized function receives input that was produced by one or more unknown function calls, mlinspect creates a `MISSING_OP` node for those calls. The inspections also get to see this unknown input; from their perspective, it is just a new data source. 7 | * Example: 8 | ```python 9 | import networkx 10 | from inspect import cleandoc 11 | from testfixtures import compare 12 | from mlinspect import OperatorType, OperatorContext, FunctionInfo, PipelineInspector, CodeReference, DagNode, BasicCodeLocation, DagNodeDetails, \ 13 | OptionalCodeInfo 14 | 15 | 16 | test_code = cleandoc(""" 17 | from inspect import cleandoc 18 | import pandas 19 | from mlinspect.testing._testing_helper_utils import black_box_df_op 20 | 21 | df = black_box_df_op() 22 | df = df.dropna() 23 | """) 24 | 25 | extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag 26 | 27 | expected_dag = networkx.DiGraph() 28 | expected_missing_op = DagNode(-1, 29 | BasicCodeLocation("", 5), 30 | OperatorContext(OperatorType.MISSING_OP, None), 31 | DagNodeDetails('Warning! 
Operator :5 (df.dropna()) encountered a ' 32 | 'DataFrame resulting from an operation without mlinspect support!', 33 | ['A']), 34 | OptionalCodeInfo(CodeReference(5, 5, 5, 16), 'df.dropna()')) 35 | expected_select = DagNode(0, 36 | BasicCodeLocation("", 5), 37 | OperatorContext(OperatorType.SELECTION, FunctionInfo('pandas.core.frame', 'dropna')), 38 | DagNodeDetails('dropna', ['A']), 39 | OptionalCodeInfo(CodeReference(5, 5, 5, 16), 'df.dropna()')) 40 | expected_dag.add_edge(expected_missing_op, expected_select) 41 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag)) 42 | ``` 43 | 44 | ## Pandas 45 | * The implementation can be found mainly [here](./_patch_pandas.py) 46 | * The [tests](../../test/monkeypatching/test_patch_pandas.py) are probably more useful to look at 47 | * Currently supported functions: 48 | 49 | | Function Call | Operator | 50 | | ------------- |:-------------:| 51 | | `('pandas.io.parsers', 'read_csv')` | Data Source | 52 | | `('pandas.core.frame', 'DataFrame')` | Data Source | 53 | | `('pandas.core.series', 'Series')` | Data Source | 54 | | `('pandas.core.frame', '__getitem__')`, arg type: strings | Projection | 55 | | `('pandas.core.frame', '__getitem__')`, arg type: series | Selection | 56 | | `('pandas.core.frame', 'dropna')` | Selection | 57 | | `('pandas.core.frame', 'replace')` | Projection (Mod) | 58 | | `('pandas.core.frame', '__setitem__')` | Projection (Mod) | 59 | | `('pandas.core.frame', 'merge')` | Join | 60 | | `('pandas.core.frame', 'groupby')` | Nothing (until a following agg call) | 61 | | `('pandas.core.groupbygeneric', 'agg')` | Groupby/Agg | 62 | 63 | ## Sklearn 64 | * The implementation can be found mainly [here](./_patch_sklearn.py) 65 | * The [tests](../../test/monkeypatching/test_patch_sklearn.py) are probably more useful to look at 66 | * Currently supported functions: 67 | 68 | | Function Call | Operator | 69 | | ------------- |:-------------:| 70 | | `('sklearn.compose._column_transformer', 'ColumnTransformer')`, column selection | Projection | 71 | | `('sklearn.preprocessing._label', 'label_binarize')` | Projection (Mod) | 72 | | `('sklearn.compose._column_transformer', 'ColumnTransformer')`, concatenation | Concatenation | 73 | | `('sklearn.model_selection._split', 'train_test_split')` | Split (Train/Test) | 74 | | `('sklearn.preprocessing._encoders', 'OneHotEncoder')`, arg type: strings | Transformer | 75 | | `('sklearn.preprocessing._data', 'StandardScaler')` | Transformer | 76 | | `('sklearn.impute._base', 'SimpleImputer')` | Transformer | 77 | | `('sklearn.feature_extraction.text', 'HashingVectorizer')` | Transformer | 78 | | `('sklearn.preprocessing._discretization', 'KBinsDiscretizer')` | Transformer | 79 | | `('sklearn.preprocessing_function_transformer', 'FunctionTransformer')` | Transformer | 80 | | `('sklearn.tree._classes', 'DecisionTreeClassifier')` | Estimator | 81 | | `('sklearn.linear_model._stochastic_gradient', 'SGDClassifier')` | Estimator | 82 | | `('tensorflow.python.keras.wrappers.scikit_learn', 'KerasClassifier')` | Estimator | 83 | | `('sklearn.linear_model._logistic', 'LogisticRegression')` | Estimator | 84 | 85 | 86 | ## Numpy 87 | * The implementation can be found mainly [here](./_patch_numpy.py) 88 | * The [tests](../../test/monkeypatching/test_patch_numpy.py) are probably more useful to look at 89 | * Currently supported functions: 90 | 91 | | Function Call | Operator | 92 | | ------------- |:-------------:| 93 | | `('numpy.random', 'random')` | Data Source | 94 |
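A hedged end-to-end illustration of the Numpy entry above (a sketch, not code from this repository; the tiny pipeline and the assertion are illustrative):

```python
from inspect import cleandoc
from mlinspect import PipelineInspector, OperatorType

test_code = cleandoc("""
    import numpy as np

    random_array = np.random.random(100)
    """)

extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag
# The patched ('numpy.random', 'random') call should appear as a Data Source node
assert any(node.operator_info.operator == OperatorType.DATA_SOURCE for node in extracted_dag.nodes)
```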
95 | ## Statsmodels 96 | * The implementation can be found mainly [here](./_patch_statsmodels.py) 97 | * The [tests](../../test/monkeypatching/test_patch_statsmodels.py) are probably more useful to look at 98 | * Currently supported functions: 99 | 100 | | Function Call | Operator | 101 | | ------------- |:-------------:| 102 | | `('statsmodels.datasets', 'get_rdataset')` | Data Source | 103 | | `('statsmodels.api', 'add_constant')` | Projection (Mod) | 104 | | `('statsmodel.api', 'OLS')`, numpy syntax | Estimator | -------------------------------------------------------------------------------- /mlinspect/monkeypatching/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/monkeypatching/__init__.py -------------------------------------------------------------------------------- /mlinspect/monkeypatching/_mlinspect_ndarray.py: -------------------------------------------------------------------------------- 1 | """ 2 | A numpy ndarray subclass used by the monkey patching to carry mlinspect annotations 3 | """ 4 | import numpy 5 | 6 | 7 | class MlinspectNdarray(numpy.ndarray): 8 | """ 9 | A wrapper for numpy ndarrays to store our additional annotations. 10 | See https://docs.scipy.org/doc/numpy-1.13.0/user/basics.subclassing.html 11 | """ 12 | 13 | def __new__(cls, input_array, _mlinspect_dag_node=None, _mlinspect_annotation=None): 14 | # Input array is an already formed ndarray instance 15 | # We first cast to be our class type 16 | obj = numpy.asarray(input_array).view(cls) 17 | # add the new attribute to the created instance 18 | obj._mlinspect_dag_node = _mlinspect_dag_node 19 | obj._mlinspect_annotation = _mlinspect_annotation 20 | # Finally, we must return the newly created object: 21 | return obj 22 | 23 | def __array_finalize__(self, obj): 24 | # pylint: disable=attribute-defined-outside-init 25 | # see InfoArray.__array_finalize__ for comments 26 | if obj is None: 27 | return 28 | self._mlinspect_dag_node = getattr(obj, '_mlinspect_dag_node', None) 29 | self._mlinspect_annotation = getattr(obj, '_mlinspect_annotation', None) 30 | 31 | def ravel(self, order='C'): 32 | # pylint: disable=no-member 33 | result = super().ravel(order) 34 | assert isinstance(result, MlinspectNdarray) 35 | result._mlinspect_dag_node = self._mlinspect_dag_node # pylint: disable=protected-access 36 | result._mlinspect_annotation = self._mlinspect_annotation # pylint: disable=protected-access 37 | return result 38 |
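A brief hedged sketch of how this wrapper behaves (illustrative only; the annotation value is a stand-in, and the protected attributes are accessed directly here purely for demonstration):

```python
import numpy as np

from mlinspect.monkeypatching._mlinspect_ndarray import MlinspectNdarray

arr = MlinspectNdarray(np.array([[1, 2], [3, 4]]), _mlinspect_annotation="lineage-info")
view = arr[0]        # views go through __array_finalize__, which copies the annotation
assert view._mlinspect_annotation == "lineage-info"
flat = arr.ravel()   # ravel() is overridden to preserve the annotation explicitly
assert flat._mlinspect_annotation == "lineage-info"
```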
-------------------------------------------------------------------------------- /mlinspect/monkeypatching/_patch_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monkey patching for numpy 3 | """ 4 | import gorilla 5 | from numpy import random 6 | 7 | from mlinspect import DagNode, BasicCodeLocation, DagNodeDetails 8 | from mlinspect.backends._sklearn_backend import SklearnBackend 9 | from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType 10 | from mlinspect.monkeypatching._monkey_patching_utils import execute_patched_func, add_dag_node, \ 11 | get_optional_code_info_or_none 12 | 13 | 14 | @gorilla.patches(random) 15 | class NumpyRandomPatching: 16 | """ Patches for numpy.random """ 17 | 18 | # pylint: disable=too-few-public-methods,no-self-argument 19 | 20 | @gorilla.name('random') 21 | @gorilla.settings(allow_hit=True) 22 | def patched_random(*args, **kwargs): 23 | """ Patch for ('numpy.random', 'random') """ 24 | # pylint: disable=no-method-argument 25 | original = gorilla.get_original_attribute(random, 'random') 26 | 27 | def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code): 28 | """ Execute inspections, add DAG node """ 29 | function_info = FunctionInfo('numpy.random', 'random') 30 | operator_context = OperatorContext(OperatorType.DATA_SOURCE, function_info) 31 | input_infos = SklearnBackend.before_call(operator_context, []) 32 | result = original(*args, **kwargs) 33 | backend_result = SklearnBackend.after_call(operator_context, input_infos, result) 34 | 35 | dag_node = DagNode(op_id, 36 | BasicCodeLocation(caller_filename, lineno), 37 | operator_context, 38 | DagNodeDetails("random", ['array']), 39 | get_optional_code_info_or_none(optional_code_reference, optional_source_code)) 40 | add_dag_node(dag_node, [], backend_result) 41 | new_return_value = backend_result.annotated_dfobject.result_data 42 | return new_return_value 43 | 44 | return execute_patched_func(original, execute_inspections, *args, **kwargs) 45 |
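The file above is the template every patch in this package follows: fetch the original function via gorilla, run the matching backend's `before_call`/`after_call` around it, create a `DagNode`, and return the annotated result. Below is a hedged sketch of a custom patch built from the same helpers; the module `mymodule` and its `load_data()` function returning a DataFrame are hypothetical, and such a module would be registered via `add_custom_monkey_patching_module` (as the healthcare example does):

```python
import gorilla
import mymodule  # hypothetical user library with a load_data() -> DataFrame function

from mlinspect import DagNode, BasicCodeLocation, DagNodeDetails
from mlinspect.backends._pandas_backend import PandasBackend
from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType
from mlinspect.monkeypatching._monkey_patching_utils import execute_patched_func, add_dag_node, \
    get_optional_code_info_or_none


@gorilla.patches(mymodule)
class MyModulePatching:
    """ Patches for the hypothetical mymodule """
    # pylint: disable=too-few-public-methods,no-self-argument

    @gorilla.name('load_data')
    @gorilla.settings(allow_hit=True)
    def patched_load_data(*args, **kwargs):
        """ Patch for ('mymodule', 'load_data') """
        original = gorilla.get_original_attribute(mymodule, 'load_data')

        def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code):
            # Treat the call as a Data Source, like read_csv or get_rdataset
            function_info = FunctionInfo('mymodule', 'load_data')
            operator_context = OperatorContext(OperatorType.DATA_SOURCE, function_info)
            input_infos = PandasBackend.before_call(operator_context, [])
            result = original(*args, **kwargs)
            backend_result = PandasBackend.after_call(operator_context, input_infos, result)
            dag_node = DagNode(op_id,
                               BasicCodeLocation(caller_filename, lineno),
                               operator_context,
                               DagNodeDetails("load_data", list(result.columns)),
                               get_optional_code_info_or_none(optional_code_reference, optional_source_code))
            add_dag_node(dag_node, [], backend_result)
            return backend_result.annotated_dfobject.result_data

        return execute_patched_func(original, execute_inspections, *args, **kwargs)
```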
-------------------------------------------------------------------------------- /mlinspect/monkeypatching/_patch_statsmodels.py: -------------------------------------------------------------------------------- 1 | """ 2 | Monkey patching for statsmodels 3 | """ 4 | import gorilla 5 | from statsmodels import api 6 | from statsmodels.api import datasets 7 | 8 | from mlinspect import DagNode, BasicCodeLocation, DagNodeDetails 9 | from mlinspect.backends._pandas_backend import PandasBackend 10 | from mlinspect.backends._sklearn_backend import SklearnBackend 11 | from mlinspect.inspections._inspection_input import OperatorContext, FunctionInfo, OperatorType 12 | from mlinspect.instrumentation._pipeline_executor import singleton 13 | from mlinspect.monkeypatching._monkey_patching_utils import execute_patched_func, add_dag_node, \ 14 | get_optional_code_info_or_none, get_input_info, execute_patched_func_no_op_id, add_train_data_node, \ 15 | add_train_label_node 16 | 17 | 18 | @gorilla.patches(api) 19 | class StatsmodelApiPatching: 20 | """ Patches for the statsmodels api """ 21 | # pylint: disable=too-few-public-methods,no-self-argument 22 | 23 | @gorilla.name('add_constant') 24 | @gorilla.settings(allow_hit=True) 25 | def patched_add_constant(*args, **kwargs): 26 | """ Patch for ('statsmodel.api', 'add_constant') """ 27 | # pylint: disable=no-method-argument 28 | original = gorilla.get_original_attribute(api, 'add_constant') 29 | 30 | def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code): 31 | """ Execute inspections, add DAG node """ 32 | function_info = FunctionInfo('statsmodel.api', 'add_constant') 33 | input_info = get_input_info(args[0], caller_filename, lineno, function_info, optional_code_reference, 34 | optional_source_code) 35 | 36 | operator_context = OperatorContext(OperatorType.PROJECTION_MODIFY, function_info) 37 | input_infos = SklearnBackend.before_call(operator_context, [input_info.annotated_dfobject]) 38 | result = original(input_infos[0].result_data, *args[1:], **kwargs) 39 | backend_result = SklearnBackend.after_call(operator_context, 40 | input_infos, 41 | result) 42 | new_return_value = backend_result.annotated_dfobject.result_data 43 | 44 | dag_node = DagNode(op_id, 45 | BasicCodeLocation(caller_filename, lineno), 46 | operator_context, 47 | DagNodeDetails("Adds const column", ["array"]), 48 | get_optional_code_info_or_none(optional_code_reference, optional_source_code)) 49 | add_dag_node(dag_node, [input_info.dag_node], backend_result) 50 | 51 | return new_return_value 52 | return execute_patched_func(original, execute_inspections, *args, **kwargs) 53 | 54 | 55 | @gorilla.patches(datasets) 56 | class StatsmodelsDatasetPatching: 57 | """ Patches for statsmodels datasets """ 58 | 59 | # pylint: disable=too-few-public-methods,no-self-argument 60 | 61 | @gorilla.name('get_rdataset') 62 | @gorilla.settings(allow_hit=True) 63 | def patched_get_rdataset(*args, **kwargs): 64 | """ Patch for ('statsmodels.datasets', 'get_rdataset') """ 65 | # pylint: disable=no-method-argument 66 | original = gorilla.get_original_attribute(datasets, 'get_rdataset') 67 | 68 | def execute_inspections(op_id, caller_filename, lineno, optional_code_reference, optional_source_code): 69 | """ Execute inspections, add DAG node """ 70 | function_info = FunctionInfo('statsmodels.datasets', 'get_rdataset') 71 | 72 | operator_context = OperatorContext(OperatorType.DATA_SOURCE, function_info) 73 | input_infos = PandasBackend.before_call(operator_context, []) 74 | result = original(*args, **kwargs) 75 | backend_result = PandasBackend.after_call(operator_context, 76 | input_infos, 77 | result.data) 78 | result.data = backend_result.annotated_dfobject.result_data 79 | dag_node = DagNode(op_id, 80 | BasicCodeLocation(caller_filename, lineno), 81 | operator_context, 82 | DagNodeDetails(result.title, list(result.data.columns)), 83 | get_optional_code_info_or_none(optional_code_reference, optional_source_code)) 84 | add_dag_node(dag_node, [], backend_result) 85 | return result 86 | 87 | return execute_patched_func(original, execute_inspections, *args, **kwargs) 88 | 89 | 90 | @gorilla.patches(api.OLS) 91 | class StatsmodelsOlsPatching: 92 | """ Patches for statsmodels OLS """ 93 | 94 | # pylint: disable=too-few-public-methods 95 | 96 | @gorilla.name('__init__') 97 | @gorilla.settings(allow_hit=True) 98 | def patched__init__(self, *args, **kwargs): 99 | """ Patch for ('statsmodel.api', 'OLS') """ 100 | # pylint: disable=no-method-argument, attribute-defined-outside-init, too-many-locals 101 | original = gorilla.get_original_attribute(api.OLS, '__init__') 102 | 103 | def execute_inspections(_, caller_filename, lineno, optional_code_reference, optional_source_code): 104 | """ Execute inspections, add DAG node """ 105 | original(self, *args, **kwargs) 106 | 107 | self.mlinspect_caller_filename = caller_filename 108 | self.mlinspect_lineno = lineno 109 | self.mlinspect_optional_code_reference = optional_code_reference 110 | self.mlinspect_optional_source_code = optional_source_code 111 | 112 | return execute_patched_func_no_op_id(original, execute_inspections, *args, **kwargs) 113 | 114 | @gorilla.name('fit') 115 | @gorilla.settings(allow_hit=True) 116 | def patched_fit(self, *args, **kwargs): 117 | """ Patch for ('statsmodel.api.OLS', 'fit') """ 118 | # pylint: disable=no-method-argument, too-many-locals 119 | original = gorilla.get_original_attribute(api.OLS, 'fit') 120 | function_info = FunctionInfo('statsmodel.api.OLS', 'fit') 121 | 122 | # Train data 123 | # pylint: disable=no-member 124 | data_backend_result, train_data_node, train_data_result = add_train_data_node(self, 125 | self.data.exog, 126 | function_info) 127 | self.data.exog = train_data_result 128 | # pylint: disable=no-member 129 | label_backend_result, train_labels_node, train_labels_result = add_train_label_node(self, self.data.endog, 130 | function_info) 131 | self.data.endog = train_labels_result 132 | 133 | # Estimator 134 | operator_context = 
OperatorContext(OperatorType.ESTIMATOR, function_info) 135 | input_dfs = [data_backend_result.annotated_dfobject, label_backend_result.annotated_dfobject] 136 | input_infos = SklearnBackend.before_call(operator_context, input_dfs) 137 | result = original(self, *args, **kwargs) 138 | estimator_backend_result = SklearnBackend.after_call(operator_context, 139 | input_infos, 140 | None) 141 | 142 | dag_node = DagNode(singleton.get_next_op_id(), 143 | BasicCodeLocation(self.mlinspect_caller_filename, self.mlinspect_lineno), 144 | operator_context, 145 | DagNodeDetails("OLS", []), 146 | get_optional_code_info_or_none(self.mlinspect_optional_code_reference, 147 | self.mlinspect_optional_source_code)) 148 | add_dag_node(dag_node, [train_data_node, train_labels_node], estimator_backend_result) 149 | return result 150 | -------------------------------------------------------------------------------- /mlinspect/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/mlinspect/testing/__init__.py -------------------------------------------------------------------------------- /mlinspect/testing/_random_annotation_testing_inspection.py: -------------------------------------------------------------------------------- 1 | """ 2 | A simple analyzer for testing annotation propagation 3 | """ 4 | import random 5 | from typing import Iterable 6 | 7 | from mlinspect.inspections._inspection import Inspection 8 | from mlinspect.inspections._inspection_input import InspectionRowUnaryOperator, InspectionRowNAryOperator, \ 9 | InspectionRowSinkOperator 10 | 11 | 12 | class RandomAnnotationTestingInspection(Inspection): 13 | """ 14 | A simple analyzer for testing annotation propagation 15 | """ 16 | 17 | def __init__(self, row_count: int): 18 | self.operator_count = 0 19 | self.row_count = row_count 20 | self._analyzer_id = self.row_count 21 | self._operator_output = None 22 | self.rows_to_random_numbers_operator_0 = {} 23 | 24 | def visit_operator(self, inspection_input) -> Iterable[any]: 25 | """Visit an operator, generate random number annotations and check whether they get propagated correctly""" 26 | # pylint: disable=too-many-branches 27 | operator_output = [] 28 | current_count = - 1 29 | 30 | if self.operator_count == 0: 31 | for row in inspection_input.row_iterator: 32 | current_count += 1 33 | if current_count < self.row_count: 34 | random_number = random.randint(0, 10000) 35 | output_tuple = row.output 36 | self.rows_to_random_numbers_operator_0[output_tuple] = random_number 37 | operator_output.append(row.output) 38 | yield random_number 39 | else: 40 | yield None 41 | elif self.operator_count == 1: 42 | filtered_rows = 0 43 | for row in inspection_input.row_iterator: 44 | current_count += 1 45 | assert isinstance(row, InspectionRowUnaryOperator) # This analyzer is really only for testing 46 | annotation = row.annotation 47 | if current_count < self.row_count: 48 | output_tuple = row.output 49 | if output_tuple in self.rows_to_random_numbers_operator_0: 50 | random_number = self.rows_to_random_numbers_operator_0[output_tuple] 51 | assert annotation == random_number # Test whether the annotation got propagated correctly 52 | else: 53 | filtered_rows += 1 54 | assert filtered_rows != self.row_count # If all rows got filtered, this test is useless 55 | operator_output.append(row.output) 56 | yield annotation 57 | else: 58 | for row in 
inspection_input.row_iterator: 59 | assert isinstance(row, (InspectionRowUnaryOperator, InspectionRowNAryOperator, 60 | InspectionRowSinkOperator)) 61 | if isinstance(row, InspectionRowUnaryOperator): 62 | annotation = row.annotation 63 | else: 64 | annotation = row.annotation[0] 65 | yield annotation 66 | self.operator_count += 1 67 | self._operator_output = operator_output 68 | 69 | def get_operator_annotation_after_visit(self) -> any: 70 | assert self._operator_output or self.operator_count > 1 71 | result = self._operator_output 72 | self._operator_output = None 73 | return result 74 | 75 | @property 76 | def inspection_id(self): 77 | return self._analyzer_id 78 | -------------------------------------------------------------------------------- /mlinspect/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Packages and classes we want to expose to users 3 | """ 4 | from ._utils import get_project_root 5 | 6 | __all__ = [ 7 | 'get_project_root', 8 | ] 9 | -------------------------------------------------------------------------------- /mlinspect/utils/_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Some useful utils for the project 3 | """ 4 | from pathlib import Path 5 | 6 | 7 | def get_project_root() -> Path: 8 | """Returns the project root folder.""" 9 | return Path(__file__).parent.parent.parent 10 | -------------------------------------------------------------------------------- /mlinspect/visualisation/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Packages and classes we want to expose to users 3 | """ 4 | from ._visualisation import get_dag_as_pretty_string, save_fig_to_path 5 | 6 | __all__ = [ 7 | 'get_dag_as_pretty_string', 8 | 'save_fig_to_path', 9 | ] 10 | -------------------------------------------------------------------------------- /mlinspect/visualisation/_visualisation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to visualise the extracted DAG 3 | """ 4 | from inspect import cleandoc 5 | 6 | import networkx 7 | from networkx.drawing.nx_agraph import to_agraph 8 | 9 | from mlinspect import DagNode 10 | 11 | 12 | def save_fig_to_path(extracted_dag, filename): 13 | """ 14 | Create a figure of the extracted DAG and save it with some filename 15 | """ 16 | 17 | def get_new_node_label(node: DagNode): 18 | label = cleandoc(f""" 19 | {node.node_id}: {node.operator_info.operator.value} (L{node.code_location.lineno}) 20 | {node.details.description or ""} 21 | """) 22 | return label 23 | 24 | # noinspection PyTypeChecker 25 | extracted_dag = networkx.relabel_nodes(extracted_dag, get_new_node_label) 26 | 27 | agraph = to_agraph(extracted_dag) 28 | agraph.layout('dot') 29 | agraph.draw(filename) 30 | 31 | 32 | def get_dag_as_pretty_string(extracted_dag): 33 | """ 34 | Return the extracted DAG as a pretty string in graphviz dot format 35 | """ 36 | 37 | def get_new_node_label(node: DagNode): 38 | description = "" 39 | if node.details.description: 40 | description = f"({node.details.description})" 41 | 42 | label = f"{node.operator_info.operator.value}{description}" 43 | return label 44 | 45 | # noinspection PyTypeChecker 46 | extracted_dag = networkx.relabel_nodes(extracted_dag, get_new_node_label) 47 | 48 | agraph = to_agraph(extracted_dag) 49 | return agraph.to_string() 50 |
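A hedged usage sketch for these two helpers (illustrative; the inline pipeline is a stand-in for any supported pipeline, and rendering the PNG requires the graphviz/pygraphviz setup from the requirements):

```python
from inspect import cleandoc
from mlinspect import PipelineInspector
from mlinspect.visualisation import get_dag_as_pretty_string, save_fig_to_path

test_code = cleandoc("""
    import pandas as pd

    df = pd.DataFrame({'A': [1, 2, 4, 5], 'B': [7, 8, 9, 10]})
    df = df.dropna()
    """)

extracted_dag = PipelineInspector.on_pipeline_from_string(test_code).execute().dag
print(get_dag_as_pretty_string(extracted_dag))        # dot-format string of the DAG
save_fig_to_path(extracted_dag, "extracted_dag.png")  # renders the DAG with graphviz
```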
-------------------------------------------------------------------------------- /requirements/requirements.dev.txt: -------------------------------------------------------------------------------- 1 | pylint==2.17.7 2 | pytest==8.0.0 3 | pytest-pylint==0.21.0 4 | pytest-runner==5.2 5 | pytest-cov==2.10.1 6 | pytest-pycharm==0.7.0 7 | pytest-mock==3.3.1 8 | gensim==4.3.0 9 | keras==2.15.0 10 | jupyter==1.0.0 11 | importnb==0.6.2 12 | seaborn==0.11.0 13 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn==1.3.2 2 | protobuf==3.20.3 3 | pandas==1.5.3 4 | numpy==1.23.5 5 | six==1.15.0 6 | nbformat==5.0.8 7 | nbconvert==6.4.5 8 | ipython==7.25.0 9 | astpretty==2.0.0 10 | astmonkey==0.3.6 11 | networkx==2.5 12 | more-itertools==8.6.0 13 | pygraphviz==1.7 14 | testfixtures==6.17.1 15 | matplotlib==3.4.2 16 | gorilla==0.4.0 17 | astunparse==1.6.3 18 | setuptools==57.0.0 19 | scipy==1.11.3 20 | statsmodels==0.14.1 21 | tensorflow==2.15.0 22 | scikeras==0.12 23 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [aliases] 2 | test=pytest 3 | 4 | [tool:pytest] 5 | addopts = --pylint --pylint-rcfile=./pylintrc --cov=mlinspect --cov-report=xml 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | For 'python setup.py develop' and 'python setup.py test' 3 | """ 4 | import os 5 | from setuptools import setup, find_packages 6 | 7 | ROOT = os.path.dirname(__file__) 8 | 9 | with open(os.path.join(ROOT, "requirements", "requirements.txt"), encoding="utf-8") as f: 10 | required = f.read().splitlines() 11 | 12 | with open(os.path.join(ROOT, "requirements", "requirements.dev.txt"), encoding="utf-8") as f: 13 | test_required = f.read().splitlines() 14 | 15 | with open("README.md", "r", encoding="utf-8") as fh: 16 | long_description = fh.read() 17 | 18 | setup( 19 | name="mlinspect", 20 | version="0.0.1.dev0", 21 | description="Inspect ML Pipelines in the form of a DAG", 22 | author='Stefan Grafberger', 23 | author_email='stefangrafberger@gmail.com', 24 | long_description=long_description, 25 | long_description_content_type="text/markdown", 26 | packages=find_packages(), 27 | include_package_data=True, 28 | install_requires=required, 29 | tests_require=test_required, 30 | extras_require={'dev': test_required}, 31 | license='Apache License 2.0', 32 | python_requires='==3.10.*', 33 | classifiers=[ 34 | 'License :: OSI Approved :: Apache Software License', 35 | 'Programming Language :: Python :: 3 :: Only', 36 | 'Programming Language :: Python :: 3.10' 37 | ] 38 | ) 39 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/__init__.py -------------------------------------------------------------------------------- /test/backends/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/backends/__init__.py 
-------------------------------------------------------------------------------- /test/backends/test_pandas_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the pandas backend works 3 | """ 4 | from mlinspect.testing._testing_helper_utils import get_pandas_read_csv_and_dropna_code, \ 5 | run_random_annotation_testing_analyzer, run_row_index_annotation_testing_analyzer, run_multiple_test_analyzers 6 | 7 | 8 | def test_pandas_backend_random_annotation_propagation(): 9 | """ 10 | Tests whether the pandas backend propagates random annotations correctly 11 | """ 12 | code = get_pandas_read_csv_and_dropna_code() 13 | random_annotation_analyzer_result = run_random_annotation_testing_analyzer(code) 14 | assert len(random_annotation_analyzer_result) == 3 15 | 16 | 17 | def test_pandas_backend_row_index_annotation_propagation(): 18 | """ 19 | Tests whether the pandas backend propagates row index annotations correctly 20 | """ 21 | code = get_pandas_read_csv_and_dropna_code() 22 | lineage_result = run_row_index_annotation_testing_analyzer(code) 23 | assert len(lineage_result) == 3 24 | 25 | 26 | def test_pandas_backend_annotation_propagation_multiple_analyzers(): 27 | """ 28 | Tests whether the pandas backend works with multiple analyzers at once 29 | """ 30 | code = get_pandas_read_csv_and_dropna_code() 31 | 32 | dag_node_to_inspection_results, analyzers = run_multiple_test_analyzers(code) 33 | 34 | for inspection_result in dag_node_to_inspection_results.values(): 35 | for analyzer in analyzers: 36 | assert analyzer in inspection_result 37 | -------------------------------------------------------------------------------- /test/backends/test_sklearn_backend.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the sklearn backend works 3 | """ 4 | 5 | from mlinspect.testing._testing_helper_utils import run_random_annotation_testing_analyzer, \ 6 | run_row_index_annotation_testing_analyzer, run_multiple_test_analyzers 7 | from example_pipelines import ADULT_SIMPLE_PY 8 | 9 | 10 | def test_sklearn_backend_random_annotation_propagation(): 11 | """ 12 | Tests whether the sklearn backend propagates random annotations correctly 13 | """ 14 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file: 15 | code = file.read() 16 | 17 | random_annotation_analyzer_result = run_random_annotation_testing_analyzer(code) 18 | assert len(random_annotation_analyzer_result) == 12 19 | 20 | 21 | def test_sklearn_backend_row_index_annotation_propagation(): 22 | """ 23 | Tests whether the sklearn backend propagates row index annotations correctly 24 | """ 25 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file: 26 | code = file.read() 27 | lineage_result = run_row_index_annotation_testing_analyzer(code) 28 | assert len(lineage_result) == 12 29 | 30 | 31 | def test_sklearn_backend_annotation_propagation_multiple_analyzers(): 32 | """ 33 | Tests whether the sklearn backend works with multiple analyzers at once 34 | """ 35 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file: 36 | code = file.read() 37 | 38 | dag_node_to_inspection_results, analyzers = run_multiple_test_analyzers(code) 39 | 40 | for inspection_result in dag_node_to_inspection_results.values(): 41 | for analyzer in analyzers: 42 | assert analyzer in inspection_result 43 | -------------------------------------------------------------------------------- /test/checks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/checks/__init__.py
-------------------------------------------------------------------------------- /test/checks/test_no_bias_introduced_for.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether NoBiasIntroducedFor works 3 | """ 4 | import math 5 | from inspect import cleandoc 6 | 7 | from pandas import DataFrame 8 | from testfixtures import compare 9 | 10 | from mlinspect import DagNode, BasicCodeLocation, OperatorContext, OperatorType, FunctionInfo, DagNodeDetails, \ 11 | OptionalCodeInfo 12 | from mlinspect._pipeline_inspector import PipelineInspector 13 | from mlinspect.checks import CheckStatus, NoBiasIntroducedFor, \ 14 | NoBiasIntroducedForResult 15 | from mlinspect.checks._no_bias_introduced_for import BiasDistributionChange 16 | from mlinspect.instrumentation._dag_node import CodeReference 17 | 18 | 19 | def test_no_bias_introduced_for_merge(): 20 | """ 21 | Tests whether NoBiasIntroducedFor works for joins 22 | """ 23 | test_code = cleandoc(""" 24 | import pandas as pd 25 | 26 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 'B': [1, 2, 4, 5, 7]}) 27 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]}) 28 | df_merged = df_a.merge(df_b, on='B') 29 | """) 30 | 31 | inspector_result = PipelineInspector \ 32 | .on_pipeline_from_string(test_code) \ 33 | .add_check(NoBiasIntroducedFor(['A'])) \ 34 | .execute() 35 | 36 | check_result = inspector_result.check_to_check_results[NoBiasIntroducedFor(['A'])] 37 | expected_result = get_expected_check_result_merge() 38 | compare(check_result, expected_result) 39 | 40 | 41 | def test_no_bias_introduced_simple_imputer(): 42 | """ 43 | Tests whether NoBiasIntroducedFor works for SimpleImputer 44 | """ 45 | test_code = cleandoc(""" 46 | import pandas as pd 47 | from sklearn.impute import SimpleImputer 48 | import numpy as np 49 | 50 | df = pd.DataFrame({'A': ['cat_a', np.nan, 'cat_a', 'cat_c']}) 51 | imputer = SimpleImputer(missing_values=np.nan, strategy='most_frequent') 52 | imputed_data = imputer.fit_transform(df) 53 | """) 54 | 55 | inspector_result = PipelineInspector \ 56 | .on_pipeline_from_string(test_code) \ 57 | .add_check(NoBiasIntroducedFor(['A'])) \ 58 | .execute() 59 | 60 | check_result = inspector_result.check_to_check_results[NoBiasIntroducedFor(['A'])] 61 | expected_result = get_expected_check_result_simple_imputer() 62 | compare(check_result, expected_result) 63 | 64 | 65 | def get_expected_check_result_merge(): 66 | """ Expected result for the code snippet in test_no_bias_introduced_for_merge""" 67 | failing_dag_node = DagNode(2, 68 | BasicCodeLocation('', 5), 69 | OperatorContext(OperatorType.JOIN, FunctionInfo('pandas.core.frame', 'merge')), 70 | DagNodeDetails("on 'B'", ['A', 'B', 'C']), 71 | OptionalCodeInfo(CodeReference(5, 12, 5, 36), "df_a.merge(df_b, on='B')")) 72 | 73 | change_df = DataFrame({'sensitive_column_value': ['cat_a', 'cat_b', 'cat_c'], 74 | 'count_before': [2, 2, 1], 75 | 'count_after': [2, 1, 1], 76 | 'ratio_before': [0.4, 0.4, 0.2], 77 | 'ratio_after': [0.5, 0.25, 0.25], 78 | 'relative_ratio_change': [(0.5 - 0.4) / 0.4, (.25 - 0.4) / 0.4, (0.25 - 0.2) / 0.2]}) 79 | expected_distribution_change = BiasDistributionChange(failing_dag_node, False, (.25 - 0.4) / 0.4, change_df) 80 | expected_dag_node_to_change = {failing_dag_node: {'A': expected_distribution_change}} 81 | failure_message = 'A Join causes a min_relative_ratio_change of \'A\' by -0.37500000000000006, a value below the ' \ 82 | 'configured minimum threshold -0.3!'
83 | expected_result = NoBiasIntroducedForResult(NoBiasIntroducedFor(['A']), CheckStatus.FAILURE, failure_message, 84 | expected_dag_node_to_change) 85 | return expected_result 86 | 87 | 88 | def get_expected_check_result_simple_imputer(): 89 | """ Expected result for the code snippet in test_no_bias_introduced_for_simple_imputer""" 90 | imputer_dag_node = DagNode(1, 91 | BasicCodeLocation('', 6), 92 | OperatorContext(OperatorType.TRANSFORMER, 93 | FunctionInfo('sklearn.impute._base', 'SimpleImputer')), 94 | DagNodeDetails('Simple Imputer: fit_transform', ['A']), 95 | OptionalCodeInfo(CodeReference(6, 10, 6, 72), 96 | "SimpleImputer(missing_values=np.nan, strategy='most_frequent')")) 97 | 98 | change_df = DataFrame({'sensitive_column_value': ['cat_a', 'cat_c', math.nan], 99 | 'count_before': [2, 1, 1], 100 | 'count_after': [3, 1, 0], 101 | 'ratio_before': [0.5, 0.25, 0.25], 102 | 'ratio_after': [0.75, 0.25, 0.], 103 | 'relative_ratio_change': [0.5, 0., -1.]}) 104 | expected_distribution_change = BiasDistributionChange(imputer_dag_node, True, 0., change_df) 105 | expected_dag_node_to_change = {imputer_dag_node: {'A': expected_distribution_change}} 106 | expected_result = NoBiasIntroducedForResult(NoBiasIntroducedFor(['A']), CheckStatus.SUCCESS, None, 107 | expected_dag_node_to_change) 108 | return expected_result 109 | -------------------------------------------------------------------------------- /test/checks/test_no_illegal_features.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether NoIllegalFeatures works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare, SequenceComparison, StringComparison 7 | 8 | from mlinspect._pipeline_inspector import PipelineInspector 9 | from mlinspect.checks import NoIllegalFeatures, CheckStatus, NoIllegalFeaturesResult 10 | 11 | 12 | def test_no_illegal_features(): 13 | """ 14 | Tests whether NoIllegalFeatures detects illegal feature columns 15 | """ 16 | test_code = cleandoc(""" 17 | import pandas as pd 18 | from sklearn.preprocessing import label_binarize, StandardScaler, OneHotEncoder 19 | from sklearn.compose import ColumnTransformer 20 | from sklearn.pipeline import Pipeline 21 | from sklearn.tree import DecisionTreeClassifier 22 | 23 | data = pd.DataFrame({'age': [1, 2, 10, 5], 'B': ['cat_a', 'cat_b', 'cat_a', 'cat_c'], 24 | 'C': ['cat_a', 'cat_b', 'cat_a', 'cat_c'], 'target': ['no', 'no', 'yes', 'yes']}) 25 | 26 | column_transformer = ColumnTransformer(transformers=[ 27 | ('numeric', StandardScaler(), ['age']), 28 | ('categorical', OneHotEncoder(sparse=False), ['B', 'C']) 29 | ]) 30 | 31 | income_pipeline = Pipeline([ 32 | ('features', column_transformer), 33 | ('classifier', DecisionTreeClassifier())]) 34 | 35 | labels = label_binarize(data['target'], classes=['no', 'yes']) 36 | income_pipeline.fit(data, labels) 37 | """) 38 | 39 | inspector_result = PipelineInspector \ 40 | .on_pipeline_from_string(test_code) \ 41 | .add_check(NoIllegalFeatures(['C'])) \ 42 | .execute() 43 | 44 | check_result = inspector_result.check_to_check_results[NoIllegalFeatures(['C'])] 45 | # pylint: disable=anomalous-backslash-in-string 46 | expected_result = NoIllegalFeaturesResult(NoIllegalFeatures(['C']), CheckStatus.FAILURE, 47 | StringComparison("Used illegal columns\: .*"), 48 | SequenceComparison('C', 'age', ordered=False)) 49 | compare(check_result, expected_result) 50 | -------------------------------------------------------------------------------- /test/demo/__init__.py:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/demo/__init__.py -------------------------------------------------------------------------------- /test/demo/feature_overview/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/demo/feature_overview/__init__.py -------------------------------------------------------------------------------- /test/demo/feature_overview/test_feature_overview.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the healthcare demo works 3 | """ 4 | import os 5 | 6 | from importnb import Notebook 7 | import matplotlib 8 | 9 | from mlinspect.utils import get_project_root 10 | 11 | 12 | DEMO_NB_FILE = os.path.join(str(get_project_root()), "demo", "feature_overview", "feature_overview.ipynb") 13 | 14 | 15 | def test_demo_nb(): 16 | """ 17 | Tests whether the demo notebook works 18 | """ 19 | matplotlib.use("template") # Disable plt.show when executing nb as part of this test 20 | Notebook.load(DEMO_NB_FILE) 21 | -------------------------------------------------------------------------------- /test/demo/feature_overview/test_missing_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether MissingEmbeddings works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare 7 | 8 | from demo.feature_overview.missing_embeddings import MissingEmbeddings, MissingEmbeddingsInfo 9 | from example_pipelines.healthcare import custom_monkeypatching 10 | from mlinspect._pipeline_inspector import PipelineInspector 11 | 12 | 13 | def test_missing_embeddings(): 14 | """ 15 | Tests whether MissingEmbeddings works for joins 16 | """ 17 | test_code = cleandoc(""" 18 | import pandas as pd 19 | from example_pipelines.healthcare.healthcare_utils import MyW2VTransformer 20 | 21 | df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c']}) 22 | word_to_vec = MyW2VTransformer(min_count=2, size=2, workers=1) 23 | encoded_data = word_to_vec.fit_transform(df) 24 | """) 25 | 26 | inspector_result = PipelineInspector \ 27 | .on_pipeline_from_string(test_code) \ 28 | .add_required_inspection(MissingEmbeddings(10)) \ 29 | .add_custom_monkey_patching_module(custom_monkeypatching) \ 30 | .execute() 31 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 32 | 33 | missing_embeddings_output = inspection_results[0][MissingEmbeddings(10)] 34 | expected_result = None 35 | compare(missing_embeddings_output, expected_result) 36 | 37 | missing_embeddings_output = inspection_results[1][MissingEmbeddings(10)] 38 | expected_result = MissingEmbeddingsInfo(2, ['cat_b', 'cat_c']) 39 | compare(missing_embeddings_output, expected_result) 40 | -------------------------------------------------------------------------------- /test/demo/feature_overview/test_no_missing_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether NoMissingEmbeddings works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare 7 | 8 | from demo.feature_overview.missing_embeddings import MissingEmbeddingsInfo 9 | from demo.feature_overview.no_missing_embeddings import NoMissingEmbeddings, 
NoMissingEmbeddingsResult 10 | from example_pipelines.healthcare import custom_monkeypatching 11 | from mlinspect import DagNode, BasicCodeLocation, OperatorContext, OperatorType, FunctionInfo, DagNodeDetails, \ 12 | OptionalCodeInfo 13 | from mlinspect._pipeline_inspector import PipelineInspector 14 | from mlinspect.checks import CheckStatus 15 | from mlinspect.instrumentation._dag_node import CodeReference 16 | 17 | 18 | def test_no_missing_embeddings(): 19 | """ 20 | Tests whether the NoMissingEmbeddings check works 21 | """ 22 | test_code = cleandoc(""" 23 | import pandas as pd 24 | from example_pipelines.healthcare.healthcare_utils import MyW2VTransformer 25 | 26 | df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c']}) 27 | word_to_vec = MyW2VTransformer(min_count=2, size=2, workers=1) 28 | encoded_data = word_to_vec.fit_transform(df) 29 | """) 30 | 31 | inspector_result = PipelineInspector \ 32 | .on_pipeline_from_string(test_code) \ 33 | .add_check(NoMissingEmbeddings()) \ 34 | .add_custom_monkey_patching_module(custom_monkeypatching) \ 35 | .execute() 36 | 37 | check_result = inspector_result.check_to_check_results[NoMissingEmbeddings()] 38 | expected_failed_dag_node_with_result = { 39 | DagNode(1, 40 | BasicCodeLocation('', 5), 41 | OperatorContext(OperatorType.TRANSFORMER, 42 | FunctionInfo('example_pipelines.healthcare.healthcare_utils', 'MyW2VTransformer')), 43 | DagNodeDetails('Word2Vec: fit_transform', ['array']), 44 | OptionalCodeInfo(CodeReference(5, 14, 5, 62), 'MyW2VTransformer(min_count=2, size=2, workers=1)')) 45 | : MissingEmbeddingsInfo(2, ['cat_b', 'cat_c'])} 46 | expected_result = NoMissingEmbeddingsResult(NoMissingEmbeddings(10), CheckStatus.FAILURE, 47 | 'Missing embeddings were found!', expected_failed_dag_node_with_result) 48 | compare(check_result, expected_result) 49 | -------------------------------------------------------------------------------- /test/example_pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/example_pipelines/__init__.py -------------------------------------------------------------------------------- /test/example_pipelines/test_adult_complex.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the adult_complex test pipeline works 3 | """ 4 | import ast 5 | 6 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected 7 | from example_pipelines import ADULT_COMPLEX_PY, ADULT_COMPLEX_PNG 8 | 9 | 10 | def test_py_pipeline_runs(): 11 | """ 12 | Tests whether the .py version of the pipeline works 13 | """ 14 | with open(ADULT_COMPLEX_PY, encoding="utf-8") as file: 15 | text = file.read() 16 | parsed_ast = ast.parse(text) 17 | exec(compile(parsed_ast, filename="", mode="exec")) 18 | 19 | 20 | def test_instrumented_py_pipeline_runs(): 21 | """ 22 | Tests whether the pipeline works with instrumentation 23 | """ 24 | dag = run_and_assert_all_op_outputs_inspected(ADULT_COMPLEX_PY, ["race"], ADULT_COMPLEX_PNG) 25 | assert len(dag) == 24 26 | -------------------------------------------------------------------------------- /test/example_pipelines/test_adult_simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the adult_simple test pipeline works 3 | """ 4 | import ast 5 | 6 | import nbformat 7 | from nbconvert import PythonExporter 8 |
9 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected 10 | from example_pipelines import ADULT_SIMPLE_PY, ADULT_SIMPLE_IPYNB, ADULT_SIMPLE_PNG 11 | 12 | 13 | def test_py_pipeline_runs(): 14 | """ 15 | Tests whether the .py version of the pipeline works 16 | """ 17 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file: 18 | text = file.read() 19 | parsed_ast = ast.parse(text) 20 | exec(compile(parsed_ast, filename="", mode="exec")) 21 | 22 | 23 | def test_nb_pipeline_runs(): 24 | """ 25 | Tests whether the .ipynb version of the pipeline works 26 | """ 27 | with open(ADULT_SIMPLE_IPYNB, encoding="utf-8") as file: 28 | notebook = nbformat.reads(file.read(), nbformat.NO_CONVERT) 29 | exporter = PythonExporter() 30 | 31 | code, _ = exporter.from_notebook_node(notebook) 32 | parsed_ast = ast.parse(code) 33 | exec(compile(parsed_ast, filename="", mode="exec")) 34 | 35 | 36 | def test_instrumented_py_pipeline_runs(): 37 | """ 38 | Tests whether the pipeline works with instrumentation 39 | """ 40 | dag = run_and_assert_all_op_outputs_inspected(ADULT_SIMPLE_PY, ["race"], ADULT_SIMPLE_PNG) 41 | assert len(dag) == 12 42 | -------------------------------------------------------------------------------- /test/example_pipelines/test_compas.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the compas test pipeline works 3 | """ 4 | import ast 5 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected 6 | from example_pipelines import COMPAS_PY, COMPAS_PNG 7 | 8 | 9 | def test_py_pipeline_runs(): 10 | """ 11 | Tests whether the .py version of the pipeline works 12 | """ 13 | with open(COMPAS_PY, encoding="utf-8") as file: 14 | text = file.read() 15 | parsed_ast = ast.parse(text) 16 | exec(compile(parsed_ast, filename="", mode="exec")) 17 | 18 | 19 | def test_instrumented_py_pipeline_runs(): 20 | """ 21 | Tests whether the pipeline works with instrumentation 22 | """ 23 | dag = run_and_assert_all_op_outputs_inspected(COMPAS_PY, ['sex', 'race'], COMPAS_PNG) 24 | assert len(dag) == 39 25 | -------------------------------------------------------------------------------- /test/example_pipelines/test_healthcare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the healthcare demo works 3 | """ 4 | import ast 5 | import pandas as pd 6 | from scikeras.wrappers import KerasClassifier 7 | from sklearn.preprocessing import StandardScaler, label_binarize 8 | 9 | from example_pipelines.healthcare import custom_monkeypatching 10 | from example_pipelines.healthcare.healthcare_utils import create_model, MyW2VTransformer 11 | from example_pipelines import HEALTHCARE_PY, HEALTHCARE_PNG 12 | from mlinspect.testing._testing_helper_utils import run_and_assert_all_op_outputs_inspected 13 | 14 | 15 | def test_my_word_to_vec_transformer(): 16 | """ 17 | Tests whether MyW2VTransformer works 18 | """ 19 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c']}) 20 | word_to_vec = MyW2VTransformer(min_count=2, size=2, workers=1) 21 | encoded_data = word_to_vec.fit_transform(pandas_df) 22 | assert encoded_data.shape == (4, 2) 23 | 24 | 25 | def test_my_keras_classifier(): 26 | """ 27 | Tests whether the KerasClassifier built with create_model works 28 | """ 29 | pandas_df = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [0, 1, 2, 3], 'target': ['no', 'no', 'yes', 'yes']}) 30 | 31 | train = StandardScaler().fit_transform(pandas_df[['A', 'B']]) 32 |
target = label_binarize(pandas_df[['target']], classes=['no', 'yes']) 33 | 34 | clf = KerasClassifier(build_fn=create_model, epochs=2, batch_size=1, verbose=0, 35 | hidden_layer_sizes=(9, 9,), loss="binary_crossentropy") 36 | clf.fit(train, target) 37 | 38 | test_predict = clf.predict([[0., 0.], [0.6, 0.6]]) 39 | assert test_predict.shape == (2, 1) 40 | 41 | 42 | def test_py_pipeline_runs(): 43 | """ 44 | Tests whether the pipeline works without instrumentation 45 | """ 46 | with open(HEALTHCARE_PY, encoding="utf-8") as file: 47 | healthcare_code = file.read() 48 | parsed_ast = ast.parse(healthcare_code) 49 | exec(compile(parsed_ast, filename="", mode="exec")) 50 | 51 | 52 | def test_instrumented_py_pipeline_runs(): 53 | """ 54 | Tests whether the pipeline works with instrumentation 55 | """ 56 | dag = run_and_assert_all_op_outputs_inspected(HEALTHCARE_PY, ["age_group", "race"], HEALTHCARE_PNG, 57 | [custom_monkeypatching]) 58 | assert len(dag) == 37 59 | -------------------------------------------------------------------------------- /test/experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/experiments/__init__.py -------------------------------------------------------------------------------- /test/experiments/performance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/experiments/performance/__init__.py -------------------------------------------------------------------------------- /test/experiments/performance/test_benchmark_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the benchmark utils work 3 | """ 4 | 5 | from experiments.performance._benchmark_utils import do_op_instrumentation_benchmarks, OperatorBenchmarkType, \ 6 | do_op_inspections_benchmarks, do_full_pipeline_benchmarks, PipelineBenchmarkType 7 | 8 | 9 | def test_instrumentation_benchmarks(): 10 | """ 11 | Tests whether the pipeline works with instrumentation 12 | """ 13 | for op_type in OperatorBenchmarkType: 14 | benchmark_results = do_op_instrumentation_benchmarks(100, op_type) 15 | 16 | assert benchmark_results["no mlinspect"] 17 | assert benchmark_results["no inspection"] 18 | assert benchmark_results["one inspection"] 19 | assert benchmark_results["two inspections"] 20 | assert benchmark_results["three inspections"] 21 | 22 | 23 | def test_inspection_benchmarks(): 24 | """ 25 | Tests whether the pipeline works with instrumentation 26 | """ 27 | for op_type in OperatorBenchmarkType: 28 | benchmark_results = do_op_inspections_benchmarks(100, op_type) 29 | 30 | assert benchmark_results["empty inspection"] 31 | assert benchmark_results["MaterializeFirstOutputRows(10)"] 32 | assert benchmark_results["RowLineage(10)"] 33 | assert benchmark_results["HistogramForColumns(['group_col_1'])"] 34 | assert benchmark_results["HistogramForColumns(['group_col_1', 'group_col_2', 'group_col_3'])"] 35 | 36 | 37 | def test_full_pipeline_benchmarks(): 38 | """ 39 | Tests whether the pipeline works with instrumentation 40 | """ 41 | for pipeline in PipelineBenchmarkType: 42 | benchmark_results = do_full_pipeline_benchmarks(pipeline) 43 | 44 | assert benchmark_results["no mlinspect"] 45 | assert benchmark_results["no inspection"] 46 | assert 
benchmark_results["one inspection"] 47 | assert benchmark_results["two inspections"] 48 | assert benchmark_results["three inspections"] 49 | -------------------------------------------------------------------------------- /test/experiments/performance/test_performance_benchmarks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the performance benchmark notebook works 3 | """ 4 | import os 5 | 6 | import matplotlib 7 | from importnb import Notebook 8 | 9 | from mlinspect.utils import get_project_root 10 | 11 | EXPERIMENT_NB_FILE = os.path.join(str(get_project_root()), "experiments", "performance", "performance_benchmarks.ipynb") 12 | 13 | 14 | def test_experiment_nb(): 15 | """ 16 | Tests whether the experiment notebook works 17 | """ 18 | matplotlib.use("template") # Disable plt.show when executing nb as part of this test 19 | Notebook.load(EXPERIMENT_NB_FILE) 20 | -------------------------------------------------------------------------------- /test/inspections/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/inspections/__init__.py -------------------------------------------------------------------------------- /test/inspections/test_column_propagation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether ColumnPropagation works 3 | """ 4 | from inspect import cleandoc 5 | 6 | import pandas 7 | from pandas import DataFrame 8 | 9 | from mlinspect._pipeline_inspector import PipelineInspector 10 | from mlinspect.inspections import ColumnPropagation 11 | 12 | 13 | def test_propagation_merge(): 14 | """ 15 | Tests whether ColumnPropagation works for joins 16 | """ 17 | test_code = cleandoc(""" 18 | import pandas as pd 19 | 20 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 'B': [1, 2, 4, 5, 7]}) 21 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]}) 22 | df_merged = df_a.merge(df_b, on='B') 23 | """) 24 | 25 | inspector_result = PipelineInspector \ 26 | .on_pipeline_from_string(test_code) \ 27 | .add_required_inspection(ColumnPropagation(["A"], 2)) \ 28 | .execute() 29 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 30 | 31 | propagation_output = inspection_results[0][ColumnPropagation(["A"], 2)] 32 | expected_df = DataFrame([['cat_a', 1, 'cat_a'], ['cat_b', 2, 'cat_b']], columns=['A', 'B', 'mlinspect_A']) 33 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 34 | 35 | propagation_output = inspection_results[1][ColumnPropagation(["A"], 2)] 36 | expected_df = DataFrame([[1, 1., None], [2, 5., None]], columns=['B', 'C', 'mlinspect_A']) 37 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 38 | 39 | propagation_output = inspection_results[2][ColumnPropagation(["A"], 2)] 40 | expected_df = DataFrame([['cat_a', 1, 1., 'cat_a'], ['cat_b', 2, 5., 'cat_b']], 41 | columns=['A', 'B', 'C', 'mlinspect_A']) 42 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 43 | 44 | 45 | def test_propagation_projection(): 46 | """ 47 | Tests whether ColumnPropagation works for projections 48 | """ 49 | test_code = cleandoc(""" 50 | import pandas as pd 51 | 52 | pandas_df = 
pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 53 | 'B': [1, 2, 4, 5, 7], 'C': [2, 2, 10, 5, 7]}) 54 | pandas_df = pandas_df[['B', 'C']] 55 | pandas_df = pandas_df[['C']] 56 | """) 57 | 58 | inspector_result = PipelineInspector \ 59 | .on_pipeline_from_string(test_code) \ 60 | .add_required_inspection(ColumnPropagation(["A"], 2)) \ 61 | .execute() 62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 63 | 64 | propagation_output = inspection_results[0][ColumnPropagation(["A"], 2)] 65 | expected_df = DataFrame([['cat_a', 1, 2, 'cat_a'], ['cat_b', 2, 2, 'cat_b']], columns=['A', 'B', 'C', 'mlinspect_A']) 66 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 67 | 68 | propagation_output = inspection_results[1][ColumnPropagation(["A"], 2)] 69 | expected_df = DataFrame([[1, 2, 'cat_a'], [2, 2, 'cat_b']], columns=['B', 'C', 'mlinspect_A']) 70 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 71 | 72 | propagation_output = inspection_results[2][ColumnPropagation(["A"], 2)] 73 | expected_df = DataFrame([[2, 'cat_a'], [2, 'cat_b']], columns=['C', 'mlinspect_A']) 74 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 75 | 76 | 77 | def test_propagation_score(): 78 | """ 79 | Tests whether ColumnPropagation works for estimator scoring 80 | """ 81 | test_code = cleandoc(""" 82 | import pandas as pd 83 | from sklearn.preprocessing import label_binarize, StandardScaler 84 | from sklearn.tree import DecisionTreeClassifier 85 | import numpy as np 86 | 87 | df = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [0, 1, 2, 3], 'cat_col': ['cat_a', 'cat_b', 'cat_a', 'cat_a'], 88 | 'target': ['no', 'no', 'yes', 'yes']}) 89 | 90 | train = StandardScaler().fit_transform(df[['A', 'B']]) 91 | target = label_binarize(df['target'], classes=['no', 'yes']) 92 | 93 | clf = DecisionTreeClassifier() 94 | clf = clf.fit(train, target) 95 | 96 | test_df = pd.DataFrame({'A': [0., 0.6], 'B': [0., 0.6], 'cat_col': ['cat_a', 'cat_b'], 97 | 'target': ['no', 'yes']}) 98 | test_labels = label_binarize(test_df['target'], classes=['no', 'yes']) 99 | test_score = clf.score(test_df[['A', 'B']], test_labels) 100 | assert test_score == 1.0 101 | """) 102 | 103 | inspector_result = PipelineInspector \ 104 | .on_pipeline_from_string(test_code) \ 105 | .add_required_inspection(ColumnPropagation(["cat_col"], 2)) \ 106 | .execute() 107 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 108 | 109 | propagation_output = inspection_results[14][ColumnPropagation(["cat_col"], 2)] 110 | expected_df = DataFrame([[0, 'cat_a'], [1, 'cat_b']], columns=['array', 'mlinspect_cat_col']) 111 | pandas.testing.assert_frame_equal(propagation_output.reset_index(drop=True), expected_df.reset_index(drop=True)) 112 | -------------------------------------------------------------------------------- /test/inspections/test_completeness_of_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether CompletenessOfColumns works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare 7 | 8 | from mlinspect._pipeline_inspector import PipelineInspector 9 | from mlinspect.inspections import CompletenessOfColumns 10 | 11 | 12 | def test_completeness_merge(): 13 | """ 14 | Tests whether CompletenessOfColumns works for joins 15 | """ 16
| test_code = cleandoc(""" 17 | import numpy as np 18 | import pandas as pd 19 | 20 | df_a = pd.DataFrame({'A': ['cat_a', None, 'cat_a', 'cat_c', None], 'B': [1, 2, 4, 5, 7]}) 21 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, np.nan], 'C': [1, 5, 4, 11, None]}) 22 | df_merged = df_a.merge(df_b, on='B') 23 | """) 24 | 25 | inspector_result = PipelineInspector \ 26 | .on_pipeline_from_string(test_code) \ 27 | .add_required_inspection(CompletenessOfColumns(['A', 'B'])) \ 28 | .execute() 29 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 30 | 31 | completeness_output = inspection_results[0][CompletenessOfColumns(['A', 'B'])] 32 | expected_completeness = {'A': 0.6, 'B': 1.0} 33 | compare(completeness_output, expected_completeness) 34 | 35 | completeness_output = inspection_results[1][CompletenessOfColumns(['A', 'B'])] 36 | expected_completeness = {'B': 0.8} 37 | compare(completeness_output, expected_completeness) 38 | 39 | completeness_output = inspection_results[2][CompletenessOfColumns(['A', 'B'])] 40 | expected_completeness = {'A': 2/3, 'B': 1.0} 41 | compare(completeness_output, expected_completeness) 42 | 43 | 44 | def test_completeness_projection(): 45 | """ 46 | Tests whether CompletenessOfColumns works for projections 47 | """ 48 | test_code = cleandoc(""" 49 | import pandas as pd 50 | import numpy as np 51 | 52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', None, 'cat_c', 'cat_b'], 53 | 'B': [1, None, np.nan, None, 7], 'C': [2, 2, 10, 5, 7]}) 54 | pandas_df = pandas_df[['B', 'C']] 55 | pandas_df = pandas_df[['C']] 56 | """) 57 | 58 | inspector_result = PipelineInspector \ 59 | .on_pipeline_from_string(test_code) \ 60 | .add_required_inspection(CompletenessOfColumns(['A', 'B'])) \ 61 | .execute() 62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 63 | 64 | completeness_output = inspection_results[0][CompletenessOfColumns(['A', 'B'])] 65 | expected_completeness = {'A': 0.8, 'B': 0.4} 66 | compare(completeness_output, expected_completeness) 67 | 68 | completeness_output = inspection_results[1][CompletenessOfColumns(['A', 'B'])] 69 | expected_completeness = {'B': 0.4} 70 | compare(completeness_output, expected_completeness) 71 | 72 | completeness_output = inspection_results[2][CompletenessOfColumns(['A', 'B'])] 73 | expected_completeness = {} 74 | compare(completeness_output, expected_completeness) 75 | -------------------------------------------------------------------------------- /test/inspections/test_count_distinct_of_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether CountDistinctOfColumns works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare 7 | 8 | from mlinspect._pipeline_inspector import PipelineInspector 9 | from mlinspect.inspections import CountDistinctOfColumns 10 | 11 | 12 | def test_count_distinct_merge(): 13 | """ 14 | Tests whether CountDistinctOfColumns works for joins 15 | """ 16 | test_code = cleandoc(""" 17 | import numpy as np 18 | import pandas as pd 19 | 20 | df_a = pd.DataFrame({'A': ['cat_a', None, 'cat_a', 'cat_c', None], 'B': [1, 2, 4, 5, 7]}) 21 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, np.nan], 'C': [1, 5, 4, 11, None]}) 22 | df_merged = df_a.merge(df_b, on='B') 23 | """) 24 | 25 | inspector_result = PipelineInspector \ 26 | .on_pipeline_from_string(test_code) \ 27 | .add_required_inspection(CountDistinctOfColumns(['A', 'B'])) \ 28 | .execute() 29 | inspection_results = 
list(inspector_result.dag_node_to_inspection_results.values()) 30 | 31 | count_distinct_output = inspection_results[0][CountDistinctOfColumns(['A', 'B'])] 32 | expected_count_distinct = {'A': 3, 'B': 5} 33 | compare(count_distinct_output, expected_count_distinct) 34 | 35 | count_distinct_output = inspection_results[1][CountDistinctOfColumns(['A', 'B'])] 36 | expected_count_distinct = {'B': 5} 37 | compare(count_distinct_output, expected_count_distinct) 38 | 39 | count_distinct_output = inspection_results[2][CountDistinctOfColumns(['A', 'B'])] 40 | expected_count_distinct = {'A': 2, 'B': 3} 41 | compare(count_distinct_output, expected_count_distinct) 42 | 43 | 44 | def test_count_distinct_projection(): 45 | """ 46 | Tests whether CountDistinctOfColumns works for projections 47 | """ 48 | test_code = cleandoc(""" 49 | import pandas as pd 50 | import numpy as np 51 | 52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', None, 'cat_c', 'cat_b'], 53 | 'B': [1, None, np.nan, None, 7], 'C': [2, 2, 10, 5, 7]}) 54 | pandas_df = pandas_df[['B', 'C']] 55 | pandas_df = pandas_df[['C']] 56 | """) 57 | 58 | inspector_result = PipelineInspector \ 59 | .on_pipeline_from_string(test_code) \ 60 | .add_required_inspection(CountDistinctOfColumns(['A', 'B'])) \ 61 | .execute() 62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 63 | 64 | count_distinct_output = inspection_results[0][CountDistinctOfColumns(['A', 'B'])] 65 | expected_count_distinct = {'A': 4, 'B': 5} 66 | compare(count_distinct_output, expected_count_distinct) 67 | 68 | count_distinct_output = inspection_results[1][CountDistinctOfColumns(['A', 'B'])] 69 | expected_count_distinct = {'B': 5} 70 | compare(count_distinct_output, expected_count_distinct) 71 | 72 | count_distinct_output = inspection_results[2][CountDistinctOfColumns(['A', 'B'])] 73 | expected_count_distinct = {} 74 | compare(count_distinct_output, expected_count_distinct) 75 | -------------------------------------------------------------------------------- /test/inspections/test_histogram_for_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether HistogramForColumns works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare 7 | 8 | from mlinspect._pipeline_inspector import PipelineInspector 9 | from mlinspect.inspections import HistogramForColumns 10 | 11 | 12 | def test_histogram_merge(): 13 | """ 14 | Tests whether HistogramForColumns works for joins 15 | """ 16 | test_code = cleandoc(""" 17 | import pandas as pd 18 | 19 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 'B': [1, 2, 4, 5, 7]}) 20 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]}) 21 | df_merged = df_a.merge(df_b, on='B') 22 | """) 23 | 24 | inspector_result = PipelineInspector \ 25 | .on_pipeline_from_string(test_code) \ 26 | .add_required_inspection(HistogramForColumns(["A"])) \ 27 | .execute() 28 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 29 | 30 | histogram_output = inspection_results[0][HistogramForColumns(["A"])] 31 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}} 32 | compare(histogram_output, expected_histogram) 33 | 34 | histogram_output = inspection_results[1][HistogramForColumns(["A"])] 35 | expected_histogram = {'A': {}} 36 | compare(histogram_output, expected_histogram) 37 | 38 | histogram_output = inspection_results[2][HistogramForColumns(["A"])] 39 | expected_histogram 
= {'A': {'cat_a': 2, 'cat_b': 1, 'cat_c': 1}} 40 | compare(histogram_output, expected_histogram) 41 | 42 | 43 | def test_histogram_projection(): 44 | """ 45 | Tests whether HistogramForColumns works for projections 46 | """ 47 | test_code = cleandoc(""" 48 | import pandas as pd 49 | 50 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 51 | 'B': [1, 2, 4, 5, 7], 'C': [2, 2, 10, 5, 7]}) 52 | pandas_df = pandas_df[['B', 'C']] 53 | pandas_df = pandas_df[['C']] 54 | """) 55 | 56 | inspector_result = PipelineInspector \ 57 | .on_pipeline_from_string(test_code) \ 58 | .add_required_inspection(HistogramForColumns(["A"])) \ 59 | .execute() 60 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 61 | 62 | histogram_output = inspection_results[0][HistogramForColumns(["A"])] 63 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}} 64 | compare(histogram_output, expected_histogram) 65 | 66 | histogram_output = inspection_results[1][HistogramForColumns(["A"])] 67 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}} 68 | compare(histogram_output, expected_histogram) 69 | 70 | histogram_output = inspection_results[2][HistogramForColumns(["A"])] 71 | expected_histogram = {'A': {'cat_a': 2, 'cat_b': 2, 'cat_c': 1}} 72 | compare(histogram_output, expected_histogram) 73 | 74 | 75 | def test_histogram_score(): 76 | """ 77 | Tests whether HistogramForColumns works for estimator scoring 78 | """ 79 | test_code = cleandoc(""" 80 | import pandas as pd 81 | from sklearn.preprocessing import label_binarize, StandardScaler 82 | from sklearn.tree import DecisionTreeClassifier 83 | import numpy as np 84 | 85 | df = pd.DataFrame({'A': [0, 1, 2, 3], 'B': [0, 1, 2, 3], 'target': ['no', 'no', 'yes', 'yes']}) 86 | 87 | train = StandardScaler().fit_transform(df[['A', 'B']]) 88 | target = label_binarize(df['target'], classes=['no', 'yes']) 89 | 90 | clf = DecisionTreeClassifier() 91 | clf = clf.fit(train, target) 92 | 93 | test_df = pd.DataFrame({'A': [0., 0.6], 'B': [0., 0.6], 'target': ['no', 'yes']}) 94 | test_labels = label_binarize(test_df['target'], classes=['no', 'yes']) 95 | test_score = clf.score(test_df[['A', 'B']], test_labels) 96 | assert test_score == 1.0 97 | """) 98 | 99 | inspector_result = PipelineInspector \ 100 | .on_pipeline_from_string(test_code) \ 101 | .add_required_inspection(HistogramForColumns(["target"])) \ 102 | .execute() 103 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 104 | 105 | histogram_output = inspection_results[14][HistogramForColumns(["target"])] 106 | expected_histogram = {'target': {'no': 1, 'yes': 1}} 107 | compare(histogram_output, expected_histogram) 108 | -------------------------------------------------------------------------------- /test/inspections/test_intersectional_histogram_for_columns.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether IntersectionalHistogramForColumns works 3 | """ 4 | from inspect import cleandoc 5 | 6 | from testfixtures import compare 7 | 8 | from mlinspect._pipeline_inspector import PipelineInspector 9 | from mlinspect.inspections import IntersectionalHistogramForColumns 10 | 11 | 12 | def test_histogram_merge(): 13 | """ 14 | Tests whether IntersectionalHistogramForColumns works for joins 15 | """ 16 | test_code = cleandoc(""" 17 | import pandas as pd 18 | 19 | df_a = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_b', 'cat_b'], 20 | 'B': [1, 2, 4, 5, 7], 21 | 'C':
[True, False, True, True, True]}) 22 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'D': [1, 5, 4, 11, None]}) 23 | df_merged = df_a.merge(df_b, on='B') 24 | """) 25 | 26 | inspector_result = PipelineInspector \ 27 | .on_pipeline_from_string(test_code) \ 28 | .add_required_inspection(IntersectionalHistogramForColumns(["A", "C"])) \ 29 | .execute() 30 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 31 | 32 | histogram_output = inspection_results[0][IntersectionalHistogramForColumns(["A", "C"])] 33 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_b', True): 2} 34 | compare(histogram_output, expected_histogram) 35 | 36 | histogram_output = inspection_results[1][IntersectionalHistogramForColumns(["A", "C"])] 37 | expected_histogram = {(None, None): 5} 38 | compare(histogram_output, expected_histogram) 39 | 40 | histogram_output = inspection_results[2][IntersectionalHistogramForColumns(["A", "C"])] 41 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_b', True): 1} 42 | compare(histogram_output, expected_histogram) 43 | 44 | 45 | def test_histogram_projection(): 46 | """ 47 | Tests whether IntersectionalHistogramForColumns works for projections 48 | """ 49 | test_code = cleandoc(""" 50 | import pandas as pd 51 | 52 | pandas_df = pd.DataFrame({'A': ['cat_a', 'cat_b', 'cat_a', 'cat_c', 'cat_b'], 53 | 'B': [1, 2, 4, 5, 7], 'C': [True, False, True, True, True]}) 54 | pandas_df = pandas_df[['B', 'C']] 55 | pandas_df = pandas_df[['C']] 56 | """) 57 | 58 | inspector_result = PipelineInspector \ 59 | .on_pipeline_from_string(test_code) \ 60 | .add_required_inspection(IntersectionalHistogramForColumns(["A", "C"])) \ 61 | .execute() 62 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 63 | 64 | histogram_output = inspection_results[0][IntersectionalHistogramForColumns(["A", "C"])] 65 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_c', True): 1, ('cat_b', True): 1} 66 | compare(histogram_output, expected_histogram) 67 | 68 | histogram_output = inspection_results[1][IntersectionalHistogramForColumns(["A", "C"])] 69 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_c', True): 1, ('cat_b', True): 1} 70 | compare(histogram_output, expected_histogram) 71 | 72 | histogram_output = inspection_results[2][IntersectionalHistogramForColumns(["A", "C"])] 73 | expected_histogram = {('cat_a', True): 2, ('cat_b', False): 1, ('cat_c', True): 1, ('cat_b', True): 1} 74 | compare(histogram_output, expected_histogram) 75 | -------------------------------------------------------------------------------- /test/inspections/test_lineage.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether RowLineage works 3 | """ 4 | from inspect import cleandoc 5 | 6 | import pandas 7 | import numpy as np 8 | from pandas import DataFrame 9 | 10 | from mlinspect import OperatorType 11 | from mlinspect._pipeline_inspector import PipelineInspector 12 | from mlinspect.inspections import RowLineage 13 | from mlinspect.inspections._lineage import LineageId 14 | 15 | 16 | def test_row_lineage_merge(): 17 | """ 18 | Tests whether RowLineage works for joins 19 | """ 20 | test_code = cleandoc(""" 21 | import pandas as pd 22 | 23 | df_a = pd.DataFrame({'A': [0, 2, 4, 8, 5], 'B': [1, 2, 4, 5, 7]}) 24 | df_b = pd.DataFrame({'B': [1, 2, 3, 4, 5], 'C': [1, 5, 4, 11, None]}) 25 | df_merged = df_a.merge(df_b, on='B') 26 | """) 27 | 28 | 
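# RowLineage(2) materialises the first two output rows of each operator and
# annotates every row with a set of LineageId(operator_id, row_id) entries;
# a joined row carries the union of its parents' lineage sets, e.g. the first
# merged row below is expected to get {LineageId(0, 0), LineageId(1, 0)}.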
inspector_result = PipelineInspector \ 29 | .on_pipeline_from_string(test_code) \ 30 | .add_required_inspection(RowLineage(2)) \ 31 | .execute() 32 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 33 | 34 | lineage_output = inspection_results[0][RowLineage(2)] 35 | expected_lineage_df = DataFrame([[0, 1, {LineageId(0, 0)}], 36 | [2, 2, {LineageId(0, 1)}]], 37 | columns=['A', 'B', 'mlinspect_lineage']) 38 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 39 | 40 | lineage_output = inspection_results[1][RowLineage(2)] 41 | expected_lineage_df = DataFrame([[1, 1., {LineageId(1, 0)}], 42 | [2, 5., {LineageId(1, 1)}]], 43 | columns=['B', 'C', 'mlinspect_lineage']) 44 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 45 | 46 | lineage_output = inspection_results[2][RowLineage(2)] 47 | expected_lineage_df = DataFrame([[0, 1, 1., {LineageId(0, 0), LineageId(1, 0)}], 48 | [2, 2, 5., {LineageId(0, 1), LineageId(1, 1)}]], 49 | columns=['A', 'B', 'C', 'mlinspect_lineage']) 50 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 51 | 52 | 53 | def test_row_lineage_concat(): 54 | """ 55 | Tests whether RowLineage works for concats 56 | """ 57 | test_code = cleandoc(""" 58 | import pandas as pd 59 | from sklearn.preprocessing import StandardScaler, OneHotEncoder 60 | from sklearn.compose import ColumnTransformer 61 | 62 | df = pd.DataFrame({'A': [1, 2, 10, 5], 'B': ['cat_a', 'cat_b', 'cat_a', 'cat_c']}) 63 | column_transformer = ColumnTransformer(transformers=[ 64 | ('numeric', StandardScaler(), ['A']), 65 | ('categorical', OneHotEncoder(sparse=False), ['B']) 66 | ]) 67 | encoded_data = column_transformer.fit_transform(df) 68 | """) 69 | 70 | inspector_result = PipelineInspector \ 71 | .on_pipeline_from_string(test_code) \ 72 | .add_required_inspection(RowLineage(2)) \ 73 | .execute() 74 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 75 | 76 | lineage_output = inspection_results[0][RowLineage(2)] 77 | expected_lineage_df = DataFrame([[1, 'cat_a', {LineageId(0, 0)}], 78 | [2, 'cat_b', {LineageId(0, 1)}]], 79 | columns=['A', 'B', 'mlinspect_lineage']) 80 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 81 | 82 | lineage_output = inspection_results[1][RowLineage(2)] 83 | expected_lineage_df = DataFrame([[1, {LineageId(0, 0)}], 84 | [2, {LineageId(0, 1)}]], 85 | columns=['A', 'mlinspect_lineage']) 86 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 87 | 88 | lineage_output = inspection_results[2][RowLineage(2)] 89 | expected_lineage_df = DataFrame([[np.array([-1.0]), {LineageId(0, 0)}], 90 | [np.array([-0.7142857142857143]), {LineageId(0, 1)}]], 91 | columns=['array', 'mlinspect_lineage']) 92 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 93 | 94 | lineage_output = inspection_results[3][RowLineage(2)] 95 | expected_lineage_df = DataFrame([['cat_a', {LineageId(0, 0)}], 96 | ['cat_b', {LineageId(0, 1)}]], 97 | columns=['B', 'mlinspect_lineage']) 98 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 99 | 100 | lineage_output = 
inspection_results[4][RowLineage(2)] 101 | expected_lineage_df = DataFrame([[np.array([1., 0., 0.]), {LineageId(0, 0)}], 102 | [np.array([0., 1., 0.]), {LineageId(0, 1)}]], 103 | columns=['array', 'mlinspect_lineage']) 104 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 105 | 106 | lineage_output = inspection_results[5][RowLineage(2)] 107 | expected_lineage_df = DataFrame([[np.array([-1.0, 1., 0., 0.]), {LineageId(0, 0)}], 108 | [np.array([-0.7142857142857143, 0., 1., 0.]), {LineageId(0, 1)}]], 109 | columns=['array', 'mlinspect_lineage']) 110 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 111 | 112 | 113 | def test_all_rows_for_op_type(): 114 | """ 115 | Tests whether RowLineage works for materialising all data from specific operators 116 | """ 117 | test_code = cleandoc(""" 118 | import pandas as pd 119 | 120 | df_a = pd.DataFrame({'A': [0, 2], 'B': [1, 2]}) 121 | df_b = pd.DataFrame({'B': [1, 2], 'C': [1, 5]}) 122 | df_merged = df_a.merge(df_b, on='B') 123 | """) 124 | row_lineage = RowLineage(RowLineage.ALL_ROWS, [OperatorType.DATA_SOURCE]) 125 | inspector_result = PipelineInspector \ 126 | .on_pipeline_from_string(test_code) \ 127 | .add_required_inspection(row_lineage) \ 128 | .execute() 129 | inspection_results = list(inspector_result.dag_node_to_inspection_results.values()) 130 | 131 | lineage_output = inspection_results[0][row_lineage] 132 | expected_lineage_df = DataFrame([[0, 1, {LineageId(0, 0)}], 133 | [2, 2, {LineageId(0, 1)}]], 134 | columns=['A', 'B', 'mlinspect_lineage']) 135 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 136 | 137 | lineage_output = inspection_results[1][row_lineage] 138 | expected_lineage_df = DataFrame([[1, 1, {LineageId(1, 0)}], 139 | [2, 5, {LineageId(1, 1)}]], 140 | columns=['B', 'C', 'mlinspect_lineage']) 141 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True)) 142 | 143 | lineage_output = inspection_results[2][row_lineage] 144 | assert lineage_output is None 145 | -------------------------------------------------------------------------------- /test/instrumentation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/instrumentation/__init__.py -------------------------------------------------------------------------------- /test/monkeypatching/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/monkeypatching/__init__.py -------------------------------------------------------------------------------- /test/monkeypatching/test_patch_numpy.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the monkey patching works for all patched numpy methods 3 | """ 4 | from inspect import cleandoc 5 | 6 | import pandas 7 | from pandas import DataFrame 8 | from testfixtures import compare 9 | 10 | from mlinspect import OperatorContext, FunctionInfo, OperatorType 11 | from mlinspect.inspections._lineage import RowLineage, LineageId 12 | from mlinspect.instrumentation import _pipeline_executor 13 | from 
mlinspect.instrumentation._dag_node import DagNode, CodeReference, BasicCodeLocation, DagNodeDetails, \ 14 | OptionalCodeInfo 15 | 16 | 17 | def test_numpy_random(): 18 | """ 19 | Tests whether the monkey patching of ('numpy.random', 'random') works 20 | """ 21 | test_code = cleandoc(""" 22 | import numpy as np 23 | np.random.seed(42) 24 | test = np.random.random(100) 25 | assert len(test) == 100 26 | """) 27 | 28 | inspector_result = _pipeline_executor.singleton.run(python_code=test_code, track_code_references=True, 29 | inspections=[RowLineage(2)]) 30 | extracted_node: DagNode = list(inspector_result.dag.nodes)[0] 31 | 32 | expected_node = DagNode(0, 33 | BasicCodeLocation("<string-source>", 3), 34 | OperatorContext(OperatorType.DATA_SOURCE, FunctionInfo('numpy.random', 'random')), 35 | DagNodeDetails('random', ['array']), 36 | OptionalCodeInfo(CodeReference(3, 7, 3, 28), "np.random.random(100)")) 37 | compare(extracted_node, expected_node) 38 | 39 | inspection_results_data_source = inspector_result.dag_node_to_inspection_results[extracted_node] 40 | lineage_output = inspection_results_data_source[RowLineage(2)] 41 | expected_lineage_df = DataFrame([[0.5, {LineageId(0, 0)}], 42 | [0.5, {LineageId(0, 1)}]], 43 | columns=['array', 'mlinspect_lineage']) 44 | pandas.testing.assert_frame_equal(lineage_output.reset_index(drop=True), expected_lineage_df.reset_index(drop=True), 45 | atol=1)  # loose tolerance: only the lineage annotations matter, not the sampled values 46 | -------------------------------------------------------------------------------- /test/test_pipeline_inspector.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the fluent API works 3 | """ 4 | 5 | import networkx 6 | from testfixtures import compare 7 | 8 | from example_pipelines.healthcare import custom_monkeypatching 9 | from example_pipelines import ADULT_SIMPLE_PY, ADULT_SIMPLE_IPYNB, HEALTHCARE_PY 10 | from mlinspect import PipelineInspector, OperatorType 11 | from mlinspect.checks import CheckStatus, NoBiasIntroducedFor, NoIllegalFeatures 12 | from mlinspect.inspections import HistogramForColumns, MaterializeFirstOutputRows 13 | from mlinspect.testing._testing_helper_utils import get_expected_dag_adult_easy 14 | 15 | 16 | def test_inspector_adult_easy_py_pipeline(): 17 | """ 18 | Tests whether the .py version of the inspector works 19 | """ 20 | inspector_result = PipelineInspector\ 21 | .on_pipeline_from_py_file(ADULT_SIMPLE_PY)\ 22 | .add_required_inspection(MaterializeFirstOutputRows(5))\ 23 | .add_check(NoBiasIntroducedFor(['race']))\ 24 | .add_check(NoIllegalFeatures())\ 25 | .execute() 26 | extracted_dag = inspector_result.dag 27 | expected_dag = get_expected_dag_adult_easy(ADULT_SIMPLE_PY) 28 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag)) 29 | 30 | assert HistogramForColumns(['race']) in list(inspector_result.dag_node_to_inspection_results.values())[0] 31 | check_to_check_results = inspector_result.check_to_check_results 32 | assert check_to_check_results[NoBiasIntroducedFor(['race'])].status == CheckStatus.SUCCESS 33 | assert check_to_check_results[NoIllegalFeatures()].status == CheckStatus.FAILURE 34 | 35 | 36 | def test_inspector_adult_easy_py_pipeline_without_inspections(): 37 | """ 38 | Tests whether the .py version of the inspector works without any inspections or checks 39 | """ 40 | inspector_result = PipelineInspector\ 41 | .on_pipeline_from_py_file(ADULT_SIMPLE_PY)\ 42 | .execute() 43 | extracted_dag = inspector_result.dag 44 | expected_dag = get_expected_dag_adult_easy(ADULT_SIMPLE_PY) 45 |
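# networkx graphs do not compare structurally with ==, so both DAGs are first
# converted to plain adjacency dicts; e.g. a graph with the single edge
# a -> b becomes {'a': {'b': {}}, 'b': {}}.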
compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag)) 46 | 47 | 48 | def test_inspector_adult_easy_ipynb_pipeline(): 49 | """ 50 | Tests whether the .ipynb version of the inspector works 51 | """ 52 | inspector_result = PipelineInspector\ 53 | .on_pipeline_from_ipynb_file(ADULT_SIMPLE_IPYNB)\ 54 | .add_required_inspection(MaterializeFirstOutputRows(5)) \ 55 | .add_check(NoBiasIntroducedFor(['race'])) \ 56 | .add_check(NoIllegalFeatures()) \ 57 | .execute() 58 | extracted_dag = inspector_result.dag 59 | expected_dag = get_expected_dag_adult_easy(ADULT_SIMPLE_IPYNB, 6) 60 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag)) 61 | 62 | assert HistogramForColumns(['race']) in list(inspector_result.dag_node_to_inspection_results.values())[0] 63 | check_to_check_results = inspector_result.check_to_check_results 64 | assert check_to_check_results[NoBiasIntroducedFor(['race'])].status == CheckStatus.SUCCESS 65 | assert check_to_check_results[NoIllegalFeatures()].status == CheckStatus.FAILURE 66 | 67 | 68 | def test_inspector_adult_easy_str_pipeline(): 69 | """ 70 | Tests whether the str version of the inspector works 71 | """ 72 | with open(ADULT_SIMPLE_PY, encoding="utf-8") as file: 73 | code = file.read() 74 | 75 | inspector_result = PipelineInspector\ 76 | .on_pipeline_from_string(code)\ 77 | .add_required_inspection(MaterializeFirstOutputRows(5)) \ 78 | .add_check(NoBiasIntroducedFor(['race'])) \ 79 | .add_check(NoIllegalFeatures()) \ 80 | .execute() 81 | extracted_dag = inspector_result.dag 82 | expected_dag = get_expected_dag_adult_easy("<string-source>") 83 | compare(networkx.to_dict_of_dicts(extracted_dag), networkx.to_dict_of_dicts(expected_dag)) 84 | 85 | assert HistogramForColumns(['race']) in list(inspector_result.dag_node_to_inspection_results.values())[0] 86 | check_to_check_results = inspector_result.check_to_check_results 87 | assert check_to_check_results[NoBiasIntroducedFor(['race'])].status == CheckStatus.SUCCESS 88 | assert check_to_check_results[NoIllegalFeatures()].status == CheckStatus.FAILURE 89 | 90 | 91 | def test_inspector_additional_module(): 92 | """ 93 | Tests whether the inspector works with an additional custom monkey patching module 94 | """ 95 | inspector_result = PipelineInspector \ 96 | .on_pipeline_from_py_file(HEALTHCARE_PY) \ 97 | .add_required_inspection(MaterializeFirstOutputRows(5)) \ 98 | .add_custom_monkey_patching_module(custom_monkeypatching) \ 99 | .execute() 100 | 101 | assert_healthcare_pipeline_output_complete(inspector_result) 102 | 103 | 104 | def test_inspector_additional_modules(): 105 | """ 106 | Tests whether the inspector works with a list of additional custom monkey patching modules 107 | """ 108 | inspector_result = PipelineInspector \ 109 | .on_pipeline_from_py_file(HEALTHCARE_PY) \ 110 | .add_required_inspection(MaterializeFirstOutputRows(5)) \ 111 | .add_custom_monkey_patching_modules([custom_monkeypatching]) \ 112 | .execute() 113 | 114 | assert_healthcare_pipeline_output_complete(inspector_result) 115 | 116 | 117 | def assert_healthcare_pipeline_output_complete(inspector_result): 118 | """ Assert that the healthcare DAG was extracted completely """ 119 | for dag_node, inspection_result in inspector_result.dag_node_to_inspection_results.items(): 120 | assert dag_node.operator_info.operator != OperatorType.MISSING_OP 121 | assert MaterializeFirstOutputRows(5) in inspection_result 122 | if dag_node.operator_info.operator is not OperatorType.ESTIMATOR: 123 | assert inspection_result[MaterializeFirstOutputRows(5)] is not None 124 | else: 125 |
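# Estimators are sinks in the extracted DAG: they consume their train data
# but produce no row output, so there are no rows to materialise for them.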
assert inspection_result[MaterializeFirstOutputRows(5)] is None 126 | assert len(inspector_result.dag) == 37 127 | -------------------------------------------------------------------------------- /test/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/utils/__init__.py -------------------------------------------------------------------------------- /test/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the utils work 3 | """ 4 | from pathlib import Path 5 | 6 | from mlinspect.utils._utils import get_project_root 7 | 8 | 9 | def test_get_project_root(): 10 | """ 11 | Tests whether get_project_root works 12 | """ 13 | assert get_project_root() == Path(__file__).parent.parent.parent 14 | -------------------------------------------------------------------------------- /test/visualisation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stefan-grafberger/mlinspect/c2207ef058e5fb28cc74c72c7c9f3deed04fc639/test/visualisation/__init__.py -------------------------------------------------------------------------------- /test/visualisation/test_visualisation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests whether the visualisation of the resulting DAG works 3 | """ 4 | import os 5 | 6 | from mlinspect.utils import get_project_root 7 | from mlinspect.visualisation import save_fig_to_path, get_dag_as_pretty_string 8 | from mlinspect.testing._testing_helper_utils import get_expected_dag_adult_easy 9 | 10 | 11 | def test_save_fig_to_path(): 12 | """ 13 | Tests whether save_fig_to_path works 14 | """ 15 | extracted_dag = get_expected_dag_adult_easy("<string-source>") 16 | 17 | filename = os.path.join(str(get_project_root()), "example_pipelines", "adult_simple", "adult_simple.png") 18 | save_fig_to_path(extracted_dag, filename) 19 | 20 | assert os.path.isfile(filename) 21 | 22 | 23 | def test_get_dag_as_pretty_string(): 24 | """ 25 | Tests whether get_dag_as_pretty_string works 26 | """ 27 | extracted_dag = get_expected_dag_adult_easy("<string-source>") 28 | 29 | pretty_string = get_dag_as_pretty_string(extracted_dag) 30 | 31 | print(pretty_string) 32 | assert pretty_string 33 | --------------------------------------------------------------------------------
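For orientation, the fluent API exercised by the tests above composes as in the following minimal sketch. This is illustrative only, not a file in this repository, and "my_pipeline.py" is a placeholder path:

from mlinspect import PipelineInspector
from mlinspect.checks import NoBiasIntroducedFor
from mlinspect.inspections import MaterializeFirstOutputRows

# Instrument and run the pipeline, collecting per-operator inspection results.
inspector_result = PipelineInspector \
    .on_pipeline_from_py_file("my_pipeline.py") \
    .add_required_inspection(MaterializeFirstOutputRows(5)) \
    .add_check(NoBiasIntroducedFor(['race'])) \
    .execute()

# Every extracted DAG node maps to the results of the inspections that ran on it.
for dag_node, inspection_results in inspector_result.dag_node_to_inspection_results.items():
    print(dag_node, inspection_results[MaterializeFirstOutputRows(5)])

# Checks aggregate inspection results into a status, as the tests above assert.
print(inspector_result.check_to_check_results[NoBiasIntroducedFor(['race'])].status)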