├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── Makefile
├── README.md
├── conftest.py
├── docs
│   ├── Makefile
│   ├── docs-requirements.txt
│   ├── make.bat
│   └── source
│       ├── _static
│       │   └── spark_matcher_logo.png
│       ├── api
│       │   ├── activelearner.rst
│       │   ├── blocker.rst
│       │   ├── data.rst
│       │   ├── deduplicator.rst
│       │   ├── matcher.rst
│       │   ├── matching_base.rst
│       │   ├── modules.rst
│       │   ├── sampler.rst
│       │   ├── scorer.rst
│       │   └── similarity_metrics.rst
│       ├── conf.py
│       ├── example.ipynb
│       ├── index.rst
│       └── installation_guide.rst
├── examples
│   ├── example_deduplicator.ipynb
│   ├── example_matcher.ipynb
│   ├── example_matcher_advanced.ipynb
│   └── example_stopword_removal.ipynb
├── external_dependencies
│   └── .gitkeep
├── licenses_bundled
├── pyproject.toml
├── setup.py
├── spark_matcher
│   ├── __init__.py
│   ├── activelearner
│   │   ├── __init__.py
│   │   └── active_learner.py
│   ├── blocker
│   │   ├── __init__.py
│   │   ├── block_learner.py
│   │   └── blocking_rules.py
│   ├── config.py
│   ├── data
│   │   ├── .gitkeep
│   │   ├── __init__.py
│   │   ├── acm.csv
│   │   ├── datasets.py
│   │   ├── dblp.csv
│   │   ├── stoxx50.csv
│   │   ├── voters_1.csv
│   │   └── voters_2.csv
│   ├── deduplicator
│   │   ├── __init__.py
│   │   ├── connected_components_calculator.py
│   │   ├── deduplicator.py
│   │   └── hierarchical_clustering.py
│   ├── matcher
│   │   ├── __init__.py
│   │   └── matcher.py
│   ├── matching_base
│   │   ├── __init__.py
│   │   └── matching_base.py
│   ├── sampler
│   │   ├── __init__.py
│   │   └── training_sampler.py
│   ├── scorer
│   │   ├── __init__.py
│   │   └── scorer.py
│   ├── similarity_metrics
│   │   ├── __init__.py
│   │   └── similarity_metrics.py
│   ├── table_checkpointer.py
│   └── utils.py
├── spark_requirements.txt
└── test
    ├── test_active_learner
    │   └── test_active_learner.py
    ├── test_blocker
    │   ├── test_block_learner.py
    │   └── test_blocking_rules.py
    ├── test_deduplicator
    │   ├── test_deduplicator.py
    │   └── test_hierarchical_clustering.py
    ├── test_matcher
    │   └── test_matcher.py
    ├── test_matching_base
    │   └── test_matching_base.py
    ├── test_sampler
    │   └── test_training_sampler.py
    ├── test_scorer
    │   └── test_scorer.py
    ├── test_similarity_metrics
    │   └── test_similarity_metrics.py
    ├── test_table_checkpointer
    │   └── test_table_checkpointer.py
    └── test_utils
        └── test_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data-folder
2 | /data/**
3 | !/data/**/
4 |
5 | # VSCode
6 | .vscode/
7 |
8 | # Generated documentation
9 | docs/build/*
10 | .pdf
11 | *.pdf
12 |
13 | # To make sure that the folder structure is tracked
14 | !/**/.gitkeep
15 |
16 | # General
17 | *~
18 | *.log
19 | .env
20 | *.swp
21 | .cache/*
22 |
23 | # Python pickle files
24 | *.pickle
25 | *.pkl
26 |
27 | # Byte-compiled / optimized / DLL files
28 | __pycache__/
29 | *.py[cod]
30 | *$py.class
31 |
32 | # Jupyter Notebook
33 | .ipynb_checkpoints
34 |
35 | # Pycharm
36 | .idea/
37 |
38 | # Pytest
39 | *.pytest_cache/
40 |
41 | *.DS_Store
42 |
43 |
44 |
45 | # Created by https://www.gitignore.io/api/python
46 | # Edit at https://www.gitignore.io/?templates=python
47 |
48 | ### Python ###
49 | # Byte-compiled / optimized / DLL files
50 | __pycache__/
51 | *.py[cod]
52 | *$py.class
53 |
54 | # C extensions
55 | *.so
56 |
57 | # Distribution / packaging
58 | .Python
59 | build/
60 | develop-eggs/
61 | dist/
62 | downloads/
63 | eggs/
64 | .eggs/
65 | lib/
66 | lib64/
67 | parts/
68 | sdist/
69 | var/
70 | wheels/
71 | pip-wheel-metadata/
72 | share/python-wheels/
73 | *.egg-info/
74 | .installed.cfg
75 | *.egg
76 | MANIFEST
77 |
78 | # PyInstaller
79 | # Usually these files are written by a python script from a template
80 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
81 | *.manifest
82 | *.spec
83 |
84 | # Installer logs
85 | pip-log.txt
86 | pip-delete-this-directory.txt
87 |
88 | # Unit test / coverage reports
89 | htmlcov/
90 | .tox/
91 | .nox/
92 | .coverage
93 | .coverage.*
94 | .cache
95 | nosetests.xml
96 | coverage.xml
97 | *.cover
98 | .hypothesis/
99 | .pytest_cache/
100 |
101 | # Translations
102 | *.mo
103 | *.pot
104 |
105 | # Django stuff:
106 | *.log
107 | local_settings.py
108 | db.sqlite3
109 |
110 | # Flask stuff:
111 | instance/
112 | .webassets-cache
113 |
114 | # Scrapy stuff:
115 | .scrapy
116 |
117 | # Sphinx documentation
118 | docs/_build/
119 |
120 | # PyBuilder
121 | target/
122 |
123 | # Jupyter Notebook
124 | .ipynb_checkpoints
125 |
126 | # IPython
127 | profile_default/
128 | ipython_config.py
129 |
130 | # pyenv
131 | .python-version
132 |
133 | # pipenv
134 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
135 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
136 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not
137 | # install all needed dependencies.
138 | #Pipfile.lock
139 |
140 | # celery beat schedule file
141 | celerybeat-schedule
142 |
143 | # SageMath parsed files
144 | *.sage.py
145 |
146 | # Environments
147 | .env
148 | .venv
149 | env/
150 | venv/
151 | ENV/
152 | env.bak/
153 | venv.bak/
154 |
155 | # Spyder project settings
156 | .spyderproject
157 | .spyproject
158 |
159 | # Rope project settings
160 | .ropeproject
161 |
162 | # mkdocs documentation
163 | /site
164 |
165 | # mypy
166 | .mypy_cache/
167 | .dmypy.json
168 | dmypy.json
169 |
170 | # Pyre type checker
171 | .pyre/
172 |
173 | # End of https://www.gitignore.io/api/python
174 | /external_dependencies/graphframes*.jar
175 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yaml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | version: 2
6 |
7 | sphinx:
8 | configuration: docs/source/conf.py
9 |
10 | python:
11 | version: 3.7
12 | install:
13 | - requirements: docs/docs-requirements.txt
14 | - method: pip
15 | path: .
16 | system_packages: true
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | create_documentation:
2 | make clean
3 | make html
4 |
5 | deploy:
6 | rm -rf dist
7 | rm -rf build
8 | python3 -m build
9 | python3 -m twine upload dist/*
10 |
11 | clean:
12 | 	rm -rf docs/build
13 |
14 | html:
15 | sphinx-build -b html docs/source docs/build
16 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | [](https://pypi.org/project/spark-matcher/)
3 | [](https://pepy.tech/project/spark-matcher)
4 | 
5 | [][#docs-package]
6 |
7 | [#docs-package]: https://spark-matcher.readthedocs.io/en/latest/
8 |
9 |
10 | 
11 |
12 | # Spark-Matcher
13 |
14 | Spark-Matcher is a scalable entity matching algorithm implemented in PySpark. With Spark-Matcher the user can easily
15 | train an algorithm to solve a custom matching problem. Spark Matcher uses active learning (modAL) to train a
16 | classifier (Scikit-learn) to match entities. In order to deal with the N^2 complexity of matching large tables, blocking is
17 | implemented to reduce the number of pairs. Since the implementation is done in PySpark, Spark-Matcher can deal with
18 | extremely large tables.
19 |
20 | Documentation with examples can be found [here](https://spark-matcher.readthedocs.io/en/latest/).
21 |
22 | Developed by data scientists at ING Analytics, www.ing.com.
23 |
24 | ## Installation
25 |
26 | ### Normal installation
27 |
28 | As Spark-Matcher is intended to be used with large datasets on a Spark cluster, it is assumed that Spark is already
29 | installed. If that is not the case, first install Spark and PyArrow (`pip install pyspark pyarrow`).
30 |
31 | Install Spark-Matcher using PyPi:
32 |
33 | ```
34 | pip install spark-matcher
35 | ```
36 |
37 | ### Install with possibility to create documentation
38 |
39 | Pandoc, the general markup converter, needs to be available. You may follow the official [Pandoc installation instructions](https://pandoc.org/installing.html) or use conda:
40 |
41 | ```
42 | conda install -c conda-forge pandoc
43 | ```
44 |
45 | Then clone the Spark-Matcher repository and install it with the `[doc]` extra:
46 |
47 | ```
48 | pip install ".[doc]"
49 | ```
50 |
51 | ### Install to contribute
52 |
53 | Clone this repo and install in editable mode. This also installs PySpark and Jupyterlab:
54 |
55 | ```
56 | python -m pip install -e ".[dev]"
57 | python setup.py develop
58 | ```
59 |
60 | ## Documentation
61 |
62 | Documentation can be created using the following command:
63 |
64 | ```
65 | make create_documentation
66 | ```
67 |
68 | ## Dependencies
69 |
70 | The usage examples in the `examples` directory contain notebooks that run in local mode.
71 | Using Spark-Matcher in cluster mode requires sending the Spark-Matcher package and several other Python packages (see `spark_requirements.txt`) to the executors.
72 | How to send these dependencies depends on the cluster.
73 | Please read the Apache Spark instructions and examples on how to do this: https://spark.apache.org/docs/latest/api/python/user_guide/python_packaging.html.
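
As an illustration only (the exact mechanism depends on your cluster, and the archive name below is hypothetical), the Python dependencies can for instance be bundled into a zip archive and distributed to the executors at runtime:

```python
from pyspark.sql import SparkSession

# 'spark_matcher_deps.zip' is a hypothetical archive containing spark_matcher and
# the packages listed in spark_requirements.txt; build it in whatever way suits
# your cluster (see the Spark packaging guide linked above).
spark = SparkSession.builder.getOrCreate()
spark.sparkContext.addPyFile("spark_matcher_deps.zip")  # shipped to all executors
```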
74 |
75 | Spark-Matcher uses `graphframes` under the hood.
76 | Therefore, the `graphframes` version matching your Spark version needs to be added to the `external_dependencies` directory and to the configuration of the Spark session.
77 | By default, the `graphframes` build for Spark 3.0 is used in the Spark sessions in the notebooks in the `examples` directory.
78 | For a different version, see: https://spark-packages.org/package/graphframes/graphframes.
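
As a minimal sketch, a local session can point Spark at the bundled jar roughly as follows (the jar file name is an assumption; use the file you downloaded for your Spark version):

```python
from pyspark.sql import SparkSession

# The jar name below is hypothetical; download the graphframes build that matches
# your Spark release and place it in the external_dependencies directory.
spark = (SparkSession
         .builder
         .master("local")
         .config("spark.jars", "external_dependencies/graphframes-0.8.1-spark3.0-s_2.12.jar")
         .getOrCreate())
```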
79 |
80 | ## Usage
81 |
82 | Example notebooks are provided in the `examples` directory.
83 | Using the SparkMatcher to find matches between Spark
84 | dataframes `a` and `b` goes as follows:
85 |
86 | ```python
87 | from spark_matcher.matcher import Matcher
88 |
89 | myMatcher = Matcher(spark_session, col_names=['name', 'suburb', 'postcode'])
90 | ```
91 |
92 | Now we are ready to fit the Matcher object using 'active learning': the user is asked to label candidate pairs. Enter
93 | 'y' if a pair is a match or 'n' if it is not (you can also go back to the (p)revious pair or (s)kip one). You will be
94 | notified when the model has converged, after which you can stop training by entering 'f'.
95 |
96 | ```python
97 | myMatcher.fit(a, b)
98 | ```
99 |
100 | The Matcher is now trained and can be used to predict on all data. This can be the data used for training or new data
101 | that was not seen by the model yet.
102 |
103 | ```python
104 | result = myMatcher.predict(a, b)
105 | ```
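
The result is a Spark dataframe containing the matched pairs together with a score. As a minimal sketch (assuming the `score` column used in the example notebooks), the most confident matches can be inspected with:

```python
# Show the ten highest-scoring candidate pairs; 'score' is the match probability
# column produced by predict in the example notebooks.
result.orderBy('score', ascending=False).show(10)
```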
106 |
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pyspark.sql import SparkSession
3 |
4 | from spark_matcher.table_checkpointer import ParquetCheckPointer
5 |
6 |
7 | @pytest.fixture(scope="session")
8 | def spark_session():
9 | spark = (SparkSession
10 | .builder
11 | .appName(str(__file__))
12 | .getOrCreate()
13 | )
14 | yield spark
15 | spark.stop()
16 |
17 |
18 | @pytest.fixture(scope="session")
19 | def table_checkpointer(spark_session):
20 | return ParquetCheckPointer(spark_session, 'temp_database')
21 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = source
9 | BUILDDIR = build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/docs-requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx==3.5.4
2 | nbsphinx
3 | sphinx_rtd_theme
4 | Jinja2<3.1
5 | pyspark
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=source
11 | set BUILDDIR=build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.https://www.sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/source/_static/spark_matcher_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/docs/source/_static/spark_matcher_logo.png
--------------------------------------------------------------------------------
/docs/source/api/activelearner.rst:
--------------------------------------------------------------------------------
1 | activelearner module
2 | ====================================
3 |
4 | .. automodule:: spark_matcher.activelearner
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
12 |
--------------------------------------------------------------------------------
/docs/source/api/blocker.rst:
--------------------------------------------------------------------------------
1 | blocker module
2 | ==============================
3 |
4 | .. automodule:: spark_matcher.blocker
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
12 |
--------------------------------------------------------------------------------
/docs/source/api/data.rst:
--------------------------------------------------------------------------------
1 | data module
2 | ===========================
3 |
4 | .. automodule:: spark_matcher.data.datasets
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
12 |
--------------------------------------------------------------------------------
/docs/source/api/deduplicator.rst:
--------------------------------------------------------------------------------
1 | deduplicator module
2 | ===================================
3 |
4 | .. automodule:: spark_matcher.deduplicator
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
12 |
--------------------------------------------------------------------------------
/docs/source/api/matcher.rst:
--------------------------------------------------------------------------------
1 | matcher module
2 | ==============================
3 |
4 | .. automodule:: spark_matcher.matcher
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
12 |
--------------------------------------------------------------------------------
/docs/source/api/matching_base.rst:
--------------------------------------------------------------------------------
1 | matching_base module
2 | =====================================
3 |
4 | .. automodule:: spark_matcher.matching_base
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
--------------------------------------------------------------------------------
/docs/source/api/modules.rst:
--------------------------------------------------------------------------------
1 | spark_matcher
2 | =============
3 |
4 | .. toctree::
5 | :maxdepth: 4
6 |
7 | activelearner
8 | blocker
9 | data
10 | deduplicator
11 | matcher
12 | matching_base
13 | sampler
14 | scorer
15 | similarity_metrics
16 |
--------------------------------------------------------------------------------
/docs/source/api/sampler.rst:
--------------------------------------------------------------------------------
1 | sampler module
2 | ==============================
3 |
4 | .. automodule:: spark_matcher.sampler
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
--------------------------------------------------------------------------------
/docs/source/api/scorer.rst:
--------------------------------------------------------------------------------
1 | scorer module
2 | =============================
3 |
4 | .. automodule:: spark_matcher.scorer
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
--------------------------------------------------------------------------------
/docs/source/api/similarity_metrics.rst:
--------------------------------------------------------------------------------
1 | similarity_metrics module
2 | ==========================================
3 |
4 | .. automodule:: spark_matcher.similarity_metrics
5 | :members:
6 | :undoc-members:
7 | :show-inheritance:
8 |
9 | .. toctree::
10 | :maxdepth: 4
11 | :caption: Contents:
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 | sys.path.insert(0, os.path.abspath('./../..'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = 'Spark Matcher'
21 | copyright = '2021, Ahmet Bayraktar, Frits Hermans, Stan Leisink'
22 | author = 'Ahmet Bayraktar, Frits Hermans, Stan Leisink'
23 |
24 | # The full version, including alpha/beta/rc tags
25 | release = '0.3.2'
26 |
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = ['nbsphinx', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon']
34 |
35 | # Add any paths that contain templates here, relative to this directory.
36 | templates_path = ['_templates']
37 |
38 | # List of patterns, relative to source directory, that match files and
39 | # directories to ignore when looking for source files.
40 | # This pattern also affects html_static_path and html_extra_path.
41 | exclude_patterns = []
42 |
43 |
44 | # -- Options for HTML output -------------------------------------------------
45 |
46 | # The theme to use for HTML and HTML Help pages. See the documentation for
47 | # a list of builtin themes.
48 | #
49 | html_theme = 'sphinx_rtd_theme'
50 |
51 | # Add any paths that contain custom static files (such as style sheets) here,
52 | # relative to this directory. They are copied after the builtin static files,
53 | # so a file named "default.css" will overwrite the builtin "default.css".
54 | html_static_path = ['_static']
55 |
56 | # FKH:
57 | nbsphinx_allow_errors = True
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | .. Spark Matcher documentation master file, created by
2 | sphinx-quickstart on Tue Nov 23 10:39:10 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | .. image:: _static/spark_matcher_logo.png
7 |
8 | Welcome to Spark Matcher's documentation!
9 | =========================================
10 |
11 | Spark Matcher is a scalable entity matching algorithm implemented in PySpark.
12 | With Spark Matcher the user can easily train an algorithm to solve a custom matching problem.
13 | Spark Matcher uses active learning (modAL) to train a classifier (Sklearn) to match entities.
14 | In order to deal with the N^2 complexity of matching large tables, blocking is implemented to reduce the number of pairs.
15 | Since the implementation is done in PySpark, Spark Matcher can deal with extremely large tables.
16 |
17 |
18 | .. toctree::
19 | :maxdepth: 2
20 | :caption: Contents:
21 |
22 | installation_guide
23 | example.ipynb
24 | api/modules
25 |
26 |
27 |
28 | Indices and tables
29 | ==================
30 |
31 | * :ref:`genindex`
32 | * :ref:`modindex`
33 | * :ref:`search`
34 |
--------------------------------------------------------------------------------
/docs/source/installation_guide.rst:
--------------------------------------------------------------------------------
1 | How to Install
2 | **************
3 |
4 | As Spark-Matcher is intended to be used with large datasets on a Spark cluster, it is assumed that Spark is already
5 | installed. If that is not the case, first install Spark and PyArrow (:code:`pip install pyspark pyarrow`).
6 |
7 | Install Spark-Matcher from PyPi:
8 |
9 | .. code-block:: bash
10 |
11 | pip install spark-matcher
12 |
--------------------------------------------------------------------------------
/examples/example_matcher_advanced.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7601f10c",
6 | "metadata": {},
7 | "source": [
8 | "# Spark-Matcher advanced Matcher example "
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "71300ce4",
14 | "metadata": {},
15 | "source": [
16 | "This notebook shows how to use the `spark_matcher` for matching entities with more customized settings. First we create a Spark session:"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "b29a36e5",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "%config Completer.use_jedi = False # for proper autocompletion\n",
27 | "from pyspark.sql import SparkSession"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "id": "b5d8664e",
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "spark = (SparkSession\n",
38 | " .builder\n",
39 | " .master(\"local\")\n",
40 | " .enableHiveSupport()\n",
41 | " .getOrCreate())"
42 | ]
43 | },
44 | {
45 | "cell_type": "markdown",
46 | "id": "0f72cb2c",
47 | "metadata": {},
48 | "source": [
49 | "Load the example data:"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": null,
55 | "id": "fc90dcba",
56 | "metadata": {},
57 | "outputs": [],
58 | "source": [
59 | "from spark_matcher.data import load_data"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "26d826fd",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "a, b = load_data(spark)"
70 | ]
71 | },
72 | {
73 | "cell_type": "markdown",
74 | "id": "eec05ac1",
75 | "metadata": {},
76 | "source": [
77 | "We now create a `Matcher` object with our own string similarity metric and blocking rules:"
78 | ]
79 | },
80 | {
81 | "cell_type": "code",
82 | "execution_count": null,
83 | "id": "e4c9b675",
84 | "metadata": {},
85 | "outputs": [],
86 | "source": [
87 | "from spark_matcher.matcher import Matcher"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "id": "9935a8a6",
93 | "metadata": {},
94 | "source": [
95 | "First create a string similarity metric that checks if the first word is a perfect match:"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "8b1c0150",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "def first_word(string_1, string_2):\n",
106 | " return float(string_1.split()[0]==string_2.split()[0])"
107 | ]
108 | },
109 | {
110 | "cell_type": "markdown",
111 | "id": "98b9d059",
112 | "metadata": {},
113 | "source": [
114 | "We also want to use the `token_sort_ratio` from the `thefuzz` package. Note that this package should be available on the Spark worker nodes."
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "id": "04f54a91",
121 | "metadata": {},
122 | "outputs": [],
123 | "source": [
124 | "from thefuzz.fuzz import token_sort_ratio"
125 | ]
126 | },
127 | {
128 | "cell_type": "code",
129 | "execution_count": null,
130 | "id": "6f1b85f8",
131 | "metadata": {},
132 | "outputs": [],
133 | "source": [
134 | "field_info={'name':[first_word, token_sort_ratio], 'suburb':[token_sort_ratio], 'postcode':[token_sort_ratio]}"
135 | ]
136 | },
137 | {
138 | "cell_type": "markdown",
139 | "id": "77b1623b",
140 | "metadata": {},
141 | "source": [
142 | "Moreover, we want to limit blocking to the 'name' field only, by looking at the first 3 characters and the first 3 words:"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": null,
148 | "id": "c3f6c974",
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "from spark_matcher.blocker.blocking_rules import FirstNChars, FirstNWords"
153 | ]
154 | },
155 | {
156 | "cell_type": "code",
157 | "execution_count": null,
158 | "id": "736c3643",
159 | "metadata": {},
160 | "outputs": [],
161 | "source": [
162 | "blocking_rules=[FirstNChars('name', 3), FirstNWords('name', 3)]"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": null,
168 | "id": "3b98dbb6",
169 | "metadata": {},
170 | "outputs": [],
171 | "source": [
172 | "myMatcher = Matcher(spark, field_info=field_info, blocking_rules=blocking_rules, checkpoint_dir='path_to_checkpoints')"
173 | ]
174 | },
175 | {
176 | "cell_type": "markdown",
177 | "id": "22440b47",
178 | "metadata": {},
179 | "source": [
180 | "Now we are ready for fitting the `Matcher` object:"
181 | ]
182 | },
183 | {
184 | "cell_type": "code",
185 | "execution_count": null,
186 | "id": "9a5999c6",
187 | "metadata": {
188 | "tags": []
189 | },
190 | "outputs": [],
191 | "source": [
192 | "myMatcher.fit(a, b)"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "id": "a0cfadaa",
198 | "metadata": {},
199 | "source": [
200 | "This fitted model can now be used to predict:"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": null,
206 | "id": "594fd2cd",
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "result = myMatcher.predict(a, b)"
211 | ]
212 | }
213 | ],
214 | "metadata": {
215 | "kernelspec": {
216 | "display_name": "Python 3 (ipykernel)",
217 | "language": "python",
218 | "name": "python3"
219 | },
220 | "language_info": {
221 | "codemirror_mode": {
222 | "name": "ipython",
223 | "version": 3
224 | },
225 | "file_extension": ".py",
226 | "mimetype": "text/x-python",
227 | "name": "python",
228 | "nbconvert_exporter": "python",
229 | "pygments_lexer": "ipython3",
230 | "version": "3.8.11"
231 | }
232 | },
233 | "nbformat": 4,
234 | "nbformat_minor": 5
235 | }
236 |
--------------------------------------------------------------------------------
/examples/example_stopword_removal.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "5287ca32",
6 | "metadata": {
7 | "tags": [
8 | "keep_output"
9 | ]
10 | },
11 | "source": [
12 | "# Spark-Matcher advanced example "
13 | ]
14 | },
15 | {
16 | "cell_type": "markdown",
17 | "id": "15c0ec9a",
18 | "metadata": {
19 | "tags": [
20 | "keep_output"
21 | ]
22 | },
23 | "source": [
24 | "This notebook shows how to use the `spark_matcher` with more customized settings. First we create a Spark session:"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "53d5be5d",
31 | "metadata": {
32 | "tags": [
33 | "keep_output"
34 | ]
35 | },
36 | "outputs": [],
37 | "source": [
38 | "%config Completer.use_jedi = False # for proper autocompletion\n",
39 | "from pyspark.sql import SparkSession"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "7a6d7d03",
46 | "metadata": {
47 | "tags": [
48 | "keep_output"
49 | ]
50 | },
51 | "outputs": [],
52 | "source": [
53 | "spark = (SparkSession\n",
54 | " .builder\n",
55 | " .master(\"local\")\n",
56 | " .enableHiveSupport()\n",
57 | " .getOrCreate())"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "id": "fd1f147c",
63 | "metadata": {},
64 | "source": [
65 | "Load the example data:"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "id": "43d1d5a4",
72 | "metadata": {
73 | "tags": [
74 | "keep_output"
75 | ]
76 | },
77 | "outputs": [],
78 | "source": [
79 | "from spark_matcher.data import load_data"
80 | ]
81 | },
82 | {
83 | "cell_type": "markdown",
84 | "id": "a6c60c7c",
85 | "metadata": {},
86 | "source": [
87 | "We use the 'library' data and remove the (numeric) 'year' column:"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": null,
93 | "id": "ba4878ba",
94 | "metadata": {
95 | "tags": [
96 | "keep_output"
97 | ]
98 | },
99 | "outputs": [],
100 | "source": [
101 | "a, b = load_data(spark, kind='library')\n",
102 | "a, b = a.drop('year'), b.drop('year')"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": null,
108 | "id": "f21332ee",
109 | "metadata": {
110 | "tags": [
111 | "keep_output"
112 | ]
113 | },
114 | "outputs": [
115 | {
116 | "data": {
164 | "text/plain": [
165 | " title \\\n",
166 | "0 The WASA2 object-oriented workflow management ... \n",
167 | "1 A user-centered interface for querying distrib... \n",
168 | "2 World Wide Database-integrating the Web, CORBA... \n",
169 | "\n",
170 | " authors \\\n",
171 | "0 Gottfried Vossen, Mathias Weske \n",
172 | "1 Isabel F. Cruz, Kimberly M. James \n",
173 | "2 Athman Bouguettaya, Boualem Benatallah, Lily H... \n",
174 | "\n",
175 | " venue \n",
176 | "0 International Conference on Management of Data \n",
177 | "1 International Conference on Management of Data \n",
178 | "2 International Conference on Management of Data "
179 | ]
180 | },
181 | "execution_count": null,
182 | "metadata": {},
183 | "output_type": "execute_result"
184 | }
185 | ],
186 | "source": [
187 | "a.limit(3).toPandas()"
188 | ]
189 | },
190 | {
191 | "cell_type": "markdown",
192 | "id": "962ef4b5",
193 | "metadata": {},
194 | "source": [
195 | "`spark_matcher` is shipped with a utility function to get the most frequently occurring words in a Spark dataframe column. We apply this to the `venue` column:"
196 | ]
197 | },
198 | {
199 | "cell_type": "code",
200 | "execution_count": null,
201 | "id": "0bdd312e",
202 | "metadata": {
203 | "tags": [
204 | "keep_output"
205 | ]
206 | },
207 | "outputs": [],
208 | "source": [
209 | "from spark_matcher.utils import get_most_frequent_words"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": null,
215 | "id": "81bc2861",
216 | "metadata": {
217 | "tags": [
218 | "keep_output"
219 | ]
220 | },
221 | "outputs": [
222 | {
223 | "data": {
313 | "text/plain": [
314 | " words count df\n",
315 | "0 SIGMOD 1917 0.390428\n",
316 | "1 Data 1640 0.334012\n",
317 | "2 Conference 1603 0.326477\n",
318 | "3 VLDB 1289 0.262525\n",
319 | "4 on 1135 0.231161\n",
320 | "5 Record 1111 0.226273\n",
321 | "6 International 1001 0.203870\n",
322 | "7 858 0.174745\n",
323 | "8 Large 843 0.171690\n",
324 | "9 Very 843 0.171690"
325 | ]
326 | },
327 | "execution_count": null,
328 | "metadata": {},
329 | "output_type": "execute_result"
330 | }
331 | ],
332 | "source": [
333 | "frequent_words = get_most_frequent_words(a.unionByName(b), col_name='venue')\n",
334 | "frequent_words.head(10)"
335 | ]
336 | },
337 | {
338 | "cell_type": "markdown",
339 | "id": "7e0a1379",
340 | "metadata": {},
341 | "source": [
342 | "Based on this list, we decide that we want to consider the words 'conference' and 'international' as stopwords. The utility function `remove_stopwords` does this job:"
343 | ]
344 | },
345 | {
346 | "cell_type": "code",
347 | "execution_count": null,
348 | "id": "00cb0b30",
349 | "metadata": {
350 | "tags": [
351 | "keep_output"
352 | ]
353 | },
354 | "outputs": [],
355 | "source": [
356 | "from spark_matcher.utils import remove_stopwords"
357 | ]
358 | },
359 | {
360 | "cell_type": "code",
361 | "execution_count": null,
362 | "id": "a386ead6",
363 | "metadata": {
364 | "tags": [
365 | "keep_output"
366 | ]
367 | },
368 | "outputs": [],
369 | "source": [
370 | "stopwords = ['conference', 'international']"
371 | ]
372 | },
373 | {
374 | "cell_type": "code",
375 | "execution_count": null,
376 | "id": "18f3322e",
377 | "metadata": {
378 | "tags": [
379 | "keep_output"
380 | ]
381 | },
382 | "outputs": [],
383 | "source": [
384 | "a = remove_stopwords(a, col_name='venue', stopwords=stopwords).drop('venue')\n",
385 | "b = remove_stopwords(b, col_name='venue', stopwords=stopwords).drop('venue')"
386 | ]
387 | },
388 | {
389 | "cell_type": "markdown",
390 | "id": "816e6201",
391 | "metadata": {},
392 | "source": [
393 | "A new column `venue_wo_stopwords` is created in which the stopwords are removed:"
394 | ]
395 | },
396 | {
397 | "cell_type": "code",
398 | "execution_count": null,
399 | "id": "d2c8d0c9",
400 | "metadata": {
401 | "tags": [
402 | "keep_output"
403 | ]
404 | },
405 | "outputs": [
406 | {
407 | "data": {
455 | "text/plain": [
456 | " title \\\n",
457 | "0 The WASA2 object-oriented workflow management ... \n",
458 | "1 A user-centered interface for querying distrib... \n",
459 | "2 World Wide Database-integrating the Web, CORBA... \n",
460 | "\n",
461 | " authors venue_wo_stopwords \n",
462 | "0 Gottfried Vossen, Mathias Weske on Management of Data \n",
463 | "1 Isabel F. Cruz, Kimberly M. James on Management of Data \n",
464 | "2 Athman Bouguettaya, Boualem Benatallah, Lily H... on Management of Data "
465 | ]
466 | },
467 | "execution_count": null,
468 | "metadata": {},
469 | "output_type": "execute_result"
470 | }
471 | ],
472 | "source": [
473 | "a.limit(3).toPandas()"
474 | ]
475 | },
476 | {
477 | "cell_type": "markdown",
478 | "id": "f1c71962",
479 | "metadata": {},
480 | "source": [
481 | "We use the `spark_matcher` to link the records in dataframe `a` with the records in dataframe `b`. Instead of the `venue` column, we now use the newly created `venue_wo_stopwords` column."
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": null,
487 | "id": "34d71374",
488 | "metadata": {
489 | "tags": [
490 | "keep_output"
491 | ]
492 | },
493 | "outputs": [],
494 | "source": [
495 | "from spark_matcher.matcher import Matcher"
496 | ]
497 | },
498 | {
499 | "cell_type": "code",
500 | "execution_count": null,
501 | "id": "833caee8",
502 | "metadata": {
503 | "tags": [
504 | "keep_output"
505 | ]
506 | },
507 | "outputs": [],
508 | "source": [
509 | "myMatcher = Matcher(spark, col_names=['title', 'authors', 'venue_wo_stopwords'], checkpoint_dir='path_to_checkpoints')"
510 | ]
511 | },
512 | {
513 | "cell_type": "markdown",
514 | "id": "4cf920a9",
515 | "metadata": {
516 | "tags": [
517 | "keep_output"
518 | ]
519 | },
520 | "source": [
521 | "Now we are ready for fitting the `Matcher` object."
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": null,
527 | "id": "ef8e304a",
528 | "metadata": {
529 | "tags": [
530 | "keep_output"
531 | ]
532 | },
533 | "outputs": [],
534 | "source": [
535 | "myMatcher.fit(a, b)"
536 | ]
537 | },
538 | {
539 | "cell_type": "markdown",
540 | "id": "5b0e4425",
541 | "metadata": {},
542 | "source": [
543 | "The `Matcher` is now trained and can be used to predict on all data as usual:"
544 | ]
545 | },
546 | {
547 | "cell_type": "code",
548 | "execution_count": null,
549 | "id": "f342c880",
550 | "metadata": {
551 | "tags": [
552 | "keep_output"
553 | ]
554 | },
555 | "outputs": [],
556 | "source": [
557 | "result = myMatcher.predict(a, b, threshold=0.5, top_n=3)"
558 | ]
559 | },
560 | {
561 | "cell_type": "markdown",
562 | "id": "2a68b63e",
563 | "metadata": {},
564 | "source": [
565 | "Now let's have a look at the results:"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": null,
571 | "id": "b3057fd0",
572 | "metadata": {
573 | "tags": [
574 | "keep_output"
575 | ]
576 | },
577 | "outputs": [],
578 | "source": [
579 | "result_pdf = result.toPandas()"
580 | ]
581 | },
582 | {
583 | "cell_type": "code",
584 | "execution_count": null,
585 | "id": "2aae3ea3",
586 | "metadata": {
587 | "tags": [
588 | "keep_output"
589 | ]
590 | },
591 | "outputs": [],
592 | "source": [
593 | "result_pdf.sort_values('score')"
594 | ]
595 | }
596 | ],
597 | "metadata": {
598 | "kernelspec": {
599 | "display_name": "Python 3 (ipykernel)",
600 | "language": "python",
601 | "name": "python3"
602 | },
603 | "language_info": {
604 | "codemirror_mode": {
605 | "name": "ipython",
606 | "version": 3
607 | },
608 | "file_extension": ".py",
609 | "mimetype": "text/x-python",
610 | "name": "python",
611 | "nbconvert_exporter": "python",
612 | "pygments_lexer": "ipython3",
613 | "version": "3.8.13"
614 | }
615 | },
616 | "nbformat": 4,
617 | "nbformat_minor": 5
618 | }
619 |
--------------------------------------------------------------------------------
/external_dependencies/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/external_dependencies/.gitkeep
--------------------------------------------------------------------------------
/licenses_bundled:
--------------------------------------------------------------------------------
1 | The Spark-Matcher repository and source distributions bundle several libraries that are
2 | compatibly licensed. We list these here.
3 |
4 | Name: dill
5 | License: BSD License (3-clause BSD)
6 |
7 | Name: graphframes
8 | License: MIT License
9 |
10 | Name: jupyterlab
11 | License: BSD License
12 |
13 | Name: modAL
14 | License: MIT License
15 |
16 | Name: multipledispatch
17 | License: BSD License
18 |
19 | Name: nbsphinx
20 | License: MIT License
21 |
22 | Name: numpy
23 | License: BSD License (3-clause BSD)
24 |
25 | Name: pandas
26 | License: BSD License (3-clause BSD)
27 |
28 | Name: pyarrow
29 | License: Apache Software License version 2
30 |
31 | Name: pyspark
32 | License: Apache Software License version 2
33 |
34 | Name: pytest
35 | License: MIT License
36 |
37 | Name: python-Levenshtein
38 | License: GNU General Public License v2
39 |
40 | Name: scikit-learn
41 | License: OSI Approved (new BSD)
42 |
43 | Name: scipy
44 | License: BSD License (3-clause BSD)
45 |
46 | Name: sphinx
47 | License: BSD License (2-clause BSD)
48 |
49 | Name: sphinx_rtd_theme
50 | License: MIT License
51 |
52 | Name: thefuzz
53 | License: GNU General Public License v2
54 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=42"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | base_packages = [
4 | 'pandas',
5 | 'numpy',
6 | 'scikit-learn',
7 | 'python-Levenshtein',
8 | 'thefuzz',
9 | 'modAL-python',
10 | 'pytest',
11 | 'multipledispatch',
12 | 'dill',
13 | 'graphframes',
14 | 'scipy'
15 | ]
16 |
17 | doc_packages = [
18 | 'sphinx',
19 | 'nbsphinx',
20 | 'sphinx_rtd_theme'
21 | ]
22 |
23 | util_packages = [
24 | 'pyspark',
25 | 'pyarrow',
26 | 'jupyterlab'
27 | ]
28 |
29 | base_doc_packages = base_packages + doc_packages
30 | dev_packages = base_packages + doc_packages + util_packages
31 |
32 | with open("README.md", "r", encoding="utf-8") as fh:
33 | long_description = fh.read()
34 |
35 | setup(name='Spark-Matcher',
36 | version='0.3.2',
37 | author="Ahmet Bayraktar, Stan Leisink, Frits Hermans",
38 | description="Record matching and entity resolution at scale in Spark",
39 | long_description=long_description,
40 | long_description_content_type="text/markdown",
41 | classifiers=[
42 | "Programming Language :: Python :: 3",
43 | "License :: OSI Approved :: MIT License",
44 | "Operating System :: OS Independent",
45 | ],
46 | packages=find_packages(exclude=['examples']),
47 | package_data={"spark_matcher": ["data/*.csv"]},
48 | install_requires=base_packages,
49 | extras_require={
50 | "base": base_packages,
51 | "doc": base_doc_packages,
52 | "dev": dev_packages,
53 | },
54 | python_requires=">=3.7",
55 | )
56 |
--------------------------------------------------------------------------------
/spark_matcher/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.2"
2 |
--------------------------------------------------------------------------------
/spark_matcher/activelearner/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['ScoringLearner']
2 |
3 | from .active_learner import ScoringLearner
--------------------------------------------------------------------------------
/spark_matcher/activelearner/active_learner.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import List, Optional, Union
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from modAL.models import ActiveLearner
10 | from modAL.uncertainty import uncertainty_sampling
11 | from pyspark.sql import DataFrame
12 | from sklearn.base import BaseEstimator
13 |
14 |
15 | class ScoringLearner:
16 | """
17 | Class to train a string matching model using active learning.
18 | Attributes:
19 | col_names: column names used for matching
20 | scorer: the scorer to be used in the active learning loop
21 | min_nr_samples: minimum number of responses required before classifier convergence is tested
22 | uncertainty_threshold: threshold on the uncertainty of the classifier during active learning,
23 | used for determining if the model has converged
24 | uncertainty_improvement_threshold: threshold on the uncertainty improvement of classifier during active
25 | learning, used for determining if the model has converged
26 | n_uncertainty_improvement: span of iterations to check for largest difference between uncertainties
27 | n_queries: maximum number of iterations to be done for the active learning session
28 | sampling_method: sampling method to be used for the active learning session
29 | verbose: sets verbosity
30 | """
31 | def __init__(self, col_names: List[str], scorer: BaseEstimator, min_nr_samples: int = 10,
32 | uncertainty_threshold: float = 0.1, uncertainty_improvement_threshold: float = 0.01,
33 | n_uncertainty_improvement: int = 5, n_queries: int = 9999, sampling_method=uncertainty_sampling,
34 | verbose: int = 0):
35 | self.col_names = col_names
36 | self.learner = ActiveLearner(
37 | estimator=scorer,
38 | query_strategy=sampling_method
39 | )
40 | self.counter_total = 0
41 | self.counter_positive = 0
42 | self.counter_negative = 0
43 | self.min_nr_samples = min_nr_samples
44 | self.uncertainty_threshold = uncertainty_threshold
45 | self.uncertainty_improvement_threshold = uncertainty_improvement_threshold
46 | self.n_uncertainty_improvement = n_uncertainty_improvement
47 | self.uncertainties = []
48 | self.n_queries = n_queries
49 | self.verbose = verbose
50 |
51 | def _input_assert(self, message: str, choices: List[str]) -> str:
52 | """
53 | Adds functionality to the python function `input` to limit the choices that can be returned
54 | Args:
55 | message: message to user
56 | choices: list containing possible choices that can be returned
57 | Returns:
58 | input returned by user
59 | """
60 | output = input(message).lower()
61 | if output not in choices:
62 | print(f"Wrong input! Your input should be one of the following: {', '.join(choices)}")
63 | return self._input_assert(message, choices)
64 | else:
65 | return output
66 |
67 | def _get_uncertainty_improvement(self) -> Optional[float]:
68 | """
69 | Calculates the uncertainty differences during active learning. The largest difference over the `last_n`
70 | iterations is returned. The aim of this function is to suggest early stopping of active learning.
71 |
72 | Returns: largest uncertainty update in `last_n` iterations
73 |
74 | """
75 | uncertainties = np.asarray(self.uncertainties)
76 | abs_differences = abs(uncertainties[1:] - uncertainties[:-1])
77 | return max(abs_differences[-self.n_uncertainty_improvement:])
78 |
79 | def _is_converged(self) -> bool:
80 | """
81 | Checks whether the model is converged by comparing the last uncertainty value with the `uncertainty_threshold`
82 | and comparing the `last_n` uncertainty improvements with the `uncertainty_improvement_threshold`. These checks
83 | are only performed if at least `min_nr_samples` are labelled.
84 |
85 | Returns:
86 | boolean indicating whether the model is converged
87 |
88 | """
89 | if (self.counter_total >= self.min_nr_samples) and (
90 | len(self.uncertainties) >= self.n_uncertainty_improvement + 1):
91 | uncertainty_improvement = self._get_uncertainty_improvement()
92 | if (self.uncertainties[-1] <= self.uncertainty_threshold) or (
93 | uncertainty_improvement <= self.uncertainty_improvement_threshold):
94 | return True
95 | else:
96 | return False
97 |
98 | def _get_active_learning_input(self, query_inst: pd.DataFrame) -> np.ndarray:
99 | """
100 | Obtain user input for a query during active learning.
101 | Args:
102 | query_inst: query as provided by the ActiveLearner instance
103 | Returns: label of user input '1' or '0' as yes or no
104 | 'p' to go to previous
105 | 'f' to finish
106 | 's' to skip the query
107 | """
108 | print(f'\nNr. {self.counter_total + 1} ({self.counter_positive}+/{self.counter_negative}-)')
109 | print("Is this a match? (y)es, (n)o, (p)revious, (s)kip, (f)inish")
110 | print('')
111 | for element in [1, 2]:
112 | for col_name in self.col_names:
113 | print(f'{col_name}_{element}' + ': ' + query_inst[f'{col_name}_{element}'].iloc[0])
114 | print('')
115 | user_input = self._input_assert("", ['y', 'n', 'p', 'f', 's'])
116 | # replace 'y' and 'n' with '1' and '0' to make them valid y labels
117 | user_input = user_input.replace('y', '1').replace('n', '0')
118 |
119 | y_new = np.array([user_input])
120 | return y_new
121 |
122 | def _calculate_uncertainty(self, x) -> None:
123 | # take the maximum probability of the predicted classes as proxy of the confidence of the classifier
124 | confidence = self.predict_proba(x).max(axis=1)[0]
125 | if self.verbose:
126 | print('uncertainty:', 1 - confidence)
127 | self.uncertainties.append(1 - confidence)
128 |
129 | def _show_min_max_scores(self, X: pd.DataFrame) -> None:
130 | """
131 | Prints the lowest and the highest logistic regression scores on train data during active learning.
132 |
133 | Args:
134 | X: Pandas dataframe containing train data that is available for labelling during active learning
135 | """
136 | X_all = pd.concat((X, self.train_samples))
137 | pred_max = self.learner.predict_proba(np.array(X_all['similarity_metrics'].tolist())).max(axis=0)
138 | print(f'lowest score: {1 - pred_max[0]:.3f}')
139 | print(f'highest score: {pred_max[1]:.3f}')
140 |
141 | def _label_perfect_train_matches(self, identical_records: pd.DataFrame) -> None:
142 | """
143 | To prevent asking labels for the perfect matches that were created by setting `n_perfect_train_matches`, these
144 | are provided to the active learner upfront.
145 |
146 | Args:
147 | identical_records: Pandas dataframe containing perfect matches
148 |
149 | """
150 | identical_records['y'] = '1'
151 | self.learner.teach(np.array(identical_records['similarity_metrics'].values.tolist()),
152 | identical_records['y'].values)
153 | self.train_samples = pd.concat([self.train_samples, identical_records])
154 |
155 | def fit(self, X: pd.DataFrame) -> 'ScoringLearner':
156 | """
157 | Fit ScoringLearner instance on pairs of strings
158 | Args:
159 | X: Pandas dataframe containing pairs of strings and distance metrics of paired strings
160 | """
161 | self.train_samples = pd.DataFrame([])
162 | query_inst_prev = None
163 |
164 | # automatically label all perfect train matches:
165 | identical_records = X[X['perfect_train_match']].copy()
166 | self._label_perfect_train_matches(identical_records)
167 | X = X.drop(identical_records.index).reset_index(drop=True) # remove identical records to avoid double labelling
168 |
169 | for i in range(self.n_queries):
170 | query_idx, query_inst = self.learner.query(np.array(X['similarity_metrics'].tolist()))
171 |
172 | if self.learner.estimator.fitted_:
173 | # the uncertainty calculations need a fitted estimator
174 | # however it can occur that the estimator can only be fit after a couple rounds of querying
175 | self._calculate_uncertainty(query_inst)
176 | if self.verbose >= 2:
177 | self._show_min_max_scores(X)
178 |
179 | y_new = self._get_active_learning_input(X.iloc[query_idx])
180 | if y_new == 'p': # use previous (input is 'p')
181 | y_new = self._get_active_learning_input(query_inst_prev)
182 | elif y_new == 'f': # finish labelling (input is 'f')
183 | break
184 | query_inst_prev = X.iloc[query_idx]
185 | if y_new != 's': # skip case (input is 's')
186 | self.learner.teach(np.asarray([X.iloc[query_idx]['similarity_metrics'].iloc[0]]), np.asarray(y_new))
187 | train_sample_to_add = X.iloc[query_idx].copy()
188 | train_sample_to_add['y'] = y_new
189 | self.train_samples = pd.concat([self.train_samples, train_sample_to_add])
190 |
191 | X = X.drop(query_idx).reset_index(drop=True)
192 |
193 | if self._is_converged():
194 | print("Classifier converged, enter 'f' to stop training")
195 |
196 | if y_new == '1':
197 | self.counter_positive += 1
198 | elif y_new == '0':
199 | self.counter_negative += 1
200 | self.counter_total += 1
201 | return self
202 |
203 | def predict_proba(self, X: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]:
204 | """
205 | Predict probabilities on new data whether the pairs are a match or not
206 | Args:
207 | X: Pandas or Spark dataframe to predict on
208 | Returns: match probabilities
209 | """
210 | return self.learner.estimator.predict_proba(X)
211 |
--------------------------------------------------------------------------------
/spark_matcher/blocker/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['BlockLearner', 'BlockingRule']
2 |
3 | from .block_learner import BlockLearner
4 | from .blocking_rules import BlockingRule
--------------------------------------------------------------------------------
/spark_matcher/blocker/block_learner.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import List, Tuple, Optional, Union
6 |
7 | from pyspark.sql import DataFrame, functions as F
8 |
9 | from spark_matcher.blocker.blocking_rules import BlockingRule
10 | from spark_matcher.table_checkpointer import TableCheckpointer
11 |
12 |
13 | class BlockLearner:
14 | """
15 | Class to learn blocking rules from training data.
16 |
17 | Attributes:
18 | blocking_rules: list of `BlockingRule` objects that are taken into account during block learning
19 | recall: the minimum required percentage of training pairs that are covered by the learned blocking rules
20 | verbose: set verbosity
21 | """
22 | def __init__(self, blocking_rules: List[BlockingRule], recall: float,
23 | table_checkpointer: Optional[TableCheckpointer] = None, verbose=0):
24 | self.table_checkpointer = table_checkpointer
25 | self.blocking_rules = blocking_rules
26 | self.recall = recall
27 | self.cover_blocking_rules = None
28 | self.fitted = False
29 | self.full_set = None
30 | self.full_set_size = None
31 | self.verbose = verbose
32 |
33 | def _greedy_set_coverage(self):
34 | """
35 | This method solves the `set cover problem` with a greedy algorithm. It identifies a subset of the blocking
36 | rules that covers at least `recall` * |full_set| elements of the original (full) set, where |full_set|
37 | denotes the cardinality of the full set.
38 | """
39 | # sort the blocking rules to start the algorithm with the ones that have most coverage:
40 | _sorted_blocking_rules = sorted(self.blocking_rules, key=lambda bl: bl.training_coverage_size, reverse=True)
41 |
42 | self.cover_blocking_rules = [_sorted_blocking_rules.pop(0)]
43 | self.cover_set = self.cover_blocking_rules[0].training_coverage
44 |
45 | for blocking_rule in _sorted_blocking_rules:
46 | # check if the required recall is already reached:
47 | if len(self.cover_set) >= int(self.recall * self.full_set_size):
48 | break
49 | # check if subset is dominated by the cover_set:
50 | if blocking_rule.training_coverage.issubset(self.cover_set):
51 | continue
52 | self.cover_set = self.cover_set.union(blocking_rule.training_coverage)
53 | self.cover_blocking_rules.append(blocking_rule)
54 |
55 | def fit(self, sdf: DataFrame) -> 'BlockLearner':
56 | """
57 | This method fits, i.e. learns, the blocking rules that are needed to cover `recall` percent of the training
58 | set pairs. The fitting is done by solving the set-cover problem. It is solved by using a greedy algorithm.
59 |
60 | Args:
61 | sdf: a labelled training set containing pairs.
62 |
63 | Returns:
64 | the object itself
65 | """
66 |         # verify whether `row_id` and `label` are columns of `sdf`
67 |         if 'row_id' not in sdf.columns:
68 |             raise AssertionError('`row_id` is not present as a column of sdf')
69 | 
70 |         if 'label' not in sdf.columns:
71 |             raise AssertionError('`label` is not present as a column of sdf')
72 |
73 |
74 | # determine the full set of pairs in the training data that have positive labels from active learning:
75 | sdf = (
76 | sdf
77 | .filter(F.col('label') == 1)
78 |             .persist() # break the lineage to avoid recomputing since sdf is used many times during the fitting
79 | )
80 | self.full_set = set(sdf.select('row_id').toPandas()['row_id'])
81 | # determine the cardinality of the full_set, i.e. |full_set|:
82 | self.full_set_size = len(self.full_set)
83 |
84 | # calculate the training coverage for each blocking rule:
85 | self.blocking_rules = (
86 | list(
87 | map(
88 | lambda x: x.calculate_training_set_coverage(sdf),
89 | self.blocking_rules
90 | )
91 | )
92 | )
93 |
94 | # use a greedy set cover algorithm to select a subset of the blocking rules that cover `recall` * |full_set|
95 | self._greedy_set_coverage()
96 | if self.verbose:
97 | print('Blocking rules:', ", ".join([x.__repr__() for x in self.cover_blocking_rules]))
98 | self.fitted = True
99 | return self
100 |
101 | def _create_blocks(self, sdf: DataFrame) -> List[DataFrame]:
102 | """
103 | This method creates a list of blocked data. Blocked data is created by applying the learned blocking rules on
104 | the input dataframe.
105 |
106 | Args:
107 | sdf: dataframe containing records
108 |
109 | Returns:
110 | A list of blocked dataframes, i.e. list of dataframes containing block-keys
111 | """
112 | sdf_blocks = []
113 | for blocking_rule in self.cover_blocking_rules:
114 | sdf_blocks.append(blocking_rule.create_block_key(sdf))
115 | return sdf_blocks
116 |
117 | @staticmethod
118 | def _create_block_table(blocks: List[DataFrame]) -> DataFrame:
119 | """
120 | This method unifies the blocked data into a single dataframe.
121 | Args:
122 |             blocks: a list of blocked dataframes, i.e. dataframes containing block-keys
123 |
124 | Returns:
125 | a unified dataframe with all the block-keys
126 | """
127 | block_table = blocks[0]
128 | for block in blocks[1:]:
129 | block_table = block_table.unionByName(block)
130 | return block_table
131 |
132 | def transform(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> Union[DataFrame,
133 | Tuple[DataFrame, DataFrame]]:
134 | """
135 | This method adds the block-keys to the input dataframes. It applies all the learned blocking rules on the
136 | input data and unifies the results. The result of this method is/are the input dataframe(s) containing the
137 | block-keys from the learned blocking rules.
138 |
139 | Args:
140 | sdf_1: dataframe containing records
141 | sdf_2: dataframe containing records
142 |
143 | Returns:
144 | dataframe(s) containing block-keys from the learned blocking-rules
145 | """
146 | if not self.fitted:
147 | raise ValueError('BlockLearner is not yet fitted')
148 |
149 | sdf_1_blocks = self._create_blocks(sdf_1)
150 | sdf_1_blocks = self._create_block_table(sdf_1_blocks)
151 |
152 | if sdf_2:
153 | sdf_1_blocks = self.table_checkpointer(sdf_1_blocks, checkpoint_name='sdf_1_blocks')
154 | sdf_2_blocks = self._create_blocks(sdf_2)
155 | sdf_2_blocks = self._create_block_table(sdf_2_blocks)
156 |
157 | return sdf_1_blocks, sdf_2_blocks
158 |
159 | return sdf_1_blocks
160 |
--------------------------------------------------------------------------------
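A minimal usage sketch for BlockLearner, assuming `train_sdf` is a labelled training-pairs dataframe with `row_id` and `label` columns and `sdf` is a Spark dataframe of records (both names are placeholders):

    # Illustrative sketch; `train_sdf` and `sdf` are assumed to exist.
    from spark_matcher.blocker import BlockLearner
    from spark_matcher.blocker.blocking_rules import FirstNChars, FirstNWords

    block_learner = BlockLearner(
        blocking_rules=[FirstNChars('name', 3), FirstNWords('name', 1)],
        recall=0.9,
        verbose=1,
    )
    block_learner.fit(train_sdf)                        # greedy set cover over the positively labelled pairs
    sdf_with_block_keys = block_learner.transform(sdf)  # single-table case: one dataframe with a `block_key` column

For the two-table (matching) case, `transform(sdf_1, sdf_2)` returns two block-key dataframes and requires a `TableCheckpointer` to be passed to the `BlockLearner` constructor.
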
/spark_matcher/blocker/blocking_rules.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | import abc
6 |
7 | from pyspark.sql import Column, DataFrame, functions as F, types as T
8 |
9 |
10 | class BlockingRule(abc.ABC):
11 | """
12 | Abstract class for blocking rules. This class contains all the base functionality for blocking rules.
13 |
14 | Attributes:
15 | blocking_column: the column on which the `BlockingRule` is applied
16 | """
17 |
18 | def __init__(self, blocking_column: str):
19 | self.blocking_column = blocking_column
20 | self.training_coverage = None
21 | self.training_coverage_size = None
22 |
23 | @abc.abstractmethod
24 | def _blocking_rule(self, c: Column) -> Column:
25 | """
26 | Abstract method for a blocking-rule
27 |
28 | Args:
29 | c: a Column
30 |
31 | Returns:
32 | A Column with a blocking-rule result
33 | """
34 | pass
35 |
36 | @abc.abstractmethod
37 | def __repr__(self) -> str:
38 | """
39 | Abstract method for a class representation
40 |
41 | Returns:
42 | return a string that represents the blocking-rule
43 |
44 | """
45 | pass
46 |
47 | def _apply_blocking_rule(self, c: Column) -> Column:
48 | """
49 |         This method applies the blocking rule on an input column and prefixes the result with the rule's
50 |         representation to make the block-key unique. Uniqueness is important to avoid collisions between block-keys,
51 |         e.g. a blocking-rule that captures the first character of a string can return the same value as one that
52 |         captures the last character. To avoid this, the blocking-rule result is concatenated with the __repr__ output.
53 |
54 | Args:
55 | c: a Column
56 |
57 | Returns:
58 | a Column with a block_key
59 |
60 | """
61 | return F.concat(F.lit(f'{self.__repr__()}:'), self._blocking_rule(c))
62 |
63 | def create_block_key(self, sdf: DataFrame) -> DataFrame:
64 | """
65 | This method calculates and adds the block-key column to the input dataframe
66 |
67 | Args:
68 | sdf: a dataframe with records that need to be matched
69 |
70 | Returns:
71 | the dataframe with the block-key column
72 | """
73 | return sdf.withColumn('block_key', self._apply_blocking_rule(sdf[f'{self.blocking_column}']))
74 |
75 | def _create_training_block_keys(self, sdf: DataFrame) -> DataFrame:
76 | """
77 | This method is used to create block-keys on a training dataframe
78 |
79 | Args:
80 | sdf: a dataframe containing record pairs for training
81 |
82 | Returns:
83 | the dataframe containing block-keys for the record pairs
84 | """
85 | return (
86 | sdf
87 | .withColumn('block_key_1', self._apply_blocking_rule(sdf[f'{self.blocking_column}_1']))
88 | .withColumn('block_key_2', self._apply_blocking_rule(sdf[f'{self.blocking_column}_2']))
89 | )
90 |
91 | @staticmethod
92 | def _compare_and_filter_keys(sdf: DataFrame) -> DataFrame:
93 | """
94 | This method is used to compare block-keys of record pairs and subsequently filter record pairs in the training
95 | dataframe that have identical block-keys
96 |
97 | Args:
98 | sdf: a dataframe containing record pairs for training with their block-keys
99 |
100 | Returns:
101 | the filtered dataframe containing only record pairs with identical block-keys
102 | """
103 | return (
104 | sdf
105 | # split on `:` since this separates the `__repr__` from the blocking_rule result
106 | .withColumn('_blocking_result_1', F.split(F.col('block_key_1'), ":").getItem(1))
107 | .withColumn('_blocking_result_2', F.split(F.col('block_key_2'), ":").getItem(1))
108 | .filter(
109 | (F.col('block_key_1') == F.col('block_key_2')) &
110 | ((F.col('_blocking_result_1') != '') & (F.col('_blocking_result_2') != ''))
111 | )
112 | .drop('_blocking_result_1', '_blocking_result_2')
113 | )
114 |
115 | @staticmethod
116 | def _length_check(c: Column, n: int, word_count: bool = False) -> Column:
117 | """
118 | This method checks the length of the created block key.
119 |
120 | Args:
121 | c: block key to check
122 | n: given length of the string
123 | word_count: whether to check the string length or word count
124 | Returns:
125 |             the block key if it has exactly `n` characters (or at least `n` words when `word_count` is True), otherwise None
126 | """
127 | if word_count:
128 | return F.when(F.size(c) >= n, c).otherwise(None)
129 |
130 | return F.when(F.length(c) == n, c).otherwise(None)
131 |
132 | def calculate_training_set_coverage(self, sdf: DataFrame) -> 'BlockingRule':
133 | """
134 |         This method calculates the set coverage of the blocking rule on the training pairs. The set coverage of the
135 |         rule is determined by looking at how many record pairs in the training set end up in the same block. This
136 |         coverage is used in the BlockLearner to sort blocking rules in the greedy set-cover algorithm.
137 | Args:
138 | sdf: a dataframe containing record pairs for training
139 |
140 | Returns:
141 | The object itself
142 | """
143 | sdf = self._create_training_block_keys(sdf)
144 |
145 | sdf = self._compare_and_filter_keys(sdf)
146 |
147 | self.training_coverage = set(
148 | sdf
149 | .agg(F.collect_set('row_id').alias('training_coverage'))
150 | .collect()[0]['training_coverage']
151 | )
152 |
153 | self.training_coverage_size = len(self.training_coverage)
154 | return self
155 |
156 |
157 | # define the concrete blocking rule examples:
158 |
159 | class FirstNChars(BlockingRule):
160 | def __init__(self, blocking_column: str, n: int = 3):
161 | super().__init__(blocking_column)
162 | self.n = n
163 |
164 | def __repr__(self):
165 | return f"first_{self.n}_characters_{self.blocking_column}"
166 |
167 | def _blocking_rule(self, c: Column) -> Column:
168 | key = F.substring(c, 0, self.n)
169 | return self._length_check(key, self.n)
170 |
171 |
172 | class FirstNCharsLastWord(BlockingRule):
173 | def __init__(self, blocking_column: str, n: int = 3, remove_non_alphanumerical=False):
174 | super().__init__(blocking_column)
175 | self.n = n
176 | self.remove_non_alphanumerical = remove_non_alphanumerical
177 |
178 | def __repr__(self):
179 | return f"first_{self.n}_characters_last_word_{self.blocking_column}"
180 |
181 | def _blocking_rule(self, c: Column) -> Column:
182 | if self.remove_non_alphanumerical:
183 | c = F.regexp_replace(c, r'\W+', ' ')
184 | tokens = F.split(c, r'\s+')
185 | last_word = F.element_at(tokens, -1)
186 | key = F.substring(last_word, 1, self.n)
187 | return self._length_check(key, self.n, word_count=False)
188 |
189 |
190 | class FirstNCharactersFirstTokenSorted(BlockingRule):
191 | def __init__(self, blocking_column: str, n: int = 3, remove_non_alphanumerical=False):
192 | super().__init__(blocking_column)
193 | self.n = n
194 | self.remove_non_alphanumerical = remove_non_alphanumerical
195 |
196 | def __repr__(self):
197 | return f"first_{self.n}_characters_first_token_sorted_{self.blocking_column}"
198 |
199 | def _blocking_rule(self, c):
200 | if self.remove_non_alphanumerical:
201 | c = F.regexp_replace(c, r'\W+', ' ')
202 | tokens = F.split(c, r'\s+')
203 | sorted_tokens = F.sort_array(tokens)
204 | filtered_tokens = F.filter(sorted_tokens, lambda x: F.length(x) >= self.n)
205 | first_token = filtered_tokens.getItem(0)
206 | return F.substring(first_token, 1, self.n)
207 |
208 |
209 | class LastNChars(BlockingRule):
210 | def __init__(self, blocking_column: str, n: int = 3):
211 | super().__init__(blocking_column)
212 | self.n = n
213 |
214 | def __repr__(self):
215 | return f"last_{self.n}_characters_{self.blocking_column}"
216 |
217 | def _blocking_rule(self, c: Column) -> Column:
218 | key = F.substring(c, -self.n, self.n)
219 | return self._length_check(key, self.n)
220 |
221 |
222 | class WholeField(BlockingRule):
223 | def __init__(self, blocking_column: str):
224 | super().__init__(blocking_column)
225 |
226 | def __repr__(self):
227 | return f"whole_field_{self.blocking_column}"
228 |
229 | def _blocking_rule(self, c: Column) -> Column:
230 | return c
231 |
232 |
233 | class FirstNWords(BlockingRule):
234 | def __init__(self, blocking_column: str, n: int = 1, remove_non_alphanumerical=False):
235 | super().__init__(blocking_column)
236 | self.n = n
237 | self.remove_non_alphanumerical = remove_non_alphanumerical
238 |
239 | def __repr__(self):
240 | return f"first_{self.n}_words_{self.blocking_column}"
241 |
242 | def _blocking_rule(self, c: Column) -> Column:
243 | if self.remove_non_alphanumerical:
244 | c = F.regexp_replace(c, r'\W+', ' ')
245 | tokens = F.split(c, r'\s+')
246 | key = self._length_check(tokens, self.n, word_count=True)
247 | return F.array_join(F.slice(key, 1, self.n), ' ')
248 |
249 |
250 | class FirstNLettersNoSpace(BlockingRule):
251 | def __init__(self, blocking_column: str, n: int = 3):
252 | super().__init__(blocking_column)
253 | self.n = n
254 |
255 | def __repr__(self):
256 | return f"first_{self.n}_letters_{self.blocking_column}_no_space"
257 |
258 | def _blocking_rule(self, c: Column) -> Column:
259 | key = F.substring(F.regexp_replace(c, r'[^a-zA-Z]+', ''), 1, self.n)
260 | return self._length_check(key, self.n)
261 |
262 |
263 | class SortedIntegers(BlockingRule):
264 | def __init__(self, blocking_column: str):
265 | super().__init__(blocking_column)
266 |
267 | def __repr__(self):
268 | return f"sorted_integers_{self.blocking_column}"
269 |
270 | def _blocking_rule(self, c: Column) -> Column:
271 | number_string = F.trim(F.regexp_replace(c, r'[^0-9\s]+', ''))
272 | number_string_array = F.when(number_string != '', F.split(number_string, r'\s+'))
273 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType()))
274 | number_sorted = F.array_sort(number_int_array)
275 | return F.array_join(number_sorted, " ")
276 |
277 |
278 | class FirstInteger(BlockingRule):
279 | def __init__(self, blocking_column: str):
280 | super().__init__(blocking_column)
281 |
282 | def __repr__(self):
283 | return f"first_integer_{self.blocking_column}"
284 |
285 | def _blocking_rule(self, c: Column) -> Column:
286 | number_string_array = F.split(F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')), r'\s+')
287 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType()))
288 | first_number = number_int_array.getItem(0)
289 | return first_number.cast(T.StringType())
290 |
291 |
292 | class LastInteger(BlockingRule):
293 | def __init__(self, blocking_column: str):
294 | super().__init__(blocking_column)
295 |
296 | def __repr__(self):
297 | return f"last_integer_{self.blocking_column}"
298 |
299 | def _blocking_rule(self, c: Column) -> Column:
300 | number_string_array = F.split(F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')), r'\s+')
301 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType()))
302 | last_number = F.slice(number_int_array, -1, 1).getItem(0)
303 | return last_number.cast(T.StringType())
304 |
305 |
306 | class LargestInteger(BlockingRule):
307 | def __init__(self, blocking_column: str):
308 | super().__init__(blocking_column)
309 |
310 | def __repr__(self):
311 | return f"largest_integer_{self.blocking_column}"
312 |
313 | def _blocking_rule(self, c: Column) -> Column:
314 | number_string_array = F.split(F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')), r'\s+')
315 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType()))
316 | largest_number = F.array_max(number_int_array)
317 | return largest_number.cast(T.StringType())
318 |
319 |
320 | class NLetterAbbreviation(BlockingRule):
321 | def __init__(self, blocking_column: str, n: int = 3):
322 | super().__init__(blocking_column)
323 | self.n = n
324 |
325 | def __repr__(self):
326 | return f"{self.n}_letter_abbreviation_{self.blocking_column}"
327 |
328 | def _blocking_rule(self, c: Column) -> Column:
329 | words = F.split(F.trim(F.regexp_replace(c, r'[0-9]+', '')), r'\s+')
330 | first_letters = F.when(F.size(words) >= self.n, F.transform(words, lambda x: F.substring(x, 1, 1)))
331 | return F.array_join(first_letters, '')
332 |
333 |
334 | # this is an example of a blocking rule that contains a udf with plain python code:
335 |
336 | class UdfFirstNChar(BlockingRule):
337 | def __init__(self, blocking_column: str, n: int):
338 | super().__init__(blocking_column)
339 | self.n = n
340 |
341 | def __repr__(self):
342 |         return f"udf_first_{self.n}_characters_{self.blocking_column}"
343 |
344 | def _blocking_rule(self, c: Column) -> Column:
345 | @F.udf
346 | def _rule(s: str) -> str:
347 | return s[:self.n]
348 |
349 | return _rule(c)
350 |
351 |
352 | default_blocking_rules = [FirstNChars, FirstNCharsLastWord, LastNChars, WholeField, FirstNWords, FirstNLettersNoSpace,
353 | SortedIntegers, FirstInteger, LastInteger, LargestInteger, NLetterAbbreviation]
354 |
--------------------------------------------------------------------------------
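A concrete blocking rule only has to implement `_blocking_rule` and `__repr__`; the base class takes care of prefixing the block-key and computing training coverage. A sketch of a custom rule (the `FirstNDigits` class below is illustrative and not shipped with the package):

    # Illustrative custom blocking rule; not part of spark_matcher itself.
    from pyspark.sql import Column, functions as F

    from spark_matcher.blocker.blocking_rules import BlockingRule


    class FirstNDigits(BlockingRule):
        """Block on the first `n` digits occurring in the column, e.g. for house numbers or postcodes."""
        def __init__(self, blocking_column: str, n: int = 4):
            super().__init__(blocking_column)
            self.n = n

        def __repr__(self):
            return f"first_{self.n}_digits_{self.blocking_column}"

        def _blocking_rule(self, c: Column) -> Column:
            digits = F.regexp_replace(c, r'[^0-9]+', '')   # keep only digits
            key = F.substring(digits, 1, self.n)
            return self._length_check(key, self.n)         # discard keys shorter than `n`
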
/spark_matcher/config.py:
--------------------------------------------------------------------------------
1 | MINHASHING_MAXDF = 0.01
2 | MINHASHING_VOCABSIZE = 1_000_000
3 |
--------------------------------------------------------------------------------
/spark_matcher/data/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/spark_matcher/data/.gitkeep
--------------------------------------------------------------------------------
/spark_matcher/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import load_data
2 |
--------------------------------------------------------------------------------
/spark_matcher/data/datasets.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import Tuple, Optional, Union
6 | from pkg_resources import resource_filename
7 |
8 | import pandas as pd
9 |
10 | from pyspark.sql import SparkSession, DataFrame
11 |
12 |
13 | def load_data(spark: SparkSession, kind: Optional[str] = 'voters') -> Union[Tuple[DataFrame, DataFrame], DataFrame]:
14 | """
15 |     Load example datasets to be used to experiment with `spark-matcher`. For matching problems, set `kind` to `voters`
16 | for North Carolina voter registry data or `library` for bibliography data. For deduplication problems, set `kind`
17 | to `stoxx50` for EuroStoxx 50 company names and addresses.
18 |
19 | Voter data:
20 | - provided by Prof. Erhard Rahm
21 | https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution
22 |
23 | Library data:
24 | - DBLP bibliography, http://www.informatik.uni-trier.de/~ley/db/index.html
25 | - ACM Digital Library, http://portal.acm.org/portal.cfm
26 |
27 | Args:
28 | spark: Spark session
29 | kind: kind of data: `voters`, `library` or `stoxx50`
30 |
31 | Returns:
32 | two Spark dataframes for `voters` or `library`, a single dataframe for `stoxx50`
33 |
34 | """
35 |     if kind == 'library':
36 |         return _load_data_library(spark)
37 |     elif kind == 'voters':
38 |         return _load_data_voters(spark)
39 |     elif kind == 'stoxx50':
40 |         return _load_data_stoxx50(spark)
41 |     else:
42 |         raise ValueError('`kind` must be `library`, `voters` or `stoxx50`')
43 |
44 |
45 | def _load_data_library(spark: SparkSession) -> Tuple[DataFrame, DataFrame]:
46 | """
47 |     Load example datasets to be used to experiment with `spark-matcher`. Two Spark dataframes are returned with the
48 | same columns:
49 |
50 | - DBLP bibliography, http://www.informatik.uni-trier.de/~ley/db/index.html
51 | - ACM Digital Library, http://portal.acm.org/portal.cfm
52 |
53 | Args:
54 | spark: Spark session
55 |
56 | Returns:
57 |         a Spark dataframe for ACM data and a Spark dataframe for DBLP data
58 |
59 | """
60 | file_path_acm = resource_filename('spark_matcher.data', 'acm.csv')
61 | file_path_dblp = resource_filename('spark_matcher.data', 'dblp.csv')
62 | acm_pdf = pd.read_csv(file_path_acm)
63 | dblp_pdf = pd.read_csv(file_path_dblp, encoding="ISO-8859-1")
64 |
65 | for col in acm_pdf.select_dtypes('object').columns:
66 | acm_pdf[col] = acm_pdf[col].fillna("")
67 | for col in dblp_pdf.select_dtypes('object').columns:
68 | dblp_pdf[col] = dblp_pdf[col].fillna("")
69 |
70 | acm_sdf = spark.createDataFrame(acm_pdf)
71 | dblp_sdf = spark.createDataFrame(dblp_pdf)
72 | return acm_sdf, dblp_sdf
73 |
74 |
75 | def _load_data_voters(spark: SparkSession) -> Tuple[DataFrame, DataFrame]:
76 | """
77 | Voters data is based on the North Carolina voter registry and this dataset is provided by Prof. Erhard Rahm
78 | ('Comparative Evaluation of Distributed Clustering Schemes for Multi-source Entity Resolution'). Two Spark
79 |     dataframes are returned with the same columns.
80 |
81 | Args:
82 | spark: Spark session
83 |
84 | Returns:
85 | two Spark dataframes containing voter data
86 |
87 | """
88 | file_path_voters_1 = resource_filename('spark_matcher.data', 'voters_1.csv')
89 | file_path_voters_2 = resource_filename('spark_matcher.data', 'voters_2.csv')
90 | voters_1_pdf = pd.read_csv(file_path_voters_1)
91 | voters_2_pdf = pd.read_csv(file_path_voters_2)
92 |
93 | voters_1_sdf = spark.createDataFrame(voters_1_pdf)
94 | voters_2_sdf = spark.createDataFrame(voters_2_pdf)
95 | return voters_1_sdf, voters_2_sdf
96 |
97 |
98 | def _load_data_stoxx50(spark: SparkSession) -> DataFrame:
99 | """
100 | The Stoxx50 dataset contains a single column containing the concatenation of Eurostoxx 50 company names and
101 | addresses. This dataset is created by the developers of spark_matcher.
102 |
103 | Args:
104 | spark: Spark session
105 |
106 | Returns:
107 | Spark dataframe containing Eurostoxx 50 names and addresses
108 |
109 | """
110 | file_path_stoxx50 = resource_filename('spark_matcher.data', 'stoxx50.csv')
111 | stoxx50_pdf = pd.read_csv(file_path_stoxx50)
112 |
113 | stoxx50_sdf = spark.createDataFrame(stoxx50_pdf)
114 | return stoxx50_sdf
115 |
--------------------------------------------------------------------------------
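A usage sketch, assuming an active SparkSession named `spark`:

    # Illustrative sketch; `spark` is assumed to be an existing SparkSession.
    from spark_matcher.data import load_data

    voters_1, voters_2 = load_data(spark, kind='voters')   # two dataframes for a matching problem
    acm, dblp = load_data(spark, kind='library')            # ACM and DBLP bibliography dataframes
    stoxx50 = load_data(spark, kind='stoxx50')              # single dataframe for a deduplication problem
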
/spark_matcher/data/dblp.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/spark_matcher/data/dblp.csv
--------------------------------------------------------------------------------
/spark_matcher/data/stoxx50.csv:
--------------------------------------------------------------------------------
1 | name
2 | adidas ag adi dassler strasse 1 91074 germany
3 | adidas ag adi dassler strasse 1 91074 herzogenaurach
4 | adidas ag adi dassler strasse 1 91074 herzogenaurach germany
5 | airbus se 2333 cs leiden netherlands
6 | airbus se 2333 cs netherlands
7 | airbus se leiden netherlands
8 | allianz se 80802 munich germany
9 | allianz se koniginstrasse 28 munich germany
10 | allianz se munich germany
11 | amadeus it group s a salvador de madariaga 1 28027
12 | amadeus it group s a salvador de madariaga 1 28027 madrid
13 | amadeus it group s a salvador de madariaga 1 28027 madrid spain
14 | anheuser busch inbev sa nv 3000 leuven belgium
15 | anheuser busch inbev sa nv brouwerijplein 1 3000 leuven belgium
16 | anheuser busch inbev sa nv brouwerijplein 1 leuven belgium
17 | asml holding n v de run 6501 5504 dr veldhoven netherlands
18 | asml holding n v de run 6501 veldhoven netherlands
19 | axa sa 25 avenue matignon 75008 paris france
20 | axa sa 75008 paris
21 | axa sa 75008 paris france
22 | banco bilbao vizcaya argentaria s a 48005 bilbao spain
23 | banco bilbao vizcaya argentaria s a bilbao
24 | banco bilbao vizcaya argentaria s a plaza san nicolas 4 48005 spain
25 | banco santander s a 28660
26 | banco santander s a 28660 madrid
27 | banco santander s a 28660 madrid spain
28 | basf se 67056 ludwigshafen am rhein germany
29 | basf se carl bosch strasse 38 67056 germany
30 | bayer aktiengesellschaft
31 | bayer aktiengesellschaft 51368 leverkusen germany
32 | bayer aktiengesellschaft kaiser wilhelm allee 1 51368 leverkusen germany
33 | bayerische motoren werke aktiengesellschaft munich germany
34 | bayerische motoren werke aktiengesellschaft petuelring 130 80788 munich
35 | bayerische motoren werke aktiengesellschaft petuelring 130 munich germany
36 | bnp paribas sa 16 boulevard des italiens 75009 france
37 | bnp paribas sa 16 boulevard des italiens paris france
38 | bnp paribas sa paris
39 | crh plc stonemason s way 16 dublin ireland
40 | crh plc stonemason s way 16 ireland
41 | daimler ag 70372 stuttgart germany
42 | daimler ag germany
43 | daimler ag mercedesstrasse 120 70372 stuttgart germany
44 | danone s a 15 rue du helder 75439 paris france
45 | danone s a 15 rue du helder paris france
46 | danone s a 75439 paris
47 | deutsche boerse
48 | deutsche boerse 60485 frankfurt
49 | deutsche boerse frankfurt
50 | deutsche post ag platz der deutschen post
51 | deutsche post ag platz der deutschen post 53113 germany
52 | deutsche post ag platz der deutschen post bonn germany
53 | deutsche telekom ag 53113 bonn germany
54 | deutsche telekom ag 53113 germany
55 | enel spa rome italy
56 | enel spa viale regina margherita 137 rm 00198 italy
57 | engie sa 1 place samuel de champlain 92400 courbevoie
58 | engie sa 1 place samuel de champlain 92400 france
59 | engie sa 92400 courbevoie france
60 | eni s p a piazzale enrico mattei 1 rome italy
61 | eni s p a rm 00144 rome italy
62 | essilorluxottica 1 6 rue paul cezanne 75008 paris france
63 | essilorluxottica 1 6 rue paul cezanne paris france
64 | fresenius se co kgaa else kroner strasse 1 61352 bad homburg vor der hohe germany
65 | fresenius se co kgaa else kroner strasse 1 61352 germany
66 | fresenius se co kgaa else kroner strasse 1 bad homburg vor der hohe germany
67 | iberdrola s a plaza euskadi 5 48009 bilbao spain
68 | iberdrola s a plaza euskadi 5 bilbao spain
69 | industria de diseno textil s a avenida de la diputacion s n arteixo 15143 a coruna spain
70 | industria de diseno textil s a avenida de la diputacion s n arteixo a coruna spain
71 | ing groep n v bijlmerplein 888 1102 mg amsterdam
72 | ing groep n v bijlmerplein 888 1102 mg netherlands
73 | intesa sanpaolo s p a piazza san carlo 156 to 10121 italy
74 | intesa sanpaolo s p a piazza san carlo 156 to 10121 turin
75 | intesa sanpaolo s p a turin
76 | kering sa 40 rue de sevres 75007 paris
77 | kering sa 40 rue de sevres 75007 paris france
78 | koninklijke ahold delhaize n v 1506 ma zaandam netherlands
79 | koninklijke ahold delhaize n v provincialeweg 11
80 | koninklijke ahold delhaize n v provincialeweg 11 netherlands
81 | koninklijke philips n v amstelplein 2 1096 bc
82 | koninklijke philips n v amstelplein 2 1096 bc amsterdam
83 | koninklijke philips n v amstelplein 2 1096 bc netherlands
84 | l air liquide s a 75 quai d orsay 75007 france
85 | l air liquide s a 75007 paris france
86 | l air liquide s a paris france
87 | l oreal s a 41 rue martre 92117 clichy france
88 | l oreal s a 92117 france
89 | linde plc 10 priestley road surrey research park gu2 7xy guildford
90 | linde plc 10 priestley road surrey research park gu2 7xy united kingdom
91 | lvmh moet hennessy louis vuitton societe europeenne 22 avenue montaigne 75008 france
92 | lvmh moet hennessy louis vuitton societe europeenne 22 avenue montaigne paris
93 | lvmh moet hennessy louis vuitton societe europeenne 75008 france
94 | munchener ruckversicherungs gesellschaft aktiengesellschaft koniginstrasse 107
95 | munchener ruckversicherungs gesellschaft aktiengesellschaft koniginstrasse 107 80802 munich
96 | munchener ruckversicherungs gesellschaft aktiengesellschaft koniginstrasse 107 munich germany
97 | nokia corporation 2610 espoo finland
98 | nokia corporation finland
99 | nokia corporation karakaari 7
100 | orange s a 75015 paris france
101 | orange s a 78 rue olivier de serres 75015 paris france
102 | orange s a paris
103 | safran sa 2 boulevard du general martial valin paris france
104 | safran sa 75724 france
105 | safran sa paris france
106 | sanofi 54 rue la boetie 75008 france
107 | sanofi 54 rue la boetie 75008 paris france
108 | sanofi 75008 france
109 | sap se 69190
110 | sap se 69190 germany
111 | sap se dietmar hopp allee 16
112 | schneider electric s e 35 rue joseph monier
113 | schneider electric s e 35 rue joseph monier 92500 rueil malmaison
114 | schneider electric s e 92500 france
115 | siemens aktiengesellschaft munich germany
116 | siemens aktiengesellschaft werner von siemens strasse 1 80333 germany
117 | siemens aktiengesellschaft werner von siemens strasse 1 80333 munich germany
118 | societe generale societe 29 boulevard haussmann 75009 paris france
119 | societe generale societe 75009 paris france
120 | telefonica s a 28050 madrid
121 | telefonica s a ronda de la comunicacion
122 | telefonica s a ronda de la comunicacion 28050 madrid spain
123 | the unilever group weena 455 3000 dk netherlands
124 | the unilever group weena 455 3000 dk rotterdam netherlands
125 | total s a 2 place jean millier paris france
126 | total s a 92078 france
127 | total s a paris
128 | vinci sa 1 cours ferdinand de lesseps 92851 france
129 | vinci sa 1 cours ferdinand de lesseps france
130 | vinci sa 1 cours ferdinand de lesseps rueil malmaison france
131 | vivendi sa 42 avenue de friedland 75380 paris france
132 | vivendi sa 75380 paris
133 | volkswagen ag 38440 germany
134 | volkswagen ag berliner ring 2 38440
135 | volkswagen ag berliner ring 2 38440 wolfsburg
--------------------------------------------------------------------------------
/spark_matcher/deduplicator/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['Deduplicator']
2 |
3 | from .deduplicator import Deduplicator
--------------------------------------------------------------------------------
/spark_matcher/deduplicator/connected_components_calculator.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import List, Tuple
6 |
7 | from graphframes import GraphFrame
8 | from pyspark.sql import functions as F, types as T, DataFrame, Window
9 | from pyspark.sql.utils import AnalysisException
10 |
11 | from spark_matcher.table_checkpointer import TableCheckpointer
12 |
13 |
14 | class ConnectedComponentsCalculator:
15 |
16 | def __init__(self, scored_pairs_table: DataFrame, max_edges_clustering: int,
17 | edge_filter_thresholds: List[float], table_checkpointer: TableCheckpointer):
18 | self.scored_pairs_table = scored_pairs_table
19 | self.max_edges_clustering = max_edges_clustering
20 | self.edge_filter_thresholds = edge_filter_thresholds
21 | self.table_checkpointer = table_checkpointer
22 |
23 | @staticmethod
24 | def _create_graph(scores_table: DataFrame) -> GraphFrame:
25 | """
26 | This function creates a graph where each row-number is a vertex and each similarity score between a pair of
27 | row-numbers represents an edge. This Graph is used as input for connected components calculations to determine
28 | the subgraphs of linked pairs.
29 |
30 | Args:
31 | scores_table: a pairs table with similarity scores for each pair
32 | Returns:
33 | a GraphFrames graph object representing the row-numbers and the similarity scores between them as a graph
34 | """
35 |
36 | vertices_1 = scores_table.select('row_number_1').withColumnRenamed('row_number_1', 'id')
37 | vertices_2 = scores_table.select('row_number_2').withColumnRenamed('row_number_2', 'id')
38 | vertices = vertices_1.unionByName(vertices_2).drop_duplicates()
39 |
40 | edges = (
41 | scores_table
42 | .select('row_number_1', 'row_number_2', 'score')
43 | .withColumnRenamed('row_number_1', 'src')
44 | .withColumnRenamed('row_number_2', 'dst')
45 | )
46 | return GraphFrame(vertices, edges)
47 |
48 | def _calculate_connected_components(self, scores_table: DataFrame, checkpoint_name: str) -> DataFrame:
49 | """
50 | This function calculates the connected components (i.e. the subgraphs) of a graph of scored pairs.
51 | The result of the connected-components algorithm is cached and saved.
52 |
53 | Args:
54 | scores_table: a pairs table with similarity scores for each pair
55 | Returns:
56 | a spark dataframe containing the connected components of a graph of scored pairs.
57 | """
58 | graph = self._create_graph(scores_table)
59 | # Since graphframes 0.3.0, checkpointing is used for the connected components algorithm. The Spark cluster might
60 |         # not allow writing checkpoints. In such a case we fall back to the graphx algorithm, which doesn't require
61 | # checkpointing.
62 | try:
63 | connected_components = graph.connectedComponents()
64 | except AnalysisException:
65 | connected_components = graph.connectedComponents(algorithm='graphx')
66 | return self.table_checkpointer(connected_components, checkpoint_name)
67 |
68 | @staticmethod
69 | def _add_component_id_to_scores_table(scores_table: DataFrame, connected_components: DataFrame) -> DataFrame:
70 | """
71 | This function joins the initial connected-component identifiers to the scored pairs table. For each scored pair,
72 | a component identifier indicating to which subgraph a pair belongs is added.
73 |
74 | Args:
75 | scores_table: a pairs table with similarity scores for each pair
76 | connected_components: a spark dataframe containing the result of the connected components algorithm
77 | Returns:
78 | the scores_table with an identifier that indicates to which component a pair belongs
79 | """
80 | scores_table_with_component_ids = (
81 | scores_table
82 | .join(connected_components, on=scores_table['row_number_1']==connected_components['id'], how='left')
83 | .withColumnRenamed('component', 'component_1')
84 | .drop('id')
85 | .join(connected_components, on=scores_table['row_number_2']==connected_components['id'], how='left')
86 | .withColumnRenamed('component', 'component_2')
87 | .drop('id')
88 | )
89 |
90 | scores_table_with_component_ids = (
91 | scores_table_with_component_ids
92 | .withColumnRenamed('component_1', 'component')
93 | .drop('component_2')
94 | )
95 | return scores_table_with_component_ids
96 |
97 | @staticmethod
98 | def _add_big_components_info(scored_pairs_w_components: DataFrame, component_cols: List[str], max_size: int) -> \
99 | Tuple[DataFrame, bool]:
100 | """
101 |         This function flags whether there are components in the scored_pairs table that contain more pairs than
102 |         `max_size`. If a component is too big, the function will identify this and indicate that further filter
103 |         iterations are required via the `continue_iteration` value.
104 |
105 | Args:
106 | scored_pairs_w_components: a pairs table with similarity scores for each pair and component identifiers
107 | component_cols: a list of columns that represent the composite key that indicate connected-components
108 | after one or multiple iterations
109 | max_size: a number indicating the maximum size of the number of pairs that are in a connected component
110 | in the scores table
111 | Returns:
112 | the scores table with a column indicating whether a pair belongs to a big component
113 | and a boolean indicating whether there are too big components and thus whether more iterations are required.
114 | """
115 | window = Window.partitionBy(*component_cols)
116 | big_components_info = (
117 | scored_pairs_w_components
118 | .withColumn('constant', F.lit(1))
119 | .withColumn('component_size', F.count('constant').over(window))
120 | .drop('constant')
121 | .withColumn('is_in_big_component', F.col('component_size') > max_size)
122 | .drop('component_size')
123 | )
124 | continue_iteration = False
125 | if big_components_info.filter(F.col('is_in_big_component')).count() > 0:
126 | continue_iteration = True
127 | return big_components_info, continue_iteration
128 |
129 | @staticmethod
130 | def _add_low_score_info(scored_pairs: DataFrame, threshold) -> DataFrame:
131 | """
132 | This function adds a column that indicates whether a similarity score between pairs is lower than a `threshold`
133 | value.
134 |
135 | Args:
136 | scored_pairs: a pairs table with similarity scores for each pair
137 |             threshold: a value indicating the threshold value for a similarity score
138 |         Returns:
139 |             the scores table with an additional column indicating whether a pair has a score lower than the `threshold`
140 | """
141 | scored_pairs = (
142 | scored_pairs
143 |             .withColumn('is_lower_than_threshold', F.col('score') < threshold)
144 |         )
145 |         return scored_pairs
146 | 
147 |     @staticmethod
148 |     def _join_connected_components_iteration_results(scored_pairs: DataFrame, connected_components: DataFrame,
149 |                                                       iteration_number: int) -> DataFrame:
150 | """
151 |         This function joins the connected component results for each iteration to the scores table and prunes the
152 |         scores table by removing edges that belong to a too-big component and have a similarity score lower than
153 |         the current threshold.
154 |
155 | Args:
156 | scored_pairs: a pairs table with similarity scores for each pair for an iteration
157 | connected_components: a spark dataframe containing the result of the connected components algorithm for
158 | an iteration
159 | iteration_number: a counter indicating which iteration it is
160 | Returns:
161 |             a scores table with the component identifiers for this iteration, excluding the edges that belonged
162 |             to too-big components and, at the same time, had scores below the current threshold; those edges have
163 |             been pruned from the table.
164 | """
165 | # join the connected components results to the scored-pairs on `row_number_1`
166 | result = (
167 | scored_pairs
168 | .join(connected_components, on=scored_pairs['row_number_1'] == connected_components['id'], how='left')
169 | .drop('id')
170 | .withColumnRenamed('component', 'component_1')
171 | .persist())
172 |
173 | # join the connected components results to the scored-pairs on `row_number_2`
174 | result = (
175 | result
176 | .join(connected_components, on=scored_pairs['row_number_2'] == connected_components['id'], how='left')
177 | .drop('id')
178 | .withColumnRenamed('component', 'component_2')
179 | .filter((F.col('component_1') == F.col('component_2')) | (F.col('component_1').isNull() & F.col(
180 | 'component_2').isNull()))
181 | .withColumnRenamed('component_1', f'component_iteration_{iteration_number}')
182 | .drop('component_2')
183 |             # add a default value for rows that were not part of the iteration, i.e. -1, since component ids are
184 | # always positive integers
185 | .fillna(-1, subset=[f'component_iteration_{iteration_number}'])
186 | .filter((F.col('is_in_big_component') & ~F.col('is_lower_than_threshold')) |
187 | (~F.col('is_in_big_component')))
188 | .drop('is_in_big_component', 'is_lower_than_threshold')
189 | )
190 | return result
191 |
192 | @staticmethod
193 | def _add_component_id(sdf: DataFrame, iteration_count: int) -> DataFrame:
194 | """
195 | This function adds a final `component_id` after all iterations that can be used as a groupby-key to
196 |         distribute the deduplication. The `component_id` is a hash of the '_'-separated concatenation of all
197 |         component identifiers of the iterations. The sha2 hash function is used in order to make the likelihood of
198 |         hash collisions negligibly small.
199 |
200 | Args:
201 | sdf: a dataframe containing the scored pairs
202 | iteration_count: a number indicating the final iteration number that is used.
203 | Returns:
204 | the final scores-table with a component-id that can be used as groupby-key to distribute the deduplication.
205 | """
206 | if iteration_count > 0:
207 | component_id_cols = ['initial_component'] + [f'component_iteration_{i}' for i in
208 | range(1, iteration_count + 1)]
209 | sdf = sdf.withColumn('component_id', F.sha2(F.concat_ws('_', *component_id_cols), numBits=256))
210 | else:
211 | # use the initial component as component_id and cast the value to string to match with the return type of
212 | # the sha2 algorithm
213 | sdf = (
214 | sdf
215 | .withColumn('component_id', F.col('initial_component').cast(T.StringType()))
216 | .drop('initial_component')
217 | )
218 | return sdf
219 |
220 | def _create_component_ids(self) -> DataFrame:
221 | """
222 | This function wraps the other methods in this class and performs the connected-component calculations and the
223 | edge filtering if required.
224 |
225 | Returns:
226 | a scored-pairs table with a component identifier that can be used as groupby-key to distribute the
227 | deduplication.
228 | """
229 | # calculate the initial connected components
230 | connected_components = self._calculate_connected_components(self.scored_pairs_table,
231 | "cached_connected_components_table_initial")
232 | # add the component ids to the scored-pairs table
233 | self.scored_pairs_table = self._add_component_id_to_scores_table(self.scored_pairs_table, connected_components)
234 | self.scored_pairs_table = self.scored_pairs_table.withColumnRenamed('component', 'initial_component')
235 |
236 |         # checkpoint the results
237 | self.scored_pairs_table = self.table_checkpointer(self.scored_pairs_table,
238 | "cached_scores_with_components_table_initial")
239 |
240 | component_columns = ['initial_component']
241 | iteration_count = 0
242 | for i in range(1, len(self.edge_filter_thresholds) + 1):
243 | # set and update the threshold
244 | threshold = self.edge_filter_thresholds[i - 1]
245 |
246 | self.scored_pairs_table, continue_iter = self._add_big_components_info(self.scored_pairs_table,
247 | component_columns,
248 | self.max_edges_clustering)
249 | if not continue_iter:
250 | # no further iterations are required
251 | break
252 |
253 | self.scored_pairs_table = self._add_low_score_info(self.scored_pairs_table, threshold)
254 | # recalculate the connected_components
255 | iteration_input_sdf = (
256 | self.scored_pairs_table
257 | .filter((F.col('is_in_big_component') & ~F.col('is_lower_than_threshold')))
258 | )
259 |
260 | print(f"calculate connected components iteration {i} with edge filter threshold {threshold}")
261 | connected_components_iteration = self._calculate_connected_components(iteration_input_sdf,
262 | f'cached_connected_components_table_iteration_{i}')
263 |
264 | self.scored_pairs_table = (
265 | self
266 | ._join_connected_components_iteration_results(self.scored_pairs_table,
267 | connected_components_iteration, i)
268 | )
269 |
270 | # checkpoint the result to break the lineage after each iteration and to inspect intermediate results
271 | self.scored_pairs_table = self.table_checkpointer(self.scored_pairs_table,
272 | f'cached_scores_with_components_table_iteration_{i}')
273 |
274 | component_columns.append(f'component_iteration_{i}')
275 | iteration_count += 1
276 |
277 |         # final check to see whether the big components were reduced sufficiently in size once all
278 |         # specified edge-filter thresholds have been used
279 | if iteration_count == len(self.edge_filter_thresholds):
280 | self.scored_pairs_table, continue_iter = self._add_big_components_info(self.scored_pairs_table,
281 | component_columns,
282 | self.max_edges_clustering)
283 | if continue_iter:
284 | print("THE EDGE-FILTER-THRESHOLDS ARE NOT ENOUGH TO SUFFICIENTLY REDUCE THE SIZE OF THE BIG COMPONENTS")
285 |
286 | # add the final component-id that can be used as groupby key for distributed deduplication
287 | self.scored_pairs_table = self._add_component_id(self.scored_pairs_table, iteration_count)
288 |
289 | return self.table_checkpointer(self.scored_pairs_table, 'cached_scores_with_components_table')
290 |
--------------------------------------------------------------------------------
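Conceptually, the iterations above prune low-score edges only inside oversized components, recomputing connected components until every component is small enough or the list of thresholds is exhausted. A toy, pure-Python illustration of that idea (not the Spark/GraphFrames implementation) might look like:

    # Toy illustration of the iterative edge-filtering idea; the real implementation runs on Spark/GraphFrames.
    from itertools import count


    def connected_components(edges):
        """Map each node to a component id for an undirected edge list of (a, b, score) tuples."""
        adjacency, comp, ids = {}, {}, count()
        for a, b, _ in edges:
            adjacency.setdefault(a, set()).add(b)
            adjacency.setdefault(b, set()).add(a)
        for start in adjacency:
            if start in comp:
                continue
            component_id, stack = next(ids), [start]
            while stack:
                node = stack.pop()
                if node not in comp:
                    comp[node] = component_id
                    stack.extend(adjacency[node])
        return comp


    def filter_big_components(edges, max_edges, thresholds):
        """Drop edges below increasing thresholds, but only inside components with more than `max_edges` edges."""
        for threshold in thresholds:
            comp = connected_components(edges)
            edge_counts = {}
            for a, _, _ in edges:
                edge_counts[comp[a]] = edge_counts.get(comp[a], 0) + 1
            if all(n <= max_edges for n in edge_counts.values()):
                break
            edges = [(a, b, s) for a, b, s in edges
                     if edge_counts[comp[a]] <= max_edges or s >= threshold]
        return edges
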
/spark_matcher/deduplicator/deduplicator.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import Optional, List, Dict
6 |
7 | from pyspark.sql import DataFrame, SparkSession, functions as F, types as T
8 | from sklearn.exceptions import NotFittedError
9 | from scipy.cluster import hierarchy
10 |
11 | from spark_matcher.blocker.blocking_rules import BlockingRule
12 | from spark_matcher.deduplicator.connected_components_calculator import ConnectedComponentsCalculator
13 | from spark_matcher.deduplicator.hierarchical_clustering import apply_deduplication
14 | from spark_matcher.matching_base.matching_base import MatchingBase
15 | from spark_matcher.scorer.scorer import Scorer
16 | from spark_matcher.table_checkpointer import TableCheckpointer
17 |
18 |
19 | class Deduplicator(MatchingBase):
20 | """
21 | Deduplicator class to apply deduplication. Provide either the column names `col_names` using the default string
22 | similarity metrics or explicitly define the string similarity metrics in a dict `field_info` as in the example
23 | below. If `blocking_rules` is left empty, default blocking rules are used. Otherwise, provide blocking rules as
24 | a list containing `BlockingRule` instances (see example below). The number of perfect matches used during
25 | training is set by `n_perfect_train_matches`.
26 |
27 | E.g.:
28 |
29 | from spark_matcher.blocker.blocking_rules import FirstNChars
30 |
31 | myDeduplicator = Deduplicator(spark_session, field_info={'name':[metric_function_1, metric_function_2],
32 |                                      'address': [metric_function_1, metric_function_3]},
33 | blocking_rules=[FirstNChars('name', 3)])
34 |
35 | Args:
36 | spark_session: Spark session
37 | col_names: list of column names to use for matching
38 | field_info: dict of column names as keys and lists of string similarity metrics as values
39 | blocking_rules: list of `BlockingRule` instances
40 | table_checkpointer: pointer object to store cached tables
41 | checkpoint_dir: checkpoint directory if provided
42 | n_train_samples: nr of pair samples to be created for training
43 | ratio_hashed_samples: ratio of hashed samples to be created for training, rest is sampled randomly
44 | n_perfect_train_matches: nr of perfect matches used for training
45 | scorer: a Scorer object used for scoring pairs
46 | verbose: sets verbosity
47 | max_edges_clustering: max number of edges per component that enters clustering
48 | edge_filter_thresholds: list of score thresholds to use for filtering when components are too large
49 |         cluster_score_threshold: threshold value in [0.0, 1.0]; pairs are only put together in clusters if their
50 |             cluster similarity scores are >= cluster_score_threshold
51 | cluster_linkage_method: linkage method to be used within hierarchical clustering, can take values such as
52 | 'centroid', 'single', 'complete', 'average', 'weighted', 'median', 'ward' etc.
53 | """
54 | def __init__(self, spark_session: SparkSession, col_names: Optional[List[str]] = None,
55 | field_info: Optional[Dict] = None, blocking_rules: Optional[List[BlockingRule]] = None,
56 | blocking_recall: float = 1.0, table_checkpointer: Optional[TableCheckpointer] = None,
57 | checkpoint_dir: Optional[str] = None, n_perfect_train_matches=1, n_train_samples: int = 100_000,
58 | ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0,
59 | max_edges_clustering: int = 500_000,
60 | edge_filter_thresholds: List[float] = [0.45, 0.55, 0.65, 0.75, 0.85, 0.95],
61 | cluster_score_threshold: float = 0.5, cluster_linkage_method: str = "centroid"):
62 |
63 | super().__init__(spark_session, table_checkpointer, checkpoint_dir, col_names, field_info, blocking_rules,
64 | blocking_recall, n_perfect_train_matches, n_train_samples, ratio_hashed_samples, scorer,
65 | verbose)
66 |
67 | self.fitted_ = False
68 | self.max_edges_clustering = max_edges_clustering
69 | self.edge_filter_thresholds = edge_filter_thresholds
70 | self.cluster_score_threshold = cluster_score_threshold
71 | if cluster_linkage_method not in list(hierarchy._LINKAGE_METHODS.keys()):
72 | raise ValueError(f"Invalid cluster_linkage_method: {cluster_linkage_method}")
73 | self.cluster_linkage_method = cluster_linkage_method
74 | # set the checkpoints directory for graphframes
75 | self.spark_session.sparkContext.setCheckpointDir('/tmp/checkpoints')
76 |
77 | def _create_predict_pairs_table(self, sdf_blocked: DataFrame) -> DataFrame:
78 | """
79 |         This method performs an alias self-join on `sdf_blocked` to create the pairs table for prediction.
80 |         `sdf_blocked` is joined to itself based on the `block_key` column. The result of this join is a pairs table.
81 | """
82 | sdf_blocked_1 = self._add_suffix_to_col_names(sdf_blocked, 1)
83 | sdf_blocked_1 = sdf_blocked_1.withColumnRenamed('row_number', 'row_number_1')
84 | sdf_blocked_2 = self._add_suffix_to_col_names(sdf_blocked, 2)
85 | sdf_blocked_2 = sdf_blocked_2.withColumnRenamed('row_number', 'row_number_2')
86 |
87 | pairs = (
88 | sdf_blocked_1
89 | .join(sdf_blocked_2, on='block_key', how='inner')
90 | .filter(F.col("row_number_1") < F.col("row_number_2"))
91 | .drop_duplicates(subset=[col + "_1" for col in self.col_names] + [col + "_2" for col in self.col_names])
92 | )
93 | return pairs
94 |
95 | def _map_distributed_identifiers_to_long(self, clustered_results: DataFrame) -> DataFrame:
96 | """
97 | Method to add a unique `entity_identifier` to the results from clustering
98 |
99 | Args:
100 | clustered_results: results from clustering
101 | Returns:
102 | Spark dataframe containing `row_number` and `entity_identifier`
103 | """
104 | long_entity_ids = (
105 | clustered_results
106 | .select('entity_identifier')
107 | .drop_duplicates()
108 | .withColumn('long_entity_identifier', F.monotonically_increasing_id())
109 | )
110 | long_entity_ids = self.table_checkpointer(long_entity_ids, 'cached_long_ids_table')
111 |
112 | clustered_results = (
113 | clustered_results
114 | .join(long_entity_ids, on='entity_identifier', how='left')
115 | .drop('entity_identifier')
116 | .withColumnRenamed('long_entity_identifier', 'entity_identifier')
117 | )
118 | return clustered_results
119 |
120 | def _distributed_deduplication(self, scored_pairs_with_components: DataFrame) -> DataFrame:
121 | schema = T.StructType([T.StructField('row_number', T.LongType(), True),
122 | T.StructField('entity_identifier', T.StringType(), True)])
123 |
124 | clustered_results = (
125 | scored_pairs_with_components
126 | .select('row_number_1', 'row_number_2', 'score', 'component_id')
127 | .groupby('component_id')
128 |             .applyInPandas(apply_deduplication(self.cluster_score_threshold, self.cluster_linkage_method), schema=schema)
129 | )
130 | return self.table_checkpointer(clustered_results, "cached_clustered_results_table")
131 |
132 | @staticmethod
133 | def _add_singletons_entity_identifiers(result_sdf: DataFrame) -> DataFrame:
134 | """
135 | Function to add entity_identifier to entities (singletons) that are not combined with other entities into a
136 | deduplicated entity. If there are no singletons, the input table will be returned as it is.
137 |
138 | Args:
139 | result_sdf: Spark dataframe containing the result of deduplication where entities that are not deduplicated
140 | have a missing value in the `entity_identifier` column.
141 |
142 | Returns:
143 | Spark dataframe with `entity_identifier` values for all entities
144 |
145 | """
146 | if result_sdf.filter(F.col('entity_identifier').isNull()).count() == 0:
147 | return result_sdf
148 | start_cluster_id = (
149 | result_sdf
150 | .filter(F.col('entity_identifier').isNotNull())
151 | .select(F.max('entity_identifier'))
152 | .first()[0])
153 |
154 | singletons_entity_identifiers = (
155 | result_sdf
156 | .filter(F.col('entity_identifier').isNull())
157 | .select('row_number')
158 | .rdd
159 | .zipWithIndex()
160 | .toDF()
161 | .withColumn('row_number', F.col('_1').getItem("row_number"))
162 | .drop("_1")
163 | .withColumn('entity_identifier_singletons', F.col('_2') + start_cluster_id + 1)
164 | .drop("_2"))
165 |
166 | result_sdf = (
167 | result_sdf
168 | .join(singletons_entity_identifiers, on='row_number', how='left')
169 | .withColumn('entity_identifier',
170 | F.when(F.col('entity_identifier').isNull(),
171 | F.col('entity_identifier_singletons'))
172 | .otherwise(F.col('entity_identifier')))
173 | .drop('entity_identifier_singletons'))
174 |
175 | return result_sdf
176 |
177 | def _create_deduplication_results(self, clustered_results: DataFrame, entities: DataFrame) -> DataFrame:
178 | """
179 | Joins deduplication results back to entity_table and adds identifiers to rows that are not deduplicated with
180 | other rows
181 | """
182 | deduplication_results = (
183 | entities
184 | .join(clustered_results, on='row_number', how='left')
185 | )
186 |
187 | deduplication_results = (
188 | self._add_singletons_entity_identifiers(deduplication_results)
189 | .drop('row_number', 'block_key')
190 | )
191 | return deduplication_results
192 |
193 | def _get_large_clusters(self, pairs_with_components: DataFrame) -> DataFrame:
194 | """
195 | Components that are too large after iteratively removing edges with a similarity score lower than the
196 | thresholds in `edge_filter_thresholds`, are considered to be one entity. For these the `component_id` is
197 | temporarily used as an `entity_identifier`.
198 | """
199 | large_component = (pairs_with_components.filter(F.col('is_in_big_component'))
200 | .withColumn('entity_identifier', F.col('component_id')))
201 |
202 | large_cluster = (large_component.select('row_number_1', 'entity_identifier')
203 | .withColumnRenamed('row_number_1', 'row_number')
204 | .unionByName(large_component.select('row_number_2', 'entity_identifier')
205 | .withColumnRenamed('row_number_2', 'row_number'))
206 | .drop_duplicates()
207 | .persist())
208 | return large_cluster
209 |
210 | def predict(self, sdf: DataFrame, threshold: float = 0.5):
211 | """
212 | Method to predict on data used for training or new data.
213 |
214 | Args:
215 | sdf: table to be applied entity deduplication
216 | threshold: probability threshold for similarity score
217 |
218 | Returns:
219 | Spark dataframe with the deduplication result
220 |
221 | """
222 | if not self.fitted_:
223 | raise NotFittedError('The Deduplicator instance is not fitted yet. Call `fit` and train the instance.')
224 |
225 | sdf = self._simplify_dataframe_for_matching(sdf)
226 | sdf = sdf.withColumn('row_number', F.monotonically_increasing_id())
227 |         # make sure the `row_number` is a fixed value. See the docs of `monotonically_increasing_id`: the function
228 |         # depends on the partitioning of the table. Not fixing/storing the result would cause problems when the table
229 |         # is recalculated under the hood during other operations. Saving it to disk and reading it back in breaks the
230 |         # lineage and makes the result deterministic in subsequent operations.
231 | sdf = self.table_checkpointer(sdf, "cached_entities_numbered_table")
232 |
233 | sdf_blocked = self.table_checkpointer(self.blocker.transform(sdf), "cached_block_table")
234 | pairs_table = self.table_checkpointer(self._create_predict_pairs_table(sdf_blocked), "cached_pairs_table")
235 | metrics_table = self.table_checkpointer(self._calculate_metrics(pairs_table), "cached_metrics_table")
236 |
237 | scores_table = (
238 | metrics_table
239 | .withColumn('score', self.scoring_learner.predict_proba(metrics_table['similarity_metrics']))
240 | )
241 | scores_table = self.table_checkpointer(scores_table, "cached_scores_table")
242 |
243 | scores_table_filtered = (
244 | scores_table
245 | .filter(F.col('score') >= threshold)
246 | .drop('block_key', 'similarity_metrics')
247 | )
248 |
249 | # create the subgraphs by using the connected components algorithm
250 | connected_components_calculator = ConnectedComponentsCalculator(scores_table_filtered,
251 | self.max_edges_clustering,
252 | self.edge_filter_thresholds,
253 | self.table_checkpointer)
254 |
255 | # calculate the component id for each scored-pairs
256 | scored_pairs_with_component_ids = connected_components_calculator._create_component_ids()
257 |
258 | large_cluster = self._get_large_clusters(scored_pairs_with_component_ids)
259 |
260 | # apply the distributed deduplication to components that are sufficiently small
261 | clustered_results = self._distributed_deduplication(
262 | scored_pairs_with_component_ids.filter(~F.col('is_in_big_component')))
263 |
264 | all_results = large_cluster.unionByName(clustered_results)
265 |
266 | # assign a unique entity_identifier to all components
267 | all_results_long = self._map_distributed_identifiers_to_long(all_results)
268 | all_results_long = self.table_checkpointer(all_results_long, "cached_clustered_results_with_long_ids_table")
269 |
270 | deduplication_results = self._create_deduplication_results(all_results_long, sdf)
271 |
272 | return deduplication_results
273 |
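# Example usage (illustrative sketch, not part of the original module). It assumes a live
# SparkSession `spark`, a Spark dataframe `sdf` with a 'name' column, and that the
# `Deduplicator` constructor mirrors `MatchingBase` (col_names, checkpoint_dir, ...):
#
#   myDeduplicator = Deduplicator(spark, col_names=['name'], checkpoint_dir='/tmp/checkpoints')
#   myDeduplicator.fit(sdf)                                 # interactive active-learning labelling
#   deduplicated = myDeduplicator.predict(sdf, threshold=0.8)
#   deduplicated.show()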
--------------------------------------------------------------------------------
/spark_matcher/deduplicator/hierarchical_clustering.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from collections import defaultdict
6 | from typing import List, Dict, Callable
7 |
8 | import numpy as np
9 | import pandas as pd
10 | import scipy.spatial.distance as ssd
11 | from scipy.cluster import hierarchy
12 |
13 |
14 | def _perform_clustering(component: pd.DataFrame, threshold: float, linkage_method: str):
15 | """
16 | Apply hierarchical clustering to scored_pairs_table with component_ids
17 | Args:
18 | component: pandas dataframe containing all pairs and the similarity scores for a connected component
19 | threshold: threshold to apply in hierarchical clustering
20 | linkage_method: linkage method to apply in hierarchical clustering
21 | Returns:
22 | Generator that contains tuples of ids and scores
23 | """
24 |
25 | distance_threshold = 1 - threshold
26 | if len(component) > 1:
27 | i_to_id, condensed_distances = _get_condensed_distances(component)
28 |
29 | linkage = hierarchy.linkage(condensed_distances, method=linkage_method)
30 | partition = hierarchy.fcluster(linkage, t=distance_threshold, criterion='distance')
31 |
32 | clusters: Dict[int, List[int]] = defaultdict(list)
33 |
34 | for i, cluster_id in enumerate(partition):
35 | clusters[cluster_id].append(i)
36 |
37 | for cluster in clusters.values():
38 | if len(cluster) > 1:
39 | yield tuple(i_to_id[i] for i in cluster), None
40 | else:
41 | ids = np.array([int(component['row_number_1']), int(component['row_number_2'])])
42 | score = float(component['score'])
43 | if score > threshold:
44 | yield tuple(ids), (score,) * 2
45 |
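# Illustrative sketch (not part of the original module): calling `_perform_clustering` on a tiny,
# hypothetical connected component of three records. With complete linkage and a threshold of 0.5,
# records 1 and 2 (similarity 0.9) form one cluster while record 3 stays separate:
#
#   component = pd.DataFrame({'row_number_1': [1, 2, 1],
#                             'row_number_2': [2, 3, 3],
#                             'score': [0.9, 0.2, 0.15]})
#   list(_perform_clustering(component, threshold=0.5, linkage_method='complete'))
#   # expected to yield [((1, 2), None)]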
46 |
47 | def _convert_data_to_adjacency_matrix(component: pd.DataFrame):
48 | """
49 | This function converts a pd.DataFrame to a numpy adjacency matrix
50 | Args:
51 | component: pd.DataFrame
52 | Returns:
53 | index of elements of the components and a numpy adjacency matrix
54 | """
55 | def _get_adjacency_matrix(df, col1, col2, score_col:str = 'score'):
56 | df = pd.crosstab(df[col1], df[col2], values=df[score_col], aggfunc='max')
57 | idx = df.columns.union(df.index)
58 | df = df.reindex(index = idx, columns=idx, fill_value=0).fillna(0)
59 | return df
60 |
61 | a_to_b = _get_adjacency_matrix(component, "row_number_1", "row_number_2")
62 | b_to_a = _get_adjacency_matrix(component, "row_number_2", "row_number_1")
63 |
64 | symmetric_adjacency_matrix = a_to_b + b_to_a
65 |
66 | return symmetric_adjacency_matrix.index, np.array(symmetric_adjacency_matrix)
67 |
68 |
69 | def _get_condensed_distances(component: pd.DataFrame):
70 | """
71 | Converts the pairwise list of distances to the condensed distance matrix required by the hierarchical
72 | clustering algorithms. Also returns a dictionary that maps the indices of the distance matrix to the ids.
73 |
74 | Args:
75 | component: pandas dataframe containing all pairs and the similarity scores for a connected component
76 |
77 | Returns:
78 | condensed distances and a dict with ids
79 | """
80 |
81 | i_to_id, adj_matrix = _convert_data_to_adjacency_matrix(component)
82 | distances = (np.ones_like(adj_matrix) - np.eye(len(adj_matrix))) - adj_matrix
83 | return dict(enumerate(i_to_id)), ssd.squareform(distances)
84 |
85 |
86 | def _convert_dedupe_result_to_pandas_dataframe(dedupe_result: List, component_id: int) -> pd.DataFrame:
87 | """
88 | Function to convert the dedupe result into a pandas dataframe.
89 | E.g.
90 | dedupe_result = [((1, 2), array([0.96, 0.96])), ((3, 4, 5), array([0.95, 0.95, 0.95]))]
91 |
92 | returns
93 |
94 | | row_number | entity_identifier |
95 | | ---------- | ------------------- |
96 | | 1 | 1 |
97 | | 2 | 1 |
98 | | 3 | 2 |
99 | | 4 | 2 |
100 | | 5 | 2 |
101 |
102 | Args:
103 | dedupe_result: the result with the deduplication results from the clustering
104 | Returns:
105 | pandas dataframe with row_number and entity_identifier
106 | """
107 | if len(dedupe_result) == 0:
108 | return pd.DataFrame(data={'row_number': [], 'entity_identifier': []})
109 |
110 | entity_identifier = 0
111 | df_list = []
112 |
113 | for ids, _ in dedupe_result:
114 | df_ = pd.DataFrame(data={'row_number': list(ids), 'entity_identifier': f"{component_id}_{entity_identifier}"})
115 | df_list.append(df_)
116 | entity_identifier += 1
117 | return pd.concat(df_list)
118 |
119 |
120 | def apply_deduplication(cluster_score_threshold: float, cluster_linkage_method: str) -> Callable:
121 | """
122 | This function is a wrapper that parameterizes the _apply_deduplication function with extra parameters for
123 | the cluster_score_threshold and the linkage method.
124 | Args:
125 | cluster_score_threshold: a float in [0,1]
126 | cluster_linkage_method: a string indicating the linkage method to be used for hierarchical clustering
127 | Returns:
128 | a function, i.e. _apply_deduplication, that can be called as a Pandas udf
129 | """
130 | def _apply_deduplication(component: pd.DataFrame) -> pd.DataFrame:
131 | """
132 | This function applies deduplication on a component, i.e. a subgraph calculated by the connected components
133 | algorithm. This function is applied to a spark dataframe in a pandas udf to distribute the deduplication
134 | over a Spark cluster, component by component.
135 | Args:
136 | component: a pandas Dataframe
137 | Returns:
138 | a pandas Dataframe with the results from the hierarchical clustering of the deduplication
139 | """
140 | component_id = component['component_id'][0]
141 |
142 | # perform the clustering:
143 | component = list(_perform_clustering(component, cluster_score_threshold, cluster_linkage_method))
144 |
145 | # convert the results to a dataframe:
146 | return _convert_dedupe_result_to_pandas_dataframe(component, component_id)
147 |
148 | return _apply_deduplication
149 |
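# Illustrative sketch (not part of the original module) of how the parameterized function could be
# applied per connected component via Spark's `applyInPandas`; the actual wiring lives in the
# Deduplicator and may differ. `scored_pairs_with_component_ids` is assumed to be a Spark dataframe
# of scored pairs carrying a `component_id` column:
#
#   dedup_fn = apply_deduplication(cluster_score_threshold=0.5, cluster_linkage_method='centroid')
#   clustered = (scored_pairs_with_component_ids
#                .groupby('component_id')
#                .applyInPandas(dedup_fn, schema='row_number long, entity_identifier string'))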
--------------------------------------------------------------------------------
/spark_matcher/matcher/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['Matcher']
2 |
3 | from .matcher import Matcher
--------------------------------------------------------------------------------
/spark_matcher/matcher/matcher.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import Optional, List, Dict
6 |
7 | from pyspark.sql import DataFrame, functions as F
8 | from pyspark.sql import SparkSession
9 | from pyspark.sql import Window
10 | from sklearn.exceptions import NotFittedError
11 |
12 | from spark_matcher.blocker.blocking_rules import BlockingRule
13 | from spark_matcher.matching_base.matching_base import MatchingBase
14 | from spark_matcher.scorer.scorer import Scorer
15 | from spark_matcher.table_checkpointer import TableCheckpointer
16 |
17 |
18 | class Matcher(MatchingBase):
19 | """
20 | Matcher class to apply record linkage. Provide either the column names `col_names` using the default string
21 | similarity metrics or explicitly define the string similarity metrics in a dict `field_info` as in the example
22 | below. If `blocking_rules` is left empty, default blocking rules are used. Otherwise provide blocking rules as
23 | a list containing `BlockingRule` instances (see example below). The number of perfect matches used during
24 | training is set by `n_perfect_train_matches`.
25 |
26 | E.g.:
27 |
28 | from spark_matcher.blocker.blocking_rules import FirstNChars
29 |
30 | myMatcher = Matcher(spark_session, field_info={'name':[metric_function_1, metric_function_2],
31 | 'address': [metric_function_1, metric_function_3]},
32 | blocking_rules=[FirstNChars('name', 3)])
33 |
34 | Args:
35 | spark_session: Spark session
36 | col_names: list of column names to use for matching
37 | field_info: dict of column names as keys and lists of string similarity metrics as values
38 | blocking_rules: list of `BlockingRule` instances
39 | n_train_samples: nr of pair samples to be created for training
40 | ratio_hashed_samples: ratio of hashed samples to be created for training, rest is sampled randomly
41 | n_perfect_train_matches: nr of perfect matches used for training
42 | scorer: a Scorer object used for scoring pairs
43 | verbose: sets verbosity
44 | """
45 | def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[TableCheckpointer]=None,
46 | checkpoint_dir: Optional[str]=None, col_names: Optional[List[str]] = None,
47 | field_info: Optional[Dict] = None, blocking_rules: Optional[List[BlockingRule]] = None,
48 | blocking_recall: float = 1.0, n_perfect_train_matches=1, n_train_samples: int = 100_000,
49 | ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0):
50 | super().__init__(spark_session, table_checkpointer, checkpoint_dir, col_names, field_info, blocking_rules,
51 | blocking_recall, n_perfect_train_matches, n_train_samples, ratio_hashed_samples, scorer,
52 | verbose)
53 | self.fitted_ = False
54 |
55 | def _create_predict_pairs_table(self, sdf_1_blocked: DataFrame, sdf_2_blocked: DataFrame) -> DataFrame:
56 | """
57 | Method to create pairs within blocks as provided by both input Spark dataframes in the column `block_key`.
58 | Note that an inner join is performed to create pairs: entries in `sdf_1_blocked` and `sdf_2_blocked`
59 | whose `block_key` value does not occur in the other table will not appear in the resulting pairs table.
60 |
61 | Args:
62 | sdf_1_blocked: Spark dataframe containing all columns used for matching and the `block_key`
63 | sdf_2_blocked: Spark dataframe containing all columns used for matching and the `block_key`
64 |
65 | Returns:
66 | Spark dataframe containing all pairs to score
67 |
68 | """
69 | sdf_1_blocked = self._add_suffix_to_col_names(sdf_1_blocked, 1)
70 | sdf_2_blocked = self._add_suffix_to_col_names(sdf_2_blocked, 2)
71 |
72 | predict_pairs_table = (sdf_1_blocked.join(sdf_2_blocked, on='block_key', how='inner')
73 | .drop_duplicates(subset=[col + "_1" for col in self.col_names] +
74 | [col + "_2" for col in self.col_names])
75 | )
76 | return predict_pairs_table
77 |
78 | def predict(self, sdf_1, sdf_2, threshold=0.5, top_n=None):
79 | """
80 | Method to predict on data used for training or new data.
81 |
82 | Args:
83 | sdf_1: first table
84 | sdf_2: second table
85 | threshold: probability threshold
86 | top_n: only return best `top_n` matches above threshold
87 |
88 | Returns:
89 | Spark dataframe with the matching result
90 |
91 | """
92 | if not self.fitted_:
93 | raise NotFittedError('The Matcher instance is not fitted yet. Call `fit` and train the instance.')
94 |
95 | sdf_1, sdf_2 = self._simplify_dataframe_for_matching(sdf_1), self._simplify_dataframe_for_matching(sdf_2)
96 | sdf_1_blocks, sdf_2_blocks = self.blocker.transform(sdf_1, sdf_2)
97 | pairs_table = self._create_predict_pairs_table(sdf_1_blocks, sdf_2_blocks)
98 | metrics_table = self._calculate_metrics(pairs_table)
99 | scores_table = (
100 | metrics_table
101 | .withColumn('score', self.scoring_learner.predict_proba(metrics_table['similarity_metrics']))
102 | )
103 | scores_table_filtered = (scores_table.filter(F.col('score') >= threshold)
104 | .drop('block_key', 'similarity_metrics'))
105 |
106 | if top_n:
107 | # we add additional columns to order by to remove ties and return exactly n items
108 | window = (Window.partitionBy(*[x + "_1" for x in self.col_names])
109 | .orderBy(F.desc('score'), *[F.asc(col) for col in [x + "_2" for x in self.col_names]]))
110 | result = (scores_table_filtered.withColumn('rank', F.rank().over(window))
111 | .filter(F.col('rank') <= top_n)
112 | .drop('rank'))
113 | else:
114 | result = scores_table_filtered
115 |
116 | return result
117 |
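# Example usage (illustrative sketch, not part of the original module), assuming a live
# SparkSession `spark` and two Spark dataframes `sdf_1` and `sdf_2` that both contain
# 'name' and 'address' columns:
#
#   myMatcher = Matcher(spark, col_names=['name', 'address'], checkpoint_dir='/tmp/checkpoints')
#   myMatcher.fit(sdf_1, sdf_2)                              # interactive active-learning labelling
#   matches = myMatcher.predict(sdf_1, sdf_2, threshold=0.8, top_n=1)
#   matches.show()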
--------------------------------------------------------------------------------
/spark_matcher/matching_base/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['MatchingBase']
2 |
3 | from .matching_base import MatchingBase
--------------------------------------------------------------------------------
/spark_matcher/matching_base/matching_base.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | import warnings
6 | from typing import Optional, List, Dict
7 |
8 | import dill
9 | from pyspark.sql import DataFrame, functions as F, SparkSession
10 | import numpy as np
11 | import pandas as pd
12 | from thefuzz.fuzz import token_set_ratio, token_sort_ratio
13 |
14 | from spark_matcher.activelearner.active_learner import ScoringLearner
15 | from spark_matcher.blocker.block_learner import BlockLearner
16 | from spark_matcher.sampler.training_sampler import HashSampler, RandomSampler
17 | from spark_matcher.scorer.scorer import Scorer
18 | from spark_matcher.similarity_metrics import SimilarityMetrics
19 | from spark_matcher.blocker.blocking_rules import BlockingRule, default_blocking_rules
20 | from spark_matcher.table_checkpointer import TableCheckpointer, ParquetCheckPointer
21 |
22 |
23 | class MatchingBase:
24 |
25 | def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[TableCheckpointer] = None,
26 | checkpoint_dir: Optional[str] = None, col_names: Optional[List[str]] = None,
27 | field_info: Optional[Dict] = None, blocking_rules: Optional[List[BlockingRule]] = None,
28 | blocking_recall: float = 1.0, n_perfect_train_matches=1, n_train_samples: int = 100_000,
29 | ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0):
30 | self.spark_session = spark_session
31 | self.table_checkpointer = table_checkpointer
32 | if not self.table_checkpointer:
33 | if checkpoint_dir:
34 | self.table_checkpointer = ParquetCheckPointer(self.spark_session, checkpoint_dir,
35 | "checkpoint_deduplicator")
36 | else:
37 | warnings.warn(
38 | 'Either `table_checkpointer` or `checkpoint_dir` should be provided. This instance can only be used'
39 | ' when loading a previously saved instance.')
40 |
41 | if col_names:
42 | self.col_names = col_names
43 | self.field_info = {col_name: [token_set_ratio, token_sort_ratio] for col_name in
44 | self.col_names}
45 | elif field_info:
46 | self.col_names = list(field_info.keys())
47 | self.field_info = field_info
48 | else:
49 | warnings.warn(
50 | 'Either `col_names` or `field_info` should be provided. This instance can only be used when loading a '
51 | 'previously saved instance.')
52 | self.col_names = [''] # needed for instantiating ScoringLearner
53 |
54 | self.n_train_samples = n_train_samples
55 | self.ratio_hashed_samples = ratio_hashed_samples
56 | self.n_perfect_train_matches = n_perfect_train_matches
57 | self.verbose = verbose
58 |
59 | if not scorer:
60 | scorer = Scorer(self.spark_session)
61 | self.scoring_learner = ScoringLearner(self.col_names, scorer, verbose=self.verbose)
62 |
63 | self.blocking_rules = blocking_rules
64 | if not self.blocking_rules:
65 | self.blocking_rules = [blocking_rule(col) for col in self.col_names for blocking_rule in
66 | default_blocking_rules]
67 | self.blocker = BlockLearner(blocking_rules=self.blocking_rules, recall=blocking_recall,
68 | table_checkpointer=self.table_checkpointer, verbose=self.verbose)
69 |
70 | def save(self, path: str) -> None:
71 | """
72 | Save the current instance to a pickle file.
73 |
74 | Args:
75 | path: Path and file name of pickle file
76 |
77 | """
78 | to_save = self.__dict__
79 |
80 | # the spark sessions need to be removed as they cannot be saved and re-used later
81 | to_save['spark_session'] = None
82 | setattr(to_save['scoring_learner'].learner.estimator, 'spark_session', None)
83 | setattr(to_save['table_checkpointer'], 'spark_session', None)
84 |
85 | with open(path, 'wb') as f:
86 | dill.dump(to_save, f)
87 |
88 | def load(self, path: str) -> None:
89 | """
90 | Load a previously trained and saved Matcher instance.
91 |
92 | Args:
93 | path: Path and file name of pickle file
94 |
95 | """
96 | with open(path, 'rb') as f:
97 | loaded_obj = dill.load(f)
98 |
99 | # the spark session that was removed before saving needs to be filled with the spark session of this instance
100 | loaded_obj['spark_session'] = self.spark_session
101 | setattr(loaded_obj['scoring_learner'].learner.estimator, 'spark_session', self.spark_session)
102 | setattr(loaded_obj['table_checkpointer'], 'spark_session', self.spark_session)
103 |
104 | self.__dict__.update(loaded_obj)
105 |
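    # Illustrative sketch (not part of the original module): persisting and restoring a fitted
    # instance. The Spark session is stripped before saving and re-attached by `load`; `myMatcher`
    # and `spark` are assumed to be a fitted Matcher and a live SparkSession:
    #
    #   myMatcher.save('matcher.pkl')
    #   restored = Matcher(spark, col_names=['name', 'address'], checkpoint_dir='/tmp/checkpoints')
    #   restored.load('matcher.pkl')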
106 | def _simplify_dataframe_for_matching(self, sdf: DataFrame) -> DataFrame:
107 | """
108 | Select only the columns used for matching and drop duplicates.
109 | """
110 | return sdf.select(*self.col_names).drop_duplicates()
111 |
112 | def _create_train_pairs_table(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> DataFrame:
113 | """
114 | Create pairs that are used for training. Based on the given `ratio_hashed_samples`, part of the sample
115 | pairs is generated using the MinHashLSH technique to create pairs that are more likely to be a match. The
116 | rest is sampled randomly.
117 |
118 | Args:
119 | sdf_1: Spark dataframe containing the first table to be matched
120 | sdf_2: Optional: Spark dataframe containing the second table that should be matched to the first table
121 |
122 | Returns:
123 | Spark dataframe with sampled pairs that should be compared during training
124 | """
125 | n_hashed_samples = int(self.n_train_samples * self.ratio_hashed_samples)
126 |
127 | h_sampler = HashSampler(self.table_checkpointer, self.col_names, n_hashed_samples)
128 | hashed_pairs_table = self.table_checkpointer.checkpoint_table(h_sampler.create_pairs_table(sdf_1, sdf_2),
129 | checkpoint_name='minhash_pairs_table')
130 |
131 | # creating additional random samples to ensure that the exact number of n_train_samples
132 | # is obtained after dropping duplicated records
133 | n_random_samples = int(1.5 * (self.n_train_samples - hashed_pairs_table.count()))
134 |
135 | r_sampler = RandomSampler(self.table_checkpointer, self.col_names, n_random_samples)
136 | random_pairs_table = self.table_checkpointer.checkpoint_table(r_sampler.create_pairs_table(sdf_1, sdf_2),
137 | checkpoint_name='random_pairs_table')
138 |
139 | pairs_table = (
140 | random_pairs_table
141 | .unionByName(hashed_pairs_table)
142 | .withColumn('perfect_train_match', F.lit(False))
143 | .drop_duplicates()
144 | .limit(self.n_train_samples)
145 | )
146 |
147 | # add some perfect matches to assure there are two labels in the train data and the classifier can be trained
148 | perfect_matches = (pairs_table.withColumn('perfect_train_match', F.lit(True))
149 | .limit(self.n_perfect_train_matches))
150 | for col in self.col_names:
151 | perfect_matches = perfect_matches.withColumn(col + "_1", F.col(col + "_2"))
152 |
153 | pairs_table = (perfect_matches.unionByName(pairs_table
154 | .limit(self.n_train_samples - self.n_perfect_train_matches)))
155 |
156 | return pairs_table
157 |
158 | def _calculate_metrics(self, pairs_table: DataFrame) -> DataFrame:
159 | """
160 | Method to apply similarity metrics to pairs table.
161 |
162 | Args:
163 | pairs_table: Spark dataframe containing pairs table
164 |
165 | Returns:
166 | Spark dataframe with pairs table and newly created `similarity_metrics` column
167 |
168 | """
169 | similarity_metrics = SimilarityMetrics(self.field_info)
170 | return similarity_metrics.transform(pairs_table)
171 |
172 | def _create_blocklearning_input(self, metrics_table: pd.DataFrame, threshold: float = 0.5) -> DataFrame:
173 | """
174 | Method to collect data used for block learning. This data consists of the manually labelled pairs with label 1
175 | and the train pairs with a score above the `threshold`
176 |
177 | Args:
178 | metrics_table: Pandas dataframe containing the similarity metrics
179 | threshold: scoring threshold
180 |
181 | Returns:
182 | Spark dataframe with data for block learning
183 |
184 | """
185 | metrics_table['score'] = self.scoring_learner.predict_proba(np.array(metrics_table['similarity_metrics'].tolist()))[:, 1]
186 | metrics_table.loc[metrics_table['score'] > threshold, 'label'] = '1'
187 |
188 | # get labelled positives from activelearner
189 | positive_train_labels = (self.scoring_learner.train_samples[self.scoring_learner.train_samples['y'] == '1']
190 | .rename(columns={'y': 'label'}))
191 | metrics_table = pd.concat([metrics_table, positive_train_labels]).drop_duplicates(
192 | subset=[col + "_1" for col in self.col_names] + [col + "_2" for col in self.col_names], keep='last')
193 |
194 | metrics_table = metrics_table[metrics_table.label == '1']
195 |
196 | metrics_table['row_id'] = np.arange(len(metrics_table))
197 | return self.spark_session.createDataFrame(metrics_table)
198 |
199 | def _add_suffix_to_col_names(self, sdf: DataFrame, suffix: int):
200 | """
201 | This method adds a suffix to the columns that are used in the algorithm.
202 | This is needed to join two dataframes with the same schema when creating the pairs table.
203 | """
204 | for col in self.col_names:
205 | sdf = sdf.withColumnRenamed(col, f"{col}_{suffix}")
206 | return sdf
207 |
208 | def fit(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> 'MatchingBase':
209 | """
210 | Fit the MatchingBase instance on the two dataframes `sdf_1` and `sdf_2` using active learning. You will be
211 | prompted to enter whether the presented pairs are a match or not. Note that `sdf_2` is an optional argument.
212 | `sdf_2` is used for Matcher (i.e. matching one table to another). In the case of Deduplication, only providing
213 | `sdf_1` is sufficient; in that case `sdf_1` will be deduplicated.
214 |
215 | Args:
216 | sdf_1: Spark dataframe
217 | sdf_2: Optional: Spark dataframe
218 |
219 | Returns:
220 | Fitted MatchingBase instance
221 |
222 | """
223 | sdf_1 = self._simplify_dataframe_for_matching(sdf_1)
224 | if sdf_2:
225 | sdf_2 = self._simplify_dataframe_for_matching(sdf_2)
226 |
227 | pairs_table = self._create_train_pairs_table(sdf_1, sdf_2)
228 | metrics_table = self._calculate_metrics(pairs_table)
229 | metrics_table_pdf = metrics_table.toPandas()
230 | self.scoring_learner.fit(metrics_table_pdf)
231 | block_learning_input = self._create_blocklearning_input(metrics_table_pdf)
232 | self.blocker.fit(block_learning_input)
233 | self.fitted_ = True
234 | return self
235 |
--------------------------------------------------------------------------------
/spark_matcher/sampler/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['HashSampler', 'RandomSampler']
2 |
3 | from .training_sampler import HashSampler, RandomSampler
--------------------------------------------------------------------------------
/spark_matcher/sampler/training_sampler.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | import abc
6 | from typing import List, Tuple, Optional, Union
7 |
8 | import numpy as np
9 | from pyspark.ml.feature import CountVectorizer, VectorAssembler, MinHashLSH
10 | from pyspark.ml.linalg import DenseVector, SparseVector
11 | from pyspark.sql import DataFrame, functions as F, types as T
12 |
13 | from spark_matcher.config import MINHASHING_MAXDF, MINHASHING_VOCABSIZE
14 | from spark_matcher.table_checkpointer import TableCheckpointer
15 |
16 |
17 | class Sampler:
18 | """
19 | Sampler base class to generate pairs for training.
20 |
21 | Args:
22 | col_names: list of column names to use for matching
23 | n_train_samples: nr of pair samples to be created for training
24 | table_checkpointer: table_checkpointer object to store cached tables
25 | """
26 | def __init__(self, col_names: List[str], n_train_samples: int, table_checkpointer: TableCheckpointer) -> None:
27 | self.col_names = col_names
28 | self.n_train_samples = n_train_samples
29 | self.table_checkpointer = table_checkpointer
30 |
31 | @abc.abstractmethod
32 | def create_pairs_table(self, sdf_1: DataFrame, sdf_2: DataFrame) -> DataFrame:
33 | pass
34 |
35 |
36 | class HashSampler(Sampler):
37 | def __init__(self, table_checkpointer: TableCheckpointer, col_names: List[str], n_train_samples: int,
38 | threshold: float = 0.5, num_hash_tables: int = 10) -> None:
39 | """
40 | Sampler class to generate pairs using MinHashLSH method by selecting pairs that are more likely to be a match.
41 |
42 | Args:
43 | threshold: threshold on the distance of hashed pairs, only pairs with a distance below this value are
44 | returned (a threshold equal to 1.0 returns all pairs with at least one common shingle)
45 | num_hash_tables: number of hashes to be applied
46 | table_checkpointer: table_checkpointer object to store cached tables
47 | Returns:
48 | A spark dataframe which contains all selected pairs for training
49 | """
50 | super().__init__(col_names, n_train_samples, table_checkpointer)
51 | self.threshold = threshold
52 | self.num_hash_tables = num_hash_tables
53 |
54 | @staticmethod
55 | @F.udf(T.BooleanType())
56 | def is_non_zero_vector(vector):
57 | """
58 | Check if a vector has at least 1 non-zero entry. This function can deal with dense or sparse vectors. This is
59 | needed as the VectorAssembler can return dense or sparse vectors, depending on what is more memory efficient.
60 |
61 | Args:
62 | vector: vector
63 |
64 | Returns:
65 | boolean whether a vector has at least 1 non-zero entry
66 |
67 | """
68 | if isinstance(vector, DenseVector):
69 | return bool(len(vector))
70 | if isinstance(vector, SparseVector):
71 | return vector.indices.size > 0
72 |
73 | def _vectorize(self, col_names: List[str], sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None,
74 | max_df: Union[float, int] = MINHASHING_MAXDF, vocab_size: int = MINHASHING_VOCABSIZE) -> Union[
75 | Tuple[DataFrame, DataFrame], DataFrame]:
76 | """
77 | This function is used to vectorize word-based features. First, the given dataframes are combined to cover the
78 | full feature space. The given columns are then converted into a vector. `sdf_2` is only required for record
79 | matching; for deduplication only `sdf_1` is required. To prevent too many (redundant) matches on frequently
80 | occurring tokens, the maximum document frequency used for vectorization is limited by `max_df`.
81 |
82 | Args:
83 | col_names: list of column names to create feature vectors
84 | sdf_1: spark dataframe
85 | sdf_2: Optional: spark dataframe
86 | max_df: max document frequency to use for count vectorization
87 | vocab_size: vocabulary size of vectorizer
88 |
89 | Returns:
90 | The given spark dataframe(s) with a 'features' column added
91 |
92 | """
93 | if sdf_2:
94 | # creating a source column to separate tables after vectorization is completed
95 | sdf_1 = sdf_1.withColumn('source', F.lit('sdf_1'))
96 | sdf_2 = sdf_2.withColumn('source', F.lit('sdf_2'))
97 | sdf_merged = sdf_1.unionByName(sdf_2)
98 | else:
99 | sdf_merged = sdf_1
100 |
101 | for col in col_names:
102 | # converting strings into an array
103 | sdf_merged = sdf_merged.withColumn(f'{col}_array', F.split(F.col(col), ' '))
104 |
105 | # setting CountVectorizer
106 | # minDF is set to 2 because if a token occurs only once, there is no other occurrence to match it to;
107 | # the maximum document frequency may not be smaller than the minimum document frequency, which is assured
108 | # below
109 | n_rows = sdf_merged.count()
110 | if isinstance(max_df, float) and (n_rows * max_df < 2):
111 | max_df = 10
112 | cv = CountVectorizer(binary=True, minDF=2, maxDF=max_df, vocabSize=vocab_size, inputCol=f"{col}_array",
113 | outputCol=f"{col}_vector")
114 |
115 | # creating vectors
116 | model = cv.fit(sdf_merged)
117 | sdf_merged = model.transform(sdf_merged)
118 | sdf_merged = self.table_checkpointer(sdf_merged, "cached_vectorized_table")
119 |
120 | # in case col_names contains multiple columns, merge all vectors into one
121 | vec_assembler = VectorAssembler(outputCol="features")
122 | vec_assembler.setInputCols([f'{col}_vector' for col in col_names])
123 |
124 | sdf_merged = vec_assembler.transform(sdf_merged)
125 |
126 | # breaking the lineage is found to be required below
127 | sdf_merged = self.table_checkpointer(
128 | sdf_merged.filter(HashSampler.is_non_zero_vector('features')), 'minhash_vectorized')
129 |
130 | if sdf_2:
131 | sdf_1_vectorized = sdf_merged.filter(F.col('source') == 'sdf_1').select(*col_names, 'features')
132 | sdf_2_vectorized = sdf_merged.filter(F.col('source') == 'sdf_2').select(*col_names, 'features')
133 | return sdf_1_vectorized, sdf_2_vectorized
134 | sdf_1_vectorized = sdf_merged.select(*col_names, 'features')
135 | return sdf_1_vectorized
136 |
137 | @staticmethod
138 | def _apply_min_hash(sdf_1: DataFrame, sdf_2: DataFrame, col_names: List[str],
139 | threshold: float, num_hash_tables: int):
140 | """
141 | This function applies the MinHashLSH technique to calculate the Jaccard distance between feature vectors.
142 |
143 | Args:
144 | sdf_1: spark dataframe with the `features` column and the column `row_number_1`
145 | sdf_2: spark dataframe with the `features` column and the column `row_number_2`
146 | col_names: list of column names to create feature vectors
147 |
148 | Returns:
149 | One spark dataframe that contains all pairs with a Jaccard distance smaller than the threshold
150 | """
151 | mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=42, numHashTables=num_hash_tables)
152 |
153 | model = mh.fit(sdf_1)
154 | model.transform(sdf_1)
155 |
156 | sdf_distances = (
157 | model
158 | .approxSimilarityJoin(sdf_1, sdf_2, threshold, distCol="JaccardDistance")
159 | .select(*[F.col(f'datasetA.{col}').alias(f'{col}_1') for col in col_names],
160 | *[F.col(f'datasetB.{col}').alias(f'{col}_2') for col in col_names],
161 | F.col('JaccardDistance').alias('distance'),
162 | F.col('datasetA.row_number_1').alias('row_number_1'),
163 | F.col('datasetB.row_number_2').alias('row_number_2'))
164 | )
165 | return sdf_distances
166 |
167 | def create_pairs_table(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> DataFrame:
168 | """
169 | Create hashed pairs that are used for training. `sdf_2` is only required for record matching; for deduplication
170 | only `sdf_1` is required.
171 |
172 | Args:
173 | sdf_1: Spark dataframe containing the first table to be matched
174 | sdf_2: Optional: Spark dataframe containing the second table that should be matched to the first table
175 |
176 | Returns:
177 | Spark dataframe that contains sampled pairs selected with MinHashLSH technique
178 |
179 | """
180 | if sdf_2:
181 | sdf_1_vectorized, sdf_2_vectorized = self._vectorize(self.col_names, sdf_1, sdf_2)
182 | sdf_1_vectorized = self.table_checkpointer(
183 | sdf_1_vectorized.withColumn('row_number_1', F.monotonically_increasing_id()),
184 | checkpoint_name='sdf_1_vectorized')
185 | sdf_2_vectorized = self.table_checkpointer(
186 | sdf_2_vectorized.withColumn('row_number_2', F.monotonically_increasing_id()),
187 | checkpoint_name='sdf_2_vectorized')
188 | else:
189 | sdf_1_vectorized = self._vectorize(self.col_names, sdf_1)
190 | sdf_1_vectorized = self.table_checkpointer(
191 | sdf_1_vectorized.withColumn('row_number_1', F.monotonically_increasing_id()),
192 | checkpoint_name='sdf_1_vectorized')
193 | sdf_2_vectorized = sdf_1_vectorized.alias('sdf_2_vectorized') # matched with itself for deduplication
194 | sdf_2_vectorized = sdf_2_vectorized.withColumnRenamed('row_number_1', 'row_number_2')
195 |
196 | sdf_distances = self._apply_min_hash(sdf_1_vectorized, sdf_2_vectorized, self.col_names,
197 | self.threshold, self.num_hash_tables)
198 |
199 | # for deduplication we remove identical pairs like (a,a) and duplicates of pairs like (a,b) and (b,a)
200 | if not sdf_2:
201 | sdf_distances = sdf_distances.filter(F.col('row_number_1') < F.col('row_number_2'))
202 |
203 | hashed_pairs_table = (
204 | sdf_distances
205 | .sort('distance')
206 | .limit(self.n_train_samples)
207 | .drop('distance', 'row_number_1', 'row_number_2')
208 | )
209 |
210 | return hashed_pairs_table
211 |
212 |
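# Example usage (illustrative sketch, not part of the original module), assuming a live
# SparkSession, a ParquetCheckPointer `checkpointer`, and Spark dataframes `sdf_1` and
# `sdf_2` that both contain a 'name' column:
#
#   sampler = HashSampler(checkpointer, col_names=['name'], n_train_samples=10_000)
#   hashed_pairs = sampler.create_pairs_table(sdf_1, sdf_2)   # columns: name_1, name_2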
213 | class RandomSampler(Sampler):
214 | def __init__(self, table_checkpointer: TableCheckpointer, col_names: List[str], n_train_samples: int) -> None:
215 | """
216 | Sampler class to generate randomly selected pairs
217 |
218 | Returns:
219 | A spark dataframe which contains randomly selected pairs for training
220 | """
221 | super().__init__(col_names, n_train_samples, table_checkpointer)
222 |
223 | def create_pairs_table(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> DataFrame:
224 | """
225 | Create random pairs that are used for training. `sdf_2` is only required for record matching; for deduplication
226 | only `sdf_1` is required.
227 |
228 | Args:
229 | sdf_1: Spark dataframe containing the first table to be matched
230 | sdf_2: Optional: Spark dataframe containing the second table that should be matched to the first table
231 |
232 | Returns:
233 | Spark dataframe that contains randomly sampled pairs
234 |
235 | """
236 | if sdf_2:
237 | return self._create_pairs_table_matcher(sdf_1, sdf_2)
238 | return self._create_pairs_table_deduplicator(sdf_1)
239 |
240 | def _create_pairs_table_matcher(self, sdf_1: DataFrame, sdf_2: DataFrame) -> DataFrame:
241 | """
242 | Create random pairs that are used for training.
243 |
244 | If the first table and the second table are equally large, we sample the square root of `n_train_samples`
245 | rows from both tables to create a pairs table containing `n_train_samples` rows. If one of the tables is much
246 | smaller than the other, all rows of the smaller table are taken and the sample size for the larger table is
247 | chosen such that the total number of pairs is `n_train_samples`.
248 |
249 | Args:
250 | sdf_1: Spark dataframe containing the first table to be matched
251 | sdf_2: Spark dataframe containing the second table that should be matched to the first table
252 |
253 | Returns:
254 | Spark dataframe with randomly sampled pairs to be compared during training
255 | """
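        # Worked example of the sampling sizes (illustrative): with n_train_samples = 100_000 and two
        # large tables, sample_size_small_table = int(100_000 ** 0.5) = 316 and
        # sample_size_large_table = 100_000 // 316 = 316, so the cross join below yields
        # 316 * 316 = 99_856 candidate pairs, close to the requested number of samples.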
256 | sdf_1_count, sdf_2_count = sdf_1.count(), sdf_2.count()
257 |
258 | sample_size_small_table = min([int(self.n_train_samples ** 0.5), min(sdf_1_count, sdf_2_count)])
259 | sample_size_large_table = self.n_train_samples // sample_size_small_table
260 |
261 | if sdf_1_count > sdf_2_count:
262 | fraction_sdf_1 = sample_size_large_table / sdf_1_count
263 | fraction_sdf_2 = sample_size_small_table / sdf_2_count
264 | # below the `fraction` is slightly increased (but capped at 1.) and a `limit()` is applied to assure that
265 | # the exact number of required samples is obtained
266 | sdf_1_sample = (sdf_1.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_1]))
267 | .limit(sample_size_large_table))
268 | sdf_2_sample = (sdf_2.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_2]))
269 | .limit(sample_size_small_table))
270 | else:
271 | fraction_sdf_1 = sample_size_small_table / sdf_1_count
272 | fraction_sdf_2 = sample_size_large_table / sdf_2_count
273 | sdf_1_sample = (sdf_1.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_1]))
274 | .limit(sample_size_small_table))
275 | sdf_2_sample = (sdf_2.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_2]))
276 | .limit(sample_size_large_table))
277 |
278 | for col in self.col_names:
279 | sdf_1_sample = sdf_1_sample.withColumnRenamed(col, col + "_1")
280 | sdf_2_sample = sdf_2_sample.withColumnRenamed(col, col + "_2")
281 |
282 | random_pairs_table = sdf_1_sample.crossJoin(sdf_2_sample)
283 |
284 | return random_pairs_table
285 |
286 | def _create_pairs_table_deduplicator(self, sdf: DataFrame) -> DataFrame:
287 | """
288 | Create randomly selected pairs for deduplication.
289 |
290 | Args:
291 | sdf: Spark dataframe containing rows from which pairs need to be created
292 |
293 | Returns:
294 | Spark dataframe with the randomly selected pairs
295 |
296 | """
297 | n_samples_required = int(1.5 * np.ceil((self.n_train_samples * 2) ** 0.5))
298 | fraction_samples_required = min([n_samples_required / sdf.count(), 1.])
299 | sample = sdf.sample(withReplacement=False, fraction=fraction_samples_required)
300 | sample = self.table_checkpointer(sample.withColumn('row_id', F.monotonically_increasing_id()),
301 | checkpoint_name='random_pairs_deduplicator')
302 | sample_1, sample_2 = sample, sample
303 | for col in self.col_names + ['row_id']:
304 | sample_1 = sample_1.withColumnRenamed(col, col + "_1")
305 | sample_2 = sample_2.withColumnRenamed(col, col + "_2")
306 |
307 | pairs_table = (sample_1
308 | .crossJoin(sample_2)
309 | .filter(F.col('row_id_1') < F.col('row_id_2'))
310 | .limit(self.n_train_samples)
311 | .drop('row_id_1', 'row_id_2'))
312 |
313 | return pairs_table
314 |
--------------------------------------------------------------------------------
/spark_matcher/scorer/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['Scorer']
2 |
3 | from .scorer import Scorer
--------------------------------------------------------------------------------
/spark_matcher/scorer/scorer.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import Union, Optional
6 |
7 | import numpy as np
8 | from pyspark.sql import SparkSession, types as T, functions as F, Column
9 | from sklearn.base import BaseEstimator
10 | from sklearn.pipeline import make_pipeline
11 | from sklearn.preprocessing import StandardScaler
12 | from sklearn.linear_model import LogisticRegression
13 |
14 |
15 | class Scorer:
16 | def __init__(self, spark_session: SparkSession, binary_clf: Optional[BaseEstimator] = None):
17 | self.spark_session = spark_session
18 | self.binary_clf = binary_clf
19 | if not self.binary_clf:
20 | self._create_default_clf()
21 | self.fitted_ = False
22 |
23 | def _create_default_clf(self) -> None:
24 | """
25 | This method creates a Sklearn Pipeline with a Standard Scaler and a Logistic Regression classifier.
26 | """
27 | self.binary_clf = (
28 | make_pipeline(
29 | StandardScaler(),
30 | LogisticRegression(class_weight='balanced')
31 | )
32 | )
33 |
34 | def fit(self, X: np.ndarray, y: np.ndarray) -> 'Scorer':
35 | """
36 | This method fits a classifier on the input data `X` and the binary targets `y`.
37 |
38 | Args:
39 | X: training data
40 | y: training targets, containing binary values
41 |
42 | Returns:
43 | The object itself
44 | """
45 |
46 | if len(set(y)) == 1: # in case active learning only resulted in labels from one class
47 | return self
48 |
49 | self.binary_clf.fit(X, y)
50 | self.fitted_ = True
51 | return self
52 |
53 | def _predict_proba(self, X: np.ndarray) -> np.ndarray:
54 | """
55 | This method implements the code for predict_proba on a numpy array.
56 | This method is used to score all the pairs during training.
57 | """
58 | return self.binary_clf.predict_proba(X)
59 |
60 | def _predict_proba_spark(self, X: Column) -> Column:
61 | """
62 | This method implements the code for predict_proba on a spark column.
63 | This method is used to score all the pairs during inference time.
64 | """
65 | broadcasted_clf = self.spark_session.sparkContext.broadcast(self.binary_clf)
66 |
67 | @F.pandas_udf(T.FloatType())
68 | def _distributed_predict_proba(array):
69 | """
70 | This inner function defines the Pandas UDF for predict_proba on a spark cluster
71 | """
72 | return array.apply(lambda x: broadcasted_clf.value.predict_proba([x])[0][1])
73 |
74 | return _distributed_predict_proba(X)
75 |
76 | def predict_proba(self, X: Union[Column, np.ndarray]) -> Union[Column, np.ndarray]:
77 | """
78 | This method implements the abstract predict_proba method. It predicts the 'probabilities' of the target class
79 | for given input data `X`.
80 |
81 | Args:
82 | X: input data
83 |
84 | Returns:
85 | the predicted probabilities
86 | """
87 | if isinstance(X, Column):
88 | return self._predict_proba_spark(X)
89 |
90 | if isinstance(X, np.ndarray):
91 | return self._predict_proba(X)
92 |
93 | raise ValueError(f"{type(X)} is an unsupported datatype for X")
94 |
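# Example usage (illustrative sketch, not part of the original module): any sklearn-compatible
# binary classifier exposing `fit` and `predict_proba` can be plugged in instead of the default
# pipeline. `spark_session` is an assumed live SparkSession; `X_train`, `y_train` and `X_new`
# are assumed numpy arrays of similarity metrics:
#
#   from sklearn.ensemble import RandomForestClassifier
#
#   scorer = Scorer(spark_session, binary_clf=RandomForestClassifier(class_weight='balanced'))
#   scorer.fit(X_train, y_train)
#   probabilities = scorer.predict_proba(X_new)   # numpy path; pass a Column for the Spark path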
--------------------------------------------------------------------------------
/spark_matcher/similarity_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['SimilarityMetrics']
2 |
3 | from .similarity_metrics import SimilarityMetrics
4 |
--------------------------------------------------------------------------------
/spark_matcher/similarity_metrics/similarity_metrics.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import List, Callable, Dict
6 |
7 | import pandas as pd
8 |
9 | from pyspark.sql import DataFrame, Column
10 | from pyspark.sql import functions as F
11 | from pyspark.sql import types as T
12 |
13 |
14 | class SimilarityMetrics:
15 | """
16 | Class to calculate similarity metrics for pairs of records. The `field_info` dict contains column names as keys
17 | and lists of similarity functions as values. E.g.
18 |
19 | field_info = {'name': [token_set_ratio, token_sort_ratio],
20 | 'postcode': [ratio]}
21 |
22 | where `token_set_ratio`, `token_sort_ratio` and `ratio` are string similarity functions that take two strings
23 | as arguments and return a numeric value
24 |
25 | Attributes:
26 | field_info: dict containing column names as keys and lists of similarity functions as values
27 | """
28 | def __init__(self, field_info: Dict):
29 | self.field_info = field_info
30 |
31 | @staticmethod
32 | def _distance_measure_pandas(strings_1: List[str], strings_2: List[str], func: Callable) -> pd.Series:
33 | """
34 | Helper function to apply a string similarity metric to two arrays of strings. To be used in a Pandas UDF.
35 |
36 | Args:
37 | strings_1: array containing strings
38 | strings_2: array containing strings
39 | func: string similarity function to be applied
40 |
41 | Returns:
42 | Pandas series containing string similarities
43 |
44 | """
45 | df = pd.DataFrame({'x': strings_1, 'y': strings_2})
46 | return df.apply(lambda row: func(row['x'], row['y']), axis=1)
47 |
48 | @staticmethod
49 | def _create_similarity_metric_udf(similarity_metric_function: Callable):
50 | """
51 | Function that creates a Pandas UDF for a given similarity_metric_function
52 |
53 | Args:
54 | similarity_metric_function: function that takes two strings and returns a number
55 |
56 | Returns:
57 | Pandas UDF
58 | """
59 |
60 | @F.pandas_udf(T.FloatType())
61 | def similarity_udf(strings_1: pd.Series, strings_2: pd.Series) -> pd.Series:
62 | # some similarity metrics cannot deal with empty strings, therefore these are replaced with " "
63 | strings_1 = [x if x != "" else " " for x in strings_1]
64 | strings_2 = [x if x != "" else " " for x in strings_2]
65 | return SimilarityMetrics._distance_measure_pandas(strings_1, strings_2, similarity_metric_function)
66 |
67 | return similarity_udf
68 |
69 | def _apply_distance_metrics(self) -> Column:
70 | """
71 | Function to apply all distance metrics in the right order to string pairs and returns them in an array. This
72 | array is used as input to the (logistic regression) scoring function.
73 |
74 | Returns:
75 | array with the string distance metrics
76 |
77 | """
78 | distance_metrics_list = []
79 | for field_name in self.field_info.keys():
80 | field_name_1, field_name_2 = field_name + "_1", field_name + "_2"
81 | for similarity_function in self.field_info[field_name]:
82 | similarity_metric = (
83 | SimilarityMetrics._create_similarity_metric_udf(similarity_function)(F.col(field_name_1),
84 | F.col(field_name_2)))
85 | distance_metrics_list.append(similarity_metric)
86 |
87 | array = F.array(*distance_metrics_list)
88 |
89 | return array
90 |
91 | def transform(self, pairs_table: DataFrame) -> DataFrame:
92 | """
93 | Method to apply similarity metrics to a pairs table. The metrics are added as an array column named
94 | `similarity_metrics`.
95 |
96 | Args:
97 | pairs_table: Spark dataframe containing the pairs table
98 |
99 | Returns:
100 | Spark dataframe with the pairs table and the newly created `similarity_metrics` column
101 |
102 | """
103 | return pairs_table.withColumn('similarity_metrics', self._apply_distance_metrics())
104 |
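# Example usage (illustrative sketch, not part of the original module), assuming a Spark
# dataframe `pairs_table` with columns 'name_1' and 'name_2':
#
#   from thefuzz.fuzz import token_set_ratio, token_sort_ratio
#
#   similarity_metrics = SimilarityMetrics({'name': [token_set_ratio, token_sort_ratio]})
#   scored_pairs = similarity_metrics.transform(pairs_table)
#   scored_pairs.select('name_1', 'name_2', 'similarity_metrics').show(truncate=False)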
--------------------------------------------------------------------------------
/spark_matcher/table_checkpointer.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | import abc
6 | import os
7 |
8 | from pyspark.sql import SparkSession, DataFrame
9 |
10 |
11 | class TableCheckpointer(abc.ABC):
12 | """
13 | Args:
14 | spark_session: a spark session
15 | database: a name of a database or storage system where the tables can be saved
16 | checkpoint_prefix: a prefix of the name that can be used to save tables
17 | """
18 |
19 | def __init__(self, spark_session: SparkSession, database: str, checkpoint_prefix: str = "checkpoint_spark_matcher"):
20 | self.spark_session = spark_session
21 | self.database = database
22 | self.checkpoint_prefix = checkpoint_prefix
23 |
24 | def __call__(self, sdf: DataFrame, checkpoint_name: str):
25 | return self.checkpoint_table(sdf, checkpoint_name)
26 |
27 | @abc.abstractmethod
28 | def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
29 | pass
30 |
31 |
32 | class HiveCheckpointer(TableCheckpointer):
33 | """
34 | Args:
35 | spark_session: a spark session
36 | database: a name of a database or storage system where the tables can be saved
37 | checkpoint_prefix: a prefix of the name that can be used to save tables
38 | """
39 | def __init__(self, spark_session: SparkSession, database: str, checkpoint_prefix: str = "checkpoint_spark_matcher"):
40 | super().__init__(spark_session, database, checkpoint_prefix)
41 |
42 | def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
43 | """
44 | This method saves the input dataframes as checkpoints of the algorithm. This checkpointing can be
45 | used to store intermediary results that are needed throughout the algorithm. The tables are stored using the
46 | following name convention: `{checkpoint_prefix}_{checkpoint_name}`.
47 |
48 | Args:
49 | sdf: a Spark dataframe that needs to be saved as a checkpoint
50 | checkpoint_name: name of the table
51 |
52 | Returns:
53 | the same, unchanged, spark dataframe as the input dataframe, with the only difference that the
54 | dataframe is now read back from disk as a checkpoint.
55 | """
56 | sdf.write.saveAsTable(f"{self.database}.{self.checkpoint_prefix}_{checkpoint_name}",
57 | mode="overwrite")
58 | sdf = self.spark_session.table(f"{self.database}.{self.checkpoint_prefix}_{checkpoint_name}")
59 | return sdf
60 |
61 |
62 | class ParquetCheckPointer(TableCheckpointer):
63 | """
64 | Args:
65 | spark_session: a spark session
66 | checkpoint_dir: directory where the tables can be saved
67 | checkpoint_prefix: a prefix of the name that can be used to save tables
68 | """
69 | def __init__(self, spark_session: SparkSession, checkpoint_dir: str,
70 | checkpoint_prefix: str = "checkpoint_spark_matcher"):
71 | super().__init__(spark_session, checkpoint_dir, checkpoint_prefix)
72 |
73 | def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str):
74 | """
75 | This method saves the input dataframes as checkpoints of the algorithm. This checkpointing can be
76 | used to store intermediary results that are needed throughout the algorithm. The tables are stored
77 | using the following name convention:
78 | `{checkpoint_prefix}_{checkpoint_name}`.
79 |
80 | Args:
81 | sdf: a Spark dataframe that needs to be saved as a checkpoint
82 | checkpoint_name: name of the table
83 |
84 | Returns:
85 | the same, unchanged, spark dataframe as the input dataframe, with the only difference that the
86 | dataframe is now read back from disk as a checkpoint.
87 | """
88 | file_name = os.path.join(f'{self.database}', f'{self.checkpoint_prefix}_{checkpoint_name}')
89 | sdf.write.parquet(file_name, mode='overwrite')
90 | return self.spark_session.read.parquet(file_name)
91 |
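# Example usage (illustrative sketch, not part of the original module), assuming a live
# SparkSession `spark` and some Spark dataframe `sdf`:
#
#   checkpointer = ParquetCheckPointer(spark, checkpoint_dir='/tmp/spark_matcher_checkpoints')
#   sdf = checkpointer(sdf, 'my_intermediate_table')   # written to parquet and read back from disk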
--------------------------------------------------------------------------------
/spark_matcher/utils.py:
--------------------------------------------------------------------------------
1 | # Authors: Ahmet Bayraktar
2 | # Stan Leisink
3 | # Frits Hermans
4 |
5 | from typing import List
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from pyspark.ml.feature import StopWordsRemover
10 | from pyspark.sql import DataFrame
11 | from pyspark.sql import functions as F
12 |
13 |
14 | def get_most_frequent_words(sdf: DataFrame, col_name: str, min_df=2, top_n_words=1_000) -> pd.DataFrame:
15 | """
16 | Count word frequencies in a Spark dataframe `sdf` column named `col_name` and return a Pandas dataframe containing
17 | the document frequencies of the `top_n_words`. This function is intended to be used to create a list of stopwords.
18 |
19 | Args:
20 | sdf: Spark dataframe
21 | col_name: column name to get most frequent words from
22 | min_df: minimal document frequency for a word to be included as a stopword; an int is an absolute count, a float a fraction
23 | top_n_words: number of most occurring words to include
24 |
25 | Returns:
26 | pandas dataframe containing most occurring words, their counts and document frequencies
27 |
28 | """
29 | sdf_col_splitted = sdf.withColumn(f'{col_name}_array', F.split(F.col(col_name), pattern=' '))
30 | word_count = sdf_col_splitted.select(F.explode(f'{col_name}_array').alias('words')).groupBy('words').count()
31 | doc_count = sdf.count()
32 | word_count = word_count.withColumn('df', F.col('count') / doc_count)
33 | if isinstance(min_df, int):
34 | min_count = min_df
35 | elif isinstance(min_df, float):
36 | min_count = np.ceil(min_df * doc_count)
37 | word_count_pdf = word_count.filter(F.col('count') >= min_count).sort(F.desc('df')).limit(top_n_words).toPandas()
38 | return word_count_pdf
39 |
40 |
41 | def remove_stopwords(sdf: DataFrame, col_name: str, stopwords: List[str], case=False,
42 | suffix: str = '_wo_stopwords') -> DataFrame:
43 | """
44 | Remove stopwords `stopwords` from a column `col_name` in a Spark dataframe `sdf`. The result will be written to a
45 | new column, named with the concatenation of `col_name` and `suffix`.
46 |
47 | Args:
48 | sdf: Spark dataframe
49 | col_name: column name to remove stopwords from
50 | stopwords: list of stopwords to remove
51 | case: whether stopword matching should be case sensitive
52 | suffix: suffix for the newly created column
53 |
54 | Returns:
55 | Spark dataframe with column added without stopwords
56 |
57 | """
58 | sdf = sdf.withColumn(f'{col_name}_array', F.split(F.col(col_name), pattern=' '))
59 | sw_remover = StopWordsRemover(inputCol=f'{col_name}_array', outputCol=f'{col_name}_array_wo_stopwords',
60 | stopWords=stopwords, caseSensitive=case)
61 | sdf = (sw_remover.transform(sdf)
62 | .withColumn(f'{col_name}{suffix}', F.concat_ws(' ', F.col(f'{col_name}_array_wo_stopwords')))
63 | .fillna({f'{col_name}{suffix}': ''})
64 | .drop(f'{col_name}_array', f'{col_name}_array_wo_stopwords')
65 | )
66 | return sdf
67 |
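# Example usage (illustrative sketch, not part of the original module), assuming a Spark
# dataframe `sdf` with a 'name' column:
#
#   frequent_words = get_most_frequent_words(sdf, col_name='name', min_df=2, top_n_words=10)
#   stopwords = list(frequent_words['words'])
#   sdf = remove_stopwords(sdf, col_name='name', stopwords=stopwords)   # adds 'name_wo_stopwords'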
--------------------------------------------------------------------------------
/spark_requirements.txt:
--------------------------------------------------------------------------------
1 | thefuzz
2 | pandas
3 | scikit-learn
4 | pyarrow
5 |
--------------------------------------------------------------------------------
/test/test_active_learner/test_active_learner.py:
--------------------------------------------------------------------------------
1 | from spark_matcher.activelearner.active_learner import ScoringLearner
2 | from spark_matcher.scorer.scorer import Scorer
3 |
4 |
5 | def test__get_uncertainty_improvement(spark_session):
6 | scorer = Scorer(spark_session)
7 | myScoringLearner = ScoringLearner(col_names=[''], scorer=scorer, n_uncertainty_improvement=5)
8 | myScoringLearner.uncertainties = [0.4, 0.2, 0.1, 0.08, 0.05, 0.03]
9 | assert myScoringLearner._get_uncertainty_improvement() == 0.2
10 |
11 |
12 | def test__is_converged(spark_session):
13 | scorer = Scorer(spark_session)
14 | myScoringLearner = ScoringLearner(col_names=[''], scorer=scorer, min_nr_samples=5, uncertainty_threshold=0.1,
15 | uncertainty_improvement_threshold=0.01, n_uncertainty_improvement=5)
16 |
17 | # insufficient labelled samples
18 | myScoringLearner.uncertainties = [0.4, 0.39, 0.395]
19 | myScoringLearner.counter_total = len(myScoringLearner.uncertainties)
20 | assert not myScoringLearner._is_converged()
21 |
22 | # too large improvement in last 5 iterations
23 | myScoringLearner.uncertainties = [0.4, 0.2, 0.19, 0.18, 0.17, 0.16]
24 | myScoringLearner.counter_total = len(myScoringLearner.uncertainties)
25 | assert not myScoringLearner._is_converged()
26 |
27 | # improvement in last 5 iterations below threshold and sufficient labelled cases
28 | myScoringLearner.uncertainties = [0.19, 0.1, 0.08, 0.05, 0.03, 0.02]
29 | myScoringLearner.counter_total = len(myScoringLearner.uncertainties)
30 | assert myScoringLearner._is_converged()
31 |
--------------------------------------------------------------------------------
/test/test_blocker/test_block_learner.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | from pandas.testing import assert_frame_equal
4 | from pyspark.sql import functions as F, DataFrame
5 |
6 | from spark_matcher.blocker.blocking_rules import BlockingRule
7 | from spark_matcher.blocker.block_learner import BlockLearner
8 |
9 |
10 | @pytest.fixture
11 | def blocking_rule():
12 | class DummyBlockingRule(BlockingRule):
13 | def __repr__(self):
14 | return 'dummy_blocking_rule'
15 |
16 | def _blocking_rule(self, _):
17 | return F.lit('block_key')
18 | return DummyBlockingRule('blocking_column')
19 |
20 |
21 | @pytest.fixture
22 | def block_learner(blocking_rule, monkeypatch, table_checkpointer):
23 | monkeypatch.setattr(blocking_rule, "create_block_key", lambda _sdf: _sdf)
24 | return BlockLearner([blocking_rule, blocking_rule], 1.0, table_checkpointer)
25 |
26 |
27 | def test__create_blocks(spark_session, blocking_rule, block_learner, monkeypatch):
28 | monkeypatch.setattr(block_learner, "cover_blocking_rules", [blocking_rule, blocking_rule])
29 |
30 |
31 | sdf = spark_session.createDataFrame(
32 | pd.DataFrame({
33 | 'blocking_column': ['b', 'bb', 'bbb']
34 | })
35 | )
36 |
37 | result = block_learner._create_blocks(sdf)
38 |
39 | assert isinstance(result, list)
40 | assert isinstance(result[0], DataFrame)
41 |
42 |
43 | def test__create_block_table(spark_session, block_learner):
44 |
45 | block_1 = spark_session.createDataFrame(
46 | pd.DataFrame({'column_1': [1, 2, 3], 'column_2': [4, 5, 6]})
47 | )
48 | block_2 = spark_session.createDataFrame(
49 | pd.DataFrame({'column_1': [7, 8, 9], 'column_2': [10, 11, 12]})
50 | )
51 |
52 | input_blocks = [block_1, block_2]
53 |
54 | expected_result = pd.DataFrame({
55 | 'column_1': [1, 2, 3, 7, 8, 9], 'column_2': [4, 5, 6, 10, 11, 12]
56 | })
57 |
58 | result = (
59 | block_learner
60 | ._create_block_table(input_blocks)
61 | .toPandas()
62 | .sort_values(by='column_1')
63 | .reset_index(drop=True)
64 | )
65 | assert_frame_equal(result, expected_result)
66 |
67 |
68 |
69 | def test__greedy_set_coverage(blocking_rule, block_learner):
70 | import copy
71 |
72 | full_set = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
73 | block_learner.full_set = full_set
74 | block_learner.full_set_size = 10
75 |
76 | bl1 = copy.copy(blocking_rule)
77 | bl1.training_coverage = {0, 1, 2, 3}
78 | bl1.training_coverage_size = 4
79 |
80 | bl2 = copy.copy(blocking_rule)
81 | bl2.training_coverage = {3, 4, 5, 6, 7, 8}
82 | bl2.training_coverage_size = 6
83 |
84 | bl3 = copy.copy(blocking_rule)
85 | bl3.training_coverage = {3, 4, 5, 6, 7, 8}
86 | bl3.training_coverage_size = 5
87 |
88 | bl4 = copy.copy(blocking_rule)
89 | bl4.training_coverage = {8, 9}
90 | bl4.training_coverage_size = 2
91 |
92 | block_learner.blocking_rules = [bl1, bl2, bl3, bl4]
93 |
94 | # case 1: recall = 0.9
95 | block_learner.recall = 0.9
96 | block_learner._greedy_set_coverage()
97 |
98 | assert block_learner.cover_set == {0, 1, 2, 3, 4, 5, 6, 7, 8}
99 |
100 | # case 2: recall = 1.0
101 | block_learner.recall = 1.0
102 | block_learner._greedy_set_coverage()
103 |
104 | assert block_learner.cover_set == full_set
105 |
106 |     # case 3: recall = 0.5
107 | block_learner.recall = 0.5
108 | block_learner._greedy_set_coverage()
109 |
110 | assert block_learner.cover_set == {3, 4, 5, 6, 7, 8}
111 |
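112 |
113 | # Illustrative sketch (an assumption, not BlockLearner's actual code): the three
114 | # assertions above are consistent with a greedy set-cover loop that keeps adding
115 | # the blocking rule covering the most still-uncovered training ids until at least
116 | # `recall * len(full_set)` ids are covered.
117 | def greedy_cover_sketch(rules, full_set, recall):
118 |     """Hypothetical helper returning the ids covered by a greedy choice of rules."""
119 |     target = recall * len(full_set)
120 |     covered, remaining = set(), list(rules)
121 |     while len(covered) < target and remaining:
122 |         best = max(remaining, key=lambda rule: len(rule.training_coverage - covered))
123 |         covered |= best.training_coverage
124 |         remaining.remove(best)
125 |     return covered  # e.g. recall=0.9 with the rules above yields {0, ..., 8}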
--------------------------------------------------------------------------------
/test/test_blocker/test_blocking_rules.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | from pandas.testing import assert_frame_equal
4 | from pyspark.sql import functions as F
5 |
6 | from spark_matcher.blocker.blocking_rules import BlockingRule, FirstNChars, LastNChars, WholeField, FirstNWords, \
7 | FirstNLettersNoSpace, SortedIntegers, FirstInteger, LastInteger, LargestInteger, NLetterAbbreviation, \
8 | FirstNCharsLastWord, FirstNCharactersFirstTokenSorted
9 |
10 |
11 | @pytest.fixture
12 | def blocking_rule():
13 | class DummyBlockingRule(BlockingRule):
14 | def __repr__(self):
15 | return 'dummy_blocking_rule'
16 |
17 | def _blocking_rule(self, _):
18 | return F.lit('block_key')
19 | return DummyBlockingRule('blocking_column')
20 |
21 |
22 | def test_create_block_key(spark_session, blocking_rule):
23 | input_sdf = spark_session.createDataFrame(
24 | pd.DataFrame({
25 | 'sort_column': [1, 2, 3],
26 | 'blocking_column': ['a', 'aa', 'aaa']}))
27 |
28 | expected_result = pd.DataFrame({
29 | 'sort_column': [1, 2, 3],
30 | 'blocking_column': ['a', 'aa', 'aaa'],
31 | 'block_key': ['dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key']
32 | })
33 |
34 | result = (
35 | blocking_rule
36 | .create_block_key(input_sdf)
37 | .toPandas()
38 | .sort_values(by='sort_column')
39 | .reset_index(drop=True)
40 | )
41 | assert_frame_equal(result, expected_result)
42 |
43 |
44 | def test__create_training_block_keys(spark_session, blocking_rule):
45 | input_sdf = spark_session.createDataFrame(
46 | pd.DataFrame({
47 | 'sort_column': [1, 2, 3],
48 | 'blocking_column_1': ['a', 'aa', 'aaa'],
49 | 'blocking_column_2': ['b', 'bb', 'bbb']}))
50 |
51 | expected_result = pd.DataFrame({
52 | 'sort_column': [1, 2, 3],
53 | 'blocking_column_1': ['a', 'aa', 'aaa'],
54 | 'blocking_column_2': ['b', 'bb', 'bbb'],
55 | 'block_key_1': ['dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key',
56 | 'dummy_blocking_rule:block_key'],
57 | 'block_key_2': ['dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key',
58 | 'dummy_blocking_rule:block_key']
59 | })
60 |
61 | result = (
62 | blocking_rule
63 | ._create_training_block_keys(input_sdf)
64 | .toPandas()
65 | .sort_values(by='sort_column')
66 | .reset_index(drop=True)
67 | )
68 | assert_frame_equal(result, expected_result)
69 |
70 |
71 | def test__compare_and_filter_keys(spark_session, blocking_rule):
72 | input_sdf = spark_session.createDataFrame(
73 | pd.DataFrame({
74 | 'sort_column': [1, 2, 3, 4, 5, 6],
75 | 'block_key_1': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:', 'blocking_rule:aaa',
76 | 'blocking_rule:', 'blocking_rule:'],
77 | 'block_key_2': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:', 'blocking_rule:bbb',
78 | 'blocking_rule:', 'blocking_rule:bb']}))
79 |
80 | expected_result = pd.DataFrame({
81 | 'sort_column': [1, 2, 3],
82 | 'block_key_1': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:'],
83 | 'block_key_2': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:']})
84 |
85 | result = (
86 | blocking_rule
87 | ._compare_and_filter_keys(input_sdf)
88 | .toPandas()
89 | .sort_values(by='sort_column')
90 | .reset_index(drop=True)
91 | )
97 |     assert_frame_equal(result, expected_result)
98 |
99 |
100 | def test_calculate_training_set_coverage(spark_session, blocking_rule, monkeypatch):
101 |     # monkeypatch the inner functions of `calculate_training_set_coverage`, since they are tested separately
102 | monkeypatch.setattr(blocking_rule, "_create_training_block_keys", lambda _sdf: _sdf)
103 | monkeypatch.setattr(blocking_rule, "_compare_and_filter_keys", lambda _sdf: _sdf)
104 |
105 | row_ids = [0, 1, 2, 3, 4, 5]
106 | input_sdf = spark_session.createDataFrame(
107 | pd.DataFrame({'row_id': row_ids}))
108 |
109 | expected_result = set(row_ids)
110 |
111 | blocking_rule.calculate_training_set_coverage(input_sdf)
112 |
113 | assert blocking_rule.training_coverage == expected_result
114 | assert blocking_rule.training_coverage_size == len(expected_result)
115 |
116 |
117 | @pytest.mark.parametrize('n', [2, 3])
118 | def test_firstnchars(spark_session, n):
119 | if n == 2:
120 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'],
121 | 'block_key': ["first_2_characters_name:aa",
122 | "first_2_characters_name:sp",
123 | "first_2_characters_name:aa"]})
124 | elif n == 3:
125 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'],
126 | 'block_key': [None,
127 | "first_3_characters_name:spa",
128 | "first_3_characters_name:aa "]})
129 |
130 | sdf = spark_session.createDataFrame(expected_result_pdf)
131 | result_pdf = FirstNChars('name', n).create_block_key(sdf.select('name')).toPandas()
132 | assert_frame_equal(expected_result_pdf, result_pdf)
133 |
134 |
135 | @pytest.mark.parametrize('n', [2, 3])
136 | def test_firstncharslastword(spark_session, n):
137 | if n == 2:
138 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'aa spark', 'aa bb'],
139 | 'block_key': ["first_2_characters_last_word_name:aa",
140 | "first_2_characters_last_word_name:sp",
141 | "first_2_characters_last_word_name:bb"]})
142 | elif n == 3:
143 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'aa spark', 'aa bbb'],
144 | 'block_key': [None,
145 | "first_3_characters_last_word_name:spa",
146 | "first_3_characters_last_word_name:bbb"]})
147 |
148 | sdf = spark_session.createDataFrame(expected_result_pdf)
149 | result_pdf = FirstNCharsLastWord('name', n).create_block_key(sdf.select('name')).toPandas()
150 | assert_frame_equal(expected_result_pdf, result_pdf)
151 |
152 |
153 | @pytest.mark.parametrize('n', [2, 3])
154 | def test_firstncharactersfirsttokensorted(spark_session, n):
155 | if n == 2:
156 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark aa', 'aa bb', 'a bb'],
157 | 'block_key': ["first_2_characters_first_token_sorted_name:aa",
158 | "first_2_characters_first_token_sorted_name:aa",
159 | "first_2_characters_first_token_sorted_name:aa",
160 | "first_2_characters_first_token_sorted_name:bb"]})
161 | elif n == 3:
162 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark aaa', 'bbb aa'],
163 | 'block_key': [None,
164 | "first_3_characters_first_token_sorted_name:aaa",
165 | "first_3_characters_first_token_sorted_name:bbb"]})
166 |
167 | sdf = spark_session.createDataFrame(expected_result_pdf)
168 | result_pdf = FirstNCharactersFirstTokenSorted('name', n).create_block_key(sdf.select('name')).toPandas()
169 | assert_frame_equal(expected_result_pdf, result_pdf)
170 |
171 |
172 | @pytest.mark.parametrize('n', [2, 3])
173 | def test_lastnchars(spark_session, n):
174 | if n == 2:
175 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'],
176 | 'block_key': ["last_2_characters_name:aa",
177 | "last_2_characters_name:rk",
178 | "last_2_characters_name:bb"]})
179 | elif n == 3:
180 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'],
181 | 'block_key': [None,
182 | "last_3_characters_name:ark",
183 | "last_3_characters_name: bb"]})
184 |
185 | sdf = spark_session.createDataFrame(expected_result_pdf)
186 | result_pdf = LastNChars('name', n).create_block_key(sdf.select('name')).toPandas()
187 | assert_frame_equal(expected_result_pdf, result_pdf)
188 |
189 |
190 | def test_wholefield(spark_session):
191 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark'],
192 |                                         'block_key': ["whole_field_name:python",
193 |                                                       "whole_field_name:python pyspark"]})
194 |
195 | sdf = spark_session.createDataFrame(expected_result_pdf)
196 | result_pdf = WholeField('name').create_block_key(sdf.select('name')).toPandas()
197 | assert_frame_equal(expected_result_pdf, result_pdf)
198 |
199 |
200 | @pytest.mark.parametrize('n', [2, 3])
201 | def test_firstnwords(spark_session, n):
202 | if n == 2:
203 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark'],
204 | 'block_key': [None,
205 | "first_2_words_name:python pyspark"]})
206 | elif n == 3:
207 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python py spark'],
208 | 'block_key': [None, "first_3_words_name:python py spark"]})
209 |
210 | sdf = spark_session.createDataFrame(expected_result_pdf)
211 | result_pdf = FirstNWords('name', n).create_block_key(sdf.select('name')).toPandas()
212 | assert_frame_equal(expected_result_pdf, result_pdf)
213 |
214 |
215 | @pytest.mark.parametrize('n', [2, 3])
216 | def test_firstnlettersnospace(spark_session, n):
217 | if n == 2:
218 | expected_result_pdf = pd.DataFrame({'name': ['python', 'p ython'],
219 | 'block_key': ["first_2_letters_name_no_space:py",
220 | "first_2_letters_name_no_space:py"]})
221 | elif n == 3:
222 | expected_result_pdf = pd.DataFrame({'name': ['p y', 'python', 'p y thon'],
223 | 'block_key': [None,
224 | "first_3_letters_name_no_space:pyt",
225 | "first_3_letters_name_no_space:pyt"]})
226 |
227 | sdf = spark_session.createDataFrame(expected_result_pdf)
228 | result_pdf = FirstNLettersNoSpace('name', n).create_block_key(sdf.select('name')).toPandas()
229 | assert_frame_equal(expected_result_pdf, result_pdf)
230 |
231 |
232 | def test_sortedintegers(spark_session):
233 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 2 1'],
234 | 'block_key': [None,
235 | "sorted_integers_name:1 2"]})
236 |
237 | sdf = spark_session.createDataFrame(expected_result_pdf)
238 | result_pdf = SortedIntegers('name').create_block_key(sdf.select('name')).toPandas()
239 | assert_frame_equal(expected_result_pdf, result_pdf)
240 |
241 |
242 | def test_firstinteger(spark_session):
243 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 2 1'],
244 | 'block_key': [None,
245 | "first_integer_name:2"]})
246 |
247 | sdf = spark_session.createDataFrame(expected_result_pdf)
248 | result_pdf = FirstInteger('name').create_block_key(sdf.select('name')).toPandas()
249 | assert_frame_equal(expected_result_pdf, result_pdf)
250 |
251 |
252 | def test_lastinteger(spark_session):
253 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 2 1'],
254 | 'block_key': [None,
255 | "last_integer_name:1"]})
256 |
257 | sdf = spark_session.createDataFrame(expected_result_pdf)
258 | result_pdf = LastInteger('name').create_block_key(sdf.select('name')).toPandas()
259 | assert_frame_equal(expected_result_pdf, result_pdf)
260 |
261 |
262 | def test_largestinteger(spark_session):
263 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python1', 'python 2 1'],
264 | 'block_key': [None,
265 | 'largest_integer_name:1',
266 | "largest_integer_name:2"]})
267 |
268 | sdf = spark_session.createDataFrame(expected_result_pdf)
269 | result_pdf = LargestInteger('name').create_block_key(sdf.select('name')).toPandas()
270 | assert_frame_equal(expected_result_pdf, result_pdf)
271 |
272 |
273 | @pytest.mark.parametrize('n', [2, 3])
274 | def test_nletterabbreviation(spark_session, n):
275 | if n == 2:
276 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark'],
277 | 'block_key': [None,
278 | "2_letter_abbreviation_name:pp"]})
279 | elif n == 3:
280 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark', 'python apache pyspark'],
281 | 'block_key': [None,
282 | None,
283 | "3_letter_abbreviation_name:pap"]})
284 | sdf = spark_session.createDataFrame(expected_result_pdf)
285 | result_pdf = NLetterAbbreviation('name', n).create_block_key(sdf.select('name')).toPandas()
286 | assert_frame_equal(expected_result_pdf, result_pdf)
287 |
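288 |
289 | # Illustrative sketch (an assumption, not the library's implementation): the
290 | # parametrized expectations above follow the same pattern for every rule -- the
291 | # block key is "<rule repr>:<derived value>" and becomes null when the field is
292 | # too short. Reusing the module-level `F`, a FirstNChars-style key could be
293 | # expressed as a plain column expression like this hypothetical helper:
294 | def first_n_chars_key_sketch(column_name, n):
295 |     prefix = F.lit(f"first_{n}_characters_{column_name}:")
296 |     return (F.when(F.length(F.col(column_name)) >= n,
297 |                    F.concat(prefix, F.substring(F.col(column_name), 1, n)))
298 |             .otherwise(F.lit(None)))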
--------------------------------------------------------------------------------
/test/test_deduplicator/test_deduplicator.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | from pandas.testing import assert_frame_equal
4 | from pyspark.sql import functions as F
5 |
6 | from spark_matcher.deduplicator import Deduplicator
7 |
8 |
9 | @pytest.fixture()
10 | def deduplicator(spark_session):
11 | return Deduplicator(spark_session, col_names=[], checkpoint_dir='mock_db')
12 |
13 |
14 | def test__create_predict_pairs_table(spark_session, deduplicator):
15 | deduplicator.col_names = ['name', 'address']
16 |
17 | input_df = pd.DataFrame(
18 | {
19 | 'name': ['John Doe', 'Jane Doe', 'Daniel Jacks', 'Jack Sparrow', 'Donald Trump'],
20 | 'address': ['1 Square', '2 Square', '3 Main street', '4 Harbour', '5 white house'],
21 | 'block_key': ['J', 'J', 'D', 'J', 'D'],
22 | 'row_number': [1, 2, 3, 4, 5],
23 | }
24 | )
25 |
26 | expected_df = pd.DataFrame({
27 | 'block_key': ['J', 'J', 'J', 'D'],
28 | 'name_1': ['John Doe', 'John Doe', 'Jane Doe', 'Daniel Jacks'],
29 | 'address_1': ['1 Square', '1 Square', '2 Square', '3 Main street'],
30 | 'row_number_1': [1, 1, 2, 3],
31 | 'name_2': ['Jane Doe', 'Jack Sparrow', 'Jack Sparrow', 'Donald Trump'],
32 | 'address_2': ['2 Square', '4 Harbour', '4 Harbour', '5 white house'],
33 | 'row_number_2': [2, 4, 4, 5]
34 | })
35 |
36 | result = (
37 | deduplicator
38 | ._create_predict_pairs_table(spark_session.createDataFrame(input_df))
39 | .toPandas()
40 | .sort_values(by=['address_1', 'address_2'])
41 | .reset_index(drop=True)
42 | )
43 | assert_frame_equal(result, expected_df)
44 |
45 |
46 | def test__add_singletons_entity_identifiers(spark_session, deduplicator) -> None:
47 | df = pd.DataFrame(
48 | data={
49 | "entity_identifier": [0, 1, None, 2, 3, 4, 5, 6],
50 | "row_number": [10, 11, 12, 13, 14, 15, 16, 17]})
51 |
52 | input_data = (
53 | spark_session
54 | .createDataFrame(df)
55 | .withColumn('entity_identifier',
56 | F.when(F.isnan('entity_identifier'), F.lit(None)).otherwise(F.col('entity_identifier')))
57 | .repartition(6))
58 |
59 | expected_df = pd.DataFrame(
60 | data={
61 | "row_number": [10, 11, 12, 13, 14, 15, 16, 17],
62 | "entity_identifier": [0.0, 1.0, 7.0, 2.0, 3.0, 4.0, 5.0, 6.0]})
63 |
64 | result = (deduplicator._add_singletons_entity_identifiers(input_data)
65 | .toPandas()
66 | .sort_values(by='row_number')
67 | .reset_index(drop=True))
68 | assert_frame_equal(result, expected_df, check_dtype=False)
69 |
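70 |
71 | # Illustrative sketch (an assumption, not the Deduplicator's actual code): the
72 | # expected pairs in test__create_predict_pairs_table follow from a self-join on
73 | # block_key restricted to row_number_1 < row_number_2, which keeps every
74 | # candidate pair exactly once and drops self-pairs.
75 | def predict_pairs_sketch(sdf_blocked, col_names):
76 |     cols = col_names + ['row_number']
77 |     left = sdf_blocked.select('block_key', *[F.col(c).alias(f'{c}_1') for c in cols])
78 |     right = sdf_blocked.select('block_key', *[F.col(c).alias(f'{c}_2') for c in cols])
79 |     return left.join(right, on='block_key').filter(F.col('row_number_1') < F.col('row_number_2'))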
--------------------------------------------------------------------------------
/test/test_deduplicator/test_hierarchical_clustering.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from pandas.testing import assert_frame_equal
4 | from pyspark.sql import SparkSession
5 |
6 |
7 | def test__convert_data_to_adjacency_matrix():
8 | from spark_matcher.deduplicator.hierarchical_clustering import _convert_data_to_adjacency_matrix
9 |
10 | input_df = pd.DataFrame({
11 | 'row_number_1': [1, 1],
12 | 'row_number_2': [2, 3],
13 | 'score': [0.8, 0.3],
14 | 'component_id': [1, 1]})
15 |
16 | expected_indexes = np.array([1, 2, 3])
17 | expected_result = np.array([[0., 0.8, 0.3],
18 | [0.8, 0., 0.],
19 | [0.3, 0., 0.]])
20 |
21 | indexes, result = _convert_data_to_adjacency_matrix(input_df)
22 |
23 | np.testing.assert_array_equal(result, expected_result)
24 | np.testing.assert_array_equal(np.array(indexes), expected_indexes)
25 |
26 |
27 | def test__get_condensed_distances():
28 | from spark_matcher.deduplicator.hierarchical_clustering import _get_condensed_distances
29 |
30 | input_df = pd.DataFrame({
31 | 'row_number_1': [1, 1],
32 | 'row_number_2': [2, 3],
33 | 'score': [0.8, 0.3],
34 | 'component_id': [1, 1]})
35 |
36 | expected_dict = {0: 1, 1: 2, 2: 3}
37 | expected_result = np.array([0.2, 0.7, 1.])
38 |
39 | result_dict, result = _get_condensed_distances(input_df)
40 |
41 | np.testing.assert_array_almost_equal_nulp(result, expected_result, nulp=2)
42 | np.testing.assert_array_equal(result_dict, expected_dict)
43 |
44 |
45 | def test__convert_dedupe_result_to_pandas_dataframe(spark_session: SparkSession) -> None:
46 | from spark_matcher.deduplicator.hierarchical_clustering import _convert_dedupe_result_to_pandas_dataframe
47 |
48 |     # case 1: non-empty input
49 | inputs = [
50 | ((1, 3), np.array([0.96, 0.96])),
51 | ((4, 7, 8), np.array([0.95, 0.95, 0.95])),
52 | ((5, 6), np.array([0.98, 0.98]))]
53 | component_id = 112233
54 |
55 | expected_df = pd.DataFrame(data={
56 | 'row_number': [1, 3, 4, 7, 8, 5, 6],
57 | 'entity_identifier': [f"{component_id}_0", f"{component_id}_0",
58 | f"{component_id}_1", f"{component_id}_1", f"{component_id}_1",
59 | f"{component_id}_2", f"{component_id}_2"]})
60 | result = _convert_dedupe_result_to_pandas_dataframe(inputs, component_id).reset_index(drop=True)
61 |
62 | assert_frame_equal(result, expected_df, check_dtype=False)
63 |
64 |     # case 2: empty input
65 | inputs = []
66 | component_id = 12345
67 |
68 | expected_df = pd.DataFrame(data={}, columns=['row_number', 'entity_identifier']).reset_index(drop=True)
69 | result = _convert_dedupe_result_to_pandas_dataframe(inputs, component_id).reset_index(drop=True)
70 |
71 | assert_frame_equal(result, expected_df, check_dtype=False)
72 |
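73 |
74 | # Worked check (an assumption about the convention, inferred from the fixtures
75 | # above): the condensed distances appear to be 1 - score, flattened into scipy's
76 | # condensed form, which reproduces the expected array [0.2, 0.7, 1.].
77 | def condensed_from_adjacency_sketch(adjacency):
78 |     from scipy.spatial.distance import squareform
79 |     distances = 1.0 - adjacency
80 |     np.fill_diagonal(distances, 0.0)  # squareform expects a zero diagonal
81 |     return squareform(distances)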
--------------------------------------------------------------------------------
/test/test_matcher/test_matcher.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | from thefuzz.fuzz import ratio
4 | from pyspark.sql import DataFrame
5 |
6 | from spark_matcher.blocker.block_learner import BlockLearner
7 | from spark_matcher.blocker.blocking_rules import FirstNChars
8 | from spark_matcher.scorer.scorer import Scorer
9 | from spark_matcher.matcher import Matcher
10 |
11 |
12 | @pytest.fixture
13 | def matcher(spark_session, table_checkpointer):
14 | block_name = FirstNChars('name', 3)
15 | block_address = FirstNChars('address', 3)
16 | blocklearner = BlockLearner([block_name, block_address], recall=1, table_checkpointer=table_checkpointer)
17 | blocklearner.fitted = True
18 | blocklearner.cover_blocking_rules = [block_name, block_address]
19 |
20 | myScorer = Scorer(spark_session)
21 | fit_df = pd.DataFrame({'similarity_metrics': [[100, 80], [50, 50]], 'y': [1, 0]})
22 | myScorer.fit(fit_df['similarity_metrics'].values.tolist(), fit_df['y'])
23 |
24 | myMatcher = Matcher(spark_session, table_checkpointer=table_checkpointer,
25 | field_info={'name': [ratio], 'address': [ratio]})
26 | myMatcher.blocker = blocklearner
27 | myMatcher.scorer = myScorer
28 | myMatcher.fitted_ = True
29 |
30 | return myMatcher
31 |
32 |
33 | def test__create_predict_pairs_table(spark_session, table_checkpointer):
34 | sdf_1_blocked = spark_session.createDataFrame(
35 | pd.DataFrame({'name': ['frits', 'stan', 'ahmet'], 'address': ['dam 1', 'leidseplein 2', 'rokin 3'],
36 | 'block_key': ['fr', 'st', 'ah']}))
37 | sdf_2_blocked = spark_session.createDataFrame(
38 | pd.DataFrame({'name': ['frits', 'fred', 'glenn'],
39 | 'address': ['amsterdam', 'waterlooplein 4', 'rembrandtplein 5'],
40 | 'block_key': ['fr', 'fr', 'gl']}))
41 | myMatcher = Matcher(spark_session, table_checkpointer=table_checkpointer, col_names=['name', 'address'])
42 | result = myMatcher._create_predict_pairs_table(sdf_1_blocked, sdf_2_blocked)
43 | assert isinstance(result, DataFrame)
44 | assert result.count() == 2
45 | assert result.select('block_key').drop_duplicates().count() == 1
46 | assert set(result.columns) == {'block_key', 'name_2', 'address_2', 'name_1', 'address_1'}
47 |
48 | def test_predict(spark_session, matcher):
49 | sdf_1 = spark_session.createDataFrame(pd.DataFrame({'name': ['frits', 'stan', 'ahmet', 'ahmet', 'ahmet'],
50 | 'address': ['damrak', 'leidseplein', 'waterlooplein',
51 | 'amstel', 'amstel 3']}))
52 |
53 | sdf_2 = spark_session.createDataFrame(
54 | pd.DataFrame({'name': ['frits h', 'stan l', 'bayraktar', 'ahmet', 'ahmet'],
55 | 'address': ['damrak 1', 'leidseplein 2', 'amstel 3', 'amstel 3', 'waterlooplein 12324']}))
56 |
57 | result = matcher.predict(sdf_1, sdf_2)
58 | assert result
59 |
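60 |
61 | # Illustrative sketch (an assumption, not Matcher's actual code): the counts
62 | # asserted in test__create_predict_pairs_table follow from an inner join of the
63 | # two blocked tables on block_key; only the 'fr' key occurs on both sides, so
64 | # 'frits' pairs with 'frits' and with 'fred' (2 rows, 1 distinct block_key).
65 | def cross_pairs_sketch(sdf_1_blocked, sdf_2_blocked, col_names):
66 |     from pyspark.sql import functions as F
67 |     left = sdf_1_blocked.select('block_key', *[F.col(c).alias(f'{c}_1') for c in col_names])
68 |     right = sdf_2_blocked.select('block_key', *[F.col(c).alias(f'{c}_2') for c in col_names])
69 |     return left.join(right, on='block_key')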
--------------------------------------------------------------------------------
/test/test_matching_base/test_matching_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | from string import ascii_lowercase
3 |
4 | import numpy as np
5 | import pandas as pd
6 | import pytest
7 | from pandas.testing import assert_frame_equal
8 | from thefuzz.fuzz import token_set_ratio
9 |
10 | from spark_matcher.blocker.blocking_rules import BlockingRule
11 | from spark_matcher.matching_base.matching_base import MatchingBase
12 |
13 |
14 | def create_fake_string(length=2):
15 | return "".join(np.random.choice(list(ascii_lowercase), size=length))
16 |
17 |
18 | def create_fake_df(size=10):
19 | return pd.DataFrame.from_dict({'name': [create_fake_string(1) + ' ' + create_fake_string(2) for _ in range(size)],
20 | 'address': [create_fake_string(2) + ' ' + create_fake_string(3) for _ in range(size)]})
21 |
22 |
23 | @pytest.fixture()
24 | def matching_base(spark_session, table_checkpointer):
25 | return MatchingBase(spark_session, table_checkpointer=table_checkpointer, col_names=['name', 'address'])
26 |
27 |
28 | def test__create_train_pairs_table(spark_session, matching_base):
29 | n_1, n_2, n_train_samples, n_perfect_train_matches = (1000, 100, 5000, 1)
30 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1))
31 | sdf_2 = spark_session.createDataFrame(create_fake_df(size=n_2))
32 |
33 | matching_base.n_train_samples = n_train_samples
34 | matching_base.n_perfect_matches = n_perfect_train_matches
35 |
36 | result = matching_base._create_train_pairs_table(sdf_1, sdf_2)
37 |
38 | assert set(result.columns) == {'name_2', 'address_2', 'name_1', 'address_1', 'perfect_train_match'}
39 | assert result.count() == pytest.approx(n_train_samples, abs=n_perfect_train_matches)
40 |
41 |
42 | def test_default_blocking_rules(matching_base):
43 | assert isinstance(matching_base.blocking_rules, list)
44 | assert isinstance(matching_base.blocking_rules[0], BlockingRule)
45 |
46 |
47 | def test_load_save(spark_session, tmpdir, matching_base, table_checkpointer):
48 | def my_metric(x, y):
49 | return float(x == y)
50 |
51 | matching_base.field_info = {'name': [token_set_ratio, my_metric]}
52 | matching_base.n_perfect_train_matches = 5
53 | matching_base.save(os.path.join(tmpdir, 'matcher.pkl'))
54 |
55 | myMatcherLoaded = MatchingBase(spark_session, table_checkpointer=table_checkpointer, col_names=['nothing'])
56 | myMatcherLoaded.load(os.path.join(tmpdir, 'matcher.pkl'))
57 | setattr(table_checkpointer, 'spark_session', spark_session) # needed to be able to continue with other unit tests
58 |
59 | assert matching_base.col_names == myMatcherLoaded.col_names
60 | assert [x.__name__ for x in matching_base.field_info['name']] == [x.__name__ for x in
61 | myMatcherLoaded.field_info['name']]
62 | assert matching_base.n_perfect_train_matches == myMatcherLoaded.n_perfect_train_matches
63 |
64 |
65 | @pytest.mark.parametrize('suffix', [1, 2, 3])
66 | def test__add_suffix_to_col_names(spark_session, suffix, matching_base):
67 | input_df = pd.DataFrame({
68 | 'name': ['John Doe', 'Jane Doe', 'Chris Rock', 'Jack Sparrow'],
69 | 'address': ['Square 1', 'Square 2', 'Main street', 'Harbour 123'],
70 | 'block_key': ['1', '2', '3', '3']
71 | })
72 |
73 | result = (
74 | matching_base
75 | ._add_suffix_to_col_names(spark_session.createDataFrame(input_df), suffix)
76 | .toPandas()
77 | .reset_index(drop=True)
78 | )
79 | assert_frame_equal(result, input_df.rename(columns={'name': f'name_{suffix}', 'address': f'address_{suffix}'}))
80 |
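81 |
82 | # Illustrative sketch (an assumption, not MatchingBase's actual code): the rename
83 | # asserted above only touches the configured col_names, leaving helper columns
84 | # such as block_key untouched.
85 | def add_suffix_sketch(sdf, col_names, suffix):
86 |     for col in col_names:
87 |         sdf = sdf.withColumnRenamed(col, f'{col}_{suffix}')
88 |     return sdf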
--------------------------------------------------------------------------------
/test/test_sampler/test_training_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import pytest
4 | from pandas.testing import assert_frame_equal
5 | from string import ascii_lowercase
6 |
7 | from spark_matcher.sampler.training_sampler import HashSampler, RandomSampler
8 |
9 |
10 | def create_fake_string(length=2):
11 | return "".join(np.random.choice(list(ascii_lowercase), size=length))
12 |
13 |
14 | def create_fake_df(size=10):
15 | return pd.DataFrame.from_dict({'name': [create_fake_string(1) + ' ' + create_fake_string(2) for _ in range(size)],
16 | 'surname': [create_fake_string(2) + ' ' + create_fake_string(3) for _ in range(size)]})
17 |
18 |
19 | @pytest.mark.parametrize('test_case', [(100_000, 10, 10_000), (100, 100_000, 10_000), (100_000, 100_000, 10_000)])
20 | def test__create_random_pairs_table_matcher(spark_session, test_case, table_checkpointer):
21 | n_1, n_2, n_train_samples = test_case
22 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1))
23 | sdf_2 = spark_session.createDataFrame(create_fake_df(size=n_2))
24 |
25 | rSampler = RandomSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples)
26 |
27 | result = rSampler._create_pairs_table_matcher(sdf_1, sdf_2)
28 |
29 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'}
30 | assert result.count() == pytest.approx(n_train_samples, abs=1)
31 | if n_1 > n_2:
32 | assert result.select('name_1', 'surname_1').drop_duplicates().count() == pytest.approx(
33 | n_train_samples / n_2, abs=1)
34 | assert result.select('name_2', 'surname_2').drop_duplicates().count() == n_2
35 | elif n_1 < n_2:
36 | assert result.select('name_1', 'surname_1').drop_duplicates().count() == pytest.approx(n_1, abs=1)
37 | assert result.select('name_2', 'surname_2').drop_duplicates().count() == pytest.approx(
38 | n_train_samples / n_1, abs=1)
39 | elif n_1 == n_2:
40 | assert result.select('name_2', 'surname_2').drop_duplicates().count() == pytest.approx(
41 | int(n_train_samples ** 0.5), abs=1)
42 | assert result.select('name_1', 'surname_1').drop_duplicates().count() == pytest.approx(
43 | int(n_train_samples ** 0.5), abs=1)
44 |
45 |
46 | @pytest.mark.parametrize('test_case', [(100_000, 10_000), (1_000, 10_000), (1_000, 1_000)])
47 | def test__create_random_pairs_table_deduplicator(spark_session, test_case, table_checkpointer):
48 | n_rows, n_train_samples = test_case
49 | sdf = spark_session.createDataFrame(create_fake_df(size=n_rows))
50 |
51 | rSampler = RandomSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples)
52 |
53 | result = rSampler._create_pairs_table_deduplicator(sdf)
54 |
55 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'}
56 | assert result.count() == n_train_samples
57 |
58 |
59 | def test__create_hashed_pairs_table_matcher(spark_session, table_checkpointer):
60 | n_1, n_2, n_train_samples = (1000, 1000, 1000)
61 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1))
62 | sdf_2 = spark_session.createDataFrame(create_fake_df(size=n_2))
63 |
64 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples, threshold=1)
65 |
66 | result = hSampler.create_pairs_table(sdf_1, sdf_2)
67 |
68 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'}
69 | assert result.count() == n_train_samples
70 |
71 |
72 | def test__create_hashed_pairs_table_deduplicator(spark_session, table_checkpointer):
73 | n_1, n_train_samples = (1000, 1000)
74 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1))
75 |
76 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples, threshold=1)
77 |
78 | result = hSampler.create_pairs_table(sdf_1)
79 |
80 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'}
81 | assert result.count() == n_train_samples
82 |
83 |
84 | def test__vectorize_matcher(spark_session, table_checkpointer):
85 | input_sdf_1 = spark_session.createDataFrame(
86 | pd.DataFrame({
87 | 'name': ['aa bb cc', 'cc dd'],
88 | 'surname': ['xx yy', 'yy zz xx']
89 | }))
90 |
91 | input_sdf_2 = spark_session.createDataFrame(
92 | pd.DataFrame({
93 | 'name': ['bb cc'],
94 | 'surname': ['yy']
95 | }))
96 |
97 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=999)
98 |
99 |     # set max_df to a large value because otherwise no tokens would be retained
100 | result_sdf_1, result_sdf_2 = hSampler._vectorize(['name', 'surname'], input_sdf_1, input_sdf_2, max_df=10)
101 |
102 | result_df_1 = result_sdf_1.toPandas()
103 | result_df_2 = result_sdf_2.toPandas()
104 |
105 | assert set(result_df_1.columns) == {'name', 'surname', 'features'}
106 | assert set(result_df_2.columns) == {'name', 'surname', 'features'}
107 | assert result_df_1.shape[0] == 2
108 | assert result_df_2.shape[0] == 1
109 | assert len(result_df_1.features[0]) == 4 # note that not all tokens occur at least twice (minDF=2)
110 | assert len(result_df_2.features[0]) == 4 # note that not all tokens occur at least twice (minDF=2)
111 |
112 |
113 | def test__vectorize_deduplicator(spark_session, table_checkpointer):
114 | input_sdf = spark_session.createDataFrame(
115 | pd.DataFrame({
116 | 'name': ['aa bb cc', 'cc dd', 'bb cc'],
117 | 'surname': ['xx yy', 'yy zz xx', 'yy zz']
118 | }))
119 |
120 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=999)
121 |
122 |     # set max_df to a large value because otherwise no tokens would be retained
123 | result_sdf = hSampler._vectorize(['name', 'surname'], input_sdf, max_df=10)
124 |
125 | result_df = result_sdf.toPandas()
126 |
127 | assert set(result_df.columns) == {'name', 'surname', 'features'}
128 | assert result_df.shape[0] == 3
129 | assert len(result_df.features[0]) == 5 # note that not all tokens occur at least twice (minDF=2)
130 |
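131 |
132 | # Illustrative sketch (an assumption, not HashSampler._vectorize itself): the
133 | # asserted feature lengths (4 resp. 5) are consistent with tokenising the
134 | # concatenated columns and fitting a count vectorizer with minDF=2 over all rows,
135 | # so only tokens occurring in at least two rows survive.
136 | def vectorize_sketch(sdf, col_names, min_df=2, max_df=10):
137 |     from pyspark.ml.feature import CountVectorizer, Tokenizer
138 |     from pyspark.sql import functions as F
139 |     sdf = sdf.withColumn('concatenated', F.concat_ws(' ', *col_names))
140 |     tokenized = Tokenizer(inputCol='concatenated', outputCol='tokens').transform(sdf)
141 |     cv = CountVectorizer(inputCol='tokens', outputCol='features', minDF=min_df, maxDF=max_df)
142 |     return cv.fit(tokenized).transform(tokenized)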
--------------------------------------------------------------------------------
/test/test_scorer/test_scorer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import pandas as pd
3 | import numpy as np
4 | from pyspark.sql import Column
5 | from sklearn.base import BaseEstimator
6 |
7 | from spark_matcher.scorer.scorer import Scorer
8 |
9 |
10 | @pytest.fixture
11 | def scorer(spark_session):
12 | class DummyScorer(Scorer):
13 | def __init__(self, spark_session=spark_session):
14 | super().__init__(spark_session)
15 |
16 | def _predict_proba(self, X):
17 | return np.array([[0, 1]])
18 |
19 | def _predict_proba_spark(self, X):
20 | return spark_session.createDataFrame(pd.DataFrame({'1': [1]}))['1']
21 |
22 | return DummyScorer()
23 |
24 |
25 | def test_fit(scorer):
26 |     # case 1: the scorer should fit without exceptions even when only one class is present:
27 | X = np.array([[1, 2, 3]])
28 | y = pd.Series([0])
29 | scorer.fit(X, y)
30 |
31 |     # case 2: a regular fit with both classes present:
32 | X = np.array([[1, 2, 3], [4, 5, 6]])
33 | y = pd.Series([0, 1])
34 | scorer.fit(X, y)
35 |
36 |
37 | def test_predict_proba(spark_session, scorer):
38 | X = np.ndarray([0])
39 | preds = scorer.predict_proba(X)
40 | assert isinstance(preds, np.ndarray)
41 |
42 | X = spark_session.createDataFrame(pd.DataFrame({'c': [0]}))
43 | preds = scorer.predict_proba(X['c'])
44 | assert isinstance(preds, Column)
45 |
46 | X = pd.DataFrame({})
47 | with pytest.raises(ValueError) as e:
48 | scorer.predict_proba(X)
49 | assert f"{type(X)} is an unsupported datatype for X" == str(e.value)
50 |
51 |
52 | def test__create_default_clf(scorer):
53 | clf = scorer.binary_clf
54 | assert isinstance(clf, BaseEstimator)
55 | assert hasattr(clf, 'fit')
56 | assert hasattr(clf, 'predict_proba')
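57 |
58 |
59 | # Illustrative sketch (an assumption, inferred from test_predict_proba above): the
60 | # public predict_proba dispatches on the input type, handing numpy arrays to the
61 | # in-memory classifier and pyspark Columns to the Spark scorer, and raising for
62 | # anything else with the exact message asserted above.
63 | def predict_proba_dispatch_sketch(scorer, X):
64 |     if isinstance(X, np.ndarray):
65 |         return scorer._predict_proba(X)
66 |     if isinstance(X, Column):
67 |         return scorer._predict_proba_spark(X)
68 |     raise ValueError(f"{type(X)} is an unsupported datatype for X")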
--------------------------------------------------------------------------------
/test/test_similarity_metrics/test_similarity_metrics.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from pandas.testing import assert_series_equal
3 | from thefuzz.fuzz import token_set_ratio
4 |
5 | from spark_matcher.similarity_metrics import SimilarityMetrics
6 |
7 |
8 | def test__distance_measure_pandas():
9 | df = pd.DataFrame({'name_1': ['aa', 'bb', 'cc'],
10 | 'name_2': ['aa', 'bb c', 'dd'],
11 | 'similarity_metrics': [100, 100, 0]
12 | })
13 | strings_1 = df['name_1'].tolist()
14 | strings_2 = df['name_2'].tolist()
15 | result = SimilarityMetrics._distance_measure_pandas(strings_1, strings_2, token_set_ratio)
16 | assert_series_equal(result, df['similarity_metrics'], check_names=False)
17 |
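18 |
19 | # Worked equivalent (a sketch, not necessarily the library's exact code): applying
20 | # the metric element-wise over the two string lists reproduces the expected column.
21 | def distance_measure_sketch(strings_1, strings_2, metric):
22 |     return pd.Series([metric(s1, s2) for s1, s2 in zip(strings_1, strings_2)])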
--------------------------------------------------------------------------------
/test/test_table_checkpointer/test_table_checkpointer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 |
4 |
5 | def test_parquetcheckpointer(spark_session):
6 | from spark_matcher.table_checkpointer import ParquetCheckPointer
7 |
8 | checkpointer = ParquetCheckPointer(spark_session, 'temp_database', 'checkpoint_name')
9 |
10 | pdf = pd.DataFrame({'col_1': ['1', '2', '3'],
11 | 'col_2': ['a', 'b', 'c'],
12 | })
13 | sdf = spark_session.createDataFrame(pdf)
14 |
15 | returned_sdf = checkpointer(sdf, 'checkpoint_name')
16 | returned_pdf = returned_sdf.toPandas().sort_values(['col_1', 'col_2'])
17 |     assert np.array_equal(pdf.values, returned_pdf.values)
18 |
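19 |
20 | # Illustrative sketch (an assumption, not the actual ParquetCheckPointer): a
21 | # checkpointer of this kind typically writes the DataFrame to parquet and returns
22 | # the re-read result, so downstream lineage starts from the materialised table --
23 | # which is why the round trip above should preserve the data.
24 | def checkpoint_sketch(spark, sdf, path):
25 |     sdf.write.mode('overwrite').parquet(path)
26 |     return spark.read.parquet(path)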
--------------------------------------------------------------------------------
/test/test_utils/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import pandas as pd
4 |
5 |
6 | @pytest.mark.parametrize('min_df', [1, 2, 0.2])
7 | def test_get_most_frequent_words(spark_session, min_df):
8 | from spark_matcher.utils import get_most_frequent_words
9 |
10 | sdf = spark_session.createDataFrame(pd.DataFrame({'name': ['Company A ltd', 'Company B ltd', 'Company C']}))
11 |
12 | if (min_df == 1) or (min_df == 0.2):
13 | expected_result = pd.DataFrame({'words': ['Company', 'ltd', 'A', 'B', 'C'],
14 | 'count': [3, 2, 1, 1, 1],
15 | 'df': [1, 2 / 3, 1 / 3, 1 / 3, 1 / 3]})
16 | elif min_df == 2:
17 | expected_result = pd.DataFrame({'words': ['Company', 'ltd'],
18 | 'count': [3, 2],
19 | 'df': [1, 2 / 3]})
20 |
21 |     result = (get_most_frequent_words(sdf, col_name='name', min_df=min_df, top_n_words=10)
22 |               .sort_values(['count', 'words'], ascending=[False, True]).reset_index(drop=True))
23 | pd.testing.assert_frame_equal(expected_result, result)
24 |
25 |
26 | def test_remove_stop_words(spark_session):
27 | from spark_matcher.utils import remove_stopwords
28 |
29 | stopwords = ['ltd', 'bv']
30 |
31 | expected_sdf = spark_session.createDataFrame(
32 | pd.DataFrame({'name': ['Company A ltd', 'Company B ltd', 'Company C bv', 'Company D'],
33 | 'name_wo_stopwords': ['Company A', 'Company B', 'Company C', 'Company D']}))
34 | input_cols = ['name']
35 |
36 | result = remove_stopwords(expected_sdf.select(*input_cols), col_name='name', stopwords=stopwords)
37 | pd.testing.assert_frame_equal(result.toPandas(), expected_sdf.toPandas())
38 |
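39 |
40 | # Worked check (an assumption about the convention, mirroring the expectations in
41 | # test_get_most_frequent_words): `count` is the number of occurrences of a word,
42 | # `df` is count divided by the number of rows, and `min_df` behaves like
43 | # sklearn's CountVectorizer -- an int is an absolute count threshold, a float a
44 | # fractional document-frequency threshold.
45 | def frequent_words_sketch(names, min_df):
46 |     rows = pd.Series(names)
47 |     counts = rows.str.split().explode().value_counts().rename_axis('words').reset_index(name='count')
48 |     counts['df'] = counts['count'] / len(rows)
49 |     keep = counts['count'] >= min_df if isinstance(min_df, int) else counts['df'] >= min_df
50 |     return counts[keep]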
--------------------------------------------------------------------------------