├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── Makefile ├── README.md ├── conftest.py ├── docs ├── Makefile ├── docs-requirements.txt ├── make.bat └── source │ ├── _static │ └── spark_matcher_logo.png │ ├── api │ ├── activelearner.rst │ ├── blocker.rst │ ├── data.rst │ ├── deduplicator.rst │ ├── matcher.rst │ ├── matching_base.rst │ ├── modules.rst │ ├── sampler.rst │ ├── scorer.rst │ └── similarity_metrics.rst │ ├── conf.py │ ├── example.ipynb │ ├── index.rst │ └── installation_guide.rst ├── examples ├── example_deduplicator.ipynb ├── example_matcher.ipynb ├── example_matcher_advanced.ipynb └── example_stopword_removal.ipynb ├── external_dependencies └── .gitkeep ├── licenses_bundled ├── pyproject.toml ├── setup.py ├── spark_matcher ├── __init__.py ├── activelearner │ ├── __init__.py │ └── active_learner.py ├── blocker │ ├── __init__.py │ ├── block_learner.py │ └── blocking_rules.py ├── config.py ├── data │ ├── .gitkeep │ ├── __init__.py │ ├── acm.csv │ ├── datasets.py │ ├── dblp.csv │ ├── stoxx50.csv │ ├── voters_1.csv │ └── voters_2.csv ├── deduplicator │ ├── __init__.py │ ├── connected_components_calculator.py │ ├── deduplicator.py │ └── hierarchical_clustering.py ├── matcher │ ├── __init__.py │ └── matcher.py ├── matching_base │ ├── __init__.py │ └── matching_base.py ├── sampler │ ├── __init__.py │ └── training_sampler.py ├── scorer │ ├── __init__.py │ └── scorer.py ├── similarity_metrics │ ├── __init__.py │ └── similarity_metrics.py ├── table_checkpointer.py └── utils.py ├── spark_requirements.txt └── test ├── test_active_learner └── test_active_learner.py ├── test_blocker ├── test_block_learner.py └── test_blocking_rules.py ├── test_deduplicator ├── test_deduplicator.py └── test_hierarchical_clustering.py ├── test_matcher └── test_matcher.py ├── test_matching_base └── test_matching_base.py ├── test_sampler └── test_training_sampler.py ├── test_scorer └── test_scorer.py ├── test_similarity_metrics └── test_similarity_metrics.py ├── test_table_checkpointer └── test_table_checkpointer.py └── test_utils └── test_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data-folder 2 | /data/** 3 | !/data/**/ 4 | 5 | # VSCode 6 | .vscode/ 7 | 8 | # Generated documentation 9 | docs/build/* 10 | .pdf 11 | *.pdf 12 | 13 | # To make sure that the folder structure is tracked 14 | !/**/.gitkeep 15 | 16 | # General 17 | *~ 18 | *.log 19 | .env 20 | *.swp 21 | .cache/* 22 | 23 | # Python pickle files 24 | *.pickle 25 | *.pkl 26 | 27 | # Byte-compiled / optimized / DLL files 28 | __pycache__/ 29 | *.py[cod] 30 | *$py.class 31 | 32 | # Jupyter Notebook 33 | .ipynb_checkpoints 34 | 35 | # Pycharm 36 | .idea/ 37 | 38 | # Pytest 39 | *.pytest_cache/ 40 | 41 | *.DS_Store 42 | 43 | 44 | 45 | # Created by https://www.gitignore.io/api/python 46 | # Edit at https://www.gitignore.io/?templates=python 47 | 48 | ### Python ### 49 | # Byte-compiled / optimized / DLL files 50 | __pycache__/ 51 | *.py[cod] 52 | *$py.class 53 | 54 | # C extensions 55 | *.so 56 | 57 | # Distribution / packaging 58 | .Python 59 | build/ 60 | develop-eggs/ 61 | dist/ 62 | downloads/ 63 | eggs/ 64 | .eggs/ 65 | lib/ 66 | lib64/ 67 | parts/ 68 | sdist/ 69 | var/ 70 | wheels/ 71 | pip-wheel-metadata/ 72 | share/python-wheels/ 73 | *.egg-info/ 74 | .installed.cfg 75 | *.egg 76 | MANIFEST 77 | 78 | # PyInstaller 79 | # Usually these files are written by a python script from a template 80 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
81 | *.manifest 82 | *.spec 83 | 84 | # Installer logs 85 | pip-log.txt 86 | pip-delete-this-directory.txt 87 | 88 | # Unit test / coverage reports 89 | htmlcov/ 90 | .tox/ 91 | .nox/ 92 | .coverage 93 | .coverage.* 94 | .cache 95 | nosetests.xml 96 | coverage.xml 97 | *.cover 98 | .hypothesis/ 99 | .pytest_cache/ 100 | 101 | # Translations 102 | *.mo 103 | *.pot 104 | 105 | # Django stuff: 106 | *.log 107 | local_settings.py 108 | db.sqlite3 109 | 110 | # Flask stuff: 111 | instance/ 112 | .webassets-cache 113 | 114 | # Scrapy stuff: 115 | .scrapy 116 | 117 | # Sphinx documentation 118 | docs/_build/ 119 | 120 | # PyBuilder 121 | target/ 122 | 123 | # Jupyter Notebook 124 | .ipynb_checkpoints 125 | 126 | # IPython 127 | profile_default/ 128 | ipython_config.py 129 | 130 | # pyenv 131 | .python-version 132 | 133 | # pipenv 134 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 135 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 136 | # having no cross-platform support, pipenv may install dependencies that don’t work, or not 137 | # install all needed dependencies. 138 | #Pipfile.lock 139 | 140 | # celery beat schedule file 141 | celerybeat-schedule 142 | 143 | # SageMath parsed files 144 | *.sage.py 145 | 146 | # Environments 147 | .env 148 | .venv 149 | env/ 150 | venv/ 151 | ENV/ 152 | env.bak/ 153 | venv.bak/ 154 | 155 | # Spyder project settings 156 | .spyderproject 157 | .spyproject 158 | 159 | # Rope project settings 160 | .ropeproject 161 | 162 | # mkdocs documentation 163 | /site 164 | 165 | # mypy 166 | .mypy_cache/ 167 | .dmypy.json 168 | dmypy.json 169 | 170 | # Pyre type checker 171 | .pyre/ 172 | 173 | # End of https://www.gitignore.io/api/python 174 | /external_dependencies/graphframes*.jar 175 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | version: 2 6 | 7 | sphinx: 8 | configuration: docs/source/conf.py 9 | 10 | python: 11 | version: 3.7 12 | install: 13 | - requirements: docs/docs-requirements.txt 14 | - method: pip 15 | path: . 
16 | system_packages: true -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | create_documentation: 2 | make clean 3 | make html 4 | 5 | deploy: 6 | rm -rf dist 7 | rm -rf build 8 | python3 -m build 9 | python3 -m twine upload dist/* 10 | 11 | clean: 12 | rm -f *.o prog3 13 | 14 | html: 15 | sphinx-build -b html docs/source docs/build 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | [![Version](https://img.shields.io/pypi/v/spark-matcher)](https://pypi.org/project/spark-matcher/) 3 | [![Downloads](https://pepy.tech/badge/spark-matcher)](https://pepy.tech/project/spark-matcher) 4 | ![](https://img.shields.io/github/license/ing-bank/spark-matcher) 5 | [![Docs - GitHub.io](https://img.shields.io/static/v1?logo=readthdocs&style=flat&color=pink&label=docs&message=spark-matcher)][#docs-package] 6 | 7 | [#docs-package]: https://spark-matcher.readthedocs.io/en/latest/ 8 | 9 | 10 | ![spark_matcher_logo](https://spark-matcher.readthedocs.io/en/latest/_images/spark_matcher_logo.png) 11 | 12 | # Spark-Matcher 13 | 14 | Spark-Matcher is a scalable entity matching algorithm implemented in PySpark. With Spark-Matcher the user can easily 15 | train an algorithm to solve a custom matching problem. Spark-Matcher uses active learning (modAL) to train a 16 | classifier (Scikit-learn) to match entities. In order to deal with the N^2 complexity of matching large tables, blocking is 17 | implemented to reduce the number of pairs. Since the implementation is done in PySpark, Spark-Matcher can deal with 18 | extremely large tables. 19 | 20 | Documentation with examples can be found [here](https://spark-matcher.readthedocs.io/en/latest/). 21 | 22 | Developed by data scientists at ING Analytics, www.ing.com. 23 | 24 | ## Installation 25 | 26 | ### Normal installation 27 | 28 | As Spark-Matcher is intended to be used with large datasets on a Spark cluster, it is assumed that Spark is already 29 | installed. If that is not the case, first install Spark and PyArrow (`pip install pyspark pyarrow`). 30 | 31 | Install Spark-Matcher from PyPI: 32 | 33 | ``` 34 | pip install spark-matcher 35 | ``` 36 | 37 | ### Install with possibility to create documentation 38 | 39 | Pandoc, the general markup converter, needs to be available. You may follow the official [Pandoc installation instructions](https://pandoc.org/installing.html) or use conda: 40 | 41 | ``` 42 | conda install -c conda-forge pandoc 43 | ``` 44 | 45 | Then clone the Spark-Matcher repository and install it with the `[doc]` extra: 46 | 47 | ``` 48 | pip install ".[doc]" 49 | ``` 50 | 51 | ### Install to contribute 52 | 53 | Clone this repo and install in editable mode. This also installs PySpark and Jupyterlab: 54 | 55 | ``` 56 | python -m pip install -e ".[dev]" 57 | python setup.py develop 58 | ``` 59 | 60 | ## Documentation 61 | 62 | Documentation can be created using the following command: 63 | 64 | ``` 65 | make create_documentation 66 | ``` 67 | 68 | ## Dependencies 69 | 70 | The usage examples in the `examples` directory contain notebooks that run in local mode. 71 | Using SparkMatcher in cluster mode requires sending the SparkMatcher package and several other Python packages (see `spark_requirements.txt`) to the executors. 72 | How to send these dependencies depends on the cluster; a minimal sketch is given below.
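The snippet below is an illustrative sketch only, not part of the package: it assumes a client-mode session, a user-built zip archive containing the Python dependencies, and a graphframes jar placed in the `external_dependencies` directory. The archive and jar file names are placeholders and depend on your Spark version:

```python
from pyspark.sql import SparkSession

# Illustrative only: the archive and jar names below are placeholders.
spark = (SparkSession
         .builder
         .config("spark.jars", "external_dependencies/graphframes-0.8.1-spark3.0-s_2.12.jar")
         .config("spark.submit.pyFiles", "spark_matcher_dependencies.zip")
         .getOrCreate())
```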
73 | Please read the instructions and examples of Apache Spark on how to do this: https://spark.apache.org/docs/latest/api/python/user_guide/python_packaging.html. 74 | 75 | SparkMatcher uses `graphframes` under the hood. 76 | Therefore, depending on the Spark version, the correct version of `graphframes` needs to be added to the `external_dependencies` directory and to the configuration of the Spark session. 77 | By default, `graphframes` for Spark 3.0 is used in the Spark sessions in the notebooks in the `examples` directory. 78 | For a different version, see: https://spark-packages.org/package/graphframes/graphframes. 79 | 80 | ## Usage 81 | 82 | Example notebooks are provided in the `examples` directory. 83 | Using SparkMatcher to find matches between Spark 84 | dataframes `a` and `b` goes as follows: 85 | 86 | ```python 87 | from spark_matcher.matcher import Matcher 88 | 89 | myMatcher = Matcher(spark_session, col_names=['name', 'suburb', 'postcode']) 90 | ``` 91 | 92 | Now we are ready for fitting the Matcher object using 'active learning'; this means that the user has to enter whether a 93 | pair is a match or not. Enter 'y' if a pair is a match or 'n' if it is not a match. You will be notified when 94 | the model has converged and you can stop training by pressing 'f'. 95 | 96 | ```python 97 | myMatcher.fit(a, b) 98 | ``` 99 | 100 | The Matcher is now trained and can be used to predict on all data. This can be the data used for training or new data 101 | that was not seen by the model yet. 102 | 103 | ```python 104 | result = myMatcher.predict(a, b) 105 | ``` 106 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pyspark.sql import SparkSession 3 | 4 | from spark_matcher.table_checkpointer import ParquetCheckPointer 5 | 6 | 7 | @pytest.fixture(scope="session") 8 | def spark_session(): 9 | spark = (SparkSession 10 | .builder 11 | .appName(str(__file__)) 12 | .getOrCreate() 13 | ) 14 | yield spark 15 | spark.stop() 16 | 17 | 18 | @pytest.fixture(scope="session") 19 | def table_checkpointer(spark_session): 20 | return ParquetCheckPointer(spark_session, 'temp_database') 21 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/docs-requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==3.5.4 2 | nbsphinx 3 | sphinx_rtd_theme 4 | Jinja2<3.1 5 | pyspark -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/_static/spark_matcher_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/docs/source/_static/spark_matcher_logo.png -------------------------------------------------------------------------------- /docs/source/api/activelearner.rst: -------------------------------------------------------------------------------- 1 | activelearner module 2 | ==================================== 3 | 4 | .. automodule:: spark_matcher.activelearner 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: 12 | -------------------------------------------------------------------------------- /docs/source/api/blocker.rst: -------------------------------------------------------------------------------- 1 | blocker module 2 | ============================== 3 | 4 | .. automodule:: spark_matcher.blocker 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: 12 | -------------------------------------------------------------------------------- /docs/source/api/data.rst: -------------------------------------------------------------------------------- 1 | data module 2 | =========================== 3 | 4 | .. automodule:: spark_matcher.data.datasets 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: 12 | -------------------------------------------------------------------------------- /docs/source/api/deduplicator.rst: -------------------------------------------------------------------------------- 1 | deduplicator module 2 | =================================== 3 | 4 | .. automodule:: spark_matcher.deduplicator 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. 
toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: 12 | -------------------------------------------------------------------------------- /docs/source/api/matcher.rst: -------------------------------------------------------------------------------- 1 | matcher module 2 | ============================== 3 | 4 | .. automodule:: spark_matcher.matcher 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: 12 | -------------------------------------------------------------------------------- /docs/source/api/matching_base.rst: -------------------------------------------------------------------------------- 1 | matching_base module 2 | ===================================== 3 | 4 | .. automodule:: spark_matcher.matching_base 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: -------------------------------------------------------------------------------- /docs/source/api/modules.rst: -------------------------------------------------------------------------------- 1 | spark_matcher 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | activelearner 8 | blocker 9 | data 10 | deduplicator 11 | matcher 12 | matching_base 13 | sampler 14 | scorer 15 | similarity_metrics 16 | -------------------------------------------------------------------------------- /docs/source/api/sampler.rst: -------------------------------------------------------------------------------- 1 | sampler module 2 | ============================== 3 | 4 | .. automodule:: spark_matcher.sampler 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: -------------------------------------------------------------------------------- /docs/source/api/scorer.rst: -------------------------------------------------------------------------------- 1 | scorer module 2 | ============================= 3 | 4 | .. automodule:: spark_matcher.scorer 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: -------------------------------------------------------------------------------- /docs/source/api/similarity_metrics.rst: -------------------------------------------------------------------------------- 1 | similarity_metrics module 2 | ========================================== 3 | 4 | .. automodule:: spark_matcher.similarity_metrics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | 9 | .. toctree:: 10 | :maxdepth: 4 11 | :caption: Contents: -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('./../..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'Spark Matcher' 21 | copyright = '2021, Ahmet Bayraktar, Frits Hermans, Stan Leisink' 22 | author = 'Ahmet Bayraktar, Frits Hermans, Stan Leisink' 23 | 24 | # The full version, including alpha/beta/rc tags 25 | release = '0.3.2' 26 | 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = ['nbsphinx', 'sphinx.ext.autodoc', 'sphinx.ext.napoleon'] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = [] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 48 | # 49 | html_theme = 'sphinx_rtd_theme' 50 | 51 | # Add any paths that contain custom static files (such as style sheets) here, 52 | # relative to this directory. They are copied after the builtin static files, 53 | # so a file named "default.css" will overwrite the builtin "default.css". 54 | html_static_path = ['_static'] 55 | 56 | # FKH: 57 | nbsphinx_allow_errors = True -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. Spark Matcher documentation master file, created by 2 | sphinx-quickstart on Tue Nov 23 10:39:10 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. image:: _static/spark_matcher_logo.png 7 | 8 | Welcome to Spark Matcher's documentation! 9 | ========================================= 10 | 11 | Spark Matcher is a scalable entity matching algorithm implemented in PySpark. 12 | With Spark Matcher the user can easily train an algorithm to solve a custom matching problem. 13 | Spark Matcher uses active learning (modAL) to train a classifier (Sklearn) to match entities. 14 | In order to deal with the N^2 complexity of matching large tables, blocking is implemented to reduce the number of pairs. 15 | Since the implementation is done in PySpark, Spark Matcher can deal with extremely large tables. 16 | 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | :caption: Contents: 21 | 22 | installation_guide 23 | example.ipynb 24 | api/modules 25 | 26 | 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /docs/source/installation_guide.rst: -------------------------------------------------------------------------------- 1 | How to Install 2 | ************** 3 | 4 | As Spark-Matcher is intended to be used with large datasets on a Spark cluster, it is assumed that Spark is already 5 | installed. If that is not the case, first install Spark and PyArrow (:code:`pip install pyspark pyarrow`). 
6 | 7 | Install Spark-Matcher from PyPi: 8 | 9 | .. code-block:: bash 10 | 11 | pip install spark-matcher 12 | -------------------------------------------------------------------------------- /examples/example_matcher_advanced.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7601f10c", 6 | "metadata": {}, 7 | "source": [ 8 | "# Spark-Matcher advanced Matcher example " 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "71300ce4", 14 | "metadata": {}, 15 | "source": [ 16 | "This notebook shows how to use the `spark_matcher` for matching entities with more customized settings. First we create a Spark session:" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "b29a36e5", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "%config Completer.use_jedi = False # for proper autocompletion\n", 27 | "from pyspark.sql import SparkSession" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "b5d8664e", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "spark = (SparkSession\n", 38 | " .builder\n", 39 | " .master(\"local\")\n", 40 | " .enableHiveSupport()\n", 41 | " .getOrCreate())" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "id": "0f72cb2c", 47 | "metadata": {}, 48 | "source": [ 49 | "Load the example data:" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "fc90dcba", 56 | "metadata": {}, 57 | "outputs": [], 58 | "source": [ 59 | "from spark_matcher.data import load_data" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": null, 65 | "id": "26d826fd", 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "a, b = load_data(spark)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "eec05ac1", 75 | "metadata": {}, 76 | "source": [ 77 | "We now create a `Matcher` object with our own string similarity metric and blocking rules:" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "e4c9b675", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "from spark_matcher.matcher import Matcher" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "9935a8a6", 93 | "metadata": {}, 94 | "source": [ 95 | "First create a string similarity metric that checks if the first word is a perfect match:" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "8b1c0150", 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "def first_word(string_1, string_2):\n", 106 | " return float(string_1.split()[0]==string_2.split()[0])" 107 | ] 108 | }, 109 | { 110 | "cell_type": "markdown", 111 | "id": "98b9d059", 112 | "metadata": {}, 113 | "source": [ 114 | "We also want to use the `token_sort_ratio` from the `thefuzz` package. Note that this package should be available on the Spark worker nodes." 
115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "id": "04f54a91", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "from thefuzz.fuzz import token_sort_ratio" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "id": "6f1b85f8", 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "field_info={'name':[first_word, token_sort_ratio], 'suburb':[token_sort_ratio], 'postcode':[token_sort_ratio]}" 135 | ] 136 | }, 137 | { 138 | "cell_type": "markdown", 139 | "id": "77b1623b", 140 | "metadata": {}, 141 | "source": [ 142 | "Moreover, we want to limit blocking to the 'name' field only, by looking at the first 3 characters and the first 3 words:" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "c3f6c974", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "from spark_matcher.blocker.blocking_rules import FirstNChars, FirstNWords" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "736c3643", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "blocking_rules=[FirstNChars('name', 3), FirstNWords('name', 3)]" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "id": "3b98dbb6", 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "myMatcher = Matcher(spark, field_info=field_info, blocking_rules=blocking_rules, checkpoint_dir='path_to_checkpoints')" 173 | ] 174 | }, 175 | { 176 | "cell_type": "markdown", 177 | "id": "22440b47", 178 | "metadata": {}, 179 | "source": [ 180 | "Now we are ready for fitting the `Matcher` object:" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": null, 186 | "id": "9a5999c6", 187 | "metadata": { 188 | "tags": [] 189 | }, 190 | "outputs": [], 191 | "source": [ 192 | "myMatcher.fit(a, b)" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "a0cfadaa", 198 | "metadata": {}, 199 | "source": [ 200 | "This fitted model can now be used to predict:" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "id": "594fd2cd", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "result = myMatcher.predict(a, b)" 211 | ] 212 | } 213 | ], 214 | "metadata": { 215 | "kernelspec": { 216 | "display_name": "Python 3 (ipykernel)", 217 | "language": "python", 218 | "name": "python3" 219 | }, 220 | "language_info": { 221 | "codemirror_mode": { 222 | "name": "ipython", 223 | "version": 3 224 | }, 225 | "file_extension": ".py", 226 | "mimetype": "text/x-python", 227 | "name": "python", 228 | "nbconvert_exporter": "python", 229 | "pygments_lexer": "ipython3", 230 | "version": "3.8.11" 231 | } 232 | }, 233 | "nbformat": 4, 234 | "nbformat_minor": 5 235 | } 236 | -------------------------------------------------------------------------------- /examples/example_stopword_removal.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "5287ca32", 6 | "metadata": { 7 | "tags": [ 8 | "keep_output" 9 | ] 10 | }, 11 | "source": [ 12 | "# Spark-Matcher advanced example " 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "15c0ec9a", 18 | "metadata": { 19 | "tags": [ 20 | "keep_output" 21 | ] 22 | }, 23 | "source": [ 24 | "This notebook shows how to use the `spark_matcher` with more customized settings. 
First we create a Spark session:" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "53d5be5d", 31 | "metadata": { 32 | "tags": [ 33 | "keep_output" 34 | ] 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "%config Completer.use_jedi = False # for proper autocompletion\n", 39 | "from pyspark.sql import SparkSession" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "7a6d7d03", 46 | "metadata": { 47 | "tags": [ 48 | "keep_output" 49 | ] 50 | }, 51 | "outputs": [], 52 | "source": [ 53 | "spark = (SparkSession\n", 54 | " .builder\n", 55 | " .master(\"local\")\n", 56 | " .enableHiveSupport()\n", 57 | " .getOrCreate())" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "fd1f147c", 63 | "metadata": {}, 64 | "source": [ 65 | "Load the example data:" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "43d1d5a4", 72 | "metadata": { 73 | "tags": [ 74 | "keep_output" 75 | ] 76 | }, 77 | "outputs": [], 78 | "source": [ 79 | "from spark_matcher.data import load_data" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "id": "a6c60c7c", 85 | "metadata": {}, 86 | "source": [ 87 | "We use the 'library' data and remove the (numeric) 'year' column:" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "ba4878ba", 94 | "metadata": { 95 | "tags": [ 96 | "keep_output" 97 | ] 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "a, b = load_data(spark, kind='library')\n", 102 | "a, b = a.drop('year'), b.drop('year')" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "f21332ee", 109 | "metadata": { 110 | "tags": [ 111 | "keep_output" 112 | ] 113 | }, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/html": [ 118 | "
\n", 119 | "\n", 132 | "\n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | "
titleauthorsvenue
0The WASA2 object-oriented workflow management ...Gottfried Vossen, Mathias WeskeInternational Conference on Management of Data
1A user-centered interface for querying distrib...Isabel F. Cruz, Kimberly M. JamesInternational Conference on Management of Data
2World Wide Database-integrating the Web, CORBA...Athman Bouguettaya, Boualem Benatallah, Lily H...International Conference on Management of Data
\n", 162 | "
" 163 | ], 164 | "text/plain": [ 165 | " title \\\n", 166 | "0 The WASA2 object-oriented workflow management ... \n", 167 | "1 A user-centered interface for querying distrib... \n", 168 | "2 World Wide Database-integrating the Web, CORBA... \n", 169 | "\n", 170 | " authors \\\n", 171 | "0 Gottfried Vossen, Mathias Weske \n", 172 | "1 Isabel F. Cruz, Kimberly M. James \n", 173 | "2 Athman Bouguettaya, Boualem Benatallah, Lily H... \n", 174 | "\n", 175 | " venue \n", 176 | "0 International Conference on Management of Data \n", 177 | "1 International Conference on Management of Data \n", 178 | "2 International Conference on Management of Data " 179 | ] 180 | }, 181 | "execution_count": null, 182 | "metadata": {}, 183 | "output_type": "execute_result" 184 | } 185 | ], 186 | "source": [ 187 | "a.limit(3).toPandas()" 188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "id": "962ef4b5", 193 | "metadata": {}, 194 | "source": [ 195 | "`spark_matcher` is shipped with a utility function to get the most frequenty occurring words in a Spark dataframe column. We apply this to the `venue` column:" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": null, 201 | "id": "0bdd312e", 202 | "metadata": { 203 | "tags": [ 204 | "keep_output" 205 | ] 206 | }, 207 | "outputs": [], 208 | "source": [ 209 | "from spark_matcher.utils import get_most_frequent_words" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "81bc2861", 216 | "metadata": { 217 | "tags": [ 218 | "keep_output" 219 | ] 220 | }, 221 | "outputs": [ 222 | { 223 | "data": { 224 | "text/html": [ 225 | "
\n", 226 | "\n", 239 | "\n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | "
wordscountdf
0SIGMOD19170.390428
1Data16400.334012
2Conference16030.326477
3VLDB12890.262525
4on11350.231161
5Record11110.226273
6International10010.203870
78580.174745
8Large8430.171690
9Very8430.171690
\n", 311 | "
" 312 | ], 313 | "text/plain": [ 314 | " words count df\n", 315 | "0 SIGMOD 1917 0.390428\n", 316 | "1 Data 1640 0.334012\n", 317 | "2 Conference 1603 0.326477\n", 318 | "3 VLDB 1289 0.262525\n", 319 | "4 on 1135 0.231161\n", 320 | "5 Record 1111 0.226273\n", 321 | "6 International 1001 0.203870\n", 322 | "7 858 0.174745\n", 323 | "8 Large 843 0.171690\n", 324 | "9 Very 843 0.171690" 325 | ] 326 | }, 327 | "execution_count": null, 328 | "metadata": {}, 329 | "output_type": "execute_result" 330 | } 331 | ], 332 | "source": [ 333 | "frequent_words = get_most_frequent_words(a.unionByName(b), col_name='venue')\n", 334 | "frequent_words.head(10)" 335 | ] 336 | }, 337 | { 338 | "cell_type": "markdown", 339 | "id": "7e0a1379", 340 | "metadata": {}, 341 | "source": [ 342 | "Based on this list, we decide that we want to consider the words 'conference' and 'international' as stopwords. The utility function `remove_stopwords` does this job:" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": null, 348 | "id": "00cb0b30", 349 | "metadata": { 350 | "tags": [ 351 | "keep_output" 352 | ] 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "from spark_matcher.utils import remove_stopwords" 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": null, 362 | "id": "a386ead6", 363 | "metadata": { 364 | "tags": [ 365 | "keep_output" 366 | ] 367 | }, 368 | "outputs": [], 369 | "source": [ 370 | "stopwords = ['conference', 'international']" 371 | ] 372 | }, 373 | { 374 | "cell_type": "code", 375 | "execution_count": null, 376 | "id": "18f3322e", 377 | "metadata": { 378 | "tags": [ 379 | "keep_output" 380 | ] 381 | }, 382 | "outputs": [], 383 | "source": [ 384 | "a = remove_stopwords(a, col_name='venue', stopwords=stopwords).drop('venue')\n", 385 | "b = remove_stopwords(b, col_name='venue', stopwords=stopwords).drop('venue')" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "id": "816e6201", 391 | "metadata": {}, 392 | "source": [ 393 | "A new column `venue_wo_stopwords` is created in which the stopwords are removed:" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "id": "d2c8d0c9", 400 | "metadata": { 401 | "tags": [ 402 | "keep_output" 403 | ] 404 | }, 405 | "outputs": [ 406 | { 407 | "data": { 408 | "text/html": [ 409 | "
\n", 410 | "\n", 423 | "\n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | "
titleauthorsvenue_wo_stopwords
0The WASA2 object-oriented workflow management ...Gottfried Vossen, Mathias Weskeon Management of Data
1A user-centered interface for querying distrib...Isabel F. Cruz, Kimberly M. Jameson Management of Data
2World Wide Database-integrating the Web, CORBA...Athman Bouguettaya, Boualem Benatallah, Lily H...on Management of Data
\n", 453 | "
" 454 | ], 455 | "text/plain": [ 456 | " title \\\n", 457 | "0 The WASA2 object-oriented workflow management ... \n", 458 | "1 A user-centered interface for querying distrib... \n", 459 | "2 World Wide Database-integrating the Web, CORBA... \n", 460 | "\n", 461 | " authors venue_wo_stopwords \n", 462 | "0 Gottfried Vossen, Mathias Weske on Management of Data \n", 463 | "1 Isabel F. Cruz, Kimberly M. James on Management of Data \n", 464 | "2 Athman Bouguettaya, Boualem Benatallah, Lily H... on Management of Data " 465 | ] 466 | }, 467 | "execution_count": null, 468 | "metadata": {}, 469 | "output_type": "execute_result" 470 | } 471 | ], 472 | "source": [ 473 | "a.limit(3).toPandas()" 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "id": "f1c71962", 479 | "metadata": {}, 480 | "source": [ 481 | "We use the `spark_matcher` to link the records in dataframe `a` with the records in dataframe `b`. Instead of the `venue` column, we now use the newly created `venue_wo_stopwords` column." 482 | ] 483 | }, 484 | { 485 | "cell_type": "code", 486 | "execution_count": null, 487 | "id": "34d71374", 488 | "metadata": { 489 | "tags": [ 490 | "keep_output" 491 | ] 492 | }, 493 | "outputs": [], 494 | "source": [ 495 | "from spark_matcher.matcher import Matcher" 496 | ] 497 | }, 498 | { 499 | "cell_type": "code", 500 | "execution_count": null, 501 | "id": "833caee8", 502 | "metadata": { 503 | "tags": [ 504 | "keep_output" 505 | ] 506 | }, 507 | "outputs": [], 508 | "source": [ 509 | "myMatcher = Matcher(spark, col_names=['title', 'authors', 'venue_wo_stopwords'], checkpoint_dir='path_to_checkpoints')" 510 | ] 511 | }, 512 | { 513 | "cell_type": "markdown", 514 | "id": "4cf920a9", 515 | "metadata": { 516 | "tags": [ 517 | "keep_output" 518 | ] 519 | }, 520 | "source": [ 521 | "Now we are ready for fitting the `Matcher` object." 
522 | ] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "id": "ef8e304a", 528 | "metadata": { 529 | "tags": [ 530 | "keep_output" 531 | ] 532 | }, 533 | "outputs": [], 534 | "source": [ 535 | "myMatcher.fit(a, b)" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "id": "5b0e4425", 541 | "metadata": {}, 542 | "source": [ 543 | "The `Matcher` is now trained and can be used to predict on all data as usual:" 544 | ] 545 | }, 546 | { 547 | "cell_type": "code", 548 | "execution_count": null, 549 | "id": "f342c880", 550 | "metadata": { 551 | "tags": [ 552 | "keep_output" 553 | ] 554 | }, 555 | "outputs": [], 556 | "source": [ 557 | "result = myMatcher.predict(a, b, threshold=0.5, top_n=3)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "markdown", 562 | "id": "2a68b63e", 563 | "metadata": {}, 564 | "source": [ 565 | "Now let's have a look at the results:" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": null, 571 | "id": "b3057fd0", 572 | "metadata": { 573 | "tags": [ 574 | "keep_output" 575 | ] 576 | }, 577 | "outputs": [], 578 | "source": [ 579 | "result_pdf = result.toPandas()" 580 | ] 581 | }, 582 | { 583 | "cell_type": "code", 584 | "execution_count": null, 585 | "id": "2aae3ea3", 586 | "metadata": { 587 | "tags": [ 588 | "keep_output" 589 | ] 590 | }, 591 | "outputs": [], 592 | "source": [ 593 | "result_pdf.sort_values('score')" 594 | ] 595 | } 596 | ], 597 | "metadata": { 598 | "kernelspec": { 599 | "display_name": "Python 3 (ipykernel)", 600 | "language": "python", 601 | "name": "python3" 602 | }, 603 | "language_info": { 604 | "codemirror_mode": { 605 | "name": "ipython", 606 | "version": 3 607 | }, 608 | "file_extension": ".py", 609 | "mimetype": "text/x-python", 610 | "name": "python", 611 | "nbconvert_exporter": "python", 612 | "pygments_lexer": "ipython3", 613 | "version": "3.8.13" 614 | } 615 | }, 616 | "nbformat": 4, 617 | "nbformat_minor": 5 618 | } 619 | -------------------------------------------------------------------------------- /external_dependencies/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/external_dependencies/.gitkeep -------------------------------------------------------------------------------- /licenses_bundled: -------------------------------------------------------------------------------- 1 | The Spark-Matcher repository and source distributions bundle several libraries that are 2 | compatibly licensed. We list these here. 
3 | 4 | Name: dill 5 | License: BSD License (3-clause BSD) 6 | 7 | Name: graphframes 8 | License: MIT License 9 | 10 | Name: jupyterlab 11 | License: BSD License 12 | 13 | Name: modAL 14 | License: MIT License 15 | 16 | Name: multipledispatch 17 | License: BSD License 18 | 19 | Name: nbsphinx 20 | License: MIT License 21 | 22 | Name: numpy 23 | License: BSD License (3-clause BSD) 24 | 25 | Name: pandas 26 | License: BSD License (3-clause BSD) 27 | 28 | Name: pyarrow 29 | License: Apache Software License version 2 30 | 31 | Name: pyspark 32 | License: Apache Software License version 2 33 | 34 | Name: pytest 35 | License: MIT License 36 | 37 | Name: python-Levenshtein 38 | License: GNU General Public License v2 39 | 40 | Name: scikit-learn 41 | License: OSI Approved (new BSD) 42 | 43 | Name: scipy 44 | License: BSD License (3-clause BSD) 45 | 46 | Name: sphinx 47 | License: BSD License (2-clause BSD) 48 | 49 | Name: sphinx_rtd_theme 50 | License: MIT License 51 | 52 | Name: thefuzz 53 | License: GNU General Public License v2 54 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42"] 3 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | base_packages = [ 4 | 'pandas', 5 | 'numpy', 6 | 'scikit-learn', 7 | 'python-Levenshtein', 8 | 'thefuzz', 9 | 'modAL-python', 10 | 'pytest', 11 | 'multipledispatch', 12 | 'dill', 13 | 'graphframes', 14 | 'scipy' 15 | ] 16 | 17 | doc_packages = [ 18 | 'sphinx', 19 | 'nbsphinx', 20 | 'sphinx_rtd_theme' 21 | ] 22 | 23 | util_packages = [ 24 | 'pyspark', 25 | 'pyarrow', 26 | 'jupyterlab' 27 | ] 28 | 29 | base_doc_packages = base_packages + doc_packages 30 | dev_packages = base_packages + doc_packages + util_packages 31 | 32 | with open("README.md", "r", encoding="utf-8") as fh: 33 | long_description = fh.read() 34 | 35 | setup(name='Spark-Matcher', 36 | version='0.3.2', 37 | author="Ahmet Bayraktar, Stan Leisink, Frits Hermans", 38 | description="Record matching and entity resolution at scale in Spark", 39 | long_description=long_description, 40 | long_description_content_type="text/markdown", 41 | classifiers=[ 42 | "Programming Language :: Python :: 3", 43 | "License :: OSI Approved :: MIT License", 44 | "Operating System :: OS Independent", 45 | ], 46 | packages=find_packages(exclude=['examples']), 47 | package_data={"spark_matcher": ["data/*.csv"]}, 48 | install_requires=base_packages, 49 | extras_require={ 50 | "base": base_packages, 51 | "doc": base_doc_packages, 52 | "dev": dev_packages, 53 | }, 54 | python_requires=">=3.7", 55 | ) 56 | -------------------------------------------------------------------------------- /spark_matcher/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.2" 2 | -------------------------------------------------------------------------------- /spark_matcher/activelearner/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['ScoringLearner'] 2 | 3 | from .active_learner import ScoringLearner -------------------------------------------------------------------------------- /spark_matcher/activelearner/active_learner.py: 
-------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import List, Optional, Union 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from modAL.models import ActiveLearner 10 | from modAL.uncertainty import uncertainty_sampling 11 | from pyspark.sql import DataFrame 12 | from sklearn.base import BaseEstimator 13 | 14 | 15 | class ScoringLearner: 16 | """ 17 | Class to train a string matching model using active learning. 18 | Attributes: 19 | col_names: column names used for matching 20 | scorer: the scorer to be used in the active learning loop 21 | min_nr_samples: minimum number of responses required before classifier convergence is tested 22 | uncertainty_threshold: threshold on the uncertainty of the classifier during active learning, 23 | used for determining if the model has converged 24 | uncertainty_improvement_threshold: threshold on the uncertainty improvement of classifier during active 25 | learning, used for determining if the model has converged 26 | n_uncertainty_improvement: span of iterations to check for largest difference between uncertainties 27 | n_queries: maximum number of iterations to be done for the active learning session 28 | sampling_method: sampling method to be used for the active learning session 29 | verbose: sets verbosity 30 | """ 31 | def __init__(self, col_names: List[str], scorer: BaseEstimator, min_nr_samples: int = 10, 32 | uncertainty_threshold: float = 0.1, uncertainty_improvement_threshold: float = 0.01, 33 | n_uncertainty_improvement: int = 5, n_queries: int = 9999, sampling_method=uncertainty_sampling, 34 | verbose: int = 0): 35 | self.col_names = col_names 36 | self.learner = ActiveLearner( 37 | estimator=scorer, 38 | query_strategy=sampling_method 39 | ) 40 | self.counter_total = 0 41 | self.counter_positive = 0 42 | self.counter_negative = 0 43 | self.min_nr_samples = min_nr_samples 44 | self.uncertainty_threshold = uncertainty_threshold 45 | self.uncertainty_improvement_threshold = uncertainty_improvement_threshold 46 | self.n_uncertainty_improvement = n_uncertainty_improvement 47 | self.uncertainties = [] 48 | self.n_queries = n_queries 49 | self.verbose = verbose 50 | 51 | def _input_assert(self, message: str, choices: List[str]) -> str: 52 | """ 53 | Adds functionality to the python function `input` to limit the choices that can be returned 54 | Args: 55 | message: message to user 56 | choices: list containing possible choices that can be returned 57 | Returns: 58 | input returned by user 59 | """ 60 | output = input(message).lower() 61 | if output not in choices: 62 | print(f"Wrong input! Your input should be one of the following: {', '.join(choices)}") 63 | return self._input_assert(message, choices) 64 | else: 65 | return output 66 | 67 | def _get_uncertainty_improvement(self) -> Optional[float]: 68 | """ 69 | Calculates the uncertainty differences during active learning. The largest difference over the `last_n` 70 | iterations is returned. The aim of this function is to suggest early stopping of active learning. 
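As an illustration (the numbers below are hypothetical): if `self.uncertainties` were [0.4, 0.3, 0.25, 0.24] and `n_uncertainty_improvement` were 2, the successive absolute differences are [0.1, 0.05, 0.01] and the largest of the last two, 0.05, is returned.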
71 | 72 | Returns: largest uncertainty update in `last_n` iterations 73 | 74 | """ 75 | uncertainties = np.asarray(self.uncertainties) 76 | abs_differences = abs(uncertainties[1:] - uncertainties[:-1]) 77 | return max(abs_differences[-self.n_uncertainty_improvement:]) 78 | 79 | def _is_converged(self) -> bool: 80 | """ 81 | Checks whether the model is converged by comparing the last uncertainty value with the `uncertainty_threshold` 82 | and comparing the `last_n` uncertainty improvements with the `uncertainty_improvement_threshold`. These checks 83 | are only performed if at least `min_nr_samples` are labelled. 84 | 85 | Returns: 86 | boolean indicating whether the model is converged 87 | 88 | """ 89 | if (self.counter_total >= self.min_nr_samples) and ( 90 | len(self.uncertainties) >= self.n_uncertainty_improvement + 1): 91 | uncertainty_improvement = self._get_uncertainty_improvement() 92 | if (self.uncertainties[-1] <= self.uncertainty_threshold) or ( 93 | uncertainty_improvement <= self.uncertainty_improvement_threshold): 94 | return True 95 | else: 96 | return False 97 | 98 | def _get_active_learning_input(self, query_inst: pd.DataFrame) -> np.ndarray: 99 | """ 100 | Obtain user input for a query during active learning. 101 | Args: 102 | query_inst: query as provided by the ActiveLearner instance 103 | Returns: label of user input '1' or '0' as yes or no 104 | 'p' to go to previous 105 | 'f' to finish 106 | 's' to skip the query 107 | """ 108 | print(f'\nNr. {self.counter_total + 1} ({self.counter_positive}+/{self.counter_negative}-)') 109 | print("Is this a match? (y)es, (n)o, (p)revious, (s)kip, (f)inish") 110 | print('') 111 | for element in [1, 2]: 112 | for col_name in self.col_names: 113 | print(f'{col_name}_{element}' + ': ' + query_inst[f'{col_name}_{element}'].iloc[0]) 114 | print('') 115 | user_input = self._input_assert("", ['y', 'n', 'p', 'f', 's']) 116 | # replace 'y' and 'n' with '1' and '0' to make them valid y labels 117 | user_input = user_input.replace('y', '1').replace('n', '0') 118 | 119 | y_new = np.array([user_input]) 120 | return y_new 121 | 122 | def _calculate_uncertainty(self, x) -> None: 123 | # take the maximum probability of the predicted classes as proxy of the confidence of the classifier 124 | confidence = self.predict_proba(x).max(axis=1)[0] 125 | if self.verbose: 126 | print('uncertainty:', 1 - confidence) 127 | self.uncertainties.append(1 - confidence) 128 | 129 | def _show_min_max_scores(self, X: pd.DataFrame) -> None: 130 | """ 131 | Prints the lowest and the highest logistic regression scores on train data during active learning. 132 | 133 | Args: 134 | X: Pandas dataframe containing train data that is available for labelling duringg active learning 135 | """ 136 | X_all = pd.concat((X, self.train_samples)) 137 | pred_max = self.learner.predict_proba(np.array(X_all['similarity_metrics'].tolist())).max(axis=0) 138 | print(f'lowest score: {1 - pred_max[0]:.3f}') 139 | print(f'highest score: {pred_max[1]:.3f}') 140 | 141 | def _label_perfect_train_matches(self, identical_records: pd.DataFrame) -> None: 142 | """ 143 | To prevent asking labels for the perfect matches that were created by setting `n_perfect_train_matches`, these 144 | are provided to the active learner upfront. 
145 | 146 | Args: 147 | identical_records: Pandas dataframe containing perfect matches 148 | 149 | """ 150 | identical_records['y'] = '1' 151 | self.learner.teach(np.array(identical_records['similarity_metrics'].values.tolist()), 152 | identical_records['y'].values) 153 | self.train_samples = pd.concat([self.train_samples, identical_records]) 154 | 155 | def fit(self, X: pd.DataFrame) -> 'ScoringLearner': 156 | """ 157 | Fit ScoringLearner instance on pairs of strings 158 | Args: 159 | X: Pandas dataframe containing pairs of strings and distance metrics of paired strings 160 | """ 161 | self.train_samples = pd.DataFrame([]) 162 | query_inst_prev = None 163 | 164 | # automatically label all perfect train matches: 165 | identical_records = X[X['perfect_train_match']].copy() 166 | self._label_perfect_train_matches(identical_records) 167 | X = X.drop(identical_records.index).reset_index(drop=True) # remove identical records to avoid double labelling 168 | 169 | for i in range(self.n_queries): 170 | query_idx, query_inst = self.learner.query(np.array(X['similarity_metrics'].tolist())) 171 | 172 | if self.learner.estimator.fitted_: 173 | # the uncertainty calculations need a fitted estimator 174 | # however it can occur that the estimator can only be fit after a couple rounds of querying 175 | self._calculate_uncertainty(query_inst) 176 | if self.verbose >= 2: 177 | self._show_min_max_scores(X) 178 | 179 | y_new = self._get_active_learning_input(X.iloc[query_idx]) 180 | if y_new == 'p': # use previous (input is 'p') 181 | y_new = self._get_active_learning_input(query_inst_prev) 182 | elif y_new == 'f': # finish labelling (input is 'f') 183 | break 184 | query_inst_prev = X.iloc[query_idx] 185 | if y_new != 's': # skip case (input is 's') 186 | self.learner.teach(np.asarray([X.iloc[query_idx]['similarity_metrics'].iloc[0]]), np.asarray(y_new)) 187 | train_sample_to_add = X.iloc[query_idx].copy() 188 | train_sample_to_add['y'] = y_new 189 | self.train_samples = pd.concat([self.train_samples, train_sample_to_add]) 190 | 191 | X = X.drop(query_idx).reset_index(drop=True) 192 | 193 | if self._is_converged(): 194 | print("Classifier converged, enter 'f' to stop training") 195 | 196 | if y_new == '1': 197 | self.counter_positive += 1 198 | elif y_new == '0': 199 | self.counter_negative += 1 200 | self.counter_total += 1 201 | return self 202 | 203 | def predict_proba(self, X: Union[DataFrame, pd.DataFrame]) -> Union[DataFrame, pd.DataFrame]: 204 | """ 205 | Predict probabilities on new data whether the pairs are a match or not 206 | Args: 207 | X: Pandas or Spark dataframe to predict on 208 | Returns: match probabilities 209 | """ 210 | return self.learner.estimator.predict_proba(X) 211 | -------------------------------------------------------------------------------- /spark_matcher/blocker/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['BlockLearner', 'BlockingRule'] 2 | 3 | from .block_learner import BlockLearner 4 | from .blocking_rules import BlockingRule -------------------------------------------------------------------------------- /spark_matcher/blocker/block_learner.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import List, Tuple, Optional, Union 6 | 7 | from pyspark.sql import DataFrame, functions as F 8 | 9 | from spark_matcher.blocker.blocking_rules import BlockingRule 10 | from 
spark_matcher.table_checkpointer import TableCheckpointer 11 | 12 | 13 | class BlockLearner: 14 | """ 15 | Class to learn blocking rules from training data. 16 | 17 | Attributes: 18 | blocking_rules: list of `BlockingRule` objects that are taken into account during block learning 19 | recall: the minimum required percentage of training pairs that are covered by the learned blocking rules 20 | verbose: set verbosity 21 | """ 22 | def __init__(self, blocking_rules: List[BlockingRule], recall: float, 23 | table_checkpointer: Optional[TableCheckpointer] = None, verbose=0): 24 | self.table_checkpointer = table_checkpointer 25 | self.blocking_rules = blocking_rules 26 | self.recall = recall 27 | self.cover_blocking_rules = None 28 | self.fitted = False 29 | self.full_set = None 30 | self.full_set_size = None 31 | self.verbose = verbose 32 | 33 | def _greedy_set_coverage(self): 34 | """ 35 | This method solves the `set cover problem` with a greedy algorithm. It identifies a subset of blocking-rules 36 | that cover `recall` * |full_set| (|full_set| stands for cardinality of the full_set)percent of all the elements 37 | in the original (full) set. 38 | """ 39 | # sort the blocking rules to start the algorithm with the ones that have most coverage: 40 | _sorted_blocking_rules = sorted(self.blocking_rules, key=lambda bl: bl.training_coverage_size, reverse=True) 41 | 42 | self.cover_blocking_rules = [_sorted_blocking_rules.pop(0)] 43 | self.cover_set = self.cover_blocking_rules[0].training_coverage 44 | 45 | for blocking_rule in _sorted_blocking_rules: 46 | # check if the required recall is already reached: 47 | if len(self.cover_set) >= int(self.recall * self.full_set_size): 48 | break 49 | # check if subset is dominated by the cover_set: 50 | if blocking_rule.training_coverage.issubset(self.cover_set): 51 | continue 52 | self.cover_set = self.cover_set.union(blocking_rule.training_coverage) 53 | self.cover_blocking_rules.append(blocking_rule) 54 | 55 | def fit(self, sdf: DataFrame) -> 'BlockLearner': 56 | """ 57 | This method fits, i.e. learns, the blocking rules that are needed to cover `recall` percent of the training 58 | set pairs. The fitting is done by solving the set-cover problem. It is solved by using a greedy algorithm. 59 | 60 | Args: 61 | sdf: a labelled training set containing pairs. 62 | 63 | Returns: 64 | the object itself 65 | """ 66 | # verify whether `row_id` and `label` are columns of `sdf_1` 67 | if 'row_id' not in sdf.columns: 68 | raise AssertionError('`row_id` is not present as a column of sdf_1') 69 | 70 | if 'label' not in sdf.columns: 71 | raise AssertionError('`label` is not present as a column of sdf_1') 72 | 73 | 74 | # determine the full set of pairs in the training data that have positive labels from active learning: 75 | sdf = ( 76 | sdf 77 | .filter(F.col('label') == 1) 78 | .persist() # break the lineage to avoid recomputing since sdf_1 is used many times during the fitting 79 | ) 80 | self.full_set = set(sdf.select('row_id').toPandas()['row_id']) 81 | # determine the cardinality of the full_set, i.e. 
|full_set|: 82 | self.full_set_size = len(self.full_set) 83 | 84 | # calculate the training coverage for each blocking rule: 85 | self.blocking_rules = ( 86 | list( 87 | map( 88 | lambda x: x.calculate_training_set_coverage(sdf), 89 | self.blocking_rules 90 | ) 91 | ) 92 | ) 93 | 94 | # use a greedy set cover algorithm to select a subset of the blocking rules that cover `recall` * |full_set| 95 | self._greedy_set_coverage() 96 | if self.verbose: 97 | print('Blocking rules:', ", ".join([x.__repr__() for x in self.cover_blocking_rules])) 98 | self.fitted = True 99 | return self 100 | 101 | def _create_blocks(self, sdf: DataFrame) -> List[DataFrame]: 102 | """ 103 | This method creates a list of blocked data. Blocked data is created by applying the learned blocking rules on 104 | the input dataframe. 105 | 106 | Args: 107 | sdf: dataframe containing records 108 | 109 | Returns: 110 | A list of blocked dataframes, i.e. list of dataframes containing block-keys 111 | """ 112 | sdf_blocks = [] 113 | for blocking_rule in self.cover_blocking_rules: 114 | sdf_blocks.append(blocking_rule.create_block_key(sdf)) 115 | return sdf_blocks 116 | 117 | @staticmethod 118 | def _create_block_table(blocks: List[DataFrame]) -> DataFrame: 119 | """ 120 | This method unifies the blocked data into a single dataframe. 121 | Args: 122 | blocks: containing blocked dataframes, i.e. dataframes containing block-keys 123 | 124 | Returns: 125 | a unified dataframe with all the block-keys 126 | """ 127 | block_table = blocks[0] 128 | for block in blocks[1:]: 129 | block_table = block_table.unionByName(block) 130 | return block_table 131 | 132 | def transform(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> Union[DataFrame, 133 | Tuple[DataFrame, DataFrame]]: 134 | """ 135 | This method adds the block-keys to the input dataframes. It applies all the learned blocking rules on the 136 | input data and unifies the results. The result of this method is/are the input dataframe(s) containing the 137 | block-keys from the learned blocking rules. 138 | 139 | Args: 140 | sdf_1: dataframe containing records 141 | sdf_2: dataframe containing records 142 | 143 | Returns: 144 | dataframe(s) containing block-keys from the learned blocking-rules 145 | """ 146 | if not self.fitted: 147 | raise ValueError('BlockLearner is not yet fitted') 148 | 149 | sdf_1_blocks = self._create_blocks(sdf_1) 150 | sdf_1_blocks = self._create_block_table(sdf_1_blocks) 151 | 152 | if sdf_2: 153 | sdf_1_blocks = self.table_checkpointer(sdf_1_blocks, checkpoint_name='sdf_1_blocks') 154 | sdf_2_blocks = self._create_blocks(sdf_2) 155 | sdf_2_blocks = self._create_block_table(sdf_2_blocks) 156 | 157 | return sdf_1_blocks, sdf_2_blocks 158 | 159 | return sdf_1_blocks 160 | -------------------------------------------------------------------------------- /spark_matcher/blocker/blocking_rules.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | import abc 6 | 7 | from pyspark.sql import Column, DataFrame, functions as F, types as T 8 | 9 | 10 | class BlockingRule(abc.ABC): 11 | """ 12 | Abstract class for blocking rules. This class contains all the base functionality for blocking rules. 
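As an aside, a concrete blocking rule only has to implement `__repr__` and `_blocking_rule`; everything else (block-key creation, training coverage) comes from this base class. A minimal sketch of a hypothetical custom rule that blocks on the first `n` digits of a column, reusing the `_length_check` helper defined further down in this class:

    from pyspark.sql import Column, functions as F
    from spark_matcher.blocker.blocking_rules import BlockingRule

    class FirstNDigits(BlockingRule):
        """Hypothetical example: block on the first `n` digits found in `blocking_column`."""
        def __init__(self, blocking_column: str, n: int = 2):
            super().__init__(blocking_column)
            self.n = n

        def __repr__(self) -> str:
            return f"first_{self.n}_digits_{self.blocking_column}"

        def _blocking_rule(self, c: Column) -> Column:
            # keep only digits, take the first n and require exactly n to be present
            key = F.substring(F.regexp_replace(c, r'[^0-9]+', ''), 1, self.n)
            return self._length_check(key, self.n)

    # usage: FirstNDigits('address', 2).create_block_key(sdf) adds a 'block_key' column to `sdf`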
13 | 14 | Attributes: 15 | blocking_column: the column on which the `BlockingRule` is applied 16 | """ 17 | 18 | def __init__(self, blocking_column: str): 19 | self.blocking_column = blocking_column 20 | self.training_coverage = None 21 | self.training_coverage_size = None 22 | 23 | @abc.abstractmethod 24 | def _blocking_rule(self, c: Column) -> Column: 25 | """ 26 | Abstract method for a blocking-rule 27 | 28 | Args: 29 | c: a Column 30 | 31 | Returns: 32 | A Column with a blocking-rule result 33 | """ 34 | pass 35 | 36 | @abc.abstractmethod 37 | def __repr__(self) -> str: 38 | """ 39 | Abstract method for a class representation 40 | 41 | Returns: 42 | return a string that represents the blocking-rule 43 | 44 | """ 45 | pass 46 | 47 | def _apply_blocking_rule(self, c: Column) -> Column: 48 | """ 49 | This method applies the blocking rule on an input column and adds a unique representation id to make sure the 50 | block-key is unique. Uniqueness is important to avoid collisions between block-keys, e.g. a blocking-rule that 51 | captures the first string character can return the same as a blocking-rule that captures the last character. To 52 | avoid this, the blocking-rule results is concatenated with the __repr__ method. 53 | 54 | Args: 55 | c: a Column 56 | 57 | Returns: 58 | a Column with a block_key 59 | 60 | """ 61 | return F.concat(F.lit(f'{self.__repr__()}:'), self._blocking_rule(c)) 62 | 63 | def create_block_key(self, sdf: DataFrame) -> DataFrame: 64 | """ 65 | This method calculates and adds the block-key column to the input dataframe 66 | 67 | Args: 68 | sdf: a dataframe with records that need to be matched 69 | 70 | Returns: 71 | the dataframe with the block-key column 72 | """ 73 | return sdf.withColumn('block_key', self._apply_blocking_rule(sdf[f'{self.blocking_column}'])) 74 | 75 | def _create_training_block_keys(self, sdf: DataFrame) -> DataFrame: 76 | """ 77 | This method is used to create block-keys on a training dataframe 78 | 79 | Args: 80 | sdf: a dataframe containing record pairs for training 81 | 82 | Returns: 83 | the dataframe containing block-keys for the record pairs 84 | """ 85 | return ( 86 | sdf 87 | .withColumn('block_key_1', self._apply_blocking_rule(sdf[f'{self.blocking_column}_1'])) 88 | .withColumn('block_key_2', self._apply_blocking_rule(sdf[f'{self.blocking_column}_2'])) 89 | ) 90 | 91 | @staticmethod 92 | def _compare_and_filter_keys(sdf: DataFrame) -> DataFrame: 93 | """ 94 | This method is used to compare block-keys of record pairs and subsequently filter record pairs in the training 95 | dataframe that have identical block-keys 96 | 97 | Args: 98 | sdf: a dataframe containing record pairs for training with their block-keys 99 | 100 | Returns: 101 | the filtered dataframe containing only record pairs with identical block-keys 102 | """ 103 | return ( 104 | sdf 105 | # split on `:` since this separates the `__repr__` from the blocking_rule result 106 | .withColumn('_blocking_result_1', F.split(F.col('block_key_1'), ":").getItem(1)) 107 | .withColumn('_blocking_result_2', F.split(F.col('block_key_2'), ":").getItem(1)) 108 | .filter( 109 | (F.col('block_key_1') == F.col('block_key_2')) & 110 | ((F.col('_blocking_result_1') != '') & (F.col('_blocking_result_2') != '')) 111 | ) 112 | .drop('_blocking_result_1', '_blocking_result_2') 113 | ) 114 | 115 | @staticmethod 116 | def _length_check(c: Column, n: int, word_count: bool = False) -> Column: 117 | """ 118 | This method checks the length of the created block key. 
119 | 120 | Args: 121 | c: block key to check 122 | n: given length of the string 123 | word_count: whether to check the string length or word count 124 | Returns: 125 | the block key if it is not shorter than the given length, otherwise returns None 126 | """ 127 | if word_count: 128 | return F.when(F.size(c) >= n, c).otherwise(None) 129 | 130 | return F.when(F.length(c) == n, c).otherwise(None) 131 | 132 | def calculate_training_set_coverage(self, sdf: DataFrame) -> 'BlockingRule': 133 | """ 134 | This method calculate the set coverage of the blocking rule on the training pairs. The set coverage of the rule 135 | is determined by looking at how many record pairs in the training set end up in the same block. This coverage 136 | is used in the BlockLearner to sort blocking rules in the greedy set_covering algorithm. 137 | Args: 138 | sdf: a dataframe containing record pairs for training 139 | 140 | Returns: 141 | The object itself 142 | """ 143 | sdf = self._create_training_block_keys(sdf) 144 | 145 | sdf = self._compare_and_filter_keys(sdf) 146 | 147 | self.training_coverage = set( 148 | sdf 149 | .agg(F.collect_set('row_id').alias('training_coverage')) 150 | .collect()[0]['training_coverage'] 151 | ) 152 | 153 | self.training_coverage_size = len(self.training_coverage) 154 | return self 155 | 156 | 157 | # define the concrete blocking rule examples: 158 | 159 | class FirstNChars(BlockingRule): 160 | def __init__(self, blocking_column: str, n: int = 3): 161 | super().__init__(blocking_column) 162 | self.n = n 163 | 164 | def __repr__(self): 165 | return f"first_{self.n}_characters_{self.blocking_column}" 166 | 167 | def _blocking_rule(self, c: Column) -> Column: 168 | key = F.substring(c, 0, self.n) 169 | return self._length_check(key, self.n) 170 | 171 | 172 | class FirstNCharsLastWord(BlockingRule): 173 | def __init__(self, blocking_column: str, n: int = 3, remove_non_alphanumerical=False): 174 | super().__init__(blocking_column) 175 | self.n = n 176 | self.remove_non_alphanumerical = remove_non_alphanumerical 177 | 178 | def __repr__(self): 179 | return f"first_{self.n}_characters_last_word_{self.blocking_column}" 180 | 181 | def _blocking_rule(self, c: Column) -> Column: 182 | if self.remove_non_alphanumerical: 183 | c = F.regexp_replace(c, r'\W+', ' ') 184 | tokens = F.split(c, r'\s+') 185 | last_word = F.element_at(tokens, -1) 186 | key = F.substring(last_word, 1, self.n) 187 | return self._length_check(key, self.n, word_count=False) 188 | 189 | 190 | class FirstNCharactersFirstTokenSorted(BlockingRule): 191 | def __init__(self, blocking_column: str, n: int = 3, remove_non_alphanumerical=False): 192 | super().__init__(blocking_column) 193 | self.n = n 194 | self.remove_non_alphanumerical = remove_non_alphanumerical 195 | 196 | def __repr__(self): 197 | return f"first_{self.n}_characters_first_token_sorted_{self.blocking_column}" 198 | 199 | def _blocking_rule(self, c): 200 | if self.remove_non_alphanumerical: 201 | c = F.regexp_replace(c, r'\W+', ' ') 202 | tokens = F.split(c, r'\s+') 203 | sorted_tokens = F.sort_array(tokens) 204 | filtered_tokens = F.filter(sorted_tokens, lambda x: F.length(x) >= self.n) 205 | first_token = filtered_tokens.getItem(0) 206 | return F.substring(first_token, 1, self.n) 207 | 208 | 209 | class LastNChars(BlockingRule): 210 | def __init__(self, blocking_column: str, n: int = 3): 211 | super().__init__(blocking_column) 212 | self.n = n 213 | 214 | def __repr__(self): 215 | return f"last_{self.n}_characters_{self.blocking_column}" 216 | 217 | def 
_blocking_rule(self, c: Column) -> Column: 218 | key = F.substring(c, -self.n, self.n) 219 | return self._length_check(key, self.n) 220 | 221 | 222 | class WholeField(BlockingRule): 223 | def __init__(self, blocking_column: str): 224 | super().__init__(blocking_column) 225 | 226 | def __repr__(self): 227 | return f"whole_field_{self.blocking_column}" 228 | 229 | def _blocking_rule(self, c: Column) -> Column: 230 | return c 231 | 232 | 233 | class FirstNWords(BlockingRule): 234 | def __init__(self, blocking_column: str, n: int = 1, remove_non_alphanumerical=False): 235 | super().__init__(blocking_column) 236 | self.n = n 237 | self.remove_non_alphanumerical = remove_non_alphanumerical 238 | 239 | def __repr__(self): 240 | return f"first_{self.n}_words_{self.blocking_column}" 241 | 242 | def _blocking_rule(self, c: Column) -> Column: 243 | if self.remove_non_alphanumerical: 244 | c = F.regexp_replace(c, r'\W+', ' ') 245 | tokens = F.split(c, r'\s+') 246 | key = self._length_check(tokens, self.n, word_count=True) 247 | return F.array_join(F.slice(key, 1, self.n), ' ') 248 | 249 | 250 | class FirstNLettersNoSpace(BlockingRule): 251 | def __init__(self, blocking_column: str, n: int = 3): 252 | super().__init__(blocking_column) 253 | self.n = n 254 | 255 | def __repr__(self): 256 | return f"first_{self.n}_letters_{self.blocking_column}_no_space" 257 | 258 | def _blocking_rule(self, c: Column) -> Column: 259 | key = F.substring(F.regexp_replace(c, r'[^a-zA-Z]+', ''), 1, self.n) 260 | return self._length_check(key, self.n) 261 | 262 | 263 | class SortedIntegers(BlockingRule): 264 | def __init__(self, blocking_column: str): 265 | super().__init__(blocking_column) 266 | 267 | def __repr__(self): 268 | return f"sorted_integers_{self.blocking_column}" 269 | 270 | def _blocking_rule(self, c: Column) -> Column: 271 | number_string = F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')) 272 | number_string_array = F.when(number_string != '', F.split(number_string, r'\s+')) 273 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType())) 274 | number_sorted = F.array_sort(number_int_array) 275 | return F.array_join(number_sorted, " ") 276 | 277 | 278 | class FirstInteger(BlockingRule): 279 | def __init__(self, blocking_column: str): 280 | super().__init__(blocking_column) 281 | 282 | def __repr__(self): 283 | return f"first_integer_{self.blocking_column}" 284 | 285 | def _blocking_rule(self, c: Column) -> Column: 286 | number_string_array = F.split(F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')), r'\s+') 287 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType())) 288 | first_number = number_int_array.getItem(0) 289 | return first_number.cast(T.StringType()) 290 | 291 | 292 | class LastInteger(BlockingRule): 293 | def __init__(self, blocking_column: str): 294 | super().__init__(blocking_column) 295 | 296 | def __repr__(self): 297 | return f"last_integer_{self.blocking_column}" 298 | 299 | def _blocking_rule(self, c: Column) -> Column: 300 | number_string_array = F.split(F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')), r'\s+') 301 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType())) 302 | last_number = F.slice(number_int_array, -1, 1).getItem(0) 303 | return last_number.cast(T.StringType()) 304 | 305 | 306 | class LargestInteger(BlockingRule): 307 | def __init__(self, blocking_column: str): 308 | super().__init__(blocking_column) 309 | 310 | def __repr__(self): 311 | return f"largest_integer_{self.blocking_column}" 312 | 
313 | def _blocking_rule(self, c: Column) -> Column: 314 | number_string_array = F.split(F.trim(F.regexp_replace(c, r'[^0-9\s]+', '')), r'\s+') 315 | number_int_array = F.transform(number_string_array, lambda x: x.cast(T.IntegerType())) 316 | largest_number = F.array_max(number_int_array) 317 | return largest_number.cast(T.StringType()) 318 | 319 | 320 | class NLetterAbbreviation(BlockingRule): 321 | def __init__(self, blocking_column: str, n: int = 3): 322 | super().__init__(blocking_column) 323 | self.n = n 324 | 325 | def __repr__(self): 326 | return f"{self.n}_letter_abbreviation_{self.blocking_column}" 327 | 328 | def _blocking_rule(self, c: Column) -> Column: 329 | words = F.split(F.trim(F.regexp_replace(c, r'[0-9]+', '')), r'\s+') 330 | first_letters = F.when(F.size(words) >= self.n, F.transform(words, lambda x: F.substring(x, 1, 1))) 331 | return F.array_join(first_letters, '') 332 | 333 | 334 | # this is an example of a blocking rule that contains a udf with plain python code: 335 | 336 | class UdfFirstNChar(BlockingRule): 337 | def __init__(self, blocking_column: str, n: int): 338 | super().__init__(blocking_column) 339 | self.n = n 340 | 341 | def __repr__(self): 342 | return f"udf_first_integer_{self.blocking_column}" 343 | 344 | def _blocking_rule(self, c: Column) -> Column: 345 | @F.udf 346 | def _rule(s: str) -> str: 347 | return s[:self.n] 348 | 349 | return _rule(c) 350 | 351 | 352 | default_blocking_rules = [FirstNChars, FirstNCharsLastWord, LastNChars, WholeField, FirstNWords, FirstNLettersNoSpace, 353 | SortedIntegers, FirstInteger, LastInteger, LargestInteger, NLetterAbbreviation] 354 | -------------------------------------------------------------------------------- /spark_matcher/config.py: -------------------------------------------------------------------------------- 1 | MINHASHING_MAXDF = 0.01 2 | MINHASHING_VOCABSIZE = 1_000_000 3 | -------------------------------------------------------------------------------- /spark_matcher/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/spark_matcher/data/.gitkeep -------------------------------------------------------------------------------- /spark_matcher/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import load_data 2 | -------------------------------------------------------------------------------- /spark_matcher/data/datasets.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import Tuple, Optional, Union 6 | from pkg_resources import resource_filename 7 | 8 | import pandas as pd 9 | 10 | from pyspark.sql import SparkSession, DataFrame 11 | 12 | 13 | def load_data(spark: SparkSession, kind: Optional[str] = 'voters') -> Union[Tuple[DataFrame, DataFrame], DataFrame]: 14 | """ 15 | Load examples datasets to be used to experiment with `spark-matcher`. For matching problems, set `kind` to `voters` 16 | for North Carolina voter registry data or `library` for bibliography data. For deduplication problems, set `kind` 17 | to `stoxx50` for EuroStoxx 50 company names and addresses. 18 | 19 | Voter data: 20 | - provided by Prof. 
Erhard Rahm 21 | https://dbs.uni-leipzig.de/research/projects/object_matching/benchmark_datasets_for_entity_resolution 22 | 23 | Library data: 24 | - DBLP bibliography, http://www.informatik.uni-trier.de/~ley/db/index.html 25 | - ACM Digital Library, http://portal.acm.org/portal.cfm 26 | 27 | Args: 28 | spark: Spark session 29 | kind: kind of data: `voters`, `library` or `stoxx50` 30 | 31 | Returns: 32 | two Spark dataframes for `voters` or `library`, a single dataframe for `stoxx50` 33 | 34 | """ 35 | if kind == 'library': 36 | return _load_data_library(spark) 37 | if kind == 'voters': 38 | return _load_data_voters(spark) 39 | if kind == 'stoxx50': 40 | return _load_data_stoxx50(spark) 41 | else: 42 | raise ValueError('`kind` must be `library`, `voters` or `stoxx50`') 43 | 44 | 45 | def _load_data_library(spark: SparkSession) -> Tuple[DataFrame, DataFrame]: 46 | """ 47 | Load examples datasets to be used to experiment with `spark-matcher`. Two Spark dataframe are returned with the 48 | same columns: 49 | 50 | - DBLP bibliography, http://www.informatik.uni-trier.de/~ley/db/index.html 51 | - ACM Digital Library, http://portal.acm.org/portal.cfm 52 | 53 | Args: 54 | spark: Spark session 55 | 56 | Returns: 57 | Spark dataframe for DBLP data and a Spark dataframe for ACM data 58 | 59 | """ 60 | file_path_acm = resource_filename('spark_matcher.data', 'acm.csv') 61 | file_path_dblp = resource_filename('spark_matcher.data', 'dblp.csv') 62 | acm_pdf = pd.read_csv(file_path_acm) 63 | dblp_pdf = pd.read_csv(file_path_dblp, encoding="ISO-8859-1") 64 | 65 | for col in acm_pdf.select_dtypes('object').columns: 66 | acm_pdf[col] = acm_pdf[col].fillna("") 67 | for col in dblp_pdf.select_dtypes('object').columns: 68 | dblp_pdf[col] = dblp_pdf[col].fillna("") 69 | 70 | acm_sdf = spark.createDataFrame(acm_pdf) 71 | dblp_sdf = spark.createDataFrame(dblp_pdf) 72 | return acm_sdf, dblp_sdf 73 | 74 | 75 | def _load_data_voters(spark: SparkSession) -> Tuple[DataFrame, DataFrame]: 76 | """ 77 | Voters data is based on the North Carolina voter registry and this dataset is provided by Prof. Erhard Rahm 78 | ('Comparative Evaluation of Distributed Clustering Schemes for Multi-source Entity Resolution'). Two Spark 79 | dataframe are returned with the same columns. 80 | 81 | Args: 82 | spark: Spark session 83 | 84 | Returns: 85 | two Spark dataframes containing voter data 86 | 87 | """ 88 | file_path_voters_1 = resource_filename('spark_matcher.data', 'voters_1.csv') 89 | file_path_voters_2 = resource_filename('spark_matcher.data', 'voters_2.csv') 90 | voters_1_pdf = pd.read_csv(file_path_voters_1) 91 | voters_2_pdf = pd.read_csv(file_path_voters_2) 92 | 93 | voters_1_sdf = spark.createDataFrame(voters_1_pdf) 94 | voters_2_sdf = spark.createDataFrame(voters_2_pdf) 95 | return voters_1_sdf, voters_2_sdf 96 | 97 | 98 | def _load_data_stoxx50(spark: SparkSession) -> DataFrame: 99 | """ 100 | The Stoxx50 dataset contains a single column containing the concatenation of Eurostoxx 50 company names and 101 | addresses. This dataset is created by the developers of spark_matcher. 
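For orientation, a minimal usage sketch of the public `load_data` entry point, assuming an active SparkSession:

    from pyspark.sql import SparkSession
    from spark_matcher.data import load_data

    spark = SparkSession.builder.getOrCreate()
    voters_1, voters_2 = load_data(spark, kind='voters')   # record matching: two dataframes
    stoxx50 = load_data(spark, kind='stoxx50')              # deduplication: a single dataframe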
102 | 103 | Args: 104 | spark: Spark session 105 | 106 | Returns: 107 | Spark dataframe containing Eurostoxx 50 names and addresses 108 | 109 | """ 110 | file_path_stoxx50 = resource_filename('spark_matcher.data', 'stoxx50.csv') 111 | stoxx50_pdf = pd.read_csv(file_path_stoxx50) 112 | 113 | stoxx50_sdf = spark.createDataFrame(stoxx50_pdf) 114 | return stoxx50_sdf 115 | -------------------------------------------------------------------------------- /spark_matcher/data/dblp.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ing-bank/spark-matcher/526fd7983c5b841158bf62d70adaab383d69f0af/spark_matcher/data/dblp.csv -------------------------------------------------------------------------------- /spark_matcher/data/stoxx50.csv: -------------------------------------------------------------------------------- 1 | name 2 | adidas ag adi dassler strasse 1 91074 germany 3 | adidas ag adi dassler strasse 1 91074 herzogenaurach 4 | adidas ag adi dassler strasse 1 91074 herzogenaurach germany 5 | airbus se 2333 cs leiden netherlands 6 | airbus se 2333 cs netherlands 7 | airbus se leiden netherlands 8 | allianz se 80802 munich germany 9 | allianz se koniginstrasse 28 munich germany 10 | allianz se munich germany 11 | amadeus it group s a salvador de madariaga 1 28027 12 | amadeus it group s a salvador de madariaga 1 28027 madrid 13 | amadeus it group s a salvador de madariaga 1 28027 madrid spain 14 | anheuser busch inbev sa nv 3000 leuven belgium 15 | anheuser busch inbev sa nv brouwerijplein 1 3000 leuven belgium 16 | anheuser busch inbev sa nv brouwerijplein 1 leuven belgium 17 | asml holding n v de run 6501 5504 dr veldhoven netherlands 18 | asml holding n v de run 6501 veldhoven netherlands 19 | axa sa 25 avenue matignon 75008 paris france 20 | axa sa 75008 paris 21 | axa sa 75008 paris france 22 | banco bilbao vizcaya argentaria s a 48005 bilbao spain 23 | banco bilbao vizcaya argentaria s a bilbao 24 | banco bilbao vizcaya argentaria s a plaza san nicolas 4 48005 spain 25 | banco santander s a 28660 26 | banco santander s a 28660 madrid 27 | banco santander s a 28660 madrid spain 28 | basf se 67056 ludwigshafen am rhein germany 29 | basf se carl bosch strasse 38 67056 germany 30 | bayer aktiengesellschaft 31 | bayer aktiengesellschaft 51368 leverkusen germany 32 | bayer aktiengesellschaft kaiser wilhelm allee 1 51368 leverkusen germany 33 | bayerische motoren werke aktiengesellschaft munich germany 34 | bayerische motoren werke aktiengesellschaft petuelring 130 80788 munich 35 | bayerische motoren werke aktiengesellschaft petuelring 130 munich germany 36 | bnp paribas sa 16 boulevard des italiens 75009 france 37 | bnp paribas sa 16 boulevard des italiens paris france 38 | bnp paribas sa paris 39 | crh plc stonemason s way 16 dublin ireland 40 | crh plc stonemason s way 16 ireland 41 | daimler ag 70372 stuttgart germany 42 | daimler ag germany 43 | daimler ag mercedesstrasse 120 70372 stuttgart germany 44 | danone s a 15 rue du helder 75439 paris france 45 | danone s a 15 rue du helder paris france 46 | danone s a 75439 paris 47 | deutsche boerse 48 | deutsche boerse 60485 frankfurt 49 | deutsche boerse frankfurt 50 | deutsche post ag platz der deutschen post 51 | deutsche post ag platz der deutschen post 53113 germany 52 | deutsche post ag platz der deutschen post bonn germany 53 | deutsche telekom ag 53113 bonn germany 54 | deutsche telekom ag 53113 germany 55 | enel spa rome italy 56 | enel spa viale regina margherita 137 rm 
00198 italy 57 | engie sa 1 place samuel de champlain 92400 courbevoie 58 | engie sa 1 place samuel de champlain 92400 france 59 | engie sa 92400 courbevoie france 60 | eni s p a piazzale enrico mattei 1 rome italy 61 | eni s p a rm 00144 rome italy 62 | essilorluxottica 1 6 rue paul cezanne 75008 paris france 63 | essilorluxottica 1 6 rue paul cezanne paris france 64 | fresenius se co kgaa else kroner strasse 1 61352 bad homburg vor der hohe germany 65 | fresenius se co kgaa else kroner strasse 1 61352 germany 66 | fresenius se co kgaa else kroner strasse 1 bad homburg vor der hohe germany 67 | iberdrola s a plaza euskadi 5 48009 bilbao spain 68 | iberdrola s a plaza euskadi 5 bilbao spain 69 | industria de diseno textil s a avenida de la diputacion s n arteixo 15143 a coruna spain 70 | industria de diseno textil s a avenida de la diputacion s n arteixo a coruna spain 71 | ing groep n v bijlmerplein 888 1102 mg amsterdam 72 | ing groep n v bijlmerplein 888 1102 mg netherlands 73 | intesa sanpaolo s p a piazza san carlo 156 to 10121 italy 74 | intesa sanpaolo s p a piazza san carlo 156 to 10121 turin 75 | intesa sanpaolo s p a turin 76 | kering sa 40 rue de sevres 75007 paris 77 | kering sa 40 rue de sevres 75007 paris france 78 | koninklijke ahold delhaize n v 1506 ma zaandam netherlands 79 | koninklijke ahold delhaize n v provincialeweg 11 80 | koninklijke ahold delhaize n v provincialeweg 11 netherlands 81 | koninklijke philips n v amstelplein 2 1096 bc 82 | koninklijke philips n v amstelplein 2 1096 bc amsterdam 83 | koninklijke philips n v amstelplein 2 1096 bc netherlands 84 | l air liquide s a 75 quai d orsay 75007 france 85 | l air liquide s a 75007 paris france 86 | l air liquide s a paris france 87 | l oreal s a 41 rue martre 92117 clichy france 88 | l oreal s a 92117 france 89 | linde plc 10 priestley road surrey research park gu2 7xy guildford 90 | linde plc 10 priestley road surrey research park gu2 7xy united kingdom 91 | lvmh moet hennessy louis vuitton societe europeenne 22 avenue montaigne 75008 france 92 | lvmh moet hennessy louis vuitton societe europeenne 22 avenue montaigne paris 93 | lvmh moet hennessy louis vuitton societe europeenne 75008 france 94 | munchener ruckversicherungs gesellschaft aktiengesellschaft koniginstrasse 107 95 | munchener ruckversicherungs gesellschaft aktiengesellschaft koniginstrasse 107 80802 munich 96 | munchener ruckversicherungs gesellschaft aktiengesellschaft koniginstrasse 107 munich germany 97 | nokia corporation 2610 espoo finland 98 | nokia corporation finland 99 | nokia corporation karakaari 7 100 | orange s a 75015 paris france 101 | orange s a 78 rue olivier de serres 75015 paris france 102 | orange s a paris 103 | safran sa 2 boulevard du general martial valin paris france 104 | safran sa 75724 france 105 | safran sa paris france 106 | sanofi 54 rue la boetie 75008 france 107 | sanofi 54 rue la boetie 75008 paris france 108 | sanofi 75008 france 109 | sap se 69190 110 | sap se 69190 germany 111 | sap se dietmar hopp allee 16 112 | schneider electric s e 35 rue joseph monier 113 | schneider electric s e 35 rue joseph monier 92500 rueil malmaison 114 | schneider electric s e 92500 france 115 | siemens aktiengesellschaft munich germany 116 | siemens aktiengesellschaft werner von siemens strasse 1 80333 germany 117 | siemens aktiengesellschaft werner von siemens strasse 1 80333 munich germany 118 | societe generale societe 29 boulevard haussmann 75009 paris france 119 | societe generale societe 75009 paris france 120 | telefonica s a 
28050 madrid 121 | telefonica s a ronda de la comunicacion 122 | telefonica s a ronda de la comunicacion 28050 madrid spain 123 | the unilever group weena 455 3000 dk netherlands 124 | the unilever group weena 455 3000 dk rotterdam netherlands 125 | total s a 2 place jean millier paris france 126 | total s a 92078 france 127 | total s a paris 128 | vinci sa 1 cours ferdinand de lesseps 92851 france 129 | vinci sa 1 cours ferdinand de lesseps france 130 | vinci sa 1 cours ferdinand de lesseps rueil malmaison france 131 | vivendi sa 42 avenue de friedland 75380 paris france 132 | vivendi sa 75380 paris 133 | volkswagen ag 38440 germany 134 | volkswagen ag berliner ring 2 38440 135 | volkswagen ag berliner ring 2 38440 wolfsburg -------------------------------------------------------------------------------- /spark_matcher/deduplicator/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Deduplicator'] 2 | 3 | from .deduplicator import Deduplicator -------------------------------------------------------------------------------- /spark_matcher/deduplicator/connected_components_calculator.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import List, Tuple 6 | 7 | from graphframes import GraphFrame 8 | from pyspark.sql import functions as F, types as T, DataFrame, Window 9 | from pyspark.sql.utils import AnalysisException 10 | 11 | from spark_matcher.table_checkpointer import TableCheckpointer 12 | 13 | 14 | class ConnectedComponentsCalculator: 15 | 16 | def __init__(self, scored_pairs_table: DataFrame, max_edges_clustering: int, 17 | edge_filter_thresholds: List[float], table_checkpointer: TableCheckpointer): 18 | self.scored_pairs_table = scored_pairs_table 19 | self.max_edges_clustering = max_edges_clustering 20 | self.edge_filter_thresholds = edge_filter_thresholds 21 | self.table_checkpointer = table_checkpointer 22 | 23 | @staticmethod 24 | def _create_graph(scores_table: DataFrame) -> GraphFrame: 25 | """ 26 | This function creates a graph where each row-number is a vertex and each similarity score between a pair of 27 | row-numbers represents an edge. This Graph is used as input for connected components calculations to determine 28 | the subgraphs of linked pairs. 29 | 30 | Args: 31 | scores_table: a pairs table with similarity scores for each pair 32 | Returns: 33 | a GraphFrames graph object representing the row-numbers and the similarity scores between them as a graph 34 | """ 35 | 36 | vertices_1 = scores_table.select('row_number_1').withColumnRenamed('row_number_1', 'id') 37 | vertices_2 = scores_table.select('row_number_2').withColumnRenamed('row_number_2', 'id') 38 | vertices = vertices_1.unionByName(vertices_2).drop_duplicates() 39 | 40 | edges = ( 41 | scores_table 42 | .select('row_number_1', 'row_number_2', 'score') 43 | .withColumnRenamed('row_number_1', 'src') 44 | .withColumnRenamed('row_number_2', 'dst') 45 | ) 46 | return GraphFrame(vertices, edges) 47 | 48 | def _calculate_connected_components(self, scores_table: DataFrame, checkpoint_name: str) -> DataFrame: 49 | """ 50 | This function calculates the connected components (i.e. the subgraphs) of a graph of scored pairs. 51 | The result of the connected-components algorithm is cached and saved. 
52 | 53 | Args: 54 | scores_table: a pairs table with similarity scores for each pair 55 | Returns: 56 | a spark dataframe containing the connected components of a graph of scored pairs. 57 | """ 58 | graph = self._create_graph(scores_table) 59 | # Since graphframes 0.3.0, checkpointing is used for the connected components algorithm. The Spark cluster might 60 | # not allow writing checkpoints. In such case we fall back to the graphx algorithm that doesn't require 61 | # checkpointing. 62 | try: 63 | connected_components = graph.connectedComponents() 64 | except AnalysisException: 65 | connected_components = graph.connectedComponents(algorithm='graphx') 66 | return self.table_checkpointer(connected_components, checkpoint_name) 67 | 68 | @staticmethod 69 | def _add_component_id_to_scores_table(scores_table: DataFrame, connected_components: DataFrame) -> DataFrame: 70 | """ 71 | This function joins the initial connected-component identifiers to the scored pairs table. For each scored pair, 72 | a component identifier indicating to which subgraph a pair belongs is added. 73 | 74 | Args: 75 | scores_table: a pairs table with similarity scores for each pair 76 | connected_components: a spark dataframe containing the result of the connected components algorithm 77 | Returns: 78 | the scores_table with an identifier that indicates to which component a pair belongs 79 | """ 80 | scores_table_with_component_ids = ( 81 | scores_table 82 | .join(connected_components, on=scores_table['row_number_1']==connected_components['id'], how='left') 83 | .withColumnRenamed('component', 'component_1') 84 | .drop('id') 85 | .join(connected_components, on=scores_table['row_number_2']==connected_components['id'], how='left') 86 | .withColumnRenamed('component', 'component_2') 87 | .drop('id') 88 | ) 89 | 90 | scores_table_with_component_ids = ( 91 | scores_table_with_component_ids 92 | .withColumnRenamed('component_1', 'component') 93 | .drop('component_2') 94 | ) 95 | return scores_table_with_component_ids 96 | 97 | @staticmethod 98 | def _add_big_components_info(scored_pairs_w_components: DataFrame, component_cols: List[str], max_size: int) -> \ 99 | Tuple[DataFrame, bool]: 100 | """ 101 | This function adds information if there are components in the scored_pairs table that contain more pairs than 102 | the `max_size'. If there is a component that is too big, the function will identify this and will indicate that 103 | filter iterations are required via the `continue_iteration` value. 104 | 105 | Args: 106 | scored_pairs_w_components: a pairs table with similarity scores for each pair and component identifiers 107 | component_cols: a list of columns that represent the composite key that indicate connected-components 108 | after one or multiple iterations 109 | max_size: a number indicating the maximum size of the number of pairs that are in a connected component 110 | in the scores table 111 | Returns: 112 | the scores table with a column indicating whether a pair belongs to a big component 113 | and a boolean indicating whether there are too big components and thus whether more iterations are required. 
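A small, self-contained illustration of the window aggregation used in this method, with made-up toy data and a toy `max_size` of 2:

    from pyspark.sql import SparkSession, Window, functions as F

    spark = SparkSession.builder.getOrCreate()
    pairs = spark.createDataFrame(
        [(1, 2, 'a', 0.9), (2, 3, 'a', 0.6), (3, 4, 'a', 0.8), (5, 6, 'b', 0.7)],
        ['row_number_1', 'row_number_2', 'component', 'score'])

    window = Window.partitionBy('component')
    flagged = (pairs
               .withColumn('component_size', F.count(F.lit(1)).over(window))  # edges per component
               .withColumn('is_in_big_component', F.col('component_size') > 2)
               .drop('component_size'))
    flagged.show()  # the 3 edges of component 'a' are flagged, the single edge of 'b' is not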
114 | """ 115 | window = Window.partitionBy(*component_cols) 116 | big_components_info = ( 117 | scored_pairs_w_components 118 | .withColumn('constant', F.lit(1)) 119 | .withColumn('component_size', F.count('constant').over(window)) 120 | .drop('constant') 121 | .withColumn('is_in_big_component', F.col('component_size') > max_size) 122 | .drop('component_size') 123 | ) 124 | continue_iteration = False 125 | if big_components_info.filter(F.col('is_in_big_component')).count() > 0: 126 | continue_iteration = True 127 | return big_components_info, continue_iteration 128 | 129 | @staticmethod 130 | def _add_low_score_info(scored_pairs: DataFrame, threshold) -> DataFrame: 131 | """ 132 | This function adds a column that indicates whether a similarity score between pairs is lower than a `threshold` 133 | value. 134 | 135 | Args: 136 | scored_pairs: a pairs table with similarity scores for each pair 137 | threshold: a value indicating the threshold for a similarity score 138 | Returns: 139 | the scores table with an additional column indicating whether a pair has a score lower than the `threshold` 140 | """ 141 | scored_pairs = ( 142 | scored_pairs 143 | .withColumn('is_lower_than_threshold', F.col('score') < threshold) 144 | ) 145 | return scored_pairs 146 | 147 | @staticmethod 148 | def _join_connected_components_iteration_results(scored_pairs: DataFrame, connected_components: DataFrame, 149 | iteration_number: int) -> DataFrame: 150 | """ 151 | This function joins the connected component results for each iteration to the scores table and prunes the 152 | scores table by removing edges that belong to a component that is too big and have a similarity score lower than 153 | the current threshold. 154 | 155 | Args: 156 | scored_pairs: a pairs table with similarity scores for each pair for an iteration 157 | connected_components: a spark dataframe containing the result of the connected components algorithm for 158 | an iteration 159 | iteration_number: a counter indicating which iteration it is 160 | Returns: 161 | a scores table with the component identifiers for an iteration, without edges that belonged to too-big 162 | components 163 | and that had scores below the current threshold. 164 | """ 165 | # join the connected components results to the scored-pairs on `row_number_1` 166 | result = ( 167 | scored_pairs 168 | .join(connected_components, on=scored_pairs['row_number_1'] == connected_components['id'], how='left') 169 | .drop('id') 170 | .withColumnRenamed('component', 'component_1') 171 | .persist()) 172 | 173 | # join the connected components results to the scored-pairs on `row_number_2` 174 | result = ( 175 | result 176 | .join(connected_components, on=scored_pairs['row_number_2'] == connected_components['id'], how='left') 177 | .drop('id') 178 | .withColumnRenamed('component', 'component_2') 179 | .filter((F.col('component_1') == F.col('component_2')) | (F.col('component_1').isNull() & F.col( 180 | 'component_2').isNull())) 181 | .withColumnRenamed('component_1', f'component_iteration_{iteration_number}') 182 | .drop('component_2') 183 | # add a default value for rows that were not part of the iteration, i.e. 
-1 since component ids are 184 | # always positive integers 185 | .fillna(-1, subset=[f'component_iteration_{iteration_number}']) 186 | .filter((F.col('is_in_big_component') & ~F.col('is_lower_than_threshold')) | 187 | (~F.col('is_in_big_component'))) 188 | .drop('is_in_big_component', 'is_lower_than_threshold') 189 | ) 190 | return result 191 | 192 | @staticmethod 193 | def _add_component_id(sdf: DataFrame, iteration_count: int) -> DataFrame: 194 | """ 195 | This function adds a final `component_id` after all iterations that can be used as a groupby-key to 196 | distribute the deduplication. The `component_id` is a hashed value of the '_'-separated concatenation 197 | of all component identifiers of the iterations. The sha2 hash function is used in order to make the likelihood of 198 | hash collisions negligibly small. 199 | 200 | Args: 201 | sdf: a dataframe containing the scored pairs 202 | iteration_count: a number indicating the final iteration number that is used. 203 | Returns: 204 | the final scores-table with a component-id that can be used as groupby-key to distribute the deduplication. 205 | """ 206 | if iteration_count > 0: 207 | component_id_cols = ['initial_component'] + [f'component_iteration_{i}' for i in 208 | range(1, iteration_count + 1)] 209 | sdf = sdf.withColumn('component_id', F.sha2(F.concat_ws('_', *component_id_cols), numBits=256)) 210 | else: 211 | # use the initial component as component_id and cast the value to string to match with the return type of 212 | # the sha2 algorithm 213 | sdf = ( 214 | sdf 215 | .withColumn('component_id', F.col('initial_component').cast(T.StringType())) 216 | .drop('initial_component') 217 | ) 218 | return sdf 219 | 220 | def _create_component_ids(self) -> DataFrame: 221 | """ 222 | This function wraps the other methods in this class and performs the connected-component calculations and the 223 | edge filtering if required. 224 | 225 | Returns: 226 | a scored-pairs table with a component identifier that can be used as groupby-key to distribute the 227 | deduplication. 
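The loop implemented below can be summarised in plain Python. This is a conceptual sketch only (toy edge list, toy thresholds); the actual implementation works on Spark dataframes with GraphFrames:

    from collections import defaultdict

    def components(edges):
        # tiny union-find over an edge list; returns {node: component root}
        parent = {}
        def find(x):
            parent.setdefault(x, x)
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x
        for a, b, _ in edges:
            parent[find(a)] = find(b)
        return {node: find(node) for node in parent}

    edges = [(1, 2, 0.9), (2, 3, 0.5), (3, 4, 0.95), (5, 6, 0.8)]  # (id, id, score), made up
    max_edges, thresholds = 2, [0.45, 0.55, 0.65]

    for threshold in thresholds:
        comp = components(edges)
        edges_per_component = defaultdict(int)
        for a, _, _ in edges:
            edges_per_component[comp[a]] += 1
        if all(size <= max_edges for size in edges_per_component.values()):
            break  # every component is small enough, stop filtering
        # inside oversized components, drop edges scoring below the current threshold:
        edges = [(a, b, s) for a, b, s in edges
                 if edges_per_component[comp[a]] <= max_edges or s >= threshold]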
228 | """ 229 | # calculate the initial connected components 230 | connected_components = self._calculate_connected_components(self.scored_pairs_table, 231 | "cached_connected_components_table_initial") 232 | # add the component ids to the scored-pairs table 233 | self.scored_pairs_table = self._add_component_id_to_scores_table(self.scored_pairs_table, connected_components) 234 | self.scored_pairs_table = self.scored_pairs_table.withColumnRenamed('component', 'initial_component') 235 | 236 | #checkpoint the results 237 | self.scored_pairs_table = self.table_checkpointer(self.scored_pairs_table, 238 | "cached_scores_with_components_table_initial") 239 | 240 | component_columns = ['initial_component'] 241 | iteration_count = 0 242 | for i in range(1, len(self.edge_filter_thresholds) + 1): 243 | # set and update the threshold 244 | threshold = self.edge_filter_thresholds[i - 1] 245 | 246 | self.scored_pairs_table, continue_iter = self._add_big_components_info(self.scored_pairs_table, 247 | component_columns, 248 | self.max_edges_clustering) 249 | if not continue_iter: 250 | # no further iterations are required 251 | break 252 | 253 | self.scored_pairs_table = self._add_low_score_info(self.scored_pairs_table, threshold) 254 | # recalculate the connected_components 255 | iteration_input_sdf = ( 256 | self.scored_pairs_table 257 | .filter((F.col('is_in_big_component') & ~F.col('is_lower_than_threshold'))) 258 | ) 259 | 260 | print(f"calculate connected components iteration {i} with edge filter threshold {threshold}") 261 | connected_components_iteration = self._calculate_connected_components(iteration_input_sdf, 262 | f'cached_connected_components_table_iteration_{i}') 263 | 264 | self.scored_pairs_table = ( 265 | self 266 | ._join_connected_components_iteration_results(self.scored_pairs_table, 267 | connected_components_iteration, i) 268 | ) 269 | 270 | # checkpoint the result to break the lineage after each iteration and to inspect intermediate results 271 | self.scored_pairs_table = self.table_checkpointer(self.scored_pairs_table, 272 | f'cached_scores_with_components_table_iteration_{i}') 273 | 274 | component_columns.append(f'component_iteration_{i}') 275 | iteration_count += 1 276 | 277 | # final check to see if there were sufficient iterations to reduce the size of the big components if all 278 | # specified iterations are used to reduce the component size 279 | if iteration_count == len(self.edge_filter_thresholds): 280 | self.scored_pairs_table, continue_iter = self._add_big_components_info(self.scored_pairs_table, 281 | component_columns, 282 | self.max_edges_clustering) 283 | if continue_iter: 284 | print("THE EDGE-FILTER-THRESHOLDS ARE NOT ENOUGH TO SUFFICIENTLY REDUCE THE SIZE OF THE BIG COMPONENTS") 285 | 286 | # add the final component-id that can be used as groupby key for distributed deduplication 287 | self.scored_pairs_table = self._add_component_id(self.scored_pairs_table, iteration_count) 288 | 289 | return self.table_checkpointer(self.scored_pairs_table, 'cached_scores_with_components_table') 290 | -------------------------------------------------------------------------------- /spark_matcher/deduplicator/deduplicator.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import Optional, List, Dict 6 | 7 | from pyspark.sql import DataFrame, SparkSession, functions as F, types as T 8 | from sklearn.exceptions import NotFittedError 9 | from scipy.cluster 
import hierarchy 10 | 11 | from spark_matcher.blocker.blocking_rules import BlockingRule 12 | from spark_matcher.deduplicator.connected_components_calculator import ConnectedComponentsCalculator 13 | from spark_matcher.deduplicator.hierarchical_clustering import apply_deduplication 14 | from spark_matcher.matching_base.matching_base import MatchingBase 15 | from spark_matcher.scorer.scorer import Scorer 16 | from spark_matcher.table_checkpointer import TableCheckpointer 17 | 18 | 19 | class Deduplicator(MatchingBase): 20 | """ 21 | Deduplicator class to apply deduplication. Provide either the column names `col_names` using the default string 22 | similarity metrics or explicitly define the string similarity metrics in a dict `field_info` as in the example 23 | below. If `blocking_rules` is left empty, default blocking rules are used. Otherwise, provide blocking rules as 24 | a list containing `BlockingRule` instances (see example below). The number of perfect matches used during 25 | training is set by `n_perfect_train_matches`. 26 | 27 | E.g.: 28 | 29 | from spark_matcher.blocker.blocking_rules import FirstNChars 30 | 31 | myDeduplicator = Deduplicator(spark_session, field_info={'name': [metric_function_1, metric_function_2], 32 | 'address': [metric_function_1, metric_function_3]}, 33 | blocking_rules=[FirstNChars('name', 3)]) 34 | 35 | Args: 36 | spark_session: Spark session 37 | col_names: list of column names to use for matching 38 | field_info: dict of column names as keys and lists of string similarity metrics as values 39 | blocking_rules: list of `BlockingRule` instances 40 | table_checkpointer: pointer object to store cached tables 41 | checkpoint_dir: checkpoint directory if provided 42 | n_train_samples: nr of pair samples to be created for training 43 | ratio_hashed_samples: ratio of hashed samples to be created for training, rest is sampled randomly 44 | n_perfect_train_matches: nr of perfect matches used for training 45 | scorer: a Scorer object used for scoring pairs 46 | verbose: sets verbosity 47 | max_edges_clustering: max number of edges per component that enters clustering 48 | edge_filter_thresholds: list of score thresholds to use for filtering when components are too large 49 | cluster_score_threshold: threshold value in [0.0, 1.0]; pairs are only put together in clusters if 50 | their cluster similarity scores are >= cluster_score_threshold 51 | cluster_linkage_method: linkage method to be used within hierarchical clustering, can take values such as 52 | 'centroid', 'single', 'complete', 'average', 'weighted', 'median', 'ward' etc. 
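A minimal end-to-end sketch of typical usage, assuming `spark` is an active SparkSession and `sdf` is a Spark dataframe with (illustrative) 'name' and 'address' columns:

    from spark_matcher.deduplicator import Deduplicator

    deduplicator = Deduplicator(spark, col_names=['name', 'address'],
                                checkpoint_dir='/path/to/checkpoints')
    deduplicator.fit(sdf)                                    # interactive active-learning session
    deduplicated = deduplicator.predict(sdf, threshold=0.5)  # adds an 'entity_identifier' column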
53 | """ 54 | def __init__(self, spark_session: SparkSession, col_names: Optional[List[str]] = None, 55 | field_info: Optional[Dict] = None, blocking_rules: Optional[List[BlockingRule]] = None, 56 | blocking_recall: float = 1.0, table_checkpointer: Optional[TableCheckpointer] = None, 57 | checkpoint_dir: Optional[str] = None, n_perfect_train_matches=1, n_train_samples: int = 100_000, 58 | ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0, 59 | max_edges_clustering: int = 500_000, 60 | edge_filter_thresholds: List[float] = [0.45, 0.55, 0.65, 0.75, 0.85, 0.95], 61 | cluster_score_threshold: float = 0.5, cluster_linkage_method: str = "centroid"): 62 | 63 | super().__init__(spark_session, table_checkpointer, checkpoint_dir, col_names, field_info, blocking_rules, 64 | blocking_recall, n_perfect_train_matches, n_train_samples, ratio_hashed_samples, scorer, 65 | verbose) 66 | 67 | self.fitted_ = False 68 | self.max_edges_clustering = max_edges_clustering 69 | self.edge_filter_thresholds = edge_filter_thresholds 70 | self.cluster_score_threshold = cluster_score_threshold 71 | if cluster_linkage_method not in list(hierarchy._LINKAGE_METHODS.keys()): 72 | raise ValueError(f"Invalid cluster_linkage_method: {cluster_linkage_method}") 73 | self.cluster_linkage_method = cluster_linkage_method 74 | # set the checkpoints directory for graphframes 75 | self.spark_session.sparkContext.setCheckpointDir('/tmp/checkpoints') 76 | 77 | def _create_predict_pairs_table(self, sdf_blocked: DataFrame) -> DataFrame: 78 | """ 79 | This method performs an alias self-join on `sdf` to create the pairs table for prediction. `sdf` is joined to 80 | itself based on the `block_key` column. The result of this join is a pairs table. 81 | """ 82 | sdf_blocked_1 = self._add_suffix_to_col_names(sdf_blocked, 1) 83 | sdf_blocked_1 = sdf_blocked_1.withColumnRenamed('row_number', 'row_number_1') 84 | sdf_blocked_2 = self._add_suffix_to_col_names(sdf_blocked, 2) 85 | sdf_blocked_2 = sdf_blocked_2.withColumnRenamed('row_number', 'row_number_2') 86 | 87 | pairs = ( 88 | sdf_blocked_1 89 | .join(sdf_blocked_2, on='block_key', how='inner') 90 | .filter(F.col("row_number_1") < F.col("row_number_2")) 91 | .drop_duplicates(subset=[col + "_1" for col in self.col_names] + [col + "_2" for col in self.col_names]) 92 | ) 93 | return pairs 94 | 95 | def _map_distributed_identifiers_to_long(self, clustered_results: DataFrame) -> DataFrame: 96 | """ 97 | Method to add a unique `entity_identifier` to the results from clustering 98 | 99 | Args: 100 | clustered_results: results from clustering 101 | Returns: 102 | Spark dataframe containing `row_number` and `entity_identifier` 103 | """ 104 | long_entity_ids = ( 105 | clustered_results 106 | .select('entity_identifier') 107 | .drop_duplicates() 108 | .withColumn('long_entity_identifier', F.monotonically_increasing_id()) 109 | ) 110 | long_entity_ids = self.table_checkpointer(long_entity_ids, 'cached_long_ids_table') 111 | 112 | clustered_results = ( 113 | clustered_results 114 | .join(long_entity_ids, on='entity_identifier', how='left') 115 | .drop('entity_identifier') 116 | .withColumnRenamed('long_entity_identifier', 'entity_identifier') 117 | ) 118 | return clustered_results 119 | 120 | def _distributed_deduplication(self, scored_pairs_with_components: DataFrame) -> DataFrame: 121 | schema = T.StructType([T.StructField('row_number', T.LongType(), True), 122 | T.StructField('entity_identifier', T.StringType(), True)]) 123 | 124 | clustered_results = ( 125 | 
scored_pairs_with_components 126 | .select('row_number_1', 'row_number_2', 'score', 'component_id') 127 | .groupby('component_id') 128 | .applyInPandas(apply_deduplication(self.cluster_score_threshold, self.cluster_linkage_method) , schema=schema) 129 | ) 130 | return self.table_checkpointer(clustered_results, "cached_clustered_results_table") 131 | 132 | @staticmethod 133 | def _add_singletons_entity_identifiers(result_sdf: DataFrame) -> DataFrame: 134 | """ 135 | Function to add entity_identifier to entities (singletons) that are not combined with other entities into a 136 | deduplicated entity. If there are no singletons, the input table will be returned as it is. 137 | 138 | Args: 139 | result_sdf: Spark dataframe containing the result of deduplication where entities that are not deduplicated 140 | have a missing value in the `entity_identifier` column. 141 | 142 | Returns: 143 | Spark dataframe with `entity_identifier` values for all entities 144 | 145 | """ 146 | if result_sdf.filter(F.col('entity_identifier').isNull()).count() == 0: 147 | return result_sdf 148 | start_cluster_id = ( 149 | result_sdf 150 | .filter(F.col('entity_identifier').isNotNull()) 151 | .select(F.max('entity_identifier')) 152 | .first()[0]) 153 | 154 | singletons_entity_identifiers = ( 155 | result_sdf 156 | .filter(F.col('entity_identifier').isNull()) 157 | .select('row_number') 158 | .rdd 159 | .zipWithIndex() 160 | .toDF() 161 | .withColumn('row_number', F.col('_1').getItem("row_number")) 162 | .drop("_1") 163 | .withColumn('entity_identifier_singletons', F.col('_2') + start_cluster_id + 1) 164 | .drop("_2")) 165 | 166 | result_sdf = ( 167 | result_sdf 168 | .join(singletons_entity_identifiers, on='row_number', how='left') 169 | .withColumn('entity_identifier', 170 | F.when(F.col('entity_identifier').isNull(), 171 | F.col('entity_identifier_singletons')) 172 | .otherwise(F.col('entity_identifier'))) 173 | .drop('entity_identifier_singletons')) 174 | 175 | return result_sdf 176 | 177 | def _create_deduplication_results(self, clustered_results: DataFrame, entities: DataFrame) -> DataFrame: 178 | """ 179 | Joins deduplication results back to entity_table and adds identifiers to rows that are not deduplicated with 180 | other rows 181 | """ 182 | deduplication_results = ( 183 | entities 184 | .join(clustered_results, on='row_number', how='left') 185 | ) 186 | 187 | deduplication_results = ( 188 | self._add_singletons_entity_identifiers(deduplication_results) 189 | .drop('row_number', 'block_key') 190 | ) 191 | return deduplication_results 192 | 193 | def _get_large_clusters(self, pairs_with_components: DataFrame) -> DataFrame: 194 | """ 195 | Components that are too large after iteratively removing edges with a similarity score lower than the 196 | thresholds in `edge_filter_thresholds`, are considered to be one entity. For these the `component_id` is 197 | temporarily used as an `entity_identifier`. 
198 | """ 199 | large_component = (pairs_with_components.filter(F.col('is_in_big_component')) 200 | .withColumn('entity_identifier', F.col('component_id'))) 201 | 202 | large_cluster = (large_component.select('row_number_1', 'entity_identifier') 203 | .withColumnRenamed('row_number_1', 'row_number') 204 | .unionByName(large_component.select('row_number_2', 'entity_identifier') 205 | .withColumnRenamed('row_number_2', 'row_number')) 206 | .drop_duplicates() 207 | .persist()) 208 | return large_cluster 209 | 210 | def predict(self, sdf: DataFrame, threshold: float = 0.5): 211 | """ 212 | Method to predict on data used for training or new data. 213 | 214 | Args: 215 | sdf: table to be applied entity deduplication 216 | threshold: probability threshold for similarity score 217 | 218 | Returns: 219 | Spark dataframe with the deduplication result 220 | 221 | """ 222 | if not self.fitted_: 223 | raise NotFittedError('The Deduplicator instance is not fitted yet. Call `fit` and train the instance.') 224 | 225 | sdf = self._simplify_dataframe_for_matching(sdf) 226 | sdf = sdf.withColumn('row_number', F.monotonically_increasing_id()) 227 | # make sure the `row_number` is a fixed value. See the docs of `monotonically_increasing_id`, the function 228 | # depends on the partitioning of the table. Not fixing/storing the result will cause problems when the table is 229 | # recalculated under the hood during other operations. Saving it to disk and reading it in, breaks lineage and 230 | # forces the result to be deterministic in subsequent operations. 231 | sdf = self.table_checkpointer(sdf, "cached_entities_numbered_table") 232 | 233 | sdf_blocked = self.table_checkpointer(self.blocker.transform(sdf), "cached_block_table") 234 | pairs_table = self.table_checkpointer(self._create_predict_pairs_table(sdf_blocked), "cached_pairs_table") 235 | metrics_table = self.table_checkpointer(self._calculate_metrics(pairs_table), "cached_metrics_table") 236 | 237 | scores_table = ( 238 | metrics_table 239 | .withColumn('score', self.scoring_learner.predict_proba(metrics_table['similarity_metrics'])) 240 | ) 241 | scores_table = self.table_checkpointer(scores_table, "cached_scores_table") 242 | 243 | scores_table_filtered = ( 244 | scores_table 245 | .filter(F.col('score') >= threshold) 246 | .drop('block_key', 'similarity_metrics') 247 | ) 248 | 249 | # create the subgraphs by using the connected components algorithm 250 | connected_components_calculator = ConnectedComponentsCalculator(scores_table_filtered, 251 | self.max_edges_clustering, 252 | self.edge_filter_thresholds, 253 | self.table_checkpointer) 254 | 255 | # calculate the component id for each scored-pairs 256 | scored_pairs_with_component_ids = connected_components_calculator._create_component_ids() 257 | 258 | large_cluster = self._get_large_clusters(scored_pairs_with_component_ids) 259 | 260 | # apply the distributed deduplication to components that are sufficiently small 261 | clustered_results = self._distributed_deduplication( 262 | scored_pairs_with_component_ids.filter(~F.col('is_in_big_component'))) 263 | 264 | all_results = large_cluster.unionByName(clustered_results) 265 | 266 | # assign a unique entity_identifier to all components 267 | all_results_long = self._map_distributed_identifiers_to_long(all_results) 268 | all_results_long = self.table_checkpointer(all_results_long, "cached_clustered_results_with_long_ids_table") 269 | 270 | deduplication_results = self._create_deduplication_results(all_results_long, sdf) 271 | 272 | return 
deduplication_results 273 | -------------------------------------------------------------------------------- /spark_matcher/deduplicator/hierarchical_clustering.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from collections import defaultdict 6 | from typing import List, Dict, Callable 7 | 8 | import numpy as np 9 | import pandas as pd 10 | import scipy.spatial.distance as ssd 11 | from scipy.cluster import hierarchy 12 | 13 | 14 | def _perform_clustering(component: pd.DataFrame, threshold: float, linkage_method: str): 15 | """ 16 | Apply hierarchical clustering to scored_pairs_table with component_ids 17 | Args: 18 | component: pandas dataframe containing all pairs and the similarity scores for a connected component 19 | threshold: threshold to apply in hierarchical clustering 20 | linkage_method: linkage method to apply in hierarchical clustering 21 | Returns: 22 | Generator that contains tuples of ids and scores 23 | """ 24 | 25 | distance_threshold = 1 - threshold 26 | if len(component) > 1: 27 | i_to_id, condensed_distances = _get_condensed_distances(component) 28 | 29 | linkage = hierarchy.linkage(condensed_distances, method=linkage_method) 30 | partition = hierarchy.fcluster(linkage, t=distance_threshold, criterion='distance') 31 | 32 | clusters: Dict[int, List[int]] = defaultdict(list) 33 | 34 | for i, cluster_id in enumerate(partition): 35 | clusters[cluster_id].append(i) 36 | 37 | for cluster in clusters.values(): 38 | if len(cluster) > 1: 39 | yield tuple(i_to_id[i] for i in cluster), None 40 | else: 41 | ids = np.array([int(component['row_number_1']), int(component['row_number_2'])]) 42 | score = float(component['score']) 43 | if score > threshold: 44 | yield tuple(ids), (score,) * 2 45 | 46 | 47 | def _convert_data_to_adjacency_matrix(component: pd.DataFrame): 48 | """ 49 | This function converts a pd.DataFrame to a numpy adjacency matrix 50 | Args: 51 | component: pd.DataFrame 52 | Returns: 53 | index of elements of the components and a numpy adjacency matrix 54 | """ 55 | def _get_adjacency_matrix(df, col1, col2, score_col:str = 'score'): 56 | df = pd.crosstab(df[col1], df[col2], values=df[score_col], aggfunc='max') 57 | idx = df.columns.union(df.index) 58 | df = df.reindex(index = idx, columns=idx, fill_value=0).fillna(0) 59 | return df 60 | 61 | a_to_b = _get_adjacency_matrix(component, "row_number_1", "row_number_2") 62 | b_to_a = _get_adjacency_matrix(component, "row_number_2", "row_number_1") 63 | 64 | symmetric_adjacency_matrix = a_to_b + b_to_a 65 | 66 | return symmetric_adjacency_matrix.index, np.array(symmetric_adjacency_matrix) 67 | 68 | 69 | def _get_condensed_distances(component: pd.DataFrame): 70 | """ 71 | Converts the pairwise list of distances to "condensed distance matrix" required by the hierarchical clustering 72 | algorithms. Also return a dictionary that maps the distance matrix to the ids. 
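A self-contained toy example of the same distance convention (distance = 1 - similarity) and of the clustering step that consumes the condensed distances, with made-up scores:

    import numpy as np
    import scipy.spatial.distance as ssd
    from scipy.cluster import hierarchy

    similarity = np.array([[0.0, 0.9, 0.2],
                           [0.9, 0.0, 0.1],
                           [0.2, 0.1, 0.0]])  # symmetric pairwise similarity scores
    distance = (np.ones_like(similarity) - np.eye(3)) - similarity
    condensed = ssd.squareform(distance)
    linkage = hierarchy.linkage(condensed, method='centroid')
    labels = hierarchy.fcluster(linkage, t=1 - 0.5, criterion='distance')  # similarity threshold 0.5
    print(labels)  # records 0 and 1 share a cluster label, record 2 gets its own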
73 | 74 | Args: 75 | component: pandas dataframe containing all pairs and the similarity scores for a connected component 76 | 77 | Returns: 78 | condensed distances and a dict with ids 79 | """ 80 | 81 | i_to_id, adj_matrix = _convert_data_to_adjacency_matrix(component) 82 | distances = (np.ones_like(adj_matrix) - np.eye(len(adj_matrix))) - adj_matrix 83 | return dict(enumerate(i_to_id)), ssd.squareform(distances) 84 | 85 | 86 | def _convert_dedupe_result_to_pandas_dataframe(dedupe_result: List, component_id: int) -> pd.DataFrame: 87 | """ 88 | Function to convert the dedupe result into a pandas dataframe. 89 | E.g. 90 | dedupe_result = [((1, 2), array([0.96, 0.96])), ((3, 4, 5), array([0.95, 0.95, 0.95]))] 91 | 92 | returns 93 | 94 | | row_number | entity_identifier | 95 | | ---------- | ------------------- | 96 | | 1 | 1 | 97 | | 2 | 1 | 98 | | 3 | 2 | 99 | | 4 | 2 | 100 | | 5 | 2 | 101 | 102 | Args: 103 | dedupe_result: the result with the deduplication results from the clustering 104 | Returns: 105 | pandas dataframe with row_number and entity_identifier 106 | """ 107 | if len(dedupe_result) == 0: 108 | return pd.DataFrame(data={'row_number': [], 'entity_identifier': []}) 109 | 110 | entity_identifier = 0 111 | df_list = [] 112 | 113 | for ids, _ in dedupe_result: 114 | df_ = pd.DataFrame(data={'row_number': list(ids), 'entity_identifier': f"{component_id}_{entity_identifier}"}) 115 | df_list.append(df_) 116 | entity_identifier += 1 117 | return pd.concat(df_list) 118 | 119 | 120 | def apply_deduplication(cluster_score_threshold: float, cluster_linkage_method: str) -> Callable: 121 | """ 122 | This function is a wrapper function to parameterize the _apply_deduplucation function with extra parameters for 123 | the cluster_score_threshold and the linkage method. 124 | Args: 125 | cluster_score_threshold: a float in [0,1] 126 | cluster_linkage_method: a string indicating the linkage method to be used for hierarchical clustering 127 | Returns: 128 | a function, i.e. _apply_deduplication, that can be called as a Pandas udf 129 | """ 130 | def _apply_deduplication(component: pd.DataFrame) -> pd.DataFrame: 131 | """ 132 | This function applies deduplication on a component, i.e. a subgraph calculated by the connected components 133 | algorithm. This function is applied to a spark dataframe in a pandas udf to distribute the deduplication 134 | over a Spark cluster, component by component. 
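        For illustration only, the wrapped function is typically applied per component with
        Spark's `applyInPandas`; the output schema and linkage method below are assumptions for
        the sketch, not fixed defaults of the package:

            dedup_fn = apply_deduplication(cluster_score_threshold=0.5,
                                           cluster_linkage_method='complete')
            clustered = (scored_pairs_with_component_ids
                         .groupBy('component_id')
                         .applyInPandas(dedup_fn,
                                        schema='row_number long, entity_identifier string'))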
135 | Args: 136 | component: a pandas Dataframe 137 | Returns: 138 | a pandas Dataframe with the results from the hierarchical clustering of the deduplication 139 | """ 140 | component_id = component['component_id'][0] 141 | 142 | # perform the clustering: 143 | component = list(_perform_clustering(component, cluster_score_threshold, cluster_linkage_method)) 144 | 145 | # convert the results to a dataframe: 146 | return _convert_dedupe_result_to_pandas_dataframe(component, component_id) 147 | 148 | return _apply_deduplication 149 | -------------------------------------------------------------------------------- /spark_matcher/matcher/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Matcher'] 2 | 3 | from .matcher import Matcher -------------------------------------------------------------------------------- /spark_matcher/matcher/matcher.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import Optional, List, Dict 6 | 7 | from pyspark.sql import DataFrame, functions as F 8 | from pyspark.sql import SparkSession 9 | from pyspark.sql import Window 10 | from sklearn.exceptions import NotFittedError 11 | 12 | from spark_matcher.blocker.blocking_rules import BlockingRule 13 | from spark_matcher.matching_base.matching_base import MatchingBase 14 | from spark_matcher.scorer.scorer import Scorer 15 | from spark_matcher.table_checkpointer import TableCheckpointer 16 | 17 | 18 | class Matcher(MatchingBase): 19 | """ 20 | Matcher class to apply record linkage. Provide either the column names `col_names` using the default string 21 | similarity metrics or explicitly define the string similarity metrics in a dict `field_info` as in the example 22 | below. If `blocking_rules` is left empty, default blocking rules are used. Otherwise provide blocking rules as 23 | a list containing `BlockingRule` instances (see example below). The number of perfect matches used during 24 | training is set by `n_perfect_train_matches`. 
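    In its simplest form only a Spark session, the matching columns and a checkpoint directory
    are needed (a sketch; the column names and path are placeholders):

        myMatcher = Matcher(spark_session, col_names=['name', 'address'],
                            checkpoint_dir='/tmp/spark_matcher_checkpoints')
        myMatcher.fit(sdf_1, sdf_2)
        results = myMatcher.predict(sdf_1, sdf_2, threshold=0.5)

    A more explicit configuration with custom similarity metrics and blocking rules is shown in
    the example below.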
25 | 26 | E.g.: 27 | 28 | from spark_matcher.blocker.blocking_rules import FirstNChars 29 | 30 | myMatcher = Matcher(spark_session, field_info={'name':[metric_function_1, metric_function_2], 31 | 'address:[metric_function_1, metric_function_3]}, 32 | blocking_rules=[FirstNChars('name', 3)]) 33 | 34 | Args: 35 | spark_session: Spark session 36 | col_names: list of column names to use for matching 37 | field_info: dict of column names as keys and lists of string similarity metrics as values 38 | blocking_rules: list of `BlockingRule` instances 39 | n_train_samples: nr of pair samples to be created for training 40 | ratio_hashed_samples: ratio of hashed samples to be created for training, rest is sampled randomly 41 | n_perfect_train_matches: nr of perfect matches used for training 42 | scorer: a Scorer object used for scoring pairs 43 | verbose: sets verbosity 44 | """ 45 | def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[TableCheckpointer]=None, 46 | checkpoint_dir: Optional[str]=None, col_names: Optional[List[str]] = None, 47 | field_info: Optional[Dict] = None, blocking_rules: Optional[List[BlockingRule]] = None, 48 | blocking_recall: float = 1.0, n_perfect_train_matches=1, n_train_samples: int = 100_000, 49 | ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0): 50 | super().__init__(spark_session, table_checkpointer, checkpoint_dir, col_names, field_info, blocking_rules, 51 | blocking_recall, n_perfect_train_matches, n_train_samples, ratio_hashed_samples, scorer, 52 | verbose) 53 | self.fitted_ = False 54 | 55 | def _create_predict_pairs_table(self, sdf_1_blocked: DataFrame, sdf_2_blocked: DataFrame) -> DataFrame: 56 | """ 57 | Method to create pairs within blocks as provided by both input Spark dataframes in the column `block_key`. 58 | Note that an inner join is performed to create pairs: entries in `sdf_1_blocked` and `sdf_2_blocked` 59 | that don't share a `block_key` value in the other table will not appear in the resulting pairs table. 60 | 61 | Args: 62 | sdf_1_blocked: Spark dataframe containing all columns used for matching and the `block_key` 63 | sdf_2_blocked: Spark dataframe containing all columns used for matching and the `block_key` 64 | 65 | Returns: 66 | Spark dataframe containing all pairs to score 67 | 68 | """ 69 | sdf_1_blocked = self._add_suffix_to_col_names(sdf_1_blocked, 1) 70 | sdf_2_blocked = self._add_suffix_to_col_names(sdf_2_blocked, 2) 71 | 72 | predict_pairs_table = (sdf_1_blocked.join(sdf_2_blocked, on='block_key', how='inner') 73 | .drop_duplicates(subset=[col + "_1" for col in self.col_names] + 74 | [col + "_2" for col in self.col_names]) 75 | ) 76 | return predict_pairs_table 77 | 78 | def predict(self, sdf_1, sdf_2, threshold=0.5, top_n=None): 79 | """ 80 | Method to predict on data used for training or new data. 81 | 82 | Args: 83 | sdf_1: first table 84 | sdf_2: second table 85 | threshold: probability threshold 86 | top_n: only return best `top_n` matches above threshold 87 | 88 | Returns: 89 | Spark dataframe with the matching result 90 | 91 | """ 92 | if not self.fitted_: 93 | raise NotFittedError('The Matcher instance is not fitted yet. 
Call `fit` and train the instance.') 94 | 95 | sdf_1, sdf_2 = self._simplify_dataframe_for_matching(sdf_1), self._simplify_dataframe_for_matching(sdf_2) 96 | sdf_1_blocks, sdf_2_blocks = self.blocker.transform(sdf_1, sdf_2) 97 | pairs_table = self._create_predict_pairs_table(sdf_1_blocks, sdf_2_blocks) 98 | metrics_table = self._calculate_metrics(pairs_table) 99 | scores_table = ( 100 | metrics_table 101 | .withColumn('score', self.scoring_learner.predict_proba(metrics_table['similarity_metrics'])) 102 | ) 103 | scores_table_filtered = (scores_table.filter(F.col('score') >= threshold) 104 | .drop('block_key', 'similarity_metrics')) 105 | 106 | if top_n: 107 | # we add additional columns to order by to remove ties and return exactly n items 108 | window = (Window.partitionBy(*[x + "_1" for x in self.col_names]) 109 | .orderBy(F.desc('score'), *[F.asc(col) for col in [x + "_2" for x in self.col_names]])) 110 | result = (scores_table_filtered.withColumn('rank', F.rank().over(window)) 111 | .filter(F.col('rank') <= top_n) 112 | .drop('rank')) 113 | else: 114 | result = scores_table_filtered 115 | 116 | return result 117 | -------------------------------------------------------------------------------- /spark_matcher/matching_base/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['MatchingBase'] 2 | 3 | from .matching_base import MatchingBase -------------------------------------------------------------------------------- /spark_matcher/matching_base/matching_base.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | import warnings 6 | from typing import Optional, List, Dict 7 | 8 | import dill 9 | from pyspark.sql import DataFrame, functions as F, SparkSession 10 | import numpy as np 11 | import pandas as pd 12 | from thefuzz.fuzz import token_set_ratio, token_sort_ratio 13 | 14 | from spark_matcher.activelearner.active_learner import ScoringLearner 15 | from spark_matcher.blocker.block_learner import BlockLearner 16 | from spark_matcher.sampler.training_sampler import HashSampler, RandomSampler 17 | from spark_matcher.scorer.scorer import Scorer 18 | from spark_matcher.similarity_metrics import SimilarityMetrics 19 | from spark_matcher.blocker.blocking_rules import BlockingRule, default_blocking_rules 20 | from spark_matcher.table_checkpointer import TableCheckpointer, ParquetCheckPointer 21 | 22 | 23 | class MatchingBase: 24 | 25 | def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[TableCheckpointer] = None, 26 | checkpoint_dir: Optional[str] = None, col_names: Optional[List[str]] = None, 27 | field_info: Optional[Dict] = None, blocking_rules: Optional[List[BlockingRule]] = None, 28 | blocking_recall: float = 1.0, n_perfect_train_matches=1, n_train_samples: int = 100_000, 29 | ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0): 30 | self.spark_session = spark_session 31 | self.table_checkpointer = table_checkpointer 32 | if not self.table_checkpointer: 33 | if checkpoint_dir: 34 | self.table_checkpointer = ParquetCheckPointer(self.spark_session, checkpoint_dir, 35 | "checkpoint_deduplicator") 36 | else: 37 | warnings.warn( 38 | 'Either `table_checkpointer` or `checkpoint_dir` should be provided. 
This instance can only be used' 39 | ' when loading a previously saved instance.') 40 | 41 | if col_names: 42 | self.col_names = col_names 43 | self.field_info = {col_name: [token_set_ratio, token_sort_ratio] for col_name in 44 | self.col_names} 45 | elif field_info: 46 | self.col_names = list(field_info.keys()) 47 | self.field_info = field_info 48 | else: 49 | warnings.warn( 50 | 'Either `col_names` or `field_info` should be provided. This instance can only be used when loading a ' 51 | 'previously saved instance.') 52 | self.col_names = [''] # needed for instantiating ScoringLearner 53 | 54 | self.n_train_samples = n_train_samples 55 | self.ratio_hashed_samples = ratio_hashed_samples 56 | self.n_perfect_train_matches = n_perfect_train_matches 57 | self.verbose = verbose 58 | 59 | if not scorer: 60 | scorer = Scorer(self.spark_session) 61 | self.scoring_learner = ScoringLearner(self.col_names, scorer, verbose=self.verbose) 62 | 63 | self.blocking_rules = blocking_rules 64 | if not self.blocking_rules: 65 | self.blocking_rules = [blocking_rule(col) for col in self.col_names for blocking_rule in 66 | default_blocking_rules] 67 | self.blocker = BlockLearner(blocking_rules=self.blocking_rules, recall=blocking_recall, 68 | table_checkpointer=self.table_checkpointer, verbose=self.verbose) 69 | 70 | def save(self, path: str) -> None: 71 | """ 72 | Save the current instance to a pickle file. 73 | 74 | Args: 75 | path: Path and file name of pickle file 76 | 77 | """ 78 | to_save = self.__dict__ 79 | 80 | # the spark sessions need to be removed as they cannot be saved and re-used later 81 | to_save['spark_session'] = None 82 | setattr(to_save['scoring_learner'].learner.estimator, 'spark_session', None) 83 | setattr(to_save['table_checkpointer'], 'spark_session', None) 84 | 85 | with open(path, 'wb') as f: 86 | dill.dump(to_save, f) 87 | 88 | def load(self, path: str) -> None: 89 | """ 90 | Load a previously trained and saved Matcher instance. 91 | 92 | Args: 93 | path: Path and file name of pickle file 94 | 95 | """ 96 | with open(path, 'rb') as f: 97 | loaded_obj = dill.load(f) 98 | 99 | # the spark session that was removed before saving needs to be filled with the spark session of this instance 100 | loaded_obj['spark_session'] = self.spark_session 101 | setattr(loaded_obj['scoring_learner'].learner.estimator, 'spark_session', self.spark_session) 102 | setattr(loaded_obj['table_checkpointer'], 'spark_session', self.spark_session) 103 | 104 | self.__dict__.update(loaded_obj) 105 | 106 | def _simplify_dataframe_for_matching(self, sdf: DataFrame) -> DataFrame: 107 | """ 108 | Select only the columns used for matching and drop duplicates. 109 | """ 110 | return sdf.select(*self.col_names).drop_duplicates() 111 | 112 | def _create_train_pairs_table(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> DataFrame: 113 | """ 114 | Create pairs that are used for training. Based on the given 'ratio_hashed_samples', part of the sample 115 | pairs are generated using MinHashLSH technique to create pairs that are more likely to be a match. Rest is 116 | sampled using random selection. 
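        For example, with the defaults `n_train_samples=100_000` and `ratio_hashed_samples=0.5`,
        at most 50,000 pairs come from MinHashLSH; the random sampler is then asked for roughly
        1.5 * (100_000 - 50_000) = 75,000 candidate pairs so that the requested number of unique
        pairs remains after duplicates are dropped.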
117 | 118 | Args: 119 | sdf_1: Spark dataframe containing the first table to with the input should be matched 120 | sdf_2: Optional: Spark dataframe containing the second table that should be matched to the first table 121 | 122 | Returns: 123 | Spark dataframe with sampled pairs that should be compared during training 124 | """ 125 | n_hashed_samples = int(self.n_train_samples * self.ratio_hashed_samples) 126 | 127 | h_sampler = HashSampler(self.table_checkpointer, self.col_names, n_hashed_samples) 128 | hashed_pairs_table = self.table_checkpointer.checkpoint_table(h_sampler.create_pairs_table(sdf_1, sdf_2), 129 | checkpoint_name='minhash_pairs_table') 130 | 131 | # creating additional random samples to ensure that the exact number of n_train_samples 132 | # is obtained after dropping duplicated records 133 | n_random_samples = int(1.5 * (self.n_train_samples - hashed_pairs_table.count())) 134 | 135 | r_sampler = RandomSampler(self.table_checkpointer, self.col_names, n_random_samples) 136 | random_pairs_table = self.table_checkpointer.checkpoint_table(r_sampler.create_pairs_table(sdf_1, sdf_2), 137 | checkpoint_name='random_pairs_table') 138 | 139 | pairs_table = ( 140 | random_pairs_table 141 | .unionByName(hashed_pairs_table) 142 | .withColumn('perfect_train_match', F.lit(False)) 143 | .drop_duplicates() 144 | .limit(self.n_train_samples) 145 | ) 146 | 147 | # add some perfect matches to assure there are two labels in the train data and the classifier can be trained 148 | perfect_matches = (pairs_table.withColumn('perfect_train_match', F.lit(True)) 149 | .limit(self.n_perfect_train_matches)) 150 | for col in self.col_names: 151 | perfect_matches = perfect_matches.withColumn(col + "_1", F.col(col + "_2")) 152 | 153 | pairs_table = (perfect_matches.unionByName(pairs_table 154 | .limit(self.n_train_samples - self.n_perfect_train_matches))) 155 | 156 | return pairs_table 157 | 158 | def _calculate_metrics(self, pairs_table: DataFrame) -> DataFrame: 159 | """ 160 | Method to apply similarity metrics to pairs table. 161 | 162 | Args: 163 | pairs_table: Spark dataframe containing pairs table 164 | 165 | Returns: 166 | Spark dataframe with pairs table and newly created `similarity_metrics` column 167 | 168 | """ 169 | similarity_metrics = SimilarityMetrics(self.field_info) 170 | return similarity_metrics.transform(pairs_table) 171 | 172 | def _create_blocklearning_input(self, metrics_table: pd.DataFrame, threshold: int = 0.5) -> DataFrame: 173 | """ 174 | Method to collect data used for block learning. 
This data consists of the manually labelled pairs with label 1 175 | and the train pairs with a score above the `threshold` 176 | 177 | Args: 178 | metrics_table: Pandas dataframe containing the similarity metrics 179 | threshold: scoring threshold 180 | 181 | Returns: 182 | Spark dataframe with data for block learning 183 | 184 | """ 185 | metrics_table['score'] = self.scoring_learner.predict_proba(np.array(metrics_table['similarity_metrics'].tolist()))[:, 1] 186 | metrics_table.loc[metrics_table['score'] > threshold, 'label'] = '1' 187 | 188 | # get labelled positives from activelearner 189 | positive_train_labels = (self.scoring_learner.train_samples[self.scoring_learner.train_samples['y'] == '1'] 190 | .rename(columns={'y': 'label'})) 191 | metrics_table = pd.concat([metrics_table, positive_train_labels]).drop_duplicates( 192 | subset=[col + "_1" for col in self.col_names] + [col + "_2" for col in self.col_names], keep='last') 193 | 194 | metrics_table = metrics_table[metrics_table.label == '1'] 195 | 196 | metrics_table['row_id'] = np.arange(len(metrics_table)) 197 | return self.spark_session.createDataFrame(metrics_table) 198 | 199 | def _add_suffix_to_col_names(self, sdf: DataFrame, suffix: int): 200 | """ 201 | This method adds a suffix to the columns that are used in the algorithm. 202 | This is done in order to do a join (with two dataframes with the same schema) to create the pairs table. 203 | """ 204 | for col in self.col_names: 205 | sdf = sdf.withColumnRenamed(col, f"{col}_{suffix}") 206 | return sdf 207 | 208 | def fit(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> 'MatchingBase': 209 | """ 210 | Fit the MatchingBase instance on the two dataframes `sdf_1` and `sdf_2` using active learning. You will be 211 | prompted to enter whether the presented pairs are a match or not. Note that `sdf_2` is an optional argument. 212 | `sdf_2` is used for Matcher (i.e. matching one table to another). In the case of Deduplication, only providing 213 | `sdf_1` is sufficient, in that case `sdf_1` will be deduplicated. 
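        A typical record-linkage call looks as follows (a sketch; the Spark session, dataframes
        and column names are placeholders):

            from spark_matcher.matcher import Matcher

            myMatcher = Matcher(spark, col_names=['name', 'suburb'],
                                checkpoint_dir='/tmp/spark_matcher_checkpoints')
            myMatcher.fit(sdf_1, sdf_2)    # starts the interactive labelling session
            myMatcher.save('matcher.pkl')  # optionally persist the fitted instance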
214 | 215 | Args: 216 | sdf_1: Spark dataframe 217 | sdf_2: Optional: Spark dataframe 218 | 219 | Returns: 220 | Fitted MatchingBase instance 221 | 222 | """ 223 | sdf_1 = self._simplify_dataframe_for_matching(sdf_1) 224 | if sdf_2: 225 | sdf_2 = self._simplify_dataframe_for_matching(sdf_2) 226 | 227 | pairs_table = self._create_train_pairs_table(sdf_1, sdf_2) 228 | metrics_table = self._calculate_metrics(pairs_table) 229 | metrics_table_pdf = metrics_table.toPandas() 230 | self.scoring_learner.fit(metrics_table_pdf) 231 | block_learning_input = self._create_blocklearning_input(metrics_table_pdf) 232 | self.blocker.fit(block_learning_input) 233 | self.fitted_ = True 234 | return self 235 | -------------------------------------------------------------------------------- /spark_matcher/sampler/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['HashSampler', 'RandomSampler'] 2 | 3 | from .training_sampler import HashSampler, RandomSampler -------------------------------------------------------------------------------- /spark_matcher/sampler/training_sampler.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | import abc 6 | from typing import List, Tuple, Optional, Union 7 | 8 | import numpy as np 9 | from pyspark.ml.feature import CountVectorizer, VectorAssembler, MinHashLSH 10 | from pyspark.ml.linalg import DenseVector, SparseVector 11 | from pyspark.sql import DataFrame, functions as F, types as T 12 | 13 | from spark_matcher.config import MINHASHING_MAXDF, MINHASHING_VOCABSIZE 14 | from spark_matcher.table_checkpointer import TableCheckpointer 15 | 16 | 17 | class Sampler: 18 | """ 19 | Sampler base class to generate pairs for training. 20 | 21 | Args: 22 | col_names: list of column names to use for matching 23 | n_train_samples: nr of pair samples to be created for training 24 | table_checkpointer: table_checkpointer object to store cached tables 25 | """ 26 | def __init__(self, col_names: List[str], n_train_samples: int, table_checkpointer: TableCheckpointer) -> DataFrame: 27 | self.col_names = col_names 28 | self.n_train_samples = n_train_samples 29 | self.table_checkpointer = table_checkpointer 30 | 31 | @abc.abstractmethod 32 | def create_pairs_table(self, sdf_1: DataFrame, sdf_2: DataFrame) -> DataFrame: 33 | pass 34 | 35 | 36 | class HashSampler(Sampler): 37 | def __init__(self, table_checkpointer: TableCheckpointer, col_names: List[str], n_train_samples: int, 38 | threshold: float = 0.5, num_hash_tables: int = 10) -> DataFrame: 39 | """ 40 | Sampler class to generate pairs using MinHashLSH method by selecting pairs that are more likely to be a match. 41 | 42 | Args: 43 | threshold: for the distance of hashed pairs, values below threshold will be returned 44 | (threshold being equal to 1.0 will return all pairs with at least one common shingle) 45 | num_hash_tables: number of hashes to be applied 46 | table_checkpointer: table_checkpointer object to store cached tables 47 | Returns: 48 | A spark dataframe which contains all selected pairs for training 49 | """ 50 | super().__init__(col_names, n_train_samples, table_checkpointer) 51 | self.threshold = threshold 52 | self.num_hash_tables = num_hash_tables 53 | 54 | @staticmethod 55 | @F.udf(T.BooleanType()) 56 | def is_non_zero_vector(vector): 57 | """ 58 | Check if a vector has at least 1 non-zero entry. This function can deal with dense or sparse vectors. 
This is 59 | needed as the VectorAssembler can return dense or sparse vectors, dependent on what is more memory efficient. 60 | 61 | Args: 62 | vector: vector 63 | 64 | Returns: 65 | boolean whether a vector has at least 1 non-zero entry 66 | 67 | """ 68 | if isinstance(vector, DenseVector): 69 | return bool(len(vector)) 70 | if isinstance(vector, SparseVector): 71 | return vector.indices.size > 0 72 | 73 | def _vectorize(self, col_names: List[str], sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None, 74 | max_df: Union[float, int] = MINHASHING_MAXDF, vocab_size: int = MINHASHING_VOCABSIZE) -> Union[ 75 | Tuple[DataFrame, DataFrame], DataFrame]: 76 | """ 77 | This function is used to vectorize word based features. First, given dataframes are united to cover all feature 78 | space. Given columns are converted into a vector. `sdf_2` is only required for record matching, for 79 | deduplication only `sdf_1` is required. To prevent too many (redundant) matches on frequently occurring tokens, 80 | the maximum document frequency for vectorization is limited by maxDF. 81 | 82 | Args: 83 | col_names: list of column names to create feature vectors 84 | sdf_1: spark dataframe 85 | sdf_2: Optional: spark dataframe 86 | max_df: max document frequency to use for count vectorization 87 | vocab_size: vocabulary size of vectorizer 88 | 89 | Returns: 90 | Adds a 'features' column to given spark dataframes 91 | 92 | """ 93 | if sdf_2: 94 | # creating a source column to separate tables after vectorization is completed 95 | sdf_1 = sdf_1.withColumn('source', F.lit('sdf_1')) 96 | sdf_2 = sdf_2.withColumn('source', F.lit('sdf_2')) 97 | sdf_merged = sdf_1.unionByName(sdf_2) 98 | else: 99 | sdf_merged = sdf_1 100 | 101 | for col in col_names: 102 | # converting strings into an array 103 | sdf_merged = sdf_merged.withColumn(f'{col}_array', F.split(F.col(col), ' ')) 104 | 105 | # setting CountVectorizer 106 | # minDF is set to 2 because if a token occurs only once there is no other occurrence to match it to 107 | # the maximum document frequency may not be smaller than the minimum document frequency which is assured 108 | # below 109 | n_rows = sdf_merged.count() 110 | if isinstance(max_df, float) and (n_rows * max_df < 2): 111 | max_df = 10 112 | cv = CountVectorizer(binary=True, minDF=2, maxDF=max_df, vocabSize=vocab_size, inputCol=f"{col}_array", 113 | outputCol=f"{col}_vector") 114 | 115 | # creating vectors 116 | model = cv.fit(sdf_merged) 117 | sdf_merged = model.transform(sdf_merged) 118 | sdf_merged = self.table_checkpointer(sdf_merged, "cached_vectorized_table") 119 | 120 | # in case of col_names contains multiple columns, merge all vectors into one 121 | vec_assembler = VectorAssembler(outputCol="features") 122 | vec_assembler.setInputCols([f'{col}_vector' for col in col_names]) 123 | 124 | sdf_merged = vec_assembler.transform(sdf_merged) 125 | 126 | # breaking the lineage is found to be required below 127 | sdf_merged = self.table_checkpointer( 128 | sdf_merged.filter(HashSampler.is_non_zero_vector('features')), 'minhash_vectorized') 129 | 130 | if sdf_2: 131 | sdf_1_vectorized = sdf_merged.filter(F.col('source') == 'sdf_1').select(*col_names, 'features') 132 | sdf_2_vectorized = sdf_merged.filter(F.col('source') == 'sdf_2').select(*col_names, 'features') 133 | return sdf_1_vectorized, sdf_2_vectorized 134 | sdf_1_vectorized = sdf_merged.select(*col_names, 'features') 135 | return sdf_1_vectorized 136 | 137 | @staticmethod 138 | def _apply_min_hash(sdf_1: DataFrame, sdf_2: DataFrame, col_names: List[str], 
139 | threshold: float, num_hash_tables: int): 140 | """ 141 | This function is used to apply MinHasLSH technique to calculate the Jaccard distance between feature vectors. 142 | 143 | Args: 144 | sdf_1: spark dataframe with the `features` column and the column `row_number_1` 145 | sdf_2: spark dataframe with the `features` column and the column `row_number_2` 146 | col_names: list of column names to create feature vectors 147 | 148 | Returns: 149 | Creates one spark dataframe that contains all pairs that has a smaller Jaccard distance than threshold 150 | """ 151 | mh = MinHashLSH(inputCol="features", outputCol="hashes", seed=42, numHashTables=num_hash_tables) 152 | 153 | model = mh.fit(sdf_1) 154 | model.transform(sdf_1) 155 | 156 | sdf_distances = ( 157 | model 158 | .approxSimilarityJoin(sdf_1, sdf_2, threshold, distCol="JaccardDistance") 159 | .select(*[F.col(f'datasetA.{col}').alias(f'{col}_1') for col in col_names], 160 | *[F.col(f'datasetB.{col}').alias(f'{col}_2') for col in col_names], 161 | F.col('JaccardDistance').alias('distance'), 162 | F.col('datasetA.row_number_1').alias('row_number_1'), 163 | F.col('datasetB.row_number_2').alias('row_number_2')) 164 | ) 165 | return sdf_distances 166 | 167 | def create_pairs_table(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> DataFrame: 168 | """ 169 | Create hashed pairs that are used for training. `sdf_2` is only required for record matching, for deduplication 170 | only `sdf_1` is required. 171 | 172 | Args: 173 | sdf_1: Spark dataframe containing the first table to with the input should be matched 174 | sdf_2: Optional: Spark dataframe containing the second table that should be matched to the first table 175 | 176 | Returns: 177 | Spark dataframe that contains sampled pairs selected with MinHashLSH technique 178 | 179 | """ 180 | if sdf_2: 181 | sdf_1_vectorized, sdf_2_vectorized = self._vectorize(self.col_names, sdf_1, sdf_2) 182 | sdf_1_vectorized = self.table_checkpointer( 183 | sdf_1_vectorized.withColumn('row_number_1', F.monotonically_increasing_id()), 184 | checkpoint_name='sdf_1_vectorized') 185 | sdf_2_vectorized = self.table_checkpointer( 186 | sdf_2_vectorized.withColumn('row_number_2', F.monotonically_increasing_id()), 187 | checkpoint_name='sdf_2_vectorized') 188 | else: 189 | sdf_1_vectorized = self._vectorize(self.col_names, sdf_1) 190 | sdf_1_vectorized = self.table_checkpointer( 191 | sdf_1_vectorized.withColumn('row_number_1', F.monotonically_increasing_id()), 192 | checkpoint_name='sdf_1_vectorized') 193 | sdf_2_vectorized = sdf_1_vectorized.alias('sdf_2_vectorized') # matched with itself for deduplication 194 | sdf_2_vectorized = sdf_2_vectorized.withColumnRenamed('row_number_1', 'row_number_2') 195 | 196 | sdf_distances = self._apply_min_hash(sdf_1_vectorized, sdf_2_vectorized, self.col_names, 197 | self.threshold, self.num_hash_tables) 198 | 199 | # for deduplication we remove identical pairs like (a,a) and duplicates of pairs like (a,b) and (b,a) 200 | if not sdf_2: 201 | sdf_distances = sdf_distances.filter(F.col('row_number_1') < F.col('row_number_2')) 202 | 203 | hashed_pairs_table = ( 204 | sdf_distances 205 | .sort('distance') 206 | .limit(self.n_train_samples) 207 | .drop('distance', 'row_number_1', 'row_number_2') 208 | ) 209 | 210 | return hashed_pairs_table 211 | 212 | 213 | class RandomSampler(Sampler): 214 | def __init__(self, table_checkpointer: TableCheckpointer, col_names: List[str], n_train_samples: int) -> DataFrame: 215 | """ 216 | Sampler class to generate randomly selected pairs 
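        E.g. (a sketch, assuming a configured table checkpointer and existing dataframes):

            r_sampler = RandomSampler(table_checkpointer, col_names=['name'], n_train_samples=10_000)
            random_pairs = r_sampler.create_pairs_table(sdf_1, sdf_2)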
217 | 218 | Returns: 219 | A spark dataframe which contains randomly selected pairs for training 220 | """ 221 | super().__init__(col_names, n_train_samples, table_checkpointer) 222 | 223 | def create_pairs_table(self, sdf_1: DataFrame, sdf_2: Optional[DataFrame] = None) -> DataFrame: 224 | """ 225 | Create random pairs that are used for training. `sdf_2` is only required for record matching, for deduplication 226 | only `sdf_1` is required. 227 | 228 | Args: 229 | sdf_1: Spark dataframe containing the first table to with the input should be matched 230 | sdf_2: Optional: Spark dataframe containing the second table that should be matched to the first table 231 | 232 | Returns: 233 | Spark dataframe that contains randomly sampled pairs 234 | 235 | """ 236 | if sdf_2: 237 | return self._create_pairs_table_matcher(sdf_1, sdf_2) 238 | return self._create_pairs_table_deduplicator(sdf_1) 239 | 240 | def _create_pairs_table_matcher(self, sdf_1: DataFrame, sdf_2: DataFrame) -> DataFrame: 241 | """ 242 | Create random pairs that are used for training. 243 | 244 | If the first table and the second table are equally large, we take the square root of `n_train_samples` from 245 | both tables to create a pairs table containing `n_train_samples` rows. If one of the table is much smaller than 246 | the other, all rows of the smallest table are taken and the number of the sample from the larger table is chosen 247 | such that the total number of pairs is `n_train_samples`. 248 | 249 | Args: 250 | sdf_1: Spark dataframe containing the first table to with the input should be matched 251 | sdf_2: Spark dataframe containing the second table that should be matched to the first table 252 | 253 | Returns: 254 | Spark dataframe with randomly sampled pairs to compared during training 255 | """ 256 | sdf_1_count, sdf_2_count = sdf_1.count(), sdf_2.count() 257 | 258 | sample_size_small_table = min([int(self.n_train_samples ** 0.5), min(sdf_1_count, sdf_2_count)]) 259 | sample_size_large_table = self.n_train_samples // sample_size_small_table 260 | 261 | if sdf_1_count > sdf_2_count: 262 | fraction_sdf_1 = sample_size_large_table / sdf_1_count 263 | fraction_sdf_2 = sample_size_small_table / sdf_2_count 264 | # below the `fraction` is slightly increased (but capped at 1.) and a `limit()` is applied to assure that 265 | # the exact number of required samples is obtained 266 | sdf_1_sample = (sdf_1.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_1])) 267 | .limit(sample_size_large_table)) 268 | sdf_2_sample = (sdf_2.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_2])) 269 | .limit(sample_size_small_table)) 270 | else: 271 | fraction_sdf_1 = sample_size_small_table / sdf_1_count 272 | fraction_sdf_2 = sample_size_large_table / sdf_2_count 273 | sdf_1_sample = (sdf_1.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_1])) 274 | .limit(sample_size_small_table)) 275 | sdf_2_sample = (sdf_2.sample(withReplacement=False, fraction=min([1., 1.5 * fraction_sdf_2])) 276 | .limit(sample_size_large_table)) 277 | 278 | for col in self.col_names: 279 | sdf_1_sample = sdf_1_sample.withColumnRenamed(col, col + "_1") 280 | sdf_2_sample = sdf_2_sample.withColumnRenamed(col, col + "_2") 281 | 282 | random_pairs_table = sdf_1_sample.crossJoin(sdf_2_sample) 283 | 284 | return random_pairs_table 285 | 286 | def _create_pairs_table_deduplicator(self, sdf: DataFrame) -> DataFrame: 287 | """ 288 | Create randomly selected pairs for deduplication. 
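        For example, with `n_train_samples=100_000` about int(1.5 * ceil(sqrt(2 * 100_000))) = 672
        rows are sampled; their cross join, keeping only pairs with `row_id_1 < row_id_2`, yields
        roughly 225,000 candidate pairs, which are then limited to the requested 100,000.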
289 | 290 | Args: 291 | sdf: Spark dataframe containing rows from which pairs need to be created 292 | 293 | Returns: 294 | Spark dataframe with the randomly selected pairs 295 | 296 | """ 297 | n_samples_required = int(1.5 * np.ceil((self.n_train_samples * 2) ** 0.5)) 298 | fraction_samples_required = min([n_samples_required / sdf.count(), 1.]) 299 | sample = sdf.sample(withReplacement=False, fraction=fraction_samples_required) 300 | sample = self.table_checkpointer(sample.withColumn('row_id', F.monotonically_increasing_id()), 301 | checkpoint_name='random_pairs_deduplicator') 302 | sample_1, sample_2 = sample, sample 303 | for col in self.col_names + ['row_id']: 304 | sample_1 = sample_1.withColumnRenamed(col, col + "_1") 305 | sample_2 = sample_2.withColumnRenamed(col, col + "_2") 306 | 307 | pairs_table = (sample_1 308 | .crossJoin(sample_2) 309 | .filter(F.col('row_id_1') < F.col('row_id_2')) 310 | .limit(self.n_train_samples) 311 | .drop('row_id_1', 'row_id_2')) 312 | 313 | return pairs_table 314 | -------------------------------------------------------------------------------- /spark_matcher/scorer/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['Scorer'] 2 | 3 | from .scorer import Scorer -------------------------------------------------------------------------------- /spark_matcher/scorer/scorer.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import Union, Optional 6 | 7 | import numpy as np 8 | from pyspark.sql import SparkSession, types as T, functions as F, Column 9 | from sklearn.base import BaseEstimator 10 | from sklearn.pipeline import make_pipeline 11 | from sklearn.preprocessing import StandardScaler 12 | from sklearn.linear_model import LogisticRegression 13 | 14 | 15 | class Scorer: 16 | def __init__(self, spark_session: SparkSession, binary_clf: Optional[BaseEstimator] = None): 17 | self.spark_session = spark_session 18 | self.binary_clf = binary_clf 19 | if not self.binary_clf: 20 | self._create_default_clf() 21 | self.fitted_ = False 22 | 23 | def _create_default_clf(self) -> None: 24 | """ 25 | This method creates a Sklearn Pipeline with a Standard Scaler and a Logistic Regression classifier. 26 | """ 27 | self.binary_clf = ( 28 | make_pipeline( 29 | StandardScaler(), 30 | LogisticRegression(class_weight='balanced') 31 | ) 32 | ) 33 | 34 | def fit(self, X: np.ndarray, y: np.ndarray) -> 'Scorer': 35 | """ 36 | This method fits a clf model on input data `X` nd the binary targets `y`. 37 | 38 | Args: 39 | X: training data 40 | y: training targets, containing binary values 41 | 42 | Returns: 43 | The object itself 44 | """ 45 | 46 | if len(set(y)) == 1: # in case active learning only resulted in labels from one class 47 | return self 48 | 49 | self.binary_clf.fit(X, y) 50 | self.fitted_ = True 51 | return self 52 | 53 | def _predict_proba(self, X: np.ndarray) -> np.ndarray: 54 | """ 55 | This method implements the code for predict_proba on a numpy array. 56 | This method is used to score all the pairs during training. 57 | """ 58 | return self.binary_clf.predict_proba(X) 59 | 60 | def _predict_proba_spark(self, X: Column) -> Column: 61 | """ 62 | This method implements the code for predict_proba on a spark column. 63 | This method is used to score all the pairs during inference time. 
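        It is typically reached through `predict_proba` when a Spark column is passed, e.g.
        (a sketch; `scorer` is a fitted Scorer instance):

            scores = metrics_table.withColumn(
                'score', scorer.predict_proba(metrics_table['similarity_metrics']))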
64 | """ 65 | broadcasted_clf = self.spark_session.sparkContext.broadcast(self.binary_clf) 66 | 67 | @F.pandas_udf(T.FloatType()) 68 | def _distributed_predict_proba(array): 69 | """ 70 | This inner function defines the Pandas UDF for predict_proba on a spark cluster 71 | """ 72 | return array.apply(lambda x: broadcasted_clf.value.predict_proba([x])[0][1]) 73 | 74 | return _distributed_predict_proba(X) 75 | 76 | def predict_proba(self, X: Union[Column, np.ndarray]) -> Union[Column, np.ndarray]: 77 | """ 78 | This method implements the abstract predict_proba method. It predicts the 'probabilities' of the target class 79 | for given input data `X`. 80 | 81 | Args: 82 | X: input data 83 | 84 | Returns: 85 | the predicted probabilities 86 | """ 87 | if isinstance(X, Column): 88 | return self._predict_proba_spark(X) 89 | 90 | if isinstance(X, np.ndarray): 91 | return self._predict_proba(X) 92 | 93 | raise ValueError(f"{type(X)} is an unsupported datatype for X") 94 | -------------------------------------------------------------------------------- /spark_matcher/similarity_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['SimilarityMetrics'] 2 | 3 | from .similarity_metrics import SimilarityMetrics 4 | -------------------------------------------------------------------------------- /spark_matcher/similarity_metrics/similarity_metrics.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import List, Callable, Dict 6 | 7 | import pandas as pd 8 | 9 | from pyspark.sql import DataFrame, Column 10 | from pyspark.sql import functions as F 11 | from pyspark.sql import types as T 12 | 13 | 14 | class SimilarityMetrics: 15 | """ 16 | Class to calculate similarity metrics for pairs of records. The `field_info` dict contains column names as keys 17 | and lists of similarity functions as values. E.g. 18 | 19 | field_info = {'name': [token_set_ratio, token_sort_ratio], 20 | 'postcode': [ratio]} 21 | 22 | where `token_set_ratio`, `token_sort_ratio` and `ratio` are string similarity functions that take two strings 23 | as arguments and return a numeric value 24 | 25 | Attributes: 26 | field_info: dict containing column names as keys and lists of similarity functions as values 27 | """ 28 | def __init__(self, field_info: Dict): 29 | self.field_info = field_info 30 | 31 | @staticmethod 32 | def _distance_measure_pandas(strings_1: List[str], strings_2: List[str], func: Callable) -> pd.Series: 33 | """ 34 | Helper function to apply a string similarity metric to two arrays of strings. To be used in a Pandas UDF. 
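        E.g. (illustrative; using `token_set_ratio` from thefuzz):

            _distance_measure_pandas(['amsterdam', 'rotterdam'],
                                     ['amsterdm', 'rotterdam'],
                                     token_set_ratio)
            # -> a pandas Series with one score per pair, roughly [94, 100]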
35 | 36 | Args: 37 | strings_1: array containing strings 38 | strings_2: array containing strings 39 | func: string similarity function to be applied 40 | 41 | Returns: 42 | Pandas series containing string similarities 43 | 44 | """ 45 | df = pd.DataFrame({'x': strings_1, 'y': strings_2}) 46 | return df.apply(lambda row: func(row['x'], row['y']), axis=1) 47 | 48 | @staticmethod 49 | def _create_similarity_metric_udf(similarity_metric_function: Callable): 50 | """ 51 | Function that created Pandas UDF for a given similarity_metric_function 52 | 53 | Args: 54 | similarity_metric_function: function that takes two strings and returns a number 55 | 56 | Returns: 57 | Pandas UDF 58 | """ 59 | 60 | @F.pandas_udf(T.FloatType()) 61 | def similarity_udf(strings_1: pd.Series, strings_2: pd.Series) -> pd.Series: 62 | # some similarity metrics cannot deal with empty strings, therefore these are replaced with " " 63 | strings_1 = [x if x != "" else " " for x in strings_1] 64 | strings_2 = [x if x != "" else " " for x in strings_2] 65 | return SimilarityMetrics._distance_measure_pandas(strings_1, strings_2, similarity_metric_function) 66 | 67 | return similarity_udf 68 | 69 | def _apply_distance_metrics(self) -> Column: 70 | """ 71 | Function to apply all distance metrics in the right order to string pairs and returns them in an array. This 72 | array is used as input to the (logistic regression) scoring function. 73 | 74 | Returns: 75 | array with the string distance metrics 76 | 77 | """ 78 | distance_metrics_list = [] 79 | for field_name in self.field_info.keys(): 80 | field_name_1, field_name_2 = field_name + "_1", field_name + "_2" 81 | for similarity_function in self.field_info[field_name]: 82 | similarity_metric = ( 83 | SimilarityMetrics._create_similarity_metric_udf(similarity_function)(F.col(field_name_1), 84 | F.col(field_name_2))) 85 | distance_metrics_list.append(similarity_metric) 86 | 87 | array = F.array(*distance_metrics_list) 88 | 89 | return array 90 | 91 | def transform(self, pairs_table: DataFrame) -> DataFrame: 92 | """ 93 | Method to apply similarity metrics to pairs table. 
Method makes use of method dispatching to facilitate both 94 | Pandas and Spark dataframes 95 | 96 | Args: 97 | pairs_table: Spark or Pandas dataframe containing pairs table 98 | 99 | Returns: 100 | Pandas or Spark dataframe with pairs table and newly created `similarity_metrics` column 101 | 102 | """ 103 | return pairs_table.withColumn('similarity_metrics', self._apply_distance_metrics()) 104 | -------------------------------------------------------------------------------- /spark_matcher/table_checkpointer.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | import abc 6 | import os 7 | 8 | from pyspark.sql import SparkSession, DataFrame 9 | 10 | 11 | class TableCheckpointer(abc.ABC): 12 | """ 13 | Args: 14 | spark_session: a spark session 15 | database: a name of a database or storage system where the tables can be saved 16 | checkpoint_prefix: a prefix of the name that can be used to save tables 17 | """ 18 | 19 | def __init__(self, spark_session: SparkSession, database: str, checkpoint_prefix: str = "checkpoint_spark_matcher"): 20 | self.spark_session = spark_session 21 | self.database = database 22 | self.checkpoint_prefix = checkpoint_prefix 23 | 24 | def __call__(self, sdf: DataFrame, checkpoint_name: str): 25 | return self.checkpoint_table(sdf, checkpoint_name) 26 | 27 | @abc.abstractmethod 28 | def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str): 29 | pass 30 | 31 | 32 | class HiveCheckpointer(TableCheckpointer): 33 | """ 34 | Args: 35 | spark_session: a spark session 36 | database: a name of a database or storage system where the tables can be saved 37 | checkpoint_prefix: a prefix of the name that can be used to save tables 38 | """ 39 | def __init__(self, spark_session: SparkSession, database: str, checkpoint_prefix: str = "checkpoint_spark_matcher"): 40 | super().__init__(spark_session, database, checkpoint_prefix) 41 | 42 | def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str): 43 | """ 44 | This method saves the input dataframes as checkpoints of the algorithm. This checkpointing can be 45 | used to store intermediary results that are needed throughout the algorithm. The tables are stored using the 46 | following name convention: `{checkpoint_prefix}_{checkpoint_name}`. 47 | 48 | Args: 49 | sdf: a Spark dataframe that needs to be saved as a checkpoint 50 | checkpoint_name: name of the table 51 | 52 | Returns: 53 | the same, unchanged, spark dataframe as the input dataframe. With the only difference that the 54 | dataframe is now read from disk as a checkpoint. 55 | """ 56 | sdf.write.saveAsTable(f"{self.database}.{self.checkpoint_prefix}_{checkpoint_name}", 57 | mode="overwrite") 58 | sdf = self.spark_session.table(f"{self.database}.{self.checkpoint_prefix}_{checkpoint_name}") 59 | return sdf 60 | 61 | 62 | class ParquetCheckPointer(TableCheckpointer): 63 | """ 64 | Args: 65 | spark_session: a spark session 66 | checkpoint_dir: directory where the tables can be saved 67 | checkpoint_prefix: a prefix of the name that can be used to save tables 68 | """ 69 | def __init__(self, spark_session: SparkSession, checkpoint_dir: str, 70 | checkpoint_prefix: str = "checkpoint_spark_matcher"): 71 | super().__init__(spark_session, checkpoint_dir, checkpoint_prefix) 72 | 73 | def checkpoint_table(self, sdf: DataFrame, checkpoint_name: str): 74 | """ 75 | This method saves the input dataframes as checkpoints of the algorithm. 
This checkpointing can be 76 | used to store intermediary results that are needed throughout the algorithm. The tables are stored 77 | using the 78 | following name convention: `{checkpoint_prefix}_{checkpoint_name}`. 79 | 80 | Args: 81 | sdf: a Spark dataframe that needs to be saved as a checkpoint 82 | checkpoint_name: name of the table 83 | 84 | Returns: 85 | the same, unchanged, spark dataframe as the input dataframe. With the only difference that the 86 | dataframe is now read from disk as a checkpoint. 87 | """ 88 | file_name = os.path.join(f'{self.database}', f'{self.checkpoint_prefix}_{checkpoint_name}') 89 | sdf.write.parquet(file_name, mode='overwrite') 90 | return self.spark_session.read.parquet(file_name) 91 | -------------------------------------------------------------------------------- /spark_matcher/utils.py: -------------------------------------------------------------------------------- 1 | # Authors: Ahmet Bayraktar 2 | # Stan Leisink 3 | # Frits Hermans 4 | 5 | from typing import List 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from pyspark.ml.feature import StopWordsRemover 10 | from pyspark.sql import DataFrame 11 | from pyspark.sql import functions as F 12 | 13 | 14 | def get_most_frequent_words(sdf: DataFrame, col_name: str, min_df=2, top_n_words=1_000) -> pd.DataFrame: 15 | """ 16 | Count word frequencies in a Spark dataframe `sdf` column named `col_name` and return a Pandas dataframe containing 17 | the document frequencies of the `top_n_words`. This function is intended to be used to create a list of stopwords. 18 | 19 | Args: 20 | sdf: Spark dataframe 21 | col_name: column name to get most frequent words from 22 | min_df: minimal document frequency for a word to be included as stopword, int for number and float for fraction 23 | top_n_words: number of most occurring words to include 24 | 25 | Returns: 26 | pandas dataframe containing most occurring words, their counts and document frequencies 27 | 28 | """ 29 | sdf_col_splitted = sdf.withColumn(f'{col_name}_array', F.split(F.col(col_name), pattern=' ')) 30 | word_count = sdf_col_splitted.select(F.explode(f'{col_name}_array').alias('words')).groupBy('words').count() 31 | doc_count = sdf.count() 32 | word_count = word_count.withColumn('df', F.col('count') / doc_count) 33 | if isinstance(min_df, int): 34 | min_count = min_df 35 | elif isinstance(min_df, float): 36 | min_count = np.ceil(min_df * doc_count) 37 | word_count_pdf = word_count.filter(F.col('count') >= min_count).sort(F.desc('df')).limit(top_n_words).toPandas() 38 | return word_count_pdf 39 | 40 | 41 | def remove_stopwords(sdf: DataFrame, col_name: str, stopwords: List[str], case=False, 42 | suffix: str = '_wo_stopwords') -> DataFrame: 43 | """ 44 | Remove stopwords `stopwords` from a column `col_name` in a Spark dataframe `sdf`. The result will be written to a 45 | new column, named with the concatenation of `col_name` and `suffix`. 
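        E.g. (a sketch; `sdf` is assumed to contain a string column 'name'):

            stopwords = list(get_most_frequent_words(sdf, 'name', min_df=0.2)['words'])
            sdf = remove_stopwords(sdf, 'name', stopwords)
            # the cleaned strings are available in the new column 'name_wo_stopwords'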
46 | 47 | Args: 48 | sdf: Spark dataframe 49 | col_name: column name to remove stopwords from 50 | stopwords: list of stopwords to remove 51 | case: whether to check for stopwords including lower- or uppercase 52 | suffix: suffix for the newly created column 53 | 54 | Returns: 55 | Spark dataframe with column added without stopwords 56 | 57 | """ 58 | sdf = sdf.withColumn(f'{col_name}_array', F.split(F.col(col_name), pattern=' ')) 59 | sw_remover = StopWordsRemover(inputCol=f'{col_name}_array', outputCol=f'{col_name}_array_wo_stopwords', 60 | stopWords=stopwords, caseSensitive=case) 61 | sdf = (sw_remover.transform(sdf) 62 | .withColumn(f'{col_name}{suffix}', F.concat_ws(' ', F.col(f'{col_name}_array_wo_stopwords'))) 63 | .fillna({f'{col_name}{suffix}': ''}) 64 | .drop(f'{col_name}_array', f'{col_name}_array_wo_stopwords') 65 | ) 66 | return sdf 67 | -------------------------------------------------------------------------------- /spark_requirements.txt: -------------------------------------------------------------------------------- 1 | thefuzz 2 | pandas 3 | scikit-learn 4 | pyarrow 5 | -------------------------------------------------------------------------------- /test/test_active_learner/test_active_learner.py: -------------------------------------------------------------------------------- 1 | from spark_matcher.activelearner.active_learner import ScoringLearner 2 | from spark_matcher.scorer.scorer import Scorer 3 | 4 | 5 | def test__get_uncertainty_improvement(spark_session): 6 | scorer = Scorer(spark_session) 7 | myScoringLearner = ScoringLearner(col_names=[''], scorer=scorer, n_uncertainty_improvement=5) 8 | myScoringLearner.uncertainties = [0.4, 0.2, 0.1, 0.08, 0.05, 0.03] 9 | assert myScoringLearner._get_uncertainty_improvement() == 0.2 10 | 11 | 12 | def test__is_converged(spark_session): 13 | scorer = Scorer(spark_session) 14 | myScoringLearner = ScoringLearner(col_names=[''], scorer=scorer, min_nr_samples=5, uncertainty_threshold=0.1, 15 | uncertainty_improvement_threshold=0.01, n_uncertainty_improvement=5) 16 | 17 | # insufficient labelled samples 18 | myScoringLearner.uncertainties = [0.4, 0.39, 0.395] 19 | myScoringLearner.counter_total = len(myScoringLearner.uncertainties) 20 | assert not myScoringLearner._is_converged() 21 | 22 | # too large improvement in last 5 iterations 23 | myScoringLearner.uncertainties = [0.4, 0.2, 0.19, 0.18, 0.17, 0.16] 24 | myScoringLearner.counter_total = len(myScoringLearner.uncertainties) 25 | assert not myScoringLearner._is_converged() 26 | 27 | # improvement in last 5 iterations below threshold and sufficient labelled cases 28 | myScoringLearner.uncertainties = [0.19, 0.1, 0.08, 0.05, 0.03, 0.02] 29 | myScoringLearner.counter_total = len(myScoringLearner.uncertainties) 30 | assert myScoringLearner._is_converged() 31 | -------------------------------------------------------------------------------- /test/test_blocker/test_block_learner.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | from pandas.testing import assert_frame_equal 4 | from pyspark.sql import functions as F, DataFrame 5 | 6 | from spark_matcher.blocker.blocking_rules import BlockingRule 7 | from spark_matcher.blocker.block_learner import BlockLearner 8 | 9 | 10 | @pytest.fixture 11 | def blocking_rule(): 12 | class DummyBlockingRule(BlockingRule): 13 | def __repr__(self): 14 | return 'dummy_blocking_rule' 15 | 16 | def _blocking_rule(self, _): 17 | return F.lit('block_key') 18 | return 
DummyBlockingRule('blocking_column') 19 | 20 | 21 | @pytest.fixture 22 | def block_learner(blocking_rule, monkeypatch, table_checkpointer): 23 | monkeypatch.setattr(blocking_rule, "create_block_key", lambda _sdf: _sdf) 24 | return BlockLearner([blocking_rule, blocking_rule], 1.0, table_checkpointer) 25 | 26 | 27 | def test__create_blocks(spark_session, blocking_rule, block_learner, monkeypatch): 28 | monkeypatch.setattr(block_learner, "cover_blocking_rules", [blocking_rule, blocking_rule]) 29 | 30 | 31 | sdf = spark_session.createDataFrame( 32 | pd.DataFrame({ 33 | 'blocking_column': ['b', 'bb', 'bbb'] 34 | }) 35 | ) 36 | 37 | result = block_learner._create_blocks(sdf) 38 | 39 | assert isinstance(result, list) 40 | assert isinstance(result[0], DataFrame) 41 | 42 | 43 | def test__create_block_table(spark_session, block_learner): 44 | 45 | block_1 = spark_session.createDataFrame( 46 | pd.DataFrame({'column_1': [1, 2, 3], 'column_2': [4, 5, 6]}) 47 | ) 48 | block_2 = spark_session.createDataFrame( 49 | pd.DataFrame({'column_1': [7, 8, 9], 'column_2': [10, 11, 12]}) 50 | ) 51 | 52 | input_blocks = [block_1, block_2] 53 | 54 | expected_result = pd.DataFrame({ 55 | 'column_1': [1, 2, 3, 7, 8, 9], 'column_2': [4, 5, 6, 10, 11, 12] 56 | }) 57 | 58 | result = ( 59 | block_learner 60 | ._create_block_table(input_blocks) 61 | .toPandas() 62 | .sort_values(by='column_1') 63 | .reset_index(drop=True) 64 | ) 65 | assert_frame_equal(result, expected_result) 66 | 67 | 68 | 69 | def test__greedy_set_coverage(blocking_rule, block_learner): 70 | import copy 71 | 72 | full_set = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} 73 | block_learner.full_set = full_set 74 | block_learner.full_set_size = 10 75 | 76 | bl1 = copy.copy(blocking_rule) 77 | bl1.training_coverage = {0, 1, 2, 3} 78 | bl1.training_coverage_size = 4 79 | 80 | bl2 = copy.copy(blocking_rule) 81 | bl2.training_coverage = {3, 4, 5, 6, 7, 8} 82 | bl2.training_coverage_size = 6 83 | 84 | bl3 = copy.copy(blocking_rule) 85 | bl3.training_coverage = {3, 4, 5, 6, 7, 8} 86 | bl3.training_coverage_size = 5 87 | 88 | bl4 = copy.copy(blocking_rule) 89 | bl4.training_coverage = {8, 9} 90 | bl4.training_coverage_size = 2 91 | 92 | block_learner.blocking_rules = [bl1, bl2, bl3, bl4] 93 | 94 | # case 1: recall = 0.9 95 | block_learner.recall = 0.9 96 | block_learner._greedy_set_coverage() 97 | 98 | assert block_learner.cover_set == {0, 1, 2, 3, 4, 5, 6, 7, 8} 99 | 100 | # case 2: recall = 1.0 101 | block_learner.recall = 1.0 102 | block_learner._greedy_set_coverage() 103 | 104 | assert block_learner.cover_set == full_set 105 | 106 | # case 2: recall = 0.5 107 | block_learner.recall = 0.5 108 | block_learner._greedy_set_coverage() 109 | 110 | assert block_learner.cover_set == {3, 4, 5, 6, 7, 8} 111 | -------------------------------------------------------------------------------- /test/test_blocker/test_blocking_rules.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | from pandas.testing import assert_frame_equal 4 | from pyspark.sql import functions as F 5 | 6 | from spark_matcher.blocker.blocking_rules import BlockingRule, FirstNChars, LastNChars, WholeField, FirstNWords, \ 7 | FirstNLettersNoSpace, SortedIntegers, FirstInteger, LastInteger, LargestInteger, NLetterAbbreviation, \ 8 | FirstNCharsLastWord, FirstNCharactersFirstTokenSorted 9 | 10 | 11 | @pytest.fixture 12 | def blocking_rule(): 13 | class DummyBlockingRule(BlockingRule): 14 | def __repr__(self): 15 | return 'dummy_blocking_rule' 16 
| 17 | def _blocking_rule(self, _): 18 | return F.lit('block_key') 19 | return DummyBlockingRule('blocking_column') 20 | 21 | 22 | def test_create_block_key(spark_session, blocking_rule): 23 | input_sdf = spark_session.createDataFrame( 24 | pd.DataFrame({ 25 | 'sort_column': [1, 2, 3], 26 | 'blocking_column': ['a', 'aa', 'aaa']})) 27 | 28 | expected_result = pd.DataFrame({ 29 | 'sort_column': [1, 2, 3], 30 | 'blocking_column': ['a', 'aa', 'aaa'], 31 | 'block_key': ['dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key'] 32 | }) 33 | 34 | result = ( 35 | blocking_rule 36 | .create_block_key(input_sdf) 37 | .toPandas() 38 | .sort_values(by='sort_column') 39 | .reset_index(drop=True) 40 | ) 41 | assert_frame_equal(result, expected_result) 42 | 43 | 44 | def test__create_training_block_keys(spark_session, blocking_rule): 45 | input_sdf = spark_session.createDataFrame( 46 | pd.DataFrame({ 47 | 'sort_column': [1, 2, 3], 48 | 'blocking_column_1': ['a', 'aa', 'aaa'], 49 | 'blocking_column_2': ['b', 'bb', 'bbb']})) 50 | 51 | expected_result = pd.DataFrame({ 52 | 'sort_column': [1, 2, 3], 53 | 'blocking_column_1': ['a', 'aa', 'aaa'], 54 | 'blocking_column_2': ['b', 'bb', 'bbb'], 55 | 'block_key_1': ['dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key', 56 | 'dummy_blocking_rule:block_key'], 57 | 'block_key_2': ['dummy_blocking_rule:block_key', 'dummy_blocking_rule:block_key', 58 | 'dummy_blocking_rule:block_key'] 59 | }) 60 | 61 | result = ( 62 | blocking_rule 63 | ._create_training_block_keys(input_sdf) 64 | .toPandas() 65 | .sort_values(by='sort_column') 66 | .reset_index(drop=True) 67 | ) 68 | assert_frame_equal(result, expected_result) 69 | 70 | 71 | def test__compare_and_filter_keys(spark_session, blocking_rule): 72 | input_sdf = spark_session.createDataFrame( 73 | pd.DataFrame({ 74 | 'sort_column': [1, 2, 3, 4, 5, 6], 75 | 'block_key_1': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:', 'blocking_rule:aaa', 76 | 'blocking_rule:', 'blocking_rule:'], 77 | 'block_key_2': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:', 'blocking_rule:bbb', 78 | 'blocking_rule:', 'blocking_rule:bb']})) 79 | 80 | expected_result = pd.DataFrame({ 81 | 'sort_column': [1, 2, 3], 82 | 'block_key_1': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:'], 83 | 'block_key_2': ['blocking_rule:a', 'blocking_rule:a:a', 'blocking_rule:a:']}) 84 | 85 | result = ( 86 | blocking_rule 87 | ._compare_and_filter_keys(input_sdf) 88 | .toPandas() 89 | .sort_values(by='sort_column') 90 | .reset_index(drop=True) 91 | ) 92 | print(input_sdf.show()) 93 | print('res') 94 | print(result) 95 | print('exp') 96 | print(expected_result) 97 | assert_frame_equal(result, expected_result) 98 | 99 | 100 | def test_calculate_training_set_coverage(spark_session, blocking_rule, monkeypatch): 101 | # monkey_patch the inner functions of `calculate_training_set_coverage`, since they are tested separately 102 | monkeypatch.setattr(blocking_rule, "_create_training_block_keys", lambda _sdf: _sdf) 103 | monkeypatch.setattr(blocking_rule, "_compare_and_filter_keys", lambda _sdf: _sdf) 104 | 105 | row_ids = [0, 1, 2, 3, 4, 5] 106 | input_sdf = spark_session.createDataFrame( 107 | pd.DataFrame({'row_id': row_ids})) 108 | 109 | expected_result = set(row_ids) 110 | 111 | blocking_rule.calculate_training_set_coverage(input_sdf) 112 | 113 | assert blocking_rule.training_coverage == expected_result 114 | assert blocking_rule.training_coverage_size == len(expected_result) 115 | 116 | 
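# Illustrative note (not an exhaustive test): each blocking rule turns a column value into a
# prefixed block key, e.g. FirstNChars('name', 2) maps 'spark' to 'first_2_characters_name:sp'
# and yields None when the value is shorter than n, as asserted in the parametrized tests below.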
117 | @pytest.mark.parametrize('n', [2, 3]) 118 | def test_firstnchars(spark_session, n): 119 | if n == 2: 120 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'], 121 | 'block_key': ["first_2_characters_name:aa", 122 | "first_2_characters_name:sp", 123 | "first_2_characters_name:aa"]}) 124 | elif n == 3: 125 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'], 126 | 'block_key': [None, 127 | "first_3_characters_name:spa", 128 | "first_3_characters_name:aa "]}) 129 | 130 | sdf = spark_session.createDataFrame(expected_result_pdf) 131 | result_pdf = FirstNChars('name', n).create_block_key(sdf.select('name')).toPandas() 132 | assert_frame_equal(expected_result_pdf, result_pdf) 133 | 134 | 135 | @pytest.mark.parametrize('n', [2, 3]) 136 | def test_firstncharslastword(spark_session, n): 137 | if n == 2: 138 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'aa spark', 'aa bb'], 139 | 'block_key': ["first_2_characters_last_word_name:aa", 140 | "first_2_characters_last_word_name:sp", 141 | "first_2_characters_last_word_name:bb"]}) 142 | elif n == 3: 143 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'aa spark', 'aa bbb'], 144 | 'block_key': [None, 145 | "first_3_characters_last_word_name:spa", 146 | "first_3_characters_last_word_name:bbb"]}) 147 | 148 | sdf = spark_session.createDataFrame(expected_result_pdf) 149 | result_pdf = FirstNCharsLastWord('name', n).create_block_key(sdf.select('name')).toPandas() 150 | assert_frame_equal(expected_result_pdf, result_pdf) 151 | 152 | 153 | @pytest.mark.parametrize('n', [2, 3]) 154 | def test_firstncharactersfirsttokensorted(spark_session, n): 155 | if n == 2: 156 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark aa', 'aa bb', 'a bb'], 157 | 'block_key': ["first_2_characters_first_token_sorted_name:aa", 158 | "first_2_characters_first_token_sorted_name:aa", 159 | "first_2_characters_first_token_sorted_name:aa", 160 | "first_2_characters_first_token_sorted_name:bb"]}) 161 | elif n == 3: 162 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark aaa', 'bbb aa'], 163 | 'block_key': [None, 164 | "first_3_characters_first_token_sorted_name:aaa", 165 | "first_3_characters_first_token_sorted_name:bbb"]}) 166 | 167 | sdf = spark_session.createDataFrame(expected_result_pdf) 168 | result_pdf = FirstNCharactersFirstTokenSorted('name', n).create_block_key(sdf.select('name')).toPandas() 169 | assert_frame_equal(expected_result_pdf, result_pdf) 170 | 171 | 172 | @pytest.mark.parametrize('n', [2, 3]) 173 | def test_lastnchars(spark_session, n): 174 | if n == 2: 175 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'], 176 | 'block_key': ["last_2_characters_name:aa", 177 | "last_2_characters_name:rk", 178 | "last_2_characters_name:bb"]}) 179 | elif n == 3: 180 | expected_result_pdf = pd.DataFrame({'name': ['aa', 'spark', 'aa bb'], 181 | 'block_key': [None, 182 | "last_3_characters_name:ark", 183 | "last_3_characters_name: bb"]}) 184 | 185 | sdf = spark_session.createDataFrame(expected_result_pdf) 186 | result_pdf = LastNChars('name', n).create_block_key(sdf.select('name')).toPandas() 187 | assert_frame_equal(expected_result_pdf, result_pdf) 188 | 189 | 190 | def test_wholefield(spark_session): 191 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark'], 192 | 'block_key': [f"whole_field_name:python", 193 | f"whole_field_name:python pyspark"]}) 194 | 195 | sdf = spark_session.createDataFrame(expected_result_pdf) 196 | result_pdf = 
WholeField('name').create_block_key(sdf.select('name')).toPandas() 197 | assert_frame_equal(expected_result_pdf, result_pdf) 198 | 199 | 200 | @pytest.mark.parametrize('n', [2, 3]) 201 | def test_firstnwords(spark_session, n): 202 | if n == 2: 203 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark'], 204 | 'block_key': [None, 205 | "first_2_words_name:python pyspark"]}) 206 | elif n == 3: 207 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python py spark'], 208 | 'block_key': [None, "first_3_words_name:python py spark"]}) 209 | 210 | sdf = spark_session.createDataFrame(expected_result_pdf) 211 | result_pdf = FirstNWords('name', n).create_block_key(sdf.select('name')).toPandas() 212 | assert_frame_equal(expected_result_pdf, result_pdf) 213 | 214 | 215 | @pytest.mark.parametrize('n', [2, 3]) 216 | def test_firstnlettersnospace(spark_session, n): 217 | if n == 2: 218 | expected_result_pdf = pd.DataFrame({'name': ['python', 'p ython'], 219 | 'block_key': ["first_2_letters_name_no_space:py", 220 | "first_2_letters_name_no_space:py"]}) 221 | elif n == 3: 222 | expected_result_pdf = pd.DataFrame({'name': ['p y', 'python', 'p y thon'], 223 | 'block_key': [None, 224 | "first_3_letters_name_no_space:pyt", 225 | "first_3_letters_name_no_space:pyt"]}) 226 | 227 | sdf = spark_session.createDataFrame(expected_result_pdf) 228 | result_pdf = FirstNLettersNoSpace('name', n).create_block_key(sdf.select('name')).toPandas() 229 | assert_frame_equal(expected_result_pdf, result_pdf) 230 | 231 | 232 | def test_sortedintegers(spark_session): 233 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 2 1'], 234 | 'block_key': [None, 235 | "sorted_integers_name:1 2"]}) 236 | 237 | sdf = spark_session.createDataFrame(expected_result_pdf) 238 | result_pdf = SortedIntegers('name').create_block_key(sdf.select('name')).toPandas() 239 | assert_frame_equal(expected_result_pdf, result_pdf) 240 | 241 | 242 | def test_firstinteger(spark_session): 243 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 2 1'], 244 | 'block_key': [None, 245 | "first_integer_name:2"]}) 246 | 247 | sdf = spark_session.createDataFrame(expected_result_pdf) 248 | result_pdf = FirstInteger('name').create_block_key(sdf.select('name')).toPandas() 249 | assert_frame_equal(expected_result_pdf, result_pdf) 250 | 251 | 252 | def test_lastinteger(spark_session): 253 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 2 1'], 254 | 'block_key': [None, 255 | "last_integer_name:1"]}) 256 | 257 | sdf = spark_session.createDataFrame(expected_result_pdf) 258 | result_pdf = LastInteger('name').create_block_key(sdf.select('name')).toPandas() 259 | assert_frame_equal(expected_result_pdf, result_pdf) 260 | 261 | 262 | def test_largestinteger(spark_session): 263 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python1', 'python 2 1'], 264 | 'block_key': [None, 265 | 'largest_integer_name:1', 266 | "largest_integer_name:2"]}) 267 | 268 | sdf = spark_session.createDataFrame(expected_result_pdf) 269 | result_pdf = LargestInteger('name').create_block_key(sdf.select('name')).toPandas() 270 | assert_frame_equal(expected_result_pdf, result_pdf) 271 | 272 | 273 | @pytest.mark.parametrize('n', [2, 3]) 274 | def test_nletterabbreviation(spark_session, n): 275 | if n == 2: 276 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python pyspark'], 277 | 'block_key': [None, 278 | "2_letter_abbreviation_name:pp"]}) 279 | elif n == 3: 280 | expected_result_pdf = pd.DataFrame({'name': ['python', 'python 
pyspark', 'python apache pyspark'], 281 | 'block_key': [None, 282 | None, 283 | "3_letter_abbreviation_name:pap"]}) 284 | sdf = spark_session.createDataFrame(expected_result_pdf) 285 | result_pdf = NLetterAbbreviation('name', n).create_block_key(sdf.select('name')).toPandas() 286 | assert_frame_equal(expected_result_pdf, result_pdf) 287 | -------------------------------------------------------------------------------- /test/test_deduplicator/test_deduplicator.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | from pandas.testing import assert_frame_equal 4 | from pyspark.sql import functions as F 5 | 6 | from spark_matcher.deduplicator import Deduplicator 7 | 8 | 9 | @pytest.fixture() 10 | def deduplicator(spark_session): 11 | return Deduplicator(spark_session, col_names=[], checkpoint_dir='mock_db') 12 | 13 | 14 | def test__create_predict_pairs_table(spark_session, deduplicator): 15 | deduplicator.col_names = ['name', 'address'] 16 | 17 | input_df = pd.DataFrame( 18 | { 19 | 'name': ['John Doe', 'Jane Doe', 'Daniel Jacks', 'Jack Sparrow', 'Donald Trump'], 20 | 'address': ['1 Square', '2 Square', '3 Main street', '4 Harbour', '5 white house'], 21 | 'block_key': ['J', 'J', 'D', 'J', 'D'], 22 | 'row_number': [1, 2, 3, 4, 5], 23 | } 24 | ) 25 | 26 | expected_df = pd.DataFrame({ 27 | 'block_key': ['J', 'J', 'J', 'D'], 28 | 'name_1': ['John Doe', 'John Doe', 'Jane Doe', 'Daniel Jacks'], 29 | 'address_1': ['1 Square', '1 Square', '2 Square', '3 Main street'], 30 | 'row_number_1': [1, 1, 2, 3], 31 | 'name_2': ['Jane Doe', 'Jack Sparrow', 'Jack Sparrow', 'Donald Trump'], 32 | 'address_2': ['2 Square', '4 Harbour', '4 Harbour', '5 white house'], 33 | 'row_number_2': [2, 4, 4, 5] 34 | }) 35 | 36 | result = ( 37 | deduplicator 38 | ._create_predict_pairs_table(spark_session.createDataFrame(input_df)) 39 | .toPandas() 40 | .sort_values(by=['address_1', 'address_2']) 41 | .reset_index(drop=True) 42 | ) 43 | assert_frame_equal(result, expected_df) 44 | 45 | 46 | def test__add_singletons_entity_identifiers(spark_session, deduplicator) -> None: 47 | df = pd.DataFrame( 48 | data={ 49 | "entity_identifier": [0, 1, None, 2, 3, 4, 5, 6], 50 | "row_number": [10, 11, 12, 13, 14, 15, 16, 17]}) 51 | 52 | input_data = ( 53 | spark_session 54 | .createDataFrame(df) 55 | .withColumn('entity_identifier', 56 | F.when(F.isnan('entity_identifier'), F.lit(None)).otherwise(F.col('entity_identifier'))) 57 | .repartition(6)) 58 | 59 | expected_df = pd.DataFrame( 60 | data={ 61 | "row_number": [10, 11, 12, 13, 14, 15, 16, 17], 62 | "entity_identifier": [0.0, 1.0, 7.0, 2.0, 3.0, 4.0, 5.0, 6.0]}) 63 | 64 | result = (deduplicator._add_singletons_entity_identifiers(input_data) 65 | .toPandas() 66 | .sort_values(by='row_number') 67 | .reset_index(drop=True)) 68 | assert_frame_equal(result, expected_df, check_dtype=False) 69 | -------------------------------------------------------------------------------- /test/test_deduplicator/test_hierarchical_clustering.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from pandas.testing import assert_frame_equal 4 | from pyspark.sql import SparkSession 5 | 6 | 7 | def test__convert_data_to_adjacency_matrix(): 8 | from spark_matcher.deduplicator.hierarchical_clustering import _convert_data_to_adjacency_matrix 9 | 10 | input_df = pd.DataFrame({ 11 | 'row_number_1': [1, 1], 12 | 'row_number_2': [2, 3], 13 | 'score': [0.8,
0.3], 14 | 'component_id': [1, 1]}) 15 | 16 | expected_indexes = np.array([1, 2, 3]) 17 | expected_result = np.array([[0., 0.8, 0.3], 18 | [0.8, 0., 0.], 19 | [0.3, 0., 0.]]) 20 | 21 | indexes, result = _convert_data_to_adjacency_matrix(input_df) 22 | 23 | np.testing.assert_array_equal(result, expected_result) 24 | np.testing.assert_array_equal(np.array(indexes), expected_indexes) 25 | 26 | 27 | def test__get_condensed_distances(): 28 | from spark_matcher.deduplicator.hierarchical_clustering import _get_condensed_distances 29 | 30 | input_df = pd.DataFrame({ 31 | 'row_number_1': [1, 1], 32 | 'row_number_2': [2, 3], 33 | 'score': [0.8, 0.3], 34 | 'component_id': [1, 1]}) 35 | 36 | expected_dict = {0: 1, 1: 2, 2: 3} 37 | expected_result = np.array([0.2, 0.7, 1.]) 38 | 39 | result_dict, result = _get_condensed_distances(input_df) 40 | 41 | np.testing.assert_array_almost_equal_nulp(result, expected_result, nulp=2) 42 | np.testing.assert_array_equal(result_dict, expected_dict) 43 | 44 | 45 | def test__convert_dedupe_result_to_pandas_dataframe(spark_session: SparkSession) -> None: 46 | from spark_matcher.deduplicator.hierarchical_clustering import _convert_dedupe_result_to_pandas_dataframe 47 | 48 | # case 1: with no empty data 49 | inputs = [ 50 | ((1, 3), np.array([0.96, 0.96])), 51 | ((4, 7, 8), np.array([0.95, 0.95, 0.95])), 52 | ((5, 6), np.array([0.98, 0.98]))] 53 | component_id = 112233 54 | 55 | expected_df = pd.DataFrame(data={ 56 | 'row_number': [1, 3, 4, 7, 8, 5, 6], 57 | 'entity_identifier': [f"{component_id}_0", f"{component_id}_0", 58 | f"{component_id}_1", f"{component_id}_1", f"{component_id}_1", 59 | f"{component_id}_2", f"{component_id}_2"]}) 60 | result = _convert_dedupe_result_to_pandas_dataframe(inputs, component_id).reset_index(drop=True) 61 | 62 | assert_frame_equal(result, expected_df, check_dtype=False) 63 | 64 | # case 2: with empty data 65 | inputs = [] 66 | component_id = 12345 67 | 68 | expected_df = pd.DataFrame(data={}, columns=['row_number', 'entity_identifier']).reset_index(drop=True) 69 | result = _convert_dedupe_result_to_pandas_dataframe(inputs, component_id).reset_index(drop=True) 70 | 71 | assert_frame_equal(result, expected_df, check_dtype=False) 72 | -------------------------------------------------------------------------------- /test/test_matcher/test_matcher.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | from thefuzz.fuzz import ratio 4 | from pyspark.sql import DataFrame 5 | 6 | from spark_matcher.blocker.block_learner import BlockLearner 7 | from spark_matcher.blocker.blocking_rules import FirstNChars 8 | from spark_matcher.scorer.scorer import Scorer 9 | from spark_matcher.matcher import Matcher 10 | 11 | 12 | @pytest.fixture 13 | def matcher(spark_session, table_checkpointer): 14 | block_name = FirstNChars('name', 3) 15 | block_address = FirstNChars('address', 3) 16 | blocklearner = BlockLearner([block_name, block_address], recall=1, table_checkpointer=table_checkpointer) 17 | blocklearner.fitted = True 18 | blocklearner.cover_blocking_rules = [block_name, block_address] 19 | 20 | myScorer = Scorer(spark_session) 21 | fit_df = pd.DataFrame({'similarity_metrics': [[100, 80], [50, 50]], 'y': [1, 0]}) 22 | myScorer.fit(fit_df['similarity_metrics'].values.tolist(), fit_df['y']) 23 | 24 | myMatcher = Matcher(spark_session, table_checkpointer=table_checkpointer, 25 | field_info={'name': [ratio], 'address': [ratio]}) 26 | myMatcher.blocker = blocklearner 27 | 
myMatcher.scorer = myScorer 28 | myMatcher.fitted_ = True 29 | 30 | return myMatcher 31 | 32 | 33 | def test__create_predict_pairs_table(spark_session, table_checkpointer): 34 | sdf_1_blocked = spark_session.createDataFrame( 35 | pd.DataFrame({'name': ['frits', 'stan', 'ahmet'], 'address': ['dam 1', 'leidseplein 2', 'rokin 3'], 36 | 'block_key': ['fr', 'st', 'ah']})) 37 | sdf_2_blocked = spark_session.createDataFrame( 38 | pd.DataFrame({'name': ['frits', 'fred', 'glenn'], 39 | 'address': ['amsterdam', 'waterlooplein 4', 'rembrandtplein 5'], 40 | 'block_key': ['fr', 'fr', 'gl']})) 41 | myMatcher = Matcher(spark_session, table_checkpointer=table_checkpointer, col_names=['name', 'address']) 42 | result = myMatcher._create_predict_pairs_table(sdf_1_blocked, sdf_2_blocked) 43 | assert isinstance(result, DataFrame) 44 | assert result.count() == 2 45 | assert result.select('block_key').drop_duplicates().count() == 1 46 | assert set(result.columns) == {'block_key', 'name_2', 'address_2', 'name_1', 'address_1'} 47 | 48 | def test_predict(spark_session, matcher): 49 | sdf_1 = spark_session.createDataFrame(pd.DataFrame({'name': ['frits', 'stan', 'ahmet', 'ahmet', 'ahmet'], 50 | 'address': ['damrak', 'leidseplein', 'waterlooplein', 51 | 'amstel', 'amstel 3']})) 52 | 53 | sdf_2 = spark_session.createDataFrame( 54 | pd.DataFrame({'name': ['frits h', 'stan l', 'bayraktar', 'ahmet', 'ahmet'], 55 | 'address': ['damrak 1', 'leidseplein 2', 'amstel 3', 'amstel 3', 'waterlooplein 12324']})) 56 | 57 | result = matcher.predict(sdf_1, sdf_2) 58 | assert result 59 | -------------------------------------------------------------------------------- /test/test_matching_base/test_matching_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from string import ascii_lowercase 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import pytest 7 | from pandas.testing import assert_frame_equal 8 | from thefuzz.fuzz import token_set_ratio 9 | 10 | from spark_matcher.blocker.blocking_rules import BlockingRule 11 | from spark_matcher.matching_base.matching_base import MatchingBase 12 | 13 | 14 | def create_fake_string(length=2): 15 | return "".join(np.random.choice(list(ascii_lowercase), size=length)) 16 | 17 | 18 | def create_fake_df(size=10): 19 | return pd.DataFrame.from_dict({'name': [create_fake_string(1) + ' ' + create_fake_string(2) for _ in range(size)], 20 | 'address': [create_fake_string(2) + ' ' + create_fake_string(3) for _ in range(size)]}) 21 | 22 | 23 | @pytest.fixture() 24 | def matching_base(spark_session, table_checkpointer): 25 | return MatchingBase(spark_session, table_checkpointer=table_checkpointer, col_names=['name', 'address']) 26 | 27 | 28 | def test__create_train_pairs_table(spark_session, matching_base): 29 | n_1, n_2, n_train_samples, n_perfect_train_matches = (1000, 100, 5000, 1) 30 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1)) 31 | sdf_2 = spark_session.createDataFrame(create_fake_df(size=n_2)) 32 | 33 | matching_base.n_train_samples = n_train_samples 34 | matching_base.n_perfect_matches = n_perfect_train_matches 35 | 36 | result = matching_base._create_train_pairs_table(sdf_1, sdf_2) 37 | 38 | assert set(result.columns) == {'name_2', 'address_2', 'name_1', 'address_1', 'perfect_train_match'} 39 | assert result.count() == pytest.approx(n_train_samples, abs=n_perfect_train_matches) 40 | 41 | 42 | def test_default_blocking_rules(matching_base): 43 | assert isinstance(matching_base.blocking_rules, list) 44 | assert 
isinstance(matching_base.blocking_rules[0], BlockingRule) 45 | 46 | 47 | def test_load_save(spark_session, tmpdir, matching_base, table_checkpointer): 48 | def my_metric(x, y): 49 | return float(x == y) 50 | 51 | matching_base.field_info = {'name': [token_set_ratio, my_metric]} 52 | matching_base.n_perfect_train_matches = 5 53 | matching_base.save(os.path.join(tmpdir, 'matcher.pkl')) 54 | 55 | myMatcherLoaded = MatchingBase(spark_session, table_checkpointer=table_checkpointer, col_names=['nothing']) 56 | myMatcherLoaded.load(os.path.join(tmpdir, 'matcher.pkl')) 57 | setattr(table_checkpointer, 'spark_session', spark_session) # needed to be able to continue with other unit tests 58 | 59 | assert matching_base.col_names == myMatcherLoaded.col_names 60 | assert [x.__name__ for x in matching_base.field_info['name']] == [x.__name__ for x in 61 | myMatcherLoaded.field_info['name']] 62 | assert matching_base.n_perfect_train_matches == myMatcherLoaded.n_perfect_train_matches 63 | 64 | 65 | @pytest.mark.parametrize('suffix', [1, 2, 3]) 66 | def test__add_suffix_to_col_names(spark_session, suffix, matching_base): 67 | input_df = pd.DataFrame({ 68 | 'name': ['John Doe', 'Jane Doe', 'Chris Rock', 'Jack Sparrow'], 69 | 'address': ['Square 1', 'Square 2', 'Main street', 'Harbour 123'], 70 | 'block_key': ['1', '2', '3', '3'] 71 | }) 72 | 73 | result = ( 74 | matching_base 75 | ._add_suffix_to_col_names(spark_session.createDataFrame(input_df), suffix) 76 | .toPandas() 77 | .reset_index(drop=True) 78 | ) 79 | assert_frame_equal(result, input_df.rename(columns={'name': f'name_{suffix}', 'address': f'address_{suffix}'})) 80 | -------------------------------------------------------------------------------- /test/test_sampler/test_training_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | from pandas.testing import assert_frame_equal 5 | from string import ascii_lowercase 6 | 7 | from spark_matcher.sampler.training_sampler import HashSampler, RandomSampler 8 | 9 | 10 | def create_fake_string(length=2): 11 | return "".join(np.random.choice(list(ascii_lowercase), size=length)) 12 | 13 | 14 | def create_fake_df(size=10): 15 | return pd.DataFrame.from_dict({'name': [create_fake_string(1) + ' ' + create_fake_string(2) for _ in range(size)], 16 | 'surname': [create_fake_string(2) + ' ' + create_fake_string(3) for _ in range(size)]}) 17 | 18 | 19 | @pytest.mark.parametrize('test_case', [(100_000, 10, 10_000), (100, 100_000, 10_000), (100_000, 100_000, 10_000)]) 20 | def test__create_random_pairs_table_matcher(spark_session, test_case, table_checkpointer): 21 | n_1, n_2, n_train_samples = test_case 22 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1)) 23 | sdf_2 = spark_session.createDataFrame(create_fake_df(size=n_2)) 24 | 25 | rSampler = RandomSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples) 26 | 27 | result = rSampler._create_pairs_table_matcher(sdf_1, sdf_2) 28 | 29 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'} 30 | assert result.count() == pytest.approx(n_train_samples, abs=1) 31 | if n_1 > n_2: 32 | assert result.select('name_1', 'surname_1').drop_duplicates().count() == pytest.approx( 33 | n_train_samples / n_2, abs=1) 34 | assert result.select('name_2', 'surname_2').drop_duplicates().count() == n_2 35 | elif n_1 < n_2: 36 | assert result.select('name_1', 'surname_1').drop_duplicates().count() == 
pytest.approx(n_1, abs=1) 37 | assert result.select('name_2', 'surname_2').drop_duplicates().count() == pytest.approx( 38 | n_train_samples / n_1, abs=1) 39 | elif n_1 == n_2: 40 | assert result.select('name_2', 'surname_2').drop_duplicates().count() == pytest.approx( 41 | int(n_train_samples ** 0.5), abs=1) 42 | assert result.select('name_1', 'surname_1').drop_duplicates().count() == pytest.approx( 43 | int(n_train_samples ** 0.5), abs=1) 44 | 45 | 46 | @pytest.mark.parametrize('test_case', [(100_000, 10_000), (1_000, 10_000), (1_000, 1_000)]) 47 | def test__create_random_pairs_table_deduplicator(spark_session, test_case, table_checkpointer): 48 | n_rows, n_train_samples = test_case 49 | sdf = spark_session.createDataFrame(create_fake_df(size=n_rows)) 50 | 51 | rSampler = RandomSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples) 52 | 53 | result = rSampler._create_pairs_table_deduplicator(sdf) 54 | 55 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'} 56 | assert result.count() == n_train_samples 57 | 58 | 59 | def test__create_hashed_pairs_table_matcher(spark_session, table_checkpointer): 60 | n_1, n_2, n_train_samples = (1000, 1000, 1000) 61 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1)) 62 | sdf_2 = spark_session.createDataFrame(create_fake_df(size=n_2)) 63 | 64 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples, threshold=1) 65 | 66 | result = hSampler.create_pairs_table(sdf_1, sdf_2) 67 | 68 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'} 69 | assert result.count() == n_train_samples 70 | 71 | 72 | def test__create_hashed_pairs_table_deduplicator(spark_session, table_checkpointer): 73 | n_1, n_train_samples = (1000, 1000) 74 | sdf_1 = spark_session.createDataFrame(create_fake_df(size=n_1)) 75 | 76 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=n_train_samples, threshold=1) 77 | 78 | result = hSampler.create_pairs_table(sdf_1) 79 | 80 | assert set(result.columns) == {'name_2', 'surname_2', 'name_1', 'surname_1'} 81 | assert result.count() == n_train_samples 82 | 83 | 84 | def test__vectorize_matcher(spark_session, table_checkpointer): 85 | input_sdf_1 = spark_session.createDataFrame( 86 | pd.DataFrame({ 87 | 'name': ['aa bb cc', 'cc dd'], 88 | 'surname': ['xx yy', 'yy zz xx'] 89 | })) 90 | 91 | input_sdf_2 = spark_session.createDataFrame( 92 | pd.DataFrame({ 93 | 'name': ['bb cc'], 94 | 'surname': ['yy'] 95 | })) 96 | 97 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=999) 98 | 99 | # set max_df to a large value because otherwise nothing would be tokenized 100 | result_sdf_1, result_sdf_2 = hSampler._vectorize(['name', 'surname'], input_sdf_1, input_sdf_2, max_df=10) 101 | 102 | result_df_1 = result_sdf_1.toPandas() 103 | result_df_2 = result_sdf_2.toPandas() 104 | 105 | assert set(result_df_1.columns) == {'name', 'surname', 'features'} 106 | assert set(result_df_2.columns) == {'name', 'surname', 'features'} 107 | assert result_df_1.shape[0] == 2 108 | assert result_df_2.shape[0] == 1 109 | assert len(result_df_1.features[0]) == 4 # note that not all tokens occur at least twice (minDF=2) 110 | assert len(result_df_2.features[0]) == 4 # note that not all tokens occur at least twice (minDF=2) 111 | 112 | 113 | def test__vectorize_deduplicator(spark_session, table_checkpointer): 114 | input_sdf = spark_session.createDataFrame( 
115 | pd.DataFrame({ 116 | 'name': ['aa bb cc', 'cc dd', 'bb cc'], 117 | 'surname': ['xx yy', 'yy zz xx', 'yy zz'] 118 | })) 119 | 120 | hSampler = HashSampler(table_checkpointer, col_names=['name', 'surname'], n_train_samples=999) 121 | 122 | # set max_df to a large value because otherwise nothing would be tokenized 123 | result_sdf = hSampler._vectorize(['name', 'surname'], input_sdf, max_df=10) 124 | 125 | result_df = result_sdf.toPandas() 126 | 127 | assert set(result_df.columns) == {'name', 'surname', 'features'} 128 | assert result_df.shape[0] == 3 129 | assert len(result_df.features[0]) == 5 # note that not all tokens occur at least twice (minDF=2) 130 | -------------------------------------------------------------------------------- /test/test_scorer/test_scorer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pandas as pd 3 | import numpy as np 4 | from pyspark.sql import Column 5 | from sklearn.base import BaseEstimator 6 | 7 | from spark_matcher.scorer.scorer import Scorer 8 | 9 | 10 | @pytest.fixture 11 | def scorer(spark_session): 12 | class DummyScorer(Scorer): 13 | def __init__(self, spark_session=spark_session): 14 | super().__init__(spark_session) 15 | 16 | def _predict_proba(self, X): 17 | return np.array([[0, 1]]) 18 | 19 | def _predict_proba_spark(self, X): 20 | return spark_session.createDataFrame(pd.DataFrame({'1': [1]}))['1'] 21 | 22 | return DummyScorer() 23 | 24 | 25 | def test_fit(scorer): 26 | # case 1: the scorer should be able to be 'fitted' without exceptions even if there is only one class: 27 | X = np.array([[1, 2, 3]]) 28 | y = pd.Series([0]) 29 | scorer.fit(X, y) 30 | 31 | # case 2: 32 | X = np.array([[1, 2, 3], [4, 5, 6]]) 33 | y = pd.Series([0, 1]) 34 | scorer.fit(X, y) 35 | 36 | 37 | def test_predict_proba(spark_session, scorer): 38 | X = np.ndarray([0]) 39 | preds = scorer.predict_proba(X) 40 | assert isinstance(preds, np.ndarray) 41 | 42 | X = spark_session.createDataFrame(pd.DataFrame({'c': [0]})) 43 | preds = scorer.predict_proba(X['c']) 44 | assert isinstance(preds, Column) 45 | 46 | X = pd.DataFrame({}) 47 | with pytest.raises(ValueError) as e: 48 | scorer.predict_proba(X) 49 | assert f"{type(X)} is an unsupported datatype for X" == str(e.value) 50 | 51 | 52 | def test__create_default_clf(scorer): 53 | clf = scorer.binary_clf 54 | assert isinstance(clf, BaseEstimator) 55 | assert hasattr(clf, 'fit') 56 | assert hasattr(clf, 'predict_proba') -------------------------------------------------------------------------------- /test/test_similarity_metrics/test_similarity_metrics.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pandas.testing import assert_series_equal 3 | from thefuzz.fuzz import token_set_ratio 4 | 5 | from spark_matcher.similarity_metrics import SimilarityMetrics 6 | 7 | 8 | def test__distance_measure_pandas(): 9 | df = pd.DataFrame({'name_1': ['aa', 'bb', 'cc'], 10 | 'name_2': ['aa', 'bb c', 'dd'], 11 | 'similarity_metrics': [100, 100, 0] 12 | }) 13 | strings_1 = df['name_1'].tolist() 14 | strings_2 = df['name_2'].tolist() 15 | result = SimilarityMetrics._distance_measure_pandas(strings_1, strings_2, token_set_ratio) 16 | assert_series_equal(result, df['similarity_metrics'], check_names=False) 17 | -------------------------------------------------------------------------------- /test/test_table_checkpointer/test_table_checkpointer.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | def test_parquetcheckpointer(spark_session): 6 | from spark_matcher.table_checkpointer import ParquetCheckPointer 7 | 8 | checkpointer = ParquetCheckPointer(spark_session, 'temp_database', 'checkpoint_name') 9 | 10 | pdf = pd.DataFrame({'col_1': ['1', '2', '3'], 11 | 'col_2': ['a', 'b', 'c'], 12 | }) 13 | sdf = spark_session.createDataFrame(pdf) 14 | 15 | returned_sdf = checkpointer(sdf, 'checkpoint_name') 16 | returned_pdf = returned_sdf.toPandas().sort_values(['col_1', 'col_2']) 17 | assert np.array_equal(pdf.values, returned_pdf.values) 18 | -------------------------------------------------------------------------------- /test/test_utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import pandas as pd 4 | 5 | 6 | @pytest.mark.parametrize('min_df', [1, 2, 0.2]) 7 | def test_get_most_frequent_words(spark_session, min_df): 8 | from spark_matcher.utils import get_most_frequent_words 9 | 10 | sdf = spark_session.createDataFrame(pd.DataFrame({'name': ['Company A ltd', 'Company B ltd', 'Company C']})) 11 | 12 | if (min_df == 1) or (min_df == 0.2): 13 | expected_result = pd.DataFrame({'words': ['Company', 'ltd', 'A', 'B', 'C'], 14 | 'count': [3, 2, 1, 1, 1], 15 | 'df': [1, 2 / 3, 1 / 3, 1 / 3, 1 / 3]}) 16 | elif min_df == 2: 17 | expected_result = pd.DataFrame({'words': ['Company', 'ltd'], 18 | 'count': [3, 2], 19 | 'df': [1, 2 / 3]}) 20 | 21 | result = get_most_frequent_words(sdf, col_name='name', min_df=min_df, top_n_words=10).sort_values(['count', 22 | 'words'], ascending=[False, True]).reset_index(drop=True) 23 | pd.testing.assert_frame_equal(expected_result, result) 24 | 25 | 26 | def test_remove_stop_words(spark_session): 27 | from spark_matcher.utils import remove_stopwords 28 | 29 | stopwords = ['ltd', 'bv'] 30 | 31 | expected_sdf = spark_session.createDataFrame( 32 | pd.DataFrame({'name': ['Company A ltd', 'Company B ltd', 'Company C bv', 'Company D'], 33 | 'name_wo_stopwords': ['Company A', 'Company B', 'Company C', 'Company D']})) 34 | input_cols = ['name'] 35 | 36 | result = remove_stopwords(expected_sdf.select(*input_cols), col_name='name', stopwords=stopwords) 37 | pd.testing.assert_frame_equal(result.toPandas(), expected_sdf.toPandas()) 38 | --------------------------------------------------------------------------------
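# Note on fixtures: the spark_session and table_checkpointer fixtures used throughout the test
# modules above are not defined in the test files themselves; they are presumably provided by the
# project-level conftest.py. A typical local run of the suite (assuming Spark and the project
# requirements are installed) could look like, for example:
#
#     pip install -r spark_requirements.txt
#     pytest test/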