├── tests ├── __init__.py ├── test_drift.py ├── test_summary.py ├── test_multiple_dfs.py ├── conftest.py └── test_embedding.py ├── docs ├── images │ └── gate.png ├── index.md ├── api.md ├── how-it-works.md ├── embedding.md └── example.md ├── .gitignore ├── .github └── workflows │ ├── ruff.yml │ └── ci.yml ├── gate ├── __init__.py ├── summarize.py ├── statistics.py ├── summary.py └── drift.py ├── Makefile ├── LICENSE ├── pyproject.toml ├── .pre-commit-config.yaml ├── README.md └── mkdocs.yml /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/images/gate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dm4ml/gate/HEAD/docs/images/gate.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | examples/data* 2 | *__pycache__* 3 | *.pytest_cache* 4 | canonical-partitioned-dataset* 5 | *.ipynb_checkpoints* 6 | dist/* 7 | site/* -------------------------------------------------------------------------------- /.github/workflows/ruff.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: [ push, pull_request ] 4 | 5 | 6 | jobs: 7 | test: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v3 11 | - uses: chartboost/ruff-action@v1 12 | -------------------------------------------------------------------------------- /gate/__init__.py: -------------------------------------------------------------------------------- 1 | from gate.summarize import summarize, compute_embeddings 2 | from gate.drift import detect_drift 3 | from gate.statistics import type_to_statistics 4 | 5 | __all__ = [ 6 | "summarize", 7 | "detect_drift", 8 | "type_to_statistics", 9 | "compute_embeddings", 10 | ] 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test lint docs release install 2 | 3 | test: 4 | poetry run pytest 5 | 6 | lint: 7 | ruff . 
--fix 8 | 9 | docs: 10 | mkdocs gh-deploy --force 11 | 12 | release: 13 | poetry version patch 14 | poetry publish --build 15 | 16 | install: 17 | pip install poetry 18 | poetry install 19 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: gate 2 | 3 | on: [ push, pull_request ] 4 | 5 | 6 | jobs: 7 | test: 8 | runs-on: ubuntu-latest 9 | timeout-minutes: 30 10 | steps: 11 | - name: Checkout code 12 | uses: actions/checkout@v2 13 | 14 | - name: Set up Python 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.9 18 | 19 | - name: Install Poetry 20 | uses: snok/install-poetry@v1 21 | 22 | - name: Install dependencies 23 | run: poetry install 24 | 25 | - name: Run pytest 26 | run: poetry run pytest 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 dm4ml 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "gate-drift" 3 | version = "0.1.5" 4 | description = "Data drift detection tool for machine learning pipelines." 
5 | authors = ["Shreya Shankar "] 6 | license = "MIT" 7 | readme = "README.md" 8 | packages = [{include = "gate"}] 9 | 10 | [tool.poetry.dependencies] 11 | python = "^3.8" 12 | numpy = "^1.24.2" 13 | pandas = "^2.0.0" 14 | scikit-learn = "^1.2.2" 15 | sentence-transformers = "^2.2.2" 16 | polars = "^0.17.5" 17 | pyarrow = "^11.0.0" 18 | 19 | [tool.poetry.group.dev.dependencies] 20 | pytest = "^7.3.1" 21 | ruff = "^0.0.261" 22 | pre-commit = "^3.2.2" 23 | matplotlib = "^3.7.1" 24 | seaborn = "^0.12.2" 25 | mkdocs = "^1.4.2" 26 | mkdocs-material = "^9.1.6" 27 | mkdocstrings = "^0.21.2" 28 | pytkdocs = "^0.16.1" 29 | linkchecker = "^10.2.1" 30 | mkdocstrings-python = "^0.9.0" 31 | pytest-rerunfailures = "^11.1.2" 32 | 33 | [tool.pytest.ini_options] 34 | testpaths = ["tests"] 35 | 36 | [tool.black] 37 | max-line-length = 88 38 | preview = true 39 | include = '\.pyi?$' 40 | exclude = ''' 41 | /( 42 | \.git 43 | | \.hg 44 | | \.mypy_cache 45 | | \.tox 46 | | \.venv 47 | | _build 48 | | buck-out 49 | | build 50 | )/ 51 | ''' 52 | 53 | [build-system] 54 | requires = ["poetry-core"] 55 | build-backend = "poetry.core.masonry.api" 56 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | 4 | files: '^gate/' 5 | exclude: '\__init__.py$' 6 | 7 | repos: 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v4.4.0 10 | hooks: 11 | - id: trailing-whitespace 12 | - id: end-of-file-fixer 13 | exclude: ^.*\.egg-info/ 14 | - id: check-merge-conflict 15 | - id: check-case-conflict 16 | - id: pretty-format-json 17 | args: [--autofix, --no-ensure-ascii, --no-sort-keys] 18 | - id: check-ast 19 | - id: debug-statements 20 | - id: check-docstring-first 21 | 22 | - repo: https://github.com/hadialqattan/pycln 23 | rev: v2.1.2 24 | hooks: 25 | - id: pycln 26 | args: [--all] 27 | 28 | - repo: https://github.com/psf/black 29 | rev: 22.12.0 30 | hooks: 31 | - id: black 32 | 33 | - repo: https://github.com/pycqa/isort 34 | rev: 5.12.0 35 | hooks: 36 | - id: isort 37 | name: "isort (python)" 38 | types: [python] 39 | args: [--profile, black] 40 | 41 | - repo: https://github.com/charliermarsh/ruff-pre-commit 42 | # Ruff version. 43 | rev: 'v0.0.261' 44 | hooks: 45 | - id: ruff 46 | 47 | - repo: https://github.com/pre-commit/pre-commit 48 | rev: v2.21.0 49 | hooks: 50 | - id: validate_manifest -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | GATE is a Python module that detects drift in partitions of data. GATE computes partition _summaries_, which are then fed into an anomaly detection algorithm to detect whether a new partition is anomalous. This minimizes false positive alerts when detecting drift in machine learning (ML) pipelines, where there may be many features and prediction columns. 2 | 3 | !!! tip "Support for Embeddings" 4 | 5 | We now support drift detection on embeddings, in addition to structured data. GATE considers _both_ the structured data and the embeddings when computing partition summaries and detecting drift. Check out the [embeddings page](./embedding) for a walkthrough of how to use GATE with embeddings. 
6 | 7 | ## Installation 8 | 9 | GATE is available on PyPI and can be installed with pip: 10 | 11 | ```bash 12 | pip install gate-drift 13 | ``` 14 | 15 | Note that GATE requires Python 3.8 or higher. 16 | 17 | ## Usage 18 | 19 | GATE is designed to be used with [Pandas](https://pandas.pydata.org/) dataframes. Check out the [example](./example) for a walkthrough of how to use GATE. 20 | 21 | ## Research Contributions 22 | 23 | GATE was developed and is maintained by researchers at the UC Berkeley [EPIC Lab](https://epic.berkeley.edu/). 24 | 25 | An initial version of GATE was developed as part of a collaboration with Meta, and the research paper, "Moving Fast With Broken Data" by Shankar et al., is available on [arXiv](https://arxiv.org/abs/2303.06094). This module slightly differs from the original implementation, but the core ideas around partition summaries and anomaly detection are the same. 26 | -------------------------------------------------------------------------------- /docs/api.md: -------------------------------------------------------------------------------- 1 | ::: gate.summarize 2 | handler: python 3 | options: 4 | members: 5 | - summarize 6 | - compute_embeddings 7 | show_root_full_path: false 8 | show_root_toc_entry: false 9 | show_root_heading: false 10 | show_source: false 11 | 12 | ::: gate.summary.Summary 13 | handler: python 14 | options: 15 | members: 16 | - summary 17 | - embeddings_summary 18 | - partition_key 19 | - partition 20 | - columns 21 | - non_embedding_columns 22 | - embedding_examples 23 | - embedding_centroids 24 | - statistics 25 | - value 26 | - __str__ 27 | show_root_full_path: false 28 | show_root_toc_entry: false 29 | show_root_heading: true 30 | show_source: false 31 | 32 | ::: gate.drift 33 | handler: python 34 | options: 35 | members: 36 | - detect_drift 37 | show_root_full_path: false 38 | show_root_toc_entry: false 39 | show_root_heading: false 40 | show_source: false 41 | 42 | ::: gate.drift.DriftResult 43 | handler: python 44 | options: 45 | members: 46 | - summary 47 | - neighbor_summaries 48 | - drifted_examples 49 | - score 50 | - score_percentile 51 | - is_drifted 52 | - all_scores 53 | - clustering 54 | - drill_down 55 | - drifted_columns 56 | - __str__ 57 | show_root_full_path: false 58 | show_root_toc_entry: false 59 | show_root_heading: true 60 | show_source: false 61 | 62 | ::: gate.statistics 63 | handler: python 64 | options: 65 | members: 66 | - type_to_statistics 67 | show_root_full_path: false 68 | show_root_toc_entry: false 69 | show_root_heading: false 70 | show_source: false -------------------------------------------------------------------------------- /tests/test_drift.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from gate import summarize, detect_drift 3 | 4 | 5 | def test_no_drift(medium_df): 6 | summary = summarize( 7 | medium_df, 8 | partition_key="date", 9 | columns=["string_col", "int_col", "float_col"], 10 | ) 11 | assert len(summary) == 30 12 | 13 | drift_results = detect_drift(summary[-1], summary[:-1]) 14 | 15 | assert drift_results.score < 1e-7 16 | 17 | 18 | def test_attributes(tiny_df): 19 | summary = summarize( 20 | tiny_df, 21 | partition_key="grp", 22 | columns=["string_col", "int_col", "float_col"], 23 | ) 24 | assert len(summary) == 1 25 | 26 | with pytest.raises(ValueError): 27 | detect_drift(summary[-1], summary[:-1]) 28 | 29 | 30 | def test_drift(df_with_drift): 31 | summary = summarize( 32 | df_with_drift, 33 | partition_key="date", 34 | 
columns=["string_col", "int_col", "float_col"], 35 | ) 36 | assert len(summary) == 10 37 | 38 | drift_results = detect_drift(summary[-1], summary[:-1]) 39 | 40 | assert drift_results.score_percentile > 0.85 41 | 42 | assert drift_results.drifted_columns().index.values[0] in [ 43 | "int_col", 44 | "float_col", 45 | ] 46 | assert drift_results.drifted_columns()["z-score"].abs().values[0] > 2.0 47 | 48 | 49 | def test_drift_small_clustering(df_with_drift): 50 | columns = df_with_drift.columns.tolist() 51 | columns.remove("date") 52 | summary = summarize( 53 | df_with_drift, 54 | partition_key="date", 55 | columns=columns, 56 | ) 57 | 58 | drift_results = detect_drift(summary[-1], summary[:-1], cluster=True) 59 | 60 | assert len(drift_results.clustering) > 0 61 | assert drift_results.score_percentile > 0.85 62 | assert drift_results.drifted_columns().index.values[0] in [ 63 | "int_col", 64 | "float_col", 65 | ] 66 | 67 | assert len(drift_results.drifted_columns()) > 3 68 | -------------------------------------------------------------------------------- /tests/test_summary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from gate import summarize 5 | 6 | 7 | def test_summarize(tiny_df): 8 | summary = summarize( 9 | tiny_df, 10 | partition_key="grp", 11 | columns=["string_col", "int_col", "float_col"], 12 | ) 13 | assert len(summary) == 1 14 | summary = summary[0].summary 15 | 16 | assert len(summary) == 3 17 | 18 | # Check all the statistics 19 | expected_result = pd.DataFrame( 20 | [ 21 | { 22 | "grp": "A", 23 | "column": "float_col", 24 | "coverage": 1.0, 25 | "mean": 0.10000000149011612, 26 | "num_unique_values": np.nan, 27 | "occurrence_ratio": np.nan, 28 | "p50": 0.10000000149011612, 29 | "p95": 0.20000000298023224, 30 | }, 31 | { 32 | "grp": "A", 33 | "column": "int_col", 34 | "coverage": 0.6666666865348816, 35 | "mean": 0.5, 36 | "num_unique_values": np.nan, 37 | "occurrence_ratio": np.nan, 38 | "p50": 1.0, 39 | "p95": 1.0, 40 | }, 41 | { 42 | "grp": "A", 43 | "column": "string_col", 44 | "coverage": 1.0, 45 | "mean": np.nan, 46 | "num_unique_values": 2.0, 47 | "occurrence_ratio": 0.6666666865348816, 48 | "p50": np.nan, 49 | "p95": np.nan, 50 | }, 51 | ] 52 | ) 53 | 54 | assert expected_result.columns.tolist() == summary.columns.tolist() 55 | 56 | 57 | def test_bad_df(tiny_df, tiny_df_2): 58 | summary = summarize( 59 | tiny_df, 60 | partition_key="grp", 61 | columns=["string_col", "int_col", "float_col"], 62 | ) 63 | 64 | with pytest.raises(ValueError): 65 | summarize(tiny_df_2, previous_summaries=summary) 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GATE: Data Drift Detection for Machine Learning Pipelines 2 | 3 | [![GATE](https://github.com/dm4ml/gate/workflows/gate/badge.svg)](https://github.com/dm4ml/gate/actions?query=workflow:"gate") 4 | [![lint (via ruff)](https://github.com/dm4ml/gate/workflows/lint/badge.svg)](https://github.com/dm4ml/gate/actions?query=workflow:"lint") 5 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | 7 | GATE is a Python module that detects drift in partitions of data. GATE computes partition summaries, which are then fed into an anomaly detection algorithm to detect whether a new partition is anomalous. 
This minimizes false positive alerts when detecting drift in machine learning (ML) pipelines, where there may be many features and prediction columns. 8 | 9 | ### Support for Embeddings 10 | 11 | We now support drift detection on embeddings, in addition to structured data. GATE considers _both_ the structured data and the embeddings when computing partition summaries and detecting drift. Check out the [embeddings page](./embedding) for a walkthrough of how to use GATE with embeddings. 12 | 13 | ## Installation 14 | 15 | GATE is available on PyPI and can be installed with pip: 16 | 17 | ```bash 18 | pip install gate-drift 19 | ``` 20 | 21 | Note that GATE requires Python 3.8 or higher. 22 | 23 | ## Usage 24 | 25 | GATE is designed to be used with [Pandas](https://pandas.pydata.org/) dataframes. Check out the [documentation](https://dm4ml.github.io/gate/) for a walkthrough of how to use GATE. 26 | 27 | ## Research Contributions 28 | 29 | GATE was developed and is maintained by researchers at the UC Berkeley [EPIC Lab](https://epic.berkeley.edu/). 30 | 31 | An initial version of GATE was developed as part of a collaboration with Meta, and the research paper, "Moving Fast With Broken Data" by Shankar et al., is available on [arXiv](https://arxiv.org/abs/2303.06094). This module slightly differs from the original implementation, but the core ideas around partition summaries and anomaly detection are the same. 32 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: GATE Docs 2 | site_url: https://dm4ml.github.io/gate/ 3 | repo_url: https://github.com/dm4ml/gate 4 | repo_name: dm4ml/gate 5 | remote_branch: gh-pages 6 | nav: 7 | - Home: index.md 8 | - Example: example.md 9 | - Embeddings: embedding.md 10 | - API Reference: api.md 11 | - How it Works: how-it-works.md 12 | 13 | theme: 14 | name: material 15 | icon: 16 | logo: material/gate 17 | repo: fontawesome/brands/git-alt 18 | favicon: images/logo.png 19 | extra_files: 20 | - images/ 21 | palette: 22 | # Palette toggle for automatic mode 23 | - media: "(prefers-color-scheme)" 24 | primary: blue 25 | accent: orange 26 | toggle: 27 | icon: material/brightness-auto 28 | name: Switch to light mode 29 | 30 | # Palette toggle for light mode 31 | - media: "(prefers-color-scheme: light)" 32 | primary: blue 33 | accent: orange 34 | scheme: default 35 | toggle: 36 | icon: material/brightness-7 37 | name: Switch to dark mode 38 | 39 | # Palette toggle for dark mode 40 | - media: "(prefers-color-scheme: dark)" 41 | primary: blue 42 | accent: orange 43 | scheme: slate 44 | toggle: 45 | icon: material/brightness-4 46 | name: Switch to system preference 47 | font: 48 | text: Fira Sans 49 | code: Fira Code 50 | 51 | features: 52 | - navigation.instant 53 | - navigation.tracking 54 | - navigation.expand 55 | - navigation.path 56 | - navigation.prune 57 | - navigation.indexes 58 | - navigation.top 59 | - navigation.tabs 60 | - navigation.tabs.sticky 61 | - navigation.sections 62 | - toc.follow 63 | - toc.integrate 64 | - content.code.copy 65 | - content.code.annotate 66 | 67 | plugins: 68 | - search 69 | - mkdocstrings 70 | - autorefs 71 | 72 | markdown_extensions: 73 | - abbr 74 | - admonition 75 | - def_list 76 | - footnotes 77 | - md_in_html 78 | - tables 79 | - pymdownx.snippets 80 | - pymdownx.inlinehilite 81 | - pymdownx.tabbed: 82 | alternate_style: true 83 | - pymdownx.superfences: 84 | custom_fences: 85 | - name: 
mermaid 86 | class: mermaid 87 | format: !!python/name:pymdownx.superfences.fence_code_format 88 | - pymdownx.details 89 | - attr_list 90 | - pymdownx.emoji: 91 | emoji_index: !!python/name:materialx.emoji.twemoji 92 | emoji_generator: !!python/name:materialx.emoji.to_svg -------------------------------------------------------------------------------- /tests/test_multiple_dfs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from gate import summarize, detect_drift 4 | 5 | 6 | def get_df_no_drift(num_cols, num_rows_per_partition=10000): 7 | # create example date range 8 | date_range = pd.date_range(start="2022-01-01", periods=30, freq="D") 9 | 10 | # combine data into a DataFrame 11 | df_elems = [] 12 | for date in date_range: 13 | date_data = {"date": date} 14 | date_data = pd.DataFrame( 15 | { 16 | "date": [date] * num_rows_per_partition, 17 | **{ 18 | f"int_col_{i}": np.random.randint( 19 | low=0, high=10, size=num_rows_per_partition 20 | ) 21 | for i in range(num_cols) 22 | }, 23 | **{ 24 | f"float_col_{i}": np.random.normal( 25 | loc=0, scale=1, size=num_rows_per_partition 26 | ) 27 | for i in range(num_cols) 28 | }, 29 | **{ 30 | f"string_col_{i}": np.random.choice( 31 | ["A", "B", "C"], size=num_rows_per_partition 32 | ) 33 | for i in range(num_cols) 34 | }, 35 | } 36 | ) 37 | df_elems.append(date_data) 38 | 39 | df = pd.concat(df_elems).reset_index(drop=True) 40 | 41 | return df 42 | 43 | 44 | def test_no_drift_scale(): 45 | import time 46 | 47 | start = time.time() 48 | df = get_df_no_drift(100) 49 | time.time() - start 50 | 51 | columns = df.columns.to_list() 52 | columns.remove("date") 53 | 54 | start = time.time() 55 | summaries = summarize( 56 | df, 57 | partition_key="date", 58 | columns=columns, 59 | ) 60 | time.time() - start 61 | 62 | start = time.time() 63 | drift_results = detect_drift(summaries[-1], summaries[:-1]) 64 | time.time() - start 65 | 66 | # print(f"df_creation_time: {df_creation_time}") 67 | # print(f"summary_time: {summary_time}") 68 | # print(f"drift_time: {drift_time}") 69 | # assert False 70 | 71 | assert abs(0.5 - drift_results.score_percentile) <= 0.5 72 | 73 | 74 | def test_drift_scale(): 75 | df = get_df_no_drift(100) 76 | 77 | # Add drift 78 | max_date = df["date"].max() 79 | for i in range(50): 80 | df.loc[df["date"] == max_date, f"int_col_{i}"] = 1000 81 | 82 | columns = df.columns.to_list() 83 | columns.remove("date") 84 | 85 | summaries = summarize( 86 | df, 87 | partition_key="date", 88 | columns=columns, 89 | ) 90 | 91 | drift_results = detect_drift(summaries[-1], summaries[:-1]) 92 | 93 | assert drift_results.score_percentile > 0.85 94 | -------------------------------------------------------------------------------- /docs/how-it-works.md: -------------------------------------------------------------------------------- 1 | ![GATE Architecture Diagram](../images/gate.png) 2 | 3 | GATE is designed specifically for machine learning (ML) pipelines, where there may be many features and prediction columns. While other methods to detect drift may result in large numbers of false positives, GATE is designed to be more robust to this problem through the use of partition summaries. 4 | 5 | ## Partition Summarization 6 | 7 | GATE ingests raw data and computes a _partition summary_ for each partition. A partition summary is a vector of statistical measures that captures the distribution of the data in the partition. 
Partitions are typically time-based; for example, one per day. The following statistics are computed for each column: 8 | 9 | - coverage: The fraction of the column that has non-null values. 10 | - mean: The mean of the column. 11 | - p50: The median of the column. 12 | - num_unique_values: The number of unique values in the column. 13 | - occurrence_ratio: The count of the most frequent value divided by the total count. 14 | - p95: The 95th percentile of the column. 15 | 16 | Partition summaries are small, and can be computed quickly. They are also robust to outliers, which is important for ML pipelines where there may be many features and prediction columns. 17 | 18 | ## Drift Detection 19 | 20 | The partition summaries are then fed into an anomaly detection algorithm to detect whether a new partition is anomalous. 21 | 22 | ### Clustering 23 | 24 | Since many columns might be correlated, GATE first clusters the columns into groups. GATE considers both the semantic meaning of the column (e.g., "age" and "income") and the partition summaries. 25 | 26 | Clustering is automatically performed by the GATE algorithm. The user does not need to specify the number of clusters. Partition summaries are normalized via z-score before clustering, so that the clustering algorithm is not biased towards columns with larger values. 27 | 28 | !!! note 29 | 30 | Clustering is engaged if there are more than 10 columns. If there are fewer than 10 columns, no clustering is performed. 31 | 32 | ### Nearest Neighbor Algorithm 33 | 34 | Normalized partition summaries are then fed into a nearest neighbor algorithm to detect whether a new partition is anomalous. The nearest neighbor algorithm is a variant of the [k-nearest neighbors](https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm) algorithm. The algorithm computes the distance between the new partition and the nearest neighbors in previous partition summaries. If the distance is large (i.e., in the 90th percentile of distances), the new partition is considered drifted. 35 | 36 | !!! note 37 | If clustering is engaged, column summaries are averaged within each cluster before computing distances. This essentially reduces the dimensionality of the partition summary. 38 | 39 | ## Drill Down 40 | 41 | If a partition is detected as drifted, GATE can be used to drill down into the partition to identify the specific columns that are drifted. The columns with the largest z-score values are returned. 42 | 43 | ## Differences from the original research paper 44 | 45 | Differences from the original implementation include: 46 | 47 | - Removal of the need to specify a window size to normalize statistics over. 48 | - Removal of the Wasserstein-1 distance and num_frequent_values metrics, which are time-consuming to compute and not as useful as other metrics. 49 | - Addition of the p95 metric. 50 | - Embeddings of column names and types in the clustering algorithm (in addition to partition summaries). 51 | - Support for drift detection on embedding columns and computing drifted clusters. 
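To make the steps above concrete, here is a minimal, self-contained sketch of two of the ideas on this page: computing a per-partition summary and z-scoring a new partition's summary against previous partitions (the drill-down step). It is an illustration only, not GATE's implementation: it omits clustering and the nearest-neighbor scoring, and the helper names `summarize_partition` and `zscore_drill_down` are invented here. In practice you would call `gate.summarize` and `gate.detect_drift` instead.

```python
import numpy as np
import pandas as pd


def summarize_partition(df: pd.DataFrame, column: str) -> dict:
    # Simplified analogue of one row of a GATE partition summary.
    col = df[column]
    stats = {"coverage": col.notna().mean()}
    if pd.api.types.is_numeric_dtype(col):
        stats.update(mean=col.mean(), p50=col.quantile(0.5), p95=col.quantile(0.95))
    else:
        stats.update(
            num_unique_values=col.nunique(),
            occurrence_ratio=col.value_counts(normalize=True).iloc[0],
        )
    return stats


def zscore_drill_down(current: dict, history: list) -> dict:
    # z-score each statistic of the new partition against previous partitions.
    hist = pd.DataFrame(history)
    mean, std = hist.mean(), hist.std(ddof=0).replace(0, 1.0)
    z = ((pd.Series(current) - mean) / std).abs()
    return z.sort_values(ascending=False).to_dict()


rng = np.random.default_rng(0)
# Nine "normal" partitions, then one whose int column shifts from [0, 10) to [10, 20).
history = [
    summarize_partition(pd.DataFrame({"int_col": rng.integers(0, 10, 1000)}), "int_col")
    for _ in range(9)
]
new = summarize_partition(pd.DataFrame({"int_col": rng.integers(10, 20, 1000)}), "int_col")
print(zscore_drill_down(new, history))  # mean, p50, and p95 get very large z-scores
```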
-------------------------------------------------------------------------------- /gate/summarize.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import pandas as pd 4 | import requests 5 | from PIL import Image 6 | from sentence_transformers import SentenceTransformer 7 | 8 | from gate.summary import Summary 9 | 10 | 11 | def compute_embeddings(column: pd.Series, column_type: str) -> pd.Series: 12 | """Computes embeddings for a Series with the 13 | huggingface/transformers library. We use the 14 | clip-ViT-B-32 model. This is an optional function; 15 | we recommend you compute embeddings yourself. 16 | 17 | Args: 18 | column (pd.Series): Series to compute embeddings for. 19 | Must be of string type. Can contain either 20 | paths to files or text. 21 | column_type (str): Type of the column. Must be "text" or "image". 22 | 23 | Returns: 24 | pd.Series: Series of embeddings to add to your DataFrame. 25 | """ 26 | assert column_type in [ 27 | "text", 28 | "image", 29 | ], "column_type must be text or image" 30 | 31 | model = SentenceTransformer("clip-ViT-B-32") 32 | 33 | def compute_embedding_helper(text: str) -> typing.List[float]: 34 | if column_type == "image": 35 | try: 36 | img = Image.open(text) 37 | return model.encode(img) 38 | except FileNotFoundError: 39 | img = Image.open(requests.get(text, stream=True).raw) 40 | return model.encode(img) 41 | except Exception as e: 42 | raise e 43 | 44 | return model.encode(text) 45 | 46 | return column.apply(compute_embedding_helper) 47 | 48 | 49 | def summarize( 50 | df: pd.DataFrame, 51 | columns: typing.List[str] = [], 52 | embedding_column_map: typing.Dict[str, str] = {}, 53 | partition_key: str = "", 54 | previous_summaries: typing.List[Summary] = [], 55 | ) -> typing.List[Summary]: 56 | """This function computes partition-wide summary statistics for the given 57 | columns. df can have multiple partitions. 58 | 59 | Args: 60 | df (pd.DataFrame): 61 | Dataframe to summarize. 62 | columns (typing.List[str], optional): 63 | List of columns to generate summary statistics for. Must be a 64 | subset of df.columns. If empty, previous_summaries must not be 65 | empty. 66 | embedding_column_map (typing.Dict[str, str], optional): 67 | Dictionary of embedding key to embedding value column. Keys and 68 | values must be in df.columns. If empty, previous_summaries must not 69 | be empty. 70 | partition_key (str, optional): 71 | Name of column to partition the dataframe by. Must be in df. 72 | columns. Can be empty if no partitioning is desired, or if the 73 | dataframe represents a single partition. If empty, 74 | previous_summaries must not be empty. 75 | previous_summaries (typing.List[Summary], optional): 76 | List of Summary objects representing previous partition summaries. 77 | 78 | Returns: 79 | typing.List[Summary]: 80 | List of Summary objects, one per distinct partition found in df. 81 | 82 | Raises: 83 | ValueError: 84 | If `partition_key` is "group". 85 | ValueError: 86 | If `columns is empty` and `previous_summaries` is empty. 87 | ValueError: 88 | If `partition_key `is empty and `previous_summaries` is empty. 89 | ValueError: 90 | If `partition_key` is not in `df.columns`. 91 | ValueError: 92 | If any column in `columns` is not in `df.columns`. 
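    Example (illustrative only; the dataframe `df`, the "date" partition
        column, and the column names here are assumptions for this sketch,
        not required names):

            summaries = summarize(df, partition_key="date", columns=["int_col", "float_col"])
            latest, previous = summaries[-1], summaries[:-1]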
93 | """ 94 | if partition_key == "group": 95 | raise ValueError("Please rename the partition_key; it cannot be `group`.") 96 | 97 | if len(previous_summaries) == 0: 98 | if len(columns) == 0 and len(embedding_column_map) == 0: 99 | raise ValueError( 100 | "You must pass in some columns if you do not have any previous" 101 | " summaries." 102 | ) 103 | if not partition_key: 104 | raise ValueError( 105 | "You must pass in a partition column if you do not have any" 106 | " previous summaries." 107 | ) 108 | 109 | summary = Summary.fromRaw( 110 | df, 111 | columns=columns, 112 | embedding_column_map=embedding_column_map, 113 | partition_key=partition_key, 114 | previous_summaries=previous_summaries, 115 | ) 116 | 117 | return summary 118 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | import random 4 | import numpy as np 5 | 6 | 7 | @pytest.fixture 8 | def tiny_df(): 9 | df = pd.DataFrame( 10 | { 11 | "grp": ["A"] * 3, 12 | "string_col": ["cat", "dog", "dog"], 13 | "int_col": [0, 1, None], 14 | "float_col": [0.0, 0.1, 0.2], 15 | } 16 | ) 17 | return df 18 | 19 | 20 | @pytest.fixture 21 | def tiny_df_2(): 22 | df = pd.DataFrame( 23 | { 24 | "other_grp": ["A"] * 3, 25 | "string_col": ["cat", "dog", "dog"], 26 | "int_col": [0, 1, None], 27 | "float_col": [0.0, 0.1, 0.2], 28 | } 29 | ) 30 | return df 31 | 32 | 33 | @pytest.fixture 34 | def small_df(): 35 | # create example data 36 | groups = ["A", "B", "C", "D", "E"] * 2 37 | string_values = ["foo", "bar", "baz"] * 3 38 | string_values.append(None) 39 | int_values = [random.randint(0, 100) for _ in range(10)] 40 | float_values = [random.uniform(0, 1) for _ in range(10)] 41 | 42 | # create DataFrame 43 | df = pd.DataFrame( 44 | { 45 | "grp": groups, 46 | "string_col": string_values, 47 | "int_col": int_values, 48 | "float_col": float_values, 49 | } 50 | ) 51 | return df 52 | 53 | 54 | @pytest.fixture 55 | def medium_df(): 56 | # create example date range 57 | date_range = pd.date_range(start="2022-01-01", periods=30, freq="D") 58 | 59 | # create example data for each column 60 | int_col = np.random.randint(low=0, high=10, size=10000) 61 | float_col = np.random.normal(loc=0, scale=1, size=10000) 62 | string_col = np.random.choice(["A", "B", "C"], size=10000) 63 | 64 | # combine data into a DataFrame 65 | df_elems = [] 66 | for date in date_range: 67 | date_data = {"date": date} 68 | date_data = pd.DataFrame( 69 | { 70 | "date": [date] * len(int_col), 71 | "int_col": int_col, 72 | "float_col": float_col, 73 | "string_col": string_col, 74 | } 75 | ) 76 | df_elems.append(date_data) 77 | 78 | df = pd.concat(df_elems).reset_index(drop=True) 79 | return df 80 | 81 | 82 | @pytest.fixture 83 | def df_with_drift(): 84 | # create example date range 85 | date_range = pd.date_range(start="2022-01-01", periods=10, freq="D") 86 | 87 | # combine data into a DataFrame 88 | df_elems = [] 89 | for date in date_range: 90 | if date != date_range[-1]: 91 | int_col = np.random.randint(low=0, high=10, size=10000) 92 | float_col = np.random.normal(loc=0, scale=1, size=10000) 93 | string_col = np.random.choice(["A", "B", "C"], size=10000) 94 | 95 | date_data = {"date": date} 96 | date_data = pd.DataFrame( 97 | { 98 | "date": [date] * len(int_col), 99 | "int_col": int_col, 100 | "float_col": float_col, 101 | "string_col": string_col, 102 | "int_col_2": np.random.randint(low=10, high=20, 
size=10000), 103 | "float_col_2": np.random.normal(loc=1, scale=2, size=10000), 104 | "float_col_3": np.random.normal(loc=1, scale=2, size=10000), 105 | "float_col_4": np.random.normal(loc=1, scale=2, size=10000), 106 | "string_col_2": np.random.choice(["D", "B", "C"], size=10000), 107 | "string_col_3": np.random.choice(["E", "B", "C"], size=10000), 108 | "string_col_4": np.random.choice(["F", "B", "C"], size=10000), 109 | } 110 | ) 111 | df_elems.append(date_data) 112 | else: 113 | int_col = np.random.randint(low=10, high=20, size=10000) 114 | float_col = np.random.normal(loc=1, scale=2, size=10000) 115 | string_col = np.random.choice(["D", "B", "C"], size=10000) 116 | 117 | date_data = {"date": date} 118 | date_data = pd.DataFrame( 119 | { 120 | "date": [date] * len(int_col), 121 | "int_col": int_col, 122 | "int_col_2": np.random.randint(low=10, high=20, size=10000), 123 | "float_col": float_col, 124 | "float_col_2": np.random.normal(loc=1, scale=2, size=10000), 125 | "float_col_3": np.random.normal(loc=1, scale=2, size=10000), 126 | "float_col_4": np.random.normal(loc=1, scale=2, size=10000), 127 | "string_col": string_col, 128 | "string_col_2": np.random.choice(["D", "B", "C"], size=10000), 129 | "string_col_3": np.random.choice(["E", "B", "C"], size=10000), 130 | "string_col_4": np.random.choice(["F", "B", "C"], size=10000), 131 | } 132 | ) 133 | df_elems.append(date_data) 134 | 135 | df = pd.concat(df_elems).reset_index(drop=True) 136 | return df 137 | -------------------------------------------------------------------------------- /docs/embedding.md: -------------------------------------------------------------------------------- 1 | # Drift Detection on Embeddings 2 | 3 | GATE supports drift detection and debugging of embeddings, in addition to structured data. At a high level, embeddings are represented in their own column, and you can call [`summarize`](/gate/api/#gate.summarize.summarize) and [`detect_drift`](/gate/api/#gate.drift.detect_drift) on dataframes with embedding columns. 4 | 5 | ## Embedding key and value columns 6 | 7 | In your original dataframe, you should have a column that contains the embedding key, and a column that contains the embedding value. The key column should be a string (e.g., text, filename), and the value column should be a list of floats. For example: 8 | 9 | ```python 10 | df = pd.DataFrame( 11 | { 12 | "date": ["2020-01-01", "2020-01-01", "2020-01-01"], # This is the partition key 13 | "text": ["hello world!", "goodbye", "a third greeting"], 14 | "embedding": [ 15 | [0.1, 0.2, 0.3], # Imagine this is the embedding for "hello world!" 16 | [0.4, 0.5, 0.6], # Imagine this is the embedding for "goodbye" 17 | [0.7, 0.8, 0.9], # Imagine this is the embedding for "a third greeting" 18 | ], 19 | } 20 | ) 21 | ``` 22 | 23 | Then, when calling [`summarize`](/gate/api/#gate.summarize.summarize) on your dataframe, you can specify the embedding key-value pairs as follows: 24 | 25 | ```python 26 | from gate import summarize 27 | 28 | summarize( 29 | df, 30 | partition_key="date", 31 | embedding_column_map={"text": "embedding"}, 32 | ) 33 | ``` 34 | 35 | Both keys and values in `embedding_column_map` should be strings, representing column names in your dataframe. 36 | 37 | ## Summarizing embeddings 38 | 39 | When you call [`summarize`](/gate/api/#gate.summarize.summarize) on a dataframe with embedding columns, GATE will automatically compute summary statistics for each dimension in the embedding values. 
You can access these summaries by calling [`embeddings_summary`](/gate/api/#gate.summary.Summary.embeddings_summary) on the returned [`Summary`](/gate/api/#gate.summary.Summary) object. 40 | 41 | GATE will also cluster the embeddings, compute centroids for each cluster, and store examples for each cluster. Embeddings are clustered for each embedding column separately. You can access the examples by calling [`embedding_examples`](/gate/api/#gate.summary.Summary.embedding_examples) on the returned [`Summary`](/gate/api/#gate.summary.Summary) object. You can access the centroids by calling [`embedding_centroids`](/gate/api/#gate.summary.Summary.embedding_centroids) on the returned [`Summary`](/gate/api/#gate.summary.Summary) object. 42 | 43 | 44 | ```python 45 | from gate import summarize 46 | 47 | summaries = summarize( 48 | df, 49 | partition_key="date", 50 | columns=[], # No structured columns 51 | embedding_column_map={"text": "embedding"}, 52 | ) # (1)! 53 | 54 | # Get the summary statistics for the embedding values 55 | summaries[0].embeddings_summary 56 | 57 | # Get the examples for each cluster 58 | summaries[0].embedding_examples("text") # Must pass the embedding key 59 | 60 | # Get the centroids for each cluster 61 | summaries[0].embedding_centroids("text") # Must pass the embedding key 62 | ``` 63 | 64 | 1. Note that [`summarize`](/gate/api/#gate.summarize.summarize) returns a list of [`Summary`](/gate/api/#gate.summary.Summary) objects, one per distinct partition value. In this example, we only have one partition value, so we access the first element of the list. 65 | 66 | In practice, you probably won't need to call [`embedding_examples`](/gate/api/#gate.summary.Summary.embedding_examples) or [`embedding_centroids`](/gate/api/#gate.summary.Summary.embedding_centroids) directly. These methods are used in `detect_drift`, as described below. 67 | 68 | ## Detecting drift on embeddings 69 | 70 | You can call [`detect_drift`](/gate/api/#gate.drift.detect_drift) on summaries of dataframes with embedding columns. Drift detection takes both structured column data and embeddings into consideration, if you have both. 71 | 72 | [`detect_drift`](/gate/api/#gate.drift.detect_drift) will return a [`DriftResult`](/gate/api/#gate.drift.DriftResult) object, which contains the following information relevant to embeddings: 73 | 74 | - [`drifted_columns`](/gate/api/#gate.drift.DriftResult.drifted_columns): Returns a dataframe of column names that have drifted, their most anomalous statistic (e.g., coverage), and the z-score. This includes both structured columns and embedding columns. 75 | - [`drifted_examples`](/gate/api/#gate.drift.DriftResult.drifted_examples): Returns examples that have drifted most from their historical clusters. This is specific to embeddings. The object returned is a dictionary with `drifted_examples` and `corresponding_examples` keys. The value of each key is a dataframe with columns `partition_key`, `embedding_key_column`, and `embedding_value_column`. 
76 | 77 | An example of calling [`detect_drift`](/gate/api/#gate.drift.detect_drift) on a dataframe with embedding columns is shown below: 78 | 79 | ```python 80 | from gate import detect_drift 81 | 82 | drift_result = detect_drift( 83 | summary, 84 | previous_summaries 85 | ) 86 | 87 | # Get the drifted columns 88 | drift_result.drifted_columns() 89 | 90 | # Get the drifted examples 91 | drifted_example_result = drift_result.drifted_examples("text") # Must pass the embedding key 92 | drifted_example_result["drifted_examples"] 93 | ``` 94 | 95 | ## Real Dataset Example 96 | 97 | For an example of using GATE with embeddings, see this [example notebook](https://www.github.com/dm4ml/gate/blob/main/examples/civilcomments.ipynb) in the GitHub repository. -------------------------------------------------------------------------------- /docs/example.md: -------------------------------------------------------------------------------- 1 | There are two functions exposed by the GATE module: [`summarize`](/gate/api/#gate.summarize.summarize) and [`detect_drift`](/gate/api/#gate.drift.detect_drift). [`summarize`](/gate/api/#gate.summarize.summarize) computes partition summaries for a dataframe, and [`detect_drift`](/gate/api/#gate.drift.detect_drift) detects whether a new partition is drifted. 2 | 3 | In this example, we'll demonstrate how to use GATE to detect drift in a small synthetic dataset. 4 | 5 | ## Dataset Creation 6 | 7 | Our synthetic dataset will be created in Pandas. The partition key will be `date`. There will be 10 partitions, and each partition will have 10,000 rows. There will be 3 columns. The last partition will have a different column distribution than the first 9 partitions. 8 | 9 | ```python 10 | import numpy as np 11 | import pandas as pd 12 | 13 | # create example date range 14 | date_range = pd.date_range(start="2022-01-01", periods=10, freq="D") 15 | 16 | # create example data for each column 17 | int_col = np.random.randint(low=0, high=10, size=10000) 18 | float_col = np.random.normal(loc=0, scale=1, size=10000) 19 | string_col = np.random.choice(["A", "B", "C"], size=10000) 20 | 21 | # combine data into a DataFrame 22 | df_elems = [] 23 | for date in date_range: 24 | date_data = {"date": date} 25 | if date != date_range[-1]: 26 | date_data = pd.DataFrame( 27 | { 28 | "date": [date] * len(int_col), 29 | "int_col": int_col, 30 | "float_col": float_col, 31 | "string_col": string_col, 32 | } 33 | ) 34 | else: 35 | # Change the distribution of the int column 36 | date_data = pd.DataFrame( 37 | { 38 | "date": [date] * len(int_col), 39 | "int_col": np.random.randint(low=10, high=20, size=10000), 40 | "float_col": float_col, 41 | "string_col": string_col 42 | } 43 | ) 44 | df_elems.append(date_data) 45 | 46 | df = pd.concat(df_elems).reset_index(drop=True) 47 | ``` 48 | 49 | ## [`summarize`](/gate/api/#gate.summarize.summarize) 50 | 51 | The [`summarize`](/gate/api/#gate.summarize.summarize) function computes partition summaries for a dataframe. In addition to a Pandas dataframe of raw data, it accepts the partition key and a list of columns in the dataframe to compute statistics for. Or, one can specify a list of previous partition summaries instead of the partition key and column list, and GATE will infer the partition key and columns from the previous partition summaries. 52 | 53 | The [`summarize`](/gate/api/#gate.summarize.summarize) function returns a list of [`Summary`](/gate/api/#gate.summary.Summary) objects. 
Each [`Summary`](/gate/api/#gate.summary.Summary) object contains the partition summary and other metadata, and has a `__str__` method that prints the summary in a human-readable format. 54 | 55 | 56 | ```python 57 | from gate import summarize 58 | 59 | summaries = summarize( 60 | df, partition_key="date", columns=["int_col", "float_col", "string_col"] 61 | ) 62 | # len(summaries) == 10 because there are 10 distinct partitions 63 | 64 | print(summaries[-1]) 65 | 66 | """ 67 | date column coverage mean num_unique_values occurrence_ratio p50 p95 68 | 0 2022-01-10 float_col 1.0 0.015739 NaN NaN 0.019152 1.665352 69 | 1 2022-01-10 int_col 1.0 14.520700 10.0 0.1032 15.000000 19.000000 70 | 2 2022-01-10 string_col 1.0 NaN 3.0 0.3411 NaN NaN 71 | """ 72 | ``` 73 | 74 | !!! note 75 | 76 | You can access the summary data as a Pandas dataframe with the `value` attribute of the [`Summary`](/gate/api/#gate.summary.Summary) object (i.e., `summaries[-1].summary`). 77 | 78 | ## [`detect_drift`](/gate/api/#gate.drift.detect_drift) 79 | 80 | The [`detect_drift`](/gate/api/#gate.drift.detect_drift) function detects whether a new partition is drifted. It accepts a new partition summary and list of previous partition summaries and returns a [`DriftResult`](/gate/api/#gate.drift.DriftResult) object. The [`DriftResult`](/gate/api/#gate.drift.DriftResult) object has a `__str__` method that prints the drift result in a human-readable format. 81 | 82 | ```python 83 | from gate import detect_drift 84 | 85 | drift_result = detect_drift(summaries[-1], summaries[:-1]) 86 | print(drift_result) 87 | 88 | """ 89 | Drift score: 6.3246 (100.00% percentile) 90 | Top drifted columns: 91 | statistic z-score 92 | column 93 | int_col p95 2.846050 94 | float_col p95 0.000002 95 | string_col coverage 0.000000 96 | """ 97 | ``` 98 | 99 | The z-score represents the number of standard deviations away from the mean that the new partition is. In this case, the int col correctly has a high z-score. We recommend focusing on z-scores > 2.5 or < -2.5 when looking for drift. 100 | 101 | If you want to cluster correlated columns, you can pass `cluster = True` into [`detect_drift`](/gate/api/#gate.drift.detect_drift). The [`DriftResult`](/gate/api/#gate.drift.DriftResult) object has a `clustering` attribute that contains the clusters. 102 | 103 | !!! note 104 | 105 | The list of previous partition summaries must have at least one element. Best results are achieved when there are at least 5 previous partition summaries. 106 | 107 | ## Real Dataset Example 108 | 109 | For an end-to-end example on a real weather dataset, see the [example notebook](https://www.github.com/dm4ml/gate/blob/main/examples/weather.ipynb) in the Github repository. -------------------------------------------------------------------------------- /gate/statistics.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import polars as pl 6 | from sklearn.cluster import KMeans 7 | 8 | 9 | def type_to_statistics(t: str) -> typing.List[str]: 10 | """Returns the statistics that are computed for a given type. 11 | 12 | Args: 13 | t (str): Type (one of "int", "float", "string", "embedding"). 14 | 15 | Returns: 16 | typing.List[str]: 17 | List of statistics that are computed for the type. 18 | Partition summaries will have NaNs for statistics that are not computed. 19 | 20 | Raises: 21 | ValueError: If the type is unknown. 
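    Example (illustrative; the returned list follows directly from the mapping below):

        type_to_statistics("float")  # ["coverage", "mean", "p50", "p95"]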
22 | """ 23 | 24 | if t == "int": 25 | return [ 26 | "coverage", 27 | "mean", 28 | "p50", 29 | # "stdev", 30 | "num_unique_values", 31 | "occurrence_ratio", 32 | "p95", 33 | ] 34 | 35 | if t == "float": 36 | return [ 37 | "coverage", 38 | "mean", 39 | "p50", 40 | # "stdev", 41 | "p95", 42 | ] 43 | 44 | if t == "string": 45 | return ["coverage", "num_unique_values", "occurrence_ratio"] 46 | 47 | if t == "embedding": 48 | return ["coverage", "mean", "p50", "p95"] 49 | 50 | raise ValueError(f"Unknown type {t}") 51 | 52 | 53 | def cluster( 54 | group: pd.DataFrame, 55 | embedding_value_column: str, 56 | num_clusters: int, 57 | num_examples: int, 58 | limit: int = 2000, 59 | ): 60 | shuffled = group.sample(limit, random_state=42) if len(group) > limit else group 61 | matrix = np.vstack(shuffled[embedding_value_column].apply(np.array).values) 62 | 63 | kmeans = KMeans( 64 | n_clusters=num_clusters, 65 | init="k-means++", 66 | n_init="auto", 67 | random_state=42, 68 | ) 69 | kmeans.fit(matrix) 70 | labels = kmeans.labels_ 71 | shuffled["cluster"] = labels 72 | centroids = kmeans.cluster_centers_ 73 | 74 | # Select examples from each cluster 75 | examples = ( 76 | shuffled.groupby("cluster") 77 | .apply(lambda x: x.sample(num_examples) if len(x) > num_examples else x) 78 | .reset_index(drop=True) 79 | ) 80 | 81 | return examples, centroids 82 | 83 | 84 | def compute_embeddings_examples( 85 | polars_df: pl.DataFrame, 86 | embedding_column_map: typing.Dict[str, str], 87 | partition_key: str, 88 | num_clusters: int, 89 | num_examples: int, 90 | ) -> typing.Tuple[ 91 | typing.Dict[str, typing.Dict[str, pd.DataFrame]], 92 | typing.Dict[str, typing.Dict[str, np.ndarray]], 93 | ]: 94 | """Computes examples and centroids to store in each partition 95 | summary for each embedding column. 96 | 97 | Args: 98 | polars_df (pl.DataFrame): DataFrame with the embeddings. 99 | embedding_column_map (typing.Dict[str, str]): 100 | Map from embedding key column to embedding value column. 101 | partition_key (str): Column to partition by. 102 | num_clusters (int): Number of clusters to use in KMeans to 103 | cluster the embeddings. 104 | num_examples (int): Number of examples from each cluster to store 105 | in the partition summary. 106 | 107 | Returns: 108 | typing.Tuple[typing.Dict[str, typing.Dict[str, pd.DataFrame]], typing. 109 | Dict[str, typing.Dict[str, np.ndarray]]]: 110 | Examples and centroids to store in each partition summary. 
111 | """ 112 | 113 | all_examples = {} 114 | all_centroids = {} 115 | for ( 116 | embedding_key_column, 117 | embedding_value_column, 118 | ) in embedding_column_map.items(): 119 | # Select examples 120 | for partition_value, group in ( 121 | polars_df[[partition_key, embedding_key_column, embedding_value_column]] 122 | .to_pandas() 123 | .groupby(partition_key) 124 | ): 125 | examples, centroids = cluster( 126 | group, embedding_value_column, num_clusters, num_examples 127 | ) 128 | 129 | if partition_value not in all_examples: 130 | all_examples[partition_value] = {} 131 | 132 | if partition_value not in all_centroids: 133 | all_centroids[partition_value] = {} 134 | 135 | all_examples[partition_value][embedding_key_column] = examples 136 | all_centroids[partition_value][embedding_key_column] = centroids 137 | 138 | return all_examples, all_centroids 139 | 140 | 141 | def compute_embeddings_summary( 142 | polars_df: pl.DataFrame, 143 | embedding_column_map: typing.Dict[str, str], 144 | partition_key: str, 145 | ): 146 | embedding_dfs = [] 147 | for ( 148 | _, 149 | embedding_value_column, 150 | ) in embedding_column_map.items(): 151 | lengths = polars_df.select( 152 | pl.col(embedding_value_column).arr.lengths().alias("lengths") 153 | ) 154 | num_lengths = lengths["lengths"].n_unique() 155 | if num_lengths > 1: 156 | raise ValueError( 157 | f"Embedding value column {embedding_value_column} has" 158 | " different lengths. All embeddings must have the same" 159 | " length." 160 | ) 161 | length = lengths["lengths"].head(1)[0] 162 | 163 | embedding_df = polars_df.select( 164 | [pl.col(partition_key)] 165 | + [ 166 | pl.col(embedding_value_column) 167 | .arr.get(i) 168 | .alias(embedding_value_column + f"_{i}") 169 | for i in range(length) 170 | ] 171 | ) 172 | embedding_dfs.append(embedding_df) 173 | 174 | full_embedding_df = embedding_dfs[0] 175 | for embedding_df in embedding_dfs[1:]: 176 | full_embedding_df = full_embedding_df.join(embedding_df, on=partition_key) 177 | 178 | return full_embedding_df 179 | -------------------------------------------------------------------------------- /tests/test_embedding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from gate import summarize, detect_drift, compute_embeddings 4 | 5 | import os 6 | import pytest 7 | 8 | IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" 9 | 10 | 11 | @pytest.mark.flaky(reruns=2) 12 | def test_summarize_embedding_tiny(): 13 | embedding_list = [np.random.rand(3) for _ in range(100)] 14 | 15 | # Create the DataFrame 16 | df = pd.DataFrame( 17 | { 18 | "grp": np.random.choice(["A", "B", "C", "D", "E", "F"], size=100), 19 | "embedding_key": np.random.randint(20, 60, 100), 20 | "embedding_value": embedding_list, 21 | } 22 | ) 23 | 24 | summary = summarize( 25 | df, 26 | partition_key="grp", 27 | embedding_column_map={"embedding_key": "embedding_value"}, 28 | ) 29 | drift_results = detect_drift(summary[-1], summary[:-1]) 30 | 31 | assert not drift_results.is_drifted 32 | 33 | 34 | @pytest.mark.flaky(reruns=2) 35 | def test_summarize_embedding_medium(): 36 | embedding_list = [np.random.rand(100) for _ in range(100)] 37 | 38 | # Create the DataFrame 39 | df = pd.DataFrame( 40 | { 41 | "grp": np.random.choice(["A", "B", "C", "D", "E", "F"], size=100), 42 | "embedding_key": np.random.randint(20, 60, 100), 43 | "embedding_value": embedding_list, 44 | } 45 | ) 46 | 47 | summary = summarize( 48 | df, 49 | partition_key="grp", 50 | 
embedding_column_map={"embedding_key": "embedding_value"}, 51 | ) 52 | drift_results = detect_drift(summary[-1], summary[:-1]) 53 | 54 | assert not drift_results.is_drifted 55 | 56 | 57 | @pytest.mark.flaky(reruns=2) 58 | def test_summarize_embedding_big_with_drift(): 59 | embedding_list = [np.random.rand(2048) for _ in range(1000)] 60 | date_range = pd.date_range(start="2022-01-01", periods=10, freq="D") 61 | 62 | # Create the DataFrame 63 | df = pd.DataFrame( 64 | { 65 | "date": np.random.choice(date_range[:-1], size=1000), 66 | "embedding_key": np.random.choice( 67 | np.random.randint(20, 60, 100), size=1000 68 | ), 69 | "embedding_value": embedding_list, 70 | } 71 | ) 72 | 73 | prev_summaries = summarize( 74 | df, 75 | partition_key="date", 76 | embedding_column_map={"embedding_key": "embedding_value"}, 77 | ) 78 | 79 | drifted_df = pd.DataFrame( 80 | { 81 | "date": [date_range[-1] for _ in range(1000)], 82 | "embedding_key": np.random.choice( 83 | np.random.randint(20, 60, 100), size=1000 84 | ), 85 | "embedding_value": [ 86 | np.random.rand(2048) * 10 for _ in range(1000) 87 | ], 88 | } 89 | ) 90 | 91 | summary = summarize(drifted_df, previous_summaries=prev_summaries) 92 | 93 | assert len(summary[0].embedding_examples("embedding_key")) > 0 94 | assert len(summary[0].embedding_examples("embedding_key").columns) == 4 95 | 96 | with pytest.raises(ValueError): 97 | summary[0].embedding_examples("nonexsistent_key") 98 | 99 | drift_results = detect_drift(summary[0], prev_summaries) 100 | 101 | assert drift_results.is_drifted 102 | 103 | examples = drift_results.drifted_examples("embedding_key") 104 | assert "drifted_examples" in examples.keys() 105 | assert "corresponding_examples" in examples.keys() 106 | 107 | 108 | def test_compute_embedding(): 109 | # Create dataframe with image urls from the internet 110 | df = pd.DataFrame( 111 | { 112 | "url": [ 113 | "https://picsum.photos/200/300/?random&{}".format(i) 114 | for i in range(1, 11) 115 | ], 116 | "id": [i for i in range(1, 11)], 117 | } 118 | ) 119 | embeddings = compute_embeddings(df["url"], column_type="image") 120 | assert len(embeddings) == 10 121 | 122 | 123 | @pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="takes too long") 124 | def test_image_embedding_drift(): 125 | # Create dataframe with image urls from the internet 126 | date_range = pd.date_range(start="2022-01-01", periods=6, freq="D") 127 | df = pd.DataFrame( 128 | { 129 | "url": [ 130 | "https://picsum.photos/200/300/?random&{}".format(i) 131 | for i in range(1, 51) 132 | ], 133 | "date": np.random.choice(date_range[:-1], size=50), 134 | } 135 | ) 136 | df["embedding"] = compute_embeddings(df["url"], column_type="image") 137 | 138 | prev_summaries = summarize( 139 | df, 140 | partition_key="date", 141 | embedding_column_map={"url": "embedding"}, 142 | ) 143 | 144 | drifted_df = pd.DataFrame( 145 | { 146 | "date": [date_range[-1] for _ in range(10)], 147 | "url": ["fake_url_{}".format(i) for i in range(1, 11)], 148 | "embedding": [np.random.rand(512) for _ in range(10)], 149 | } 150 | ) 151 | 152 | summary = summarize(drifted_df, previous_summaries=prev_summaries)[0] 153 | 154 | assert len(summary.embedding_examples("url")) > 0 155 | 156 | drift_results = detect_drift(summary, prev_summaries) 157 | 158 | assert drift_results.is_drifted 159 | 160 | drifted_examples = drift_results.drifted_examples("url") 161 | 162 | assert len(drifted_examples["drifted_examples"]) > 0 163 | assert len(drifted_examples["corresponding_examples"]) > 0 164 | 165 | 166 | @pytest.mark.skipif( 
167 | IN_GITHUB_ACTIONS, reason="Test doesn't work in Github Actions." 168 | ) 169 | def test_clustering_with_drift(): 170 | num_rows = 100 171 | num_cols = 50 172 | data = np.random.rand(num_rows, num_cols) 173 | cols = [f"col_{i+1}" for i in range(num_cols)] 174 | date_range = pd.date_range(start="2022-01-01", periods=10, freq="D") 175 | 176 | df = pd.DataFrame(data, columns=cols) 177 | df["embedding_key"] = [ 178 | "embedding_key_{}".format(i) for i in range(num_rows) 179 | ] 180 | df["embedding"] = [np.random.rand(1024) for _ in range(num_rows)] 181 | df["embedding_key2"] = [ 182 | "embedding_key2_{}".format(i) for i in range(num_rows) 183 | ] 184 | df["embedding2"] = [np.random.rand(1024) + 1 for _ in range(num_rows)] 185 | df["date"] = np.random.choice(date_range[:-1], size=num_rows) 186 | 187 | prev_summaries = summarize( 188 | df, 189 | partition_key="date", 190 | columns=cols, 191 | embedding_column_map={ 192 | "embedding_key": "embedding", 193 | "embedding_key2": "embedding2", 194 | }, 195 | ) 196 | 197 | drifted_df = pd.DataFrame(data[:50, :], columns=cols) 198 | drifted_df["date"] = [date_range[-1] for _ in range(50)] 199 | drifted_df["embedding"] = [np.random.rand(1024) * 10 for _ in range(50)] 200 | drifted_df["embedding2"] = [ 201 | np.random.rand(1024) * 10 + 1 for _ in range(50) 202 | ] 203 | drifted_df["embedding_key"] = [ 204 | "embedding_key_{}".format(i) for i in range(50) 205 | ] 206 | drifted_df["embedding_key2"] = [ 207 | "embedding_key2_{}".format(i) for i in range(50) 208 | ] 209 | 210 | summary = summarize(drifted_df, previous_summaries=prev_summaries) 211 | 212 | drift_results = detect_drift(summary[0], prev_summaries) 213 | 214 | assert drift_results.is_drifted 215 | assert ( 216 | "embedding2" in drift_results.drifted_columns().head(2).index.tolist() 217 | ) 218 | assert ( 219 | "embedding" in drift_results.drifted_columns().head(2).index.tolist() 220 | ) 221 | -------------------------------------------------------------------------------- /gate/summary.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import polars as pl 6 | 7 | from gate.statistics import compute_embeddings_examples, compute_embeddings_summary 8 | 9 | 10 | class Summary: 11 | def __init__( 12 | self, 13 | summary: pd.DataFrame, 14 | embeddings_summary: pd.DataFrame, 15 | string_columns: typing.List[str], 16 | float_columns: typing.List[str], 17 | int_columns: typing.List[str], 18 | embedding_column_map: typing.Dict[str, str], 19 | embedding_examples: typing.Dict[str, pd.DataFrame], 20 | embedding_centroids: typing.Dict[str, np.ndarray], 21 | partition_key: str, 22 | partition_value: typing.Any, 23 | ): 24 | self._summary = summary 25 | self._embeddings_summary = embeddings_summary 26 | self._string_columns = string_columns 27 | self._float_columns = float_columns 28 | self._int_columns = int_columns 29 | self._embedding_columns = list(embedding_column_map.values()) 30 | self._embedding_column_map = embedding_column_map 31 | self._embedding_examples = embedding_examples 32 | self._embedding_centroids = embedding_centroids 33 | self._partition_key = partition_key 34 | self._partition = partition_value 35 | 36 | if len(embedding_column_map) > 0: 37 | assert embedding_column_map.keys() == embedding_examples.keys() 38 | 39 | @property 40 | def summary(self) -> pd.DataFrame: 41 | """Dataframe containing the summary statistics.""" 42 | return self._summary 43 | 44 | @property 45 | def 
embeddings_summary(self) -> pd.DataFrame: 46 | """Dataframe containing the embeddings summary statistics if 47 | there are embeddings, otherwise None.""" 48 | return self._embeddings_summary 49 | 50 | @property 51 | def partition(self) -> str: 52 | """Partition value.""" 53 | return self._partition 54 | 55 | @property 56 | def columns(self) -> typing.List[str]: 57 | """Columns for which summary statistics were computed.""" 58 | return ( 59 | self._string_columns 60 | + self._float_columns 61 | + self._int_columns 62 | + self._embedding_columns 63 | ) 64 | 65 | @property 66 | def non_embedding_columns(self) -> typing.List[str]: 67 | """Columns for which summary statistics were computed. 68 | Ignores embedding columns.""" 69 | return self._string_columns + self._float_columns + self._int_columns 70 | 71 | @property 72 | def partition_key(self) -> str: 73 | """Partition key column.""" 74 | return self._partition_key 75 | 76 | def embedding_examples(self, embedding_key_column: str) -> pd.DataFrame: 77 | """Returns examples in each embedding cluster for the given 78 | embedding key column. 79 | 80 | Args: 81 | embedding_key_column (str): 82 | Column name representing the embedding key. 83 | 84 | Raises: 85 | ValueError: If there are no embedding examples. 86 | ValueError: If the embedding key column does not exist. 87 | 88 | Returns: 89 | pd.DataFrame: Examples in each embedding cluster. Contains 90 | the columns partition_key, embedding_key_column, 91 | embedding_value_column, and cluster. 92 | """ 93 | if self._embedding_examples is None: 94 | raise ValueError("There are no embedding examples.") 95 | 96 | if embedding_key_column not in self._embedding_examples: 97 | raise ValueError( 98 | f"Embedding key column {embedding_key_column} does not exist." 99 | f" Valid columns are {self._embedding_examples.keys()}." 100 | ) 101 | 102 | return self._embedding_examples[embedding_key_column] 103 | 104 | def embedding_centroids(self, embedding_key_column: str) -> np.ndarray: 105 | """Returns embedding centroids for the given embedding key column. 106 | 107 | Args: 108 | embedding_key_column (str): 109 | Column name representing the embedding key. 110 | 111 | Raises: 112 | ValueError: If there are no embedding examples. 113 | ValueError: If the embedding key column does not exist. 114 | 115 | Returns: 116 | np.ndarray: 117 | Matrix of embedding centroids, size (num_clusters, embedding_dim). 118 | """ 119 | if self._embedding_centroids is None: 120 | raise ValueError("There are no embedding centroids.") 121 | 122 | if embedding_key_column not in self._embedding_centroids: 123 | raise ValueError( 124 | f"Embedding key column {embedding_key_column} does not exist." 125 | f" Valid columns are {self._embedding_centroids.keys()}." 126 | ) 127 | 128 | return self._embedding_centroids[embedding_key_column] 129 | 130 | def statistics(self) -> typing.List[str]: 131 | """ 132 | Returns list of statistics computed for each column: 133 | 134 | * coverage: Fraction of rows that are not null. 135 | * mean: Mean of the column. 136 | * p50: Median of the column. 137 | * num_unique_values: Number of unique values in the column. 138 | * occurrence_ratio: Ratio of the most common value to all other 139 | values. 140 | * p95: 95th percentile of the column. 
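Illustrative example (hypothetical output, assuming the summary covers both numeric and string columns so that all six statistics are present; exact order may vary):

    >>> summary.statistics()
    ['coverage', 'mean', 'num_unique_values', 'occurrence_ratio', 'p50', 'p95']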
141 | """ 142 | value = self.value() 143 | statistics = value.columns.tolist() 144 | statistics.remove(self.partition_key) 145 | statistics.remove("column") 146 | return statistics 147 | 148 | @classmethod 149 | def fromRaw( 150 | cls, 151 | raw_data: pd.DataFrame, 152 | columns: typing.List[str] = [], 153 | embedding_column_map: typing.Dict[str, str] = {}, 154 | partition_key: str = "", 155 | previous_summaries: typing.List["Summary"] = [], 156 | ) -> typing.List["Summary"]: 157 | polars_df = pl.DataFrame(raw_data) 158 | 159 | if len(previous_summaries) > 0: 160 | partition_key = previous_summaries[0].partition_key 161 | columns = previous_summaries[0].columns 162 | string_columns = previous_summaries[0]._string_columns 163 | float_columns = previous_summaries[0]._float_columns 164 | int_columns = previous_summaries[0]._int_columns 165 | embedding_column_map = previous_summaries[0]._embedding_column_map 166 | else: 167 | # Set up columns if it's the first partition 168 | assert ( 169 | len(columns) > 0 or len(embedding_column_map) > 0 170 | ), "Must specify columns or embedding_column_map." 171 | 172 | if not set(columns).issubset(set(raw_data.columns)): 173 | raise ValueError( 174 | "Columns to compute summaries on are not all in the dataframe." 175 | ) 176 | types = polars_df.schema 177 | 178 | column_types = {c: types[c] for c in columns} 179 | string_columns = [c for c, t in column_types.items() if t == pl.Utf8] 180 | float_columns = [ 181 | c for c, t in column_types.items() if t == pl.Float32 or t == pl.Float64 182 | ] 183 | int_columns = [ 184 | c 185 | for c, t in column_types.items() 186 | if t == pl.Int64 or t == pl.Int32 or t == pl.Int16 or t == pl.Int8 187 | ] 188 | bool_columns = [c for c, t in column_types.items() if t == pl.Boolean] 189 | for c in bool_columns: 190 | polars_df = polars_df.with_columns([pl.col(c).cast(pl.Int8).alias(c)]) 191 | int_columns += bool_columns 192 | 193 | assert len(string_columns) + len(float_columns) + len(int_columns) == len( 194 | columns 195 | ), "Columns have unknown type. Must be one of int, float, string," 196 | 197 | if partition_key not in polars_df.columns: 198 | raise ValueError( 199 | f"Partition column {partition_key} is not in dataframe columns." 200 | ) 201 | if not set(columns).issubset(set(polars_df.columns)): 202 | raise ValueError( 203 | "Columns to compute summaries on are not all in the dataframe." 
204 | ) 205 | 206 | # Compute the summary statistics 207 | statistics = [ 208 | polars_df.groupby(partition_key) 209 | .agg( 210 | [ 211 | pl.col(c).is_not_null().mean().alias(c).cast(pl.Float32) 212 | for c in string_columns + float_columns + int_columns 213 | ] 214 | ) 215 | .with_columns([pl.lit("coverage").alias("statistic")]), 216 | polars_df.groupby(partition_key) 217 | .agg( 218 | [ 219 | pl.col(c).cast(pl.Float32).mean().alias(c) 220 | for c in float_columns + int_columns 221 | ] 222 | ) 223 | .with_columns([pl.lit("mean").alias("statistic")]), 224 | polars_df.groupby(partition_key) 225 | .agg( 226 | [ 227 | pl.col(c).quantile(0.5).cast(pl.Float32).alias(c) 228 | for c in float_columns + int_columns 229 | ] 230 | ) 231 | .with_columns([pl.lit("p50").alias("statistic")]), 232 | polars_df.groupby(partition_key) 233 | .agg( 234 | [ 235 | pl.col(c).approx_unique().cast(pl.Float32).alias(c) 236 | for c in string_columns + int_columns 237 | ] 238 | ) 239 | .with_columns([pl.lit("num_unique_values").alias("statistic")]), 240 | polars_df.groupby(partition_key) 241 | .agg( 242 | [ 243 | ((pl.col(c).unique_counts().max()) / (pl.col(c).count())) 244 | .alias(c) 245 | .cast(pl.Float32) 246 | for c in string_columns + int_columns 247 | ] 248 | ) 249 | .with_columns([pl.lit("occurrence_ratio").alias("statistic")]), 250 | polars_df.groupby(partition_key) 251 | .agg( 252 | [ 253 | pl.col(c).quantile(0.95).cast(pl.Float32).alias(c) 254 | for c in float_columns + int_columns 255 | ] 256 | ) 257 | .with_columns([pl.lit("p95").alias("statistic")]), 258 | ] 259 | statistics = pl.concat(statistics, how="diagonal").to_pandas() 260 | 261 | # Pivot such that columns are the statistics and rows are the row name, 262 | # and it's grouped by partition col 263 | 264 | pivoted_statistics = ( 265 | statistics.melt( 266 | id_vars=[partition_key, "statistic"], 267 | value_vars=string_columns + float_columns + int_columns, 268 | var_name="column", 269 | ) 270 | .pivot( 271 | index=[partition_key, "column"], 272 | columns="statistic", 273 | values="value", 274 | ) 275 | .reset_index() 276 | ) 277 | pivoted_statistics.columns = pivoted_statistics.columns.tolist() 278 | 279 | if len(embedding_column_map) > 0: 280 | # Embedding statistics 281 | full_embedding_df = compute_embeddings_summary( 282 | polars_df, embedding_column_map, partition_key 283 | ) 284 | ( 285 | embedding_example_map, 286 | embedding_centroids_map, 287 | ) = compute_embeddings_examples( 288 | polars_df, 289 | embedding_column_map, 290 | partition_key, 291 | num_clusters=5, 292 | num_examples=10, 293 | ) 294 | 295 | embedding_statistics = [ 296 | full_embedding_df.groupby(partition_key) 297 | .agg( 298 | [ 299 | pl.col(c).is_not_null().mean().alias(c).cast(pl.Float32) 300 | for c in full_embedding_df.columns[1:] 301 | ] 302 | ) 303 | .with_columns([pl.lit("coverage").alias("statistic")]), 304 | full_embedding_df.groupby(partition_key) 305 | .agg( 306 | [ 307 | pl.col(c).cast(pl.Float32).mean().alias(c) 308 | for c in full_embedding_df.columns[1:] 309 | ] 310 | ) 311 | .with_columns([pl.lit("mean").alias("statistic")]), 312 | full_embedding_df.groupby(partition_key) 313 | .agg( 314 | [ 315 | pl.col(c).quantile(0.5).cast(pl.Float32).alias(c) 316 | for c in full_embedding_df.columns[1:] 317 | ] 318 | ) 319 | .with_columns([pl.lit("p50").alias("statistic")]), 320 | full_embedding_df.groupby(partition_key) 321 | .agg( 322 | [ 323 | pl.col(c).quantile(0.95).cast(pl.Float32).alias(c) 324 | for c in full_embedding_df.columns[1:] 325 | ] 326 | ) 327 | 
.with_columns([pl.lit("p95").alias("statistic")]), 328 | ] 329 | 330 | embedding_statistics = pl.concat( 331 | embedding_statistics, how="diagonal" 332 | ).to_pandas() 333 | embedding_col_index = embedding_statistics.columns.tolist() 334 | embedding_col_index.remove("statistic") 335 | embedding_col_index.remove(partition_key) 336 | 337 | pivoted_embeddings = ( 338 | embedding_statistics.melt( 339 | id_vars=[partition_key, "statistic"], 340 | value_vars=embedding_col_index, 341 | var_name="column", 342 | ) 343 | .pivot( 344 | index=[partition_key, "column"], 345 | columns="statistic", 346 | values="value", 347 | ) 348 | .reset_index() 349 | ) 350 | pivoted_embeddings.columns = pivoted_embeddings.columns.tolist() 351 | 352 | groups = [] 353 | for partition_value, group in pivoted_statistics.groupby(partition_key): 354 | relevant_embeddings = ( 355 | pivoted_embeddings[ 356 | pivoted_embeddings[partition_key] == partition_value 357 | ].reset_index(drop=True) 358 | if len(embedding_column_map) > 0 359 | else None 360 | ) 361 | 362 | groups.append( 363 | cls( 364 | group.reset_index(drop=True), 365 | relevant_embeddings, 366 | string_columns, 367 | float_columns, 368 | int_columns, 369 | embedding_column_map, 370 | ( 371 | embedding_example_map[partition_value] 372 | if len(embedding_column_map) > 0 373 | else None 374 | ), 375 | ( 376 | embedding_centroids_map[partition_value] 377 | if len(embedding_column_map) > 0 378 | else None 379 | ), 380 | partition_key, 381 | partition_value, 382 | ) 383 | ) 384 | 385 | # Handle if there are only embeddings columns 386 | if len(string_columns) + len(float_columns) + len(int_columns) == 0: 387 | for partition_value, group in pivoted_embeddings.groupby(partition_key): 388 | groups.append( 389 | cls( 390 | None, 391 | group.reset_index(drop=True), 392 | string_columns, 393 | float_columns, 394 | int_columns, 395 | embedding_column_map, 396 | embedding_example_map[partition_value], 397 | embedding_centroids_map[partition_value], 398 | partition_key, 399 | partition_value, 400 | ) 401 | ) 402 | 403 | return groups 404 | 405 | def value(self) -> pd.DataFrame: 406 | """Combines the summary and embeddings summary into a single dataframe. 407 | 408 | Returns: 409 | pd.DataFrame: Summary including embeddings, if exists. 410 | """ 411 | if self.embeddings_summary is None: 412 | return self.summary 413 | 414 | if self.summary is None: 415 | return self.embeddings_summary 416 | 417 | return pd.concat([self.summary, self.embeddings_summary], ignore_index=True) 418 | 419 | def __str__(self) -> str: 420 | """ 421 | String representation of the object's value (i.e., summary). 
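If both a regular summary and an embedding summary exist, both are included in the printed output.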
422 | 423 | Usage: `print(summary)` 424 | """ 425 | 426 | if self.embeddings_summary is None: 427 | return f"Summary:\n{self.summary.to_string()}" 428 | 429 | if self.summary is None: 430 | return f"Embedding summary:\n{self.embeddings_summary.to_string()}" 431 | 432 | return ( 433 | f"Regular summary:\n{self.summary.to_string()}\nEmbedding" 434 | f" summary:\n{self.embeddings_summary.to_string()}" 435 | ) 436 | -------------------------------------------------------------------------------- /gate/drift.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | import numpy as np 4 | import pandas as pd 5 | from scipy.spatial import cKDTree 6 | from scipy.stats import percentileofscore 7 | from sentence_transformers import SentenceTransformer 8 | from sklearn.cluster import AgglomerativeClustering 9 | from sklearn.decomposition import PCA 10 | from sklearn.metrics.pairwise import cosine_similarity 11 | 12 | from gate.summary import Summary 13 | 14 | 15 | class DriftResult: 16 | def __init__( 17 | self, 18 | all_scores: pd.Series, 19 | nn_features: pd.DataFrame, 20 | summary: Summary, 21 | neighbor_summaries: typing.List[Summary], 22 | clustered_features: pd.DataFrame, 23 | embedding_columns: typing.List[str], 24 | ) -> None: 25 | self._all_scores = all_scores 26 | self._nn_features = nn_features 27 | self._summary = summary 28 | self._neighbor_summaries = neighbor_summaries 29 | self._clustered_features = clustered_features 30 | self._embedding_columns = embedding_columns 31 | 32 | @property 33 | def summary(self) -> Summary: 34 | """Summary of the partition.""" 35 | return self._summary 36 | 37 | @property 38 | def neighbor_summaries(self) -> typing.List[Summary]: 39 | """Summaries of the nearest neighbors of the partition.""" 40 | return self._neighbor_summaries 41 | 42 | def drifted_examples( 43 | self, embedding_key_column: str 44 | ) -> typing.Dict[str, pd.DataFrame]: 45 | """Returns some examples from the partition that are 46 | most drifted from nearest neighbors in the embedding space 47 | in previous partitions. 48 | 49 | Throws an error if the embedding_key_column isn't a valid 50 | embedding key column, or if there are no embedding columns. 51 | 52 | Args: 53 | embedding_key_column (str): 54 | Column that represents the embedding key (e.g., text, image). 55 | 56 | Returns: 57 | typing.Dict[str, pd.DataFrame]: 58 | Dictionary with two keys: "drifted_examples" and 59 | "corresponding_examples". The value of each key is a 60 | dataframe with columns "partition_key", "embedding_key_column", 61 | and "embedding_value_column". 
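A sketch of typical usage (illustrative; assumes the summaries were built with an embedding column map such as {"url": "embedding"} and that drift_results came from detect_drift):

    >>> examples = drift_results.drifted_examples("url")
    >>> examples["drifted_examples"].head()  # rows from the current partition's most dissimilar cluster
    >>> examples["corresponding_examples"].head()  # rows from the matched cluster in a previous partition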
62 | """ 63 | all_centroids = np.vstack( 64 | [ 65 | s.embedding_centroids(embedding_key_column) 66 | for s in self.neighbor_summaries 67 | ] 68 | ) 69 | all_centroid_idxs = [ 70 | (i, j) 71 | for i in range(len(self.neighbor_summaries)) 72 | for j in range( 73 | len( 74 | self.neighbor_summaries[i].embedding_centroids(embedding_key_column) 75 | ) 76 | ) 77 | ] 78 | curr_centroids = self.summary.embedding_centroids(embedding_key_column) 79 | 80 | # Compute similarity 81 | similarity_matrix = cosine_similarity(curr_centroids, all_centroids) 82 | most_dissimilar_row_idx = np.argmax(np.min(similarity_matrix, axis=1)) 83 | dissimilar_examples = self.summary.embedding_examples(embedding_key_column) 84 | dissimilar_examples = dissimilar_examples[ 85 | dissimilar_examples["cluster"] == most_dissimilar_row_idx 86 | ].reset_index(drop=True) 87 | corresponding_row_idx = np.argmin(similarity_matrix[most_dissimilar_row_idx]) 88 | corresponding_examples = self.neighbor_summaries[ 89 | all_centroid_idxs[corresponding_row_idx][0] 90 | ].embedding_examples(embedding_key_column) 91 | corresponding_examples = corresponding_examples[ 92 | corresponding_examples["cluster"] 93 | == all_centroid_idxs[corresponding_row_idx][1] 94 | ].reset_index(drop=True) 95 | 96 | return { 97 | "drifted_examples": dissimilar_examples.drop("cluster", axis=1), 98 | "corresponding_examples": corresponding_examples.drop("cluster", axis=1), 99 | } 100 | 101 | @property 102 | def score(self) -> float: 103 | """Distance from the partition to its k nearest neighbors.""" 104 | return self._all_scores[self.summary.partition] 105 | 106 | @property 107 | def is_drifted(self) -> bool: 108 | """ 109 | Indicates whether the partition is drifted or not, compared 110 | to previous partitions. This is determined by the percentile 111 | of the partition's score in the distribution of all scores. 112 | The threshold is 95%. 113 | """ 114 | return self.score_percentile >= 0.95 115 | 116 | @property 117 | def score_percentile(self) -> float: 118 | """Percentile of the partition's score in the distribution 119 | of all scores.""" 120 | return percentileofscore(self.all_scores, self.score) * 1.0 / 100 121 | 122 | @property 123 | def all_scores(self) -> pd.Series: 124 | """Scores of all previous partitions.""" 125 | mask = self._all_scores.index != self.summary.partition 126 | return self._all_scores[mask] 127 | 128 | @property 129 | def clustering(self) -> typing.Dict[int, typing.List[str]]: 130 | """ 131 | Clustering of the columns based on their partition summaries 132 | and meaning of column names (determined via embeddings). Returns 133 | a dictionary with cluster numbers as keys and lists of columns 134 | as values. 135 | """ 136 | if self._clustered_features is None: 137 | raise ValueError("No clustering was performed.") 138 | 139 | clustering_map = self._clustered_features.groupby("cluster")["column"].agg(set) 140 | clustering_map = clustering_map.apply(list) 141 | 142 | return clustering_map.to_dict() 143 | 144 | def drill_down( 145 | self, 146 | sort_by_cluster_score: bool = False, 147 | average_embedding_columns: bool = True, 148 | ) -> pd.DataFrame: 149 | """Compute the columns with highest magnitude anomaly scores. 150 | Anomaly scores are computed as the z-score of the column with 151 | respect to previous partition summary statistics. 
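Concretely, for each (column, statistic) pair the z-score is z = (x - mean(x)) / (std(x) + 1e-10), where the mean and standard deviation are taken over that statistic's values across all partitions (see `normalize`).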
152 | 153 | The resulting dataframe has the following schema (column, statistic are 154 | indexes): 155 | 156 | - column: Name of the column 157 | - statistic: Name of the statistic 158 | - z-score: z-score of the column 159 | - cluster: Cluster number that the column belongs to (if clustering was 160 | performed) 161 | - abs(z-score-cluster): absolute value of the average z-score of the 162 | column in the cluster (if clustering was performed) 163 | 164 | Use the `drifted_columns` method first, since `drifted_columns` 165 | deduplicates columns. 166 | 167 | Args: 168 | sort_by_cluster_score (bool, optional): 169 | Whether to sort by cluster z-score. Defaults to False. 170 | average_embedding_columns (bool, optional): 171 | Whether to average statistics across embedding dimensions. 172 | Defaults to True. 173 | 174 | Returns: 175 | pd.DataFrame: 176 | Dataframe with columns with highest magnitude anomaly 177 | scores. Sorted by the magnitude of the z-score for a column. 178 | If clustering was performed, the dataframe will be sorted 179 | by the magnitude of the z-score in the cluster before 180 | the column score. 181 | """ 182 | 183 | # Return a dataframe with features with highest magnitude anomaly 184 | # scores 185 | last_day = self._nn_features.loc[self.summary.partition] 186 | sorted_cols = last_day.abs().sort_values(ascending=False).index 187 | sorted_df = last_day[sorted_cols].to_frame() 188 | sorted_df.rename(columns={sorted_df.columns[0]: "z-score"}, inplace=True) 189 | sorted_df = sorted_df.rename_axis(["column", "statistic"]) 190 | 191 | if self._clustered_features is not None: 192 | # Join the clustered features with the sorted_df 193 | # sorted_df.rename(index={"column": "cluster"}, inplace=True) 194 | sorted_df = sorted_df.rename_axis(["cluster", "statistic"]).reset_index() 195 | sorted_df.rename(columns={"z-score": "abs(z-score-cluster)"}, inplace=True) 196 | 197 | sorted_df = sorted_df.merge( 198 | self._clustered_features, 199 | on=["cluster", "statistic"], 200 | how="left", 201 | ) 202 | 203 | # Sort again 204 | if sort_by_cluster_score: 205 | sorted_df = sorted_df.reindex( 206 | sorted_df[["abs(z-score-cluster)", "z-score"]] 207 | .abs() 208 | .sort_values( 209 | by=["abs(z-score-cluster)", "z-score"], ascending=False 210 | ) 211 | .index 212 | ) 213 | sorted_df.set_index(["column", "statistic"], inplace=True) 214 | 215 | if len(self._embedding_columns) > 0 and average_embedding_columns: 216 | # Average the z-scores 217 | sorted_df.reset_index(inplace=True) 218 | sorted_df["column"] = sorted_df["column"].apply( 219 | lambda x: name_to_ec(x, self._embedding_columns) 220 | ) 221 | sorted_df["z-score"] = sorted_df.apply( 222 | lambda x: ( 223 | abs(x["z-score"]) 224 | if x["column"] in self._embedding_columns 225 | else x["z-score"] 226 | ), 227 | axis=1, 228 | ) 229 | sorted_df = sorted_df.groupby(["column", "statistic"]).mean() 230 | sorted_df = sorted_df.reindex( 231 | sorted_df["z-score"].abs().sort_values(ascending=False).index 232 | ) 233 | 234 | # sorted_df.sort_values(by="z-score", ascending=False, inplace=True) 235 | 236 | return sorted_df 237 | 238 | def __str__(self) -> str: 239 | """Prints the drift score, percentile, and the top drifted columns.""" 240 | results = ( 241 | "Drift score:" 242 | f" {self.score:.4f} ({self.score_percentile:.2%} percentile)\nTop" 243 | f" drifted columns:\n{self.drifted_columns()}" 244 | ) 245 | return results 246 | 247 | def drifted_columns( 248 | self, 249 | limit: int = 10, 250 | average_embedding_columns: bool = True, 251 | ) 
-> pd.DataFrame: 252 | """Returns the top limit columns that have drifted. The 253 | resulting dataframe has the following schema (column is an 254 | index): 255 | 256 | - column: Name of the column 257 | - statistic: Name of the statistic 258 | - z-score: z-score of the column 259 | - cluster: Cluster number of the column (if clustering was performed) 260 | - abs(z-score-cluster): z-score of the column in the cluster (if 261 | clustering was performed) 262 | 263 | Args: 264 | limit (int, optional): 265 | Limit for number of drifted columns to return. Defaults to 10. 266 | average_embedding_columns (bool, optional): 267 | Whether to average statistics across embedding dimensions. 268 | Defaults to True. 269 | 270 | Returns: 271 | pd.DataFrame: 272 | Dataframe with columns with highest magnitude z-scores. 273 | If clustering was performed, the dataframe will also contain 274 | the z-score in the cluster and the cluster number. 275 | Each column is deduplicated, so only the statistic with the 276 | highest magnitude z-score is returned. 277 | """ 278 | # Return a dataframe of the top limit columns that have drifted 279 | # Drop duplicate column names 280 | dd_results = self.drill_down(average_embedding_columns=average_embedding_columns) 281 | 282 | if self._clustered_features is not None: 283 | # Sort by z-score first, then abs(z-score-cluster) 284 | dd_results = dd_results.reindex( 285 | dd_results[["z-score", "abs(z-score-cluster)"]] 286 | .abs() 287 | .sort_values(by=["z-score", "abs(z-score-cluster)"], ascending=False) 288 | .index 289 | ) 290 | 291 | dd_results.reset_index(inplace=True) 292 | 293 | dd_results.drop_duplicates(subset=["column"], keep="first", inplace=True) 294 | dd_results.set_index("column", inplace=True) 295 | 296 | if self._clustered_features is not None: 297 | # Reorder columns 298 | dd_results = dd_results[ 299 | ["statistic", "z-score", "cluster", "abs(z-score-cluster)"] 300 | ] 301 | dd_results = dd_results[dd_results["abs(z-score-cluster)"].abs() > 0.0] 302 | 303 | return dd_results.head(limit) 304 | 305 | 306 | def name_to_ec(name: str, embedding_columns: typing.List[str]) -> str: 307 | """Converts a column name to an embedding column name. 308 | 309 | Args: 310 | name (str): 311 | Column name. 312 | embedding_columns (typing.List[str]): 313 | List of embedding columns. 314 | 315 | Returns: 316 | str: 317 | Embedding column name. 318 | """ 319 | if type(name) != str: 320 | print(name) 321 | split_name = name.rsplit("_", 1)[0] 322 | if split_name in embedding_columns: 323 | return split_name 324 | else: 325 | return name 326 | 327 | 328 | def detect_drift( 329 | current_summary: Summary, 330 | previous_summaries: typing.List[Summary], 331 | validity: typing.List[int] = [], 332 | cluster: bool = True, 333 | k: int = 3, 334 | ) -> DriftResult: 335 | """Computes whether the current partition summary has drifted from previous 336 | summaries. 337 | 338 | Args: 339 | current_summary (Summary): 340 | Partition summary for current partition. 341 | previous_summaries (typing.List[Summary]): 342 | Previous partition summaries. 343 | validity (typing.List[int], optional): 344 | Indicator list identifying which partition summaries are valid. 1 345 | if valid, 0 if invalid. If empty, we assume all partition summaries 346 | are valid. Must be empty or equal in length to previous_summaries. 347 | cluster (bool, optional): 348 | Whether or not to cluster columns in summaries. Increases runtime 349 | but also increases precision in drift detection.
Only engaged if 350 | summaries have 10 or more columns. Defaults to True. 351 | k (int, optional): 352 | Number of nearest neighbor partitions to inspect. 353 | Defaults to 3. 354 | 355 | Returns (DriftResult): DriftResult object with score and score percentile. 356 | """ 357 | if len(previous_summaries) < 5: 358 | raise ValueError( 359 | "You must have at least 5 previous partition summaries to detect" 360 | " drift. You can randomly split your data from previous partitions" 361 | " into 5+ partitions if you need to." 362 | ) 363 | 364 | partition_key = current_summary.partition_key 365 | columns = current_summary.columns 366 | statistics = current_summary.statistics() 367 | 368 | # Create validity vector 369 | if not validity: 370 | validity = [1] * len(previous_summaries) 371 | if len(validity) != len(previous_summaries): 372 | raise ValueError( 373 | f"Validity vector has length {len(validity)} but should have" 374 | f" length {len(previous_summaries)} to match previous_summaries." 375 | ) 376 | validity.append(1) 377 | 378 | # Normalize current and previous partition summaries 379 | prev_summaries = [ 380 | s.value() for i, s in enumerate(previous_summaries) if validity[i] == 1 381 | ] 382 | normalized_summaries = normalize( 383 | pd.concat(prev_summaries + [current_summary.value()]).reset_index(drop=True), 384 | partition_key, 385 | statistics, 386 | ) 387 | 388 | # Run clustering algorithm if there are 10 or more columns 389 | if cluster and len(columns) >= 10: 390 | clustering = compute_clusters( 391 | normalized_summaries, 392 | partition_key, 393 | current_summary._string_columns, 394 | current_summary._float_columns, 395 | current_summary._int_columns, 396 | current_summary._embedding_columns, 397 | ) 398 | 399 | normalized_summaries["value_abs"] = normalized_summaries["value"].abs() 400 | 401 | cluster_normalized = ( 402 | normalized_summaries.merge(clustering, on=["column"], how="left") 403 | # .set_index(partition_key) 404 | .groupby([partition_key, "cluster", "statistic"])["value_abs"] 405 | .mean() 406 | .reset_index() 407 | ) 408 | cluster_normalized.rename( 409 | {"cluster": "column", "value_abs": "value"}, axis=1, inplace=True 410 | ) 411 | normalized_summaries.drop("value_abs", axis=1, inplace=True) 412 | 413 | # Run nearest neighbor algorithm to get distances 414 | nn_features_unpivoted = ( 415 | cluster_normalized if (cluster and len(columns) >= 10) else normalized_summaries 416 | ) 417 | # nn_features_unpivoted["value"] = nn_features_unpivoted["value"].apply( 418 | # lambda x: 0.0 if np.abs(x) < z_score_cutoff else x 419 | # ) 420 | 421 | nn_features = ( 422 | nn_features_unpivoted.fillna(0.0) 423 | .pivot_table( 424 | index=partition_key, 425 | columns=["column", "statistic"], 426 | values="value", 427 | ) 428 | .fillna(0.0) 429 | ) 430 | 431 | dists, indices = cKDTree(nn_features.values).query(nn_features.values, k=k + 1) 432 | 433 | neighbor_partitions = nn_features.index[indices[-1][1:]].to_list() 434 | neighbor_summaries = [ 435 | s for s in previous_summaries if s.partition in neighbor_partitions 436 | ] 437 | 438 | # Replace inf with nan 439 | dists[np.isinf(dists)] = np.nan 440 | 441 | scores = pd.Series( 442 | data=np.nanmean(dists[:, 1:], axis=1), 443 | index=nn_features.index, 444 | ) 445 | 446 | if cluster and len(columns) >= 10: 447 | partition_value = scores.index[-1] 448 | clustered_features = normalized_summaries[ 449 | normalized_summaries[partition_key] == partition_value 450 | ].merge(clustering, on=["column"], how="left") 451 | 452 |
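# Label the current partition's normalized values as per-column z-scores and drop the partition key before handing them to DriftResult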
clustered_features.rename({"value": "z-score"}, axis=1, inplace=True) 453 | clustered_features.drop(partition_key, axis=1, inplace=True) 454 | 455 | return DriftResult( 456 | scores, 457 | nn_features, 458 | current_summary, 459 | neighbor_summaries=neighbor_summaries, 460 | clustered_features=clustered_features, 461 | embedding_columns=current_summary._embedding_columns, 462 | ) 463 | 464 | else: 465 | return DriftResult( 466 | scores, 467 | nn_features, 468 | current_summary, 469 | neighbor_summaries=neighbor_summaries, 470 | clustered_features=None, 471 | embedding_columns=current_summary._embedding_columns, 472 | ) 473 | 474 | 475 | def normalize( 476 | all_summaries: pd.DataFrame, 477 | partition_key: str, 478 | statistics: typing.List[str], 479 | ) -> pd.DataFrame: 480 | """Melt and normalize partition summaries. 481 | 482 | Args: 483 | all_summaries (pd.DataFrame): concatenated summaries to normalize 484 | partition_key (str): partition key 485 | statistics (typing.List[str]): statistics to normalize 486 | 487 | Returns: 488 | pd.DataFrame: normalized summary 489 | """ 490 | normalized = all_summaries.melt( 491 | id_vars=[partition_key, "column"], 492 | value_vars=statistics, 493 | var_name="statistic", 494 | value_name="value", 495 | ).dropna() 496 | 497 | grouped = normalized.groupby(["column", "statistic"]) 498 | mean = grouped["value"].transform("mean") 499 | std = grouped["value"].transform("std") 500 | std += 1e-10 501 | normalized["value"] = (normalized["value"] - mean) / std 502 | return normalized 503 | 504 | 505 | def compute_clusters( 506 | normalized: pd.DataFrame, 507 | partition_key: str, 508 | string_columns: typing.List[str], 509 | float_columns: typing.List[str], 510 | int_columns: typing.List[str], 511 | embedding_columns: typing.List[str], 512 | ) -> pd.DataFrame: 513 | """Computes clusters of columns in a partition summary. 514 | 515 | Args: 516 | normalized (pd.DataFrame): Normalized partition summary. 517 | partition_key (str): Name of partition column. 518 | string_columns (typing.List[str]): List of string columns. 519 | float_columns (typing.List[str]): List of float columns. 520 | int_columns (typing.List[str]): List of int columns. 521 | embedding_columns (typing.List[str]): List of embedding columns. 522 | 523 | Returns (pd.DataFrame): Mapping of column names to cluster numbers. 
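The number of clusters is not a parameter: columns are compared using a blend of name-meaning similarity (cosine similarity of sentence embeddings of each column name and type, weight 0.25) and summary-statistic similarity (weight 0.75); PCA is run on the blended similarity matrix, and the number of components needed to exceed 95% cumulative explained variance determines the cluster count for agglomerative clustering. Each embedding column is then placed in its own cluster.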
524 | """ 525 | 526 | column_stats = normalized.pivot_table( 527 | index="column", columns=[partition_key, "statistic"], values="value" 528 | ).fillna(0.0) 529 | 530 | column_names = column_stats.index.tolist() 531 | column_names_to_types = {c: "string" for c in string_columns} 532 | column_names_to_types.update({c: "float" for c in float_columns}) 533 | column_names_to_types.update({c: "int" for c in int_columns}) 534 | embedding_columns_with_indexes = [ 535 | c for c in column_names if name_to_ec(c, embedding_columns) in embedding_columns 536 | ] 537 | # column_names_to_types.update( 538 | # {c: "embedding" for c in embedding_columns_with_indexes} 539 | # ) 540 | for embedding_col_name in embedding_columns_with_indexes: 541 | column_names.remove(embedding_col_name) 542 | 543 | model = SentenceTransformer("sentence-transformers/clip-ViT-B-32") 544 | embeddings = model.encode( 545 | [f"{c} is of type {column_names_to_types[c]}" for c in column_names] 546 | ) 547 | 548 | embedding_similarity_matrix = cosine_similarity(embeddings) 549 | value_similarity_matrix = cosine_similarity( 550 | column_stats[column_stats.index.isin(column_names)].values 551 | ) 552 | similarity_matrix = ( 553 | 0.25 * embedding_similarity_matrix + 0.75 * value_similarity_matrix 554 | ) 555 | 556 | # Run PCA on similarity matrix to get number of clusters 557 | pca = PCA(random_state=42) 558 | pca.fit(similarity_matrix) 559 | cumev = np.cumsum(pca.explained_variance_ratio_) 560 | # Find cluster cutoff 561 | cutoff = -1 562 | PCA_THRESHOLD = 0.95 563 | for idx, elem in enumerate(cumev): 564 | if elem > PCA_THRESHOLD: 565 | cutoff = idx 566 | break 567 | 568 | clustering = AgglomerativeClustering( 569 | metric="precomputed", 570 | linkage="average", 571 | n_clusters=cutoff + 1, 572 | ) 573 | clustering.fit(similarity_matrix) 574 | 575 | # Aggregate columns based on clustering labels 576 | 577 | cluster_labels = clustering.labels_ 578 | clusters = {} 579 | for i, label in enumerate(cluster_labels): 580 | clusters[column_names[i]] = label 581 | max_label = max(cluster_labels) 582 | 583 | # Add embedding columns to clusters 584 | for i, embedding_col_name in enumerate(embedding_columns): 585 | for name in column_stats.index.tolist(): 586 | if name_to_ec(name, embedding_columns) == embedding_col_name: 587 | clusters[name] = max_label + i + 1 588 | 589 | cluster_df = ( 590 | pd.DataFrame.from_dict(clusters, orient="index", columns=["cluster"]) 591 | .reset_index() 592 | .rename(columns={"index": "column"}) 593 | ) 594 | # cluster_df["cluster"] = cluster_df["cluster"].astype(str) 595 | 596 | return cluster_df 597 | --------------------------------------------------------------------------------