├── .gitignore ├── LICENSE ├── README.md ├── [Hands_On_Tutorial] Plagiarism Detection of Essays.ipynb ├── [Hands_On_Tutorial] Semantic Search and Visualization of Abstract Sections of USPTO Patents.ipynb ├── [Tutorial] Alignment.ipynb ├── [Tutorial] Distance.ipynb ├── [Tutorial] Search.ipynb ├── [Tutorial] Similarity.ipynb ├── data └── Plagiarism_Detection_Corpus.csv ├── docs ├── Makefile ├── alignment.rst ├── conf.py ├── distance.rst ├── embedding.rst ├── hupd_example.ipynb ├── index.rst ├── make.bat ├── matching.rst ├── metrics.rst ├── miscellaneous.rst ├── plagiarism_detection.ipynb ├── requirements.txt └── similarity.rst ├── fables ├── alignment-example.png ├── similarity-example.png ├── string2string-logo.pdf ├── string2string-logo.png └── string2string-overview.png ├── readthedocs.yaml ├── requirements.txt ├── setup.py ├── string2string ├── __init__.py ├── alignment │ ├── __init__.py │ └── classical.py ├── distance │ ├── __init__.py │ └── classical.py ├── metrics │ ├── __init__.py │ ├── exact_match.py │ ├── rouge.py │ └── sbleu.py ├── misc │ ├── __init__.py │ ├── basic_functions.py │ ├── default_tokenizer.py │ ├── hash_functions.py │ ├── model_embeddings.py │ ├── plotting_functions.py │ └── word_embeddings.py ├── search │ ├── __init__.py │ ├── classical.py │ └── faiss_search.py └── similarity │ ├── __init__.py │ ├── bartscore.py │ ├── bertscore.py │ ├── classical.py │ └── cosine_similarity.py └── tests ├── README.md ├── test_alignment.py ├── test_distance.py ├── test_rogue.py ├── test_sacrebleu.py └── test_search.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.DS_Store 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # Added by MS 133 | string2string/compression/ 134 | *DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mirac Suzgun 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/alignment.rst: -------------------------------------------------------------------------------- 1 | ######### 2 | Alignment 3 | ######### 4 | 5 | The page contains the documentation for the string-to-string alignment algorithms implemented in the package. 6 | 7 | 8 | ************************** 9 | Needleman-Wunsch Algorithm 10 | ************************** 11 | 12 | .. autoclass:: string2string.alignment.NeedlemanWunsch 13 | :special-members: __init__ 14 | :members: 15 | 16 | ********************** 17 | Hirschberg's Algorithm 18 | ********************** 19 | 20 | .. 
autoclass:: string2string.alignment.Hirschberg 21 | :special-members: __init__ 22 | :members: 23 | 24 | ************************ 25 | Smith-Waterman Algorithm 26 | ************************ 27 | .. autoclass:: string2string.alignment.SmithWaterman 28 | :special-members: __init__ 29 | :members: 30 | 31 | ************************** 32 | Dynamic Time Warping (DTW) 33 | ************************** 34 | .. autoclass:: string2string.alignment.DTW 35 | :special-members: __init__ 36 | :members: 37 | 38 | ************************** 39 | Longest Common Subsequence 40 | ************************** 41 | .. autoclass:: string2string.alignment.LongestCommonSubsequence 42 | :special-members: __init__ 43 | :members: 44 | 45 | ************************ 46 | Longest Common Substring 47 | ************************ 48 | .. autoclass:: string2string.alignment.LongestCommonSubstring 49 | :special-members: __init__ 50 | :members: 51 | 52 | ***************************** 53 | String Alignment Parent Class 54 | ***************************** 55 | .. autoclass:: string2string.alignment.StringAlignment 56 | :special-members: __init__ 57 | :members: -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | from recommonmark.parser import CommonMarkParser 20 | import sphinx_rtd_theme 21 | from recommonmark.transform import AutoStructify 22 | # import torch 23 | 24 | source_suffix = ['.rst', '.md'] 25 | extensions = ['sphinx.ext.autodoc', 26 | 'sphinx.ext.mathjax', 27 | 'sphinx.ext.viewcode', 28 | 'sphinx.ext.coverage', 29 | 'sphinx.ext.githubpages', 30 | 'sphinx.ext.napoleon', 31 | 'nbsphinx', 32 | ] 33 | 34 | 35 | project = 'string2string' 36 | copyright = '2023, Mirac Suzgun' 37 | author = 'Mirac Suzgun' 38 | 39 | # The full version, including alpha/beta/rc tags 40 | release = '0.4' 41 | 42 | 43 | # -- General configuration --------------------------------------------------- 44 | 45 | # Add any Sphinx extension module names here, as strings. They can be 46 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 47 | # ones. 48 | 49 | # Add any paths that contain templates here, relative to this directory. 50 | templates_path = ['_templates'] 51 | 52 | # List of patterns, relative to source directory, that match files and 53 | # directories to ignore when looking for source files. 54 | # This pattern also affects html_static_path and html_extra_path. 55 | exclude_patterns = [] 56 | 57 | 58 | # -- Options for HTML output ------------------------------------------------- 59 | 60 | # The theme to use for HTML and HTML Help pages. See the documentation for 61 | # a list of builtin themes. 
62 | # 63 | html_theme = "sphinx_rtd_theme" 64 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 65 | 66 | # Add any paths that contain custom static files (such as style sheets) here, 67 | # relative to this directory. They are copied after the builtin static files, 68 | # so a file named "default.css" will overwrite the builtin "default.css". 69 | # html_static_path = ['_static'] 70 | html_static_path = [] 71 | -------------------------------------------------------------------------------- /docs/distance.rst: -------------------------------------------------------------------------------- 1 | ######## 2 | Distance 3 | ######## 4 | 5 | The page contains the documentation for the string distance algorithms implemented in the package. 6 | 7 | ************************* 8 | Levenshtein Edit Distance 9 | ************************* 10 | 11 | .. autoclass:: string2string.distance.LevenshteinEditDistance 12 | :special-members: 13 | :members: 14 | 15 | **************** 16 | Hamming Distance 17 | **************** 18 | .. autoclass:: string2string.distance.HammingDistance 19 | :special-members: 20 | :members: 21 | 22 | 23 | **************************** 24 | Damerau-Levenshtein Distance 25 | **************************** 26 | .. autoclass:: string2string.distance.DamerauLevenshteinDistance 27 | :special-members: __init__ 28 | :members: 29 | 30 | 31 | ************* 32 | Jaccard Index 33 | ************* 34 | .. autoclass:: string2string.distance.JaccardIndex 35 | :special-members: __init__ 36 | :members: 37 | 38 | -------------------------------------------------------------------------------- /docs/embedding.rst: -------------------------------------------------------------------------------- 1 | ############################ 2 | Word and Sentence Embeddings 3 | ############################ 4 | 5 | This page contains the documentation about the GloVe and fastText word embeddings, as well as the language model embeddings. 6 | 7 | ********************* 8 | GloVe Word Embeddings 9 | ********************* 10 | 11 | .. autoclass:: string2string.misc.word_embeddings.GloVeEmbeddings 12 | :special-members: 13 | :members: 14 | 15 | ************************ 16 | fastText Word Embeddings 17 | ************************ 18 | .. autoclass:: string2string.misc.word_embeddings.FastTextEmbeddings 19 | :special-members: 20 | :members: 21 | 22 | ************************* 23 | Language Model Embeddings 24 | ************************* 25 | .. autoclass:: string2string.misc.model_embeddings.ModelEmbeddings 26 | :special-members: 27 | :members: -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. string2string documentation master file, created by 2 | sphinx-quickstart on Thu Mar 16 16:10:42 2023. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to string2string's documentation! 7 | ========================================= 8 | 9 | The **string2string** library is an open-source tool that offers a comprehensive suite of efficient algorithms for a broad range of string-to-string problems. It includes both traditional algorithmic solutions and recent advanced neural approaches to address various problems in pairwise string alignment, distance measurement, lexical and semantic search, and similarity analysis. 
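For instance, once the package is installed (see Getting Started below), a minimal illustrative sketch of the distance module looks as follows; the classes are imported exactly as they are exposed in ``string2string/distance/__init__.py``, and the values shown assume the default unit edit weights:

.. code-block:: python

   from string2string.distance import LevenshteinEditDistance, JaccardIndex

   # Minimum number of insertions, deletions, and substitutions (unit weights).
   edit_distance = LevenshteinEditDistance()
   print(edit_distance.compute("kitten", "sitting"))  # 3.0

   # Jaccard index over the sets of characters of the two strings.
   jaccard = JaccardIndex()
   print(jaccard.compute("monty", "python"))  # 1 - 4/7 = 3/7, roughly 0.43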
Additionally, the library provides several helpful visualization tools and metrics to facilitate the interpretation and analysis of these methods. 10 | 11 | The library features notable algorithms such as the `Smith-Waterman algorithm `_ for pairwise local alignment, the `Hirschberg algorithm `_ for global alignment, the `Wagner-Fisher algorithm `_ for `edit distance `_, `BARTScore `_ and `BERTScore `_ for similarity analysis, the `Knuth-Morris-Pratt `_ algorithm for lexical search, and `Faiss `_ for `semantic search `_. Moreover, it wraps existing highly efficient and widely-used implementations of certain frameworks and metrics, such as `sacreBLEU `_ and `ROUGE `_, whenever it is appropriate and suitable. 12 | 13 | In general, the **string2string** library seeks to provide extensive coverage and increased flexibility compared to existing libraries for strings. It can be used for many downstream applications, tasks, and problems in natural-language processing, bioinformatics, and computational social sciences. With its comprehensive suite of algorithms, visualization tools, and metrics, the string2string library is a valuable resource for researchers and practitioners in various fields. 14 | 15 | Getting Started 16 | --------------- 17 | 18 | Install the string2string library by running the following command in your terminal: 19 | 20 | .. code-block:: bash 21 | 22 | pip install string2string 23 | 24 | Once the installation is complete, you can import the library and start using its functionalities. 25 | 26 | **Remark**: We recommend using Python 3.7+ for the library. 27 | 28 | 29 | Tutorials 30 | --------- 31 | 32 | .. raw:: html 33 | 34 | 42 | 43 | .. toctree:: 44 | :maxdepth: 2 45 | :caption: Main Modules: 46 | 47 | alignment 48 | distance 49 | matching 50 | similarity 51 | embedding 52 | metrics 53 | miscellaneous 54 | 55 | .. toctree:: 56 | :maxdepth: 2 57 | :caption: Tutorials: 58 | 59 | hupd_example 60 | plagiarism_detection 61 | 62 | 63 | Citation 64 | -------- 65 | 66 | .. code-block:: bibtex 67 | 68 | @article{suzgun2023string2string, 69 | title={string2string: A Modern Python Library for String-to-String Algorithms}, 70 | author={Suzgun, Mirac and Shieber, Stuart M and Jurafsky, Dan}, 71 | journal={arXiv preprint arXiv:2304.14395}, 72 | year={2023} 73 | } 74 | 75 | 76 | 77 | Thanks 78 | ------ 79 | 80 | Our project owes a debt of gratitude to the following individuals for their contributions, comments, and feedback: Federico Bianchi, Corinna Coupette, Sebastian Gehrmann, Tayfun Gür, Şule Kahraman, Deniz Keleş, Luke Melas-Kyriazi, Christopher Manning, Tolúlopé Ògúnrèmí, Alexander "Sasha" Rush, Kyle Swanson, and Garrett Tanzer. -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/matching.rst: -------------------------------------------------------------------------------- 1 | ######################### 2 | Search (Pattern Matching) 3 | ######################### 4 | 5 | The page contains the documentation for the string-to-string search algorithms implemented in the package. 6 | 7 | ************************** 8 | Naive (Brute Force) Search 9 | ************************** 10 | 11 | .. autoclass:: string2string.search.NaiveSearch 12 | :special-members: __init__ 13 | :members: 14 | 15 | 16 | *************************** 17 | Rabin-Karp Search Algorithm 18 | *************************** 19 | 20 | .. autoclass:: string2string.search.RabinKarpSearch 21 | :special-members: __init__ 22 | :members: 23 | 24 | 25 | ***************************************** 26 | Knuth-Morris-Pratt (KMP) Search Algorithm 27 | ***************************************** 28 | 29 | .. autoclass:: string2string.search.KMPSearch 30 | :special-members: __init__ 31 | :members: 32 | 33 | 34 | **************************** 35 | Boyer-Moore Search Algorithm 36 | **************************** 37 | 38 | .. autoclass:: string2string.search.BoyerMooreSearch 39 | :special-members: __init__ 40 | :members: 41 | 42 | 43 | ********************* 44 | Faiss Semantic Search 45 | ********************* 46 | .. autoclass:: string2string.search.FaissSearch 47 | :special-members: __init__ 48 | :members: -------------------------------------------------------------------------------- /docs/metrics.rst: -------------------------------------------------------------------------------- 1 | ####### 2 | Metrics 3 | ####### 4 | 5 | This page contains the documentation about the string metrics used in the library. 6 | 7 | *********** 8 | Exact Match 9 | *********** 10 | 11 | .. autoclass:: string2string.metrics.ExactMatch 12 | :special-members: 13 | :members: 14 | 15 | ***************** 16 | sacreBLEU (sBLEU) 17 | ***************** 18 | .. autoclass:: string2string.metrics.sacreBLEU 19 | :special-members: 20 | :members: 21 | 22 | ***** 23 | ROUGE 24 | ***** 25 | .. autoclass:: string2string.metrics.ROUGE 26 | :special-members: 27 | :members: 28 | 29 | ********* 30 | BERTScore 31 | ********* 32 | .. autoclass:: string2string.similarity.BERTScore 33 | :special-members: __init__ 34 | :members: 35 | 36 | ********* 37 | BARTScore 38 | ********* 39 | .. autoclass:: string2string.similarity.BARTScore 40 | :special-members: __init__ 41 | :members: -------------------------------------------------------------------------------- /docs/miscellaneous.rst: -------------------------------------------------------------------------------- 1 | ############# 2 | Miscellaneous 3 | ############# 4 | 5 | This page contains the documentation about the miscellaneous classes and functions used in the library. 6 | 7 | ********* 8 | Tokenizer 9 | ********* 10 | 11 | .. autoclass:: string2string.misc.Tokenizer 12 | :special-members: 13 | :members: 14 | 15 | ******************************** 16 | Polynomial Rolling Hash Function 17 | ******************************** 18 | 19 | .. 
autoclass:: string2string.misc.PolynomialRollingHash 20 | :special-members: 21 | :members: 22 | 23 | 24 | *********************** 25 | Plot Pairwise Alignment 26 | *********************** 27 | 28 | .. autoclass:: string2string.misc.plotting_functions.plot_pairwise_alignment 29 | :special-members: 30 | :members: 31 | 32 | 33 | ************ 34 | Plot Heatmap 35 | ************ 36 | 37 | .. autoclass:: string2string.misc.plotting_functions.plot_heatmap 38 | :special-members: 39 | :members: 40 | 41 | 42 | ************************************************** 43 | Generate 2D-Scatter Plot of Embeddings with Plotly 44 | ************************************************** 45 | 46 | .. autoclass:: string2string.misc.plotting_functions.plot_corpus_embeds_with_plotly 47 | :special-members: 48 | :members: -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-jinja 3 | sphinxcontrib-bibtex 4 | sphinx-rtd-theme 5 | recommonmark 6 | nbsphinx 7 | pandoc 8 | ipython 9 | torch 10 | fasttext 11 | transformers 12 | datasets 13 | bert_score 14 | rouge_score 15 | sacrebleu 16 | joblib 17 | matplotlib 18 | plotly -------------------------------------------------------------------------------- /docs/similarity.rst: -------------------------------------------------------------------------------- 1 | ########## 2 | Similarity 3 | ########## 4 | 5 | The page contains the documentation for the string-to-string similarity algorithms implemented in the package. 6 | 7 | ************************* 8 | Cosine Similarity Measure 9 | ************************* 10 | 11 | .. autoclass:: string2string.similarity.CosineSimilarity 12 | :special-members: __init__ 13 | :members: 14 | 15 | ********* 16 | BERTScore 17 | ********* 18 | .. autoclass:: string2string.similarity.BERTScore 19 | :special-members: __init__ 20 | :members: 21 | 22 | ********* 23 | BARTScore 24 | ********* 25 | .. autoclass:: string2string.similarity.BARTScore 26 | :special-members: __init__ 27 | :members: 28 | 29 | 30 | ******************************** 31 | LCSubstringSimilarity Similarity 32 | ******************************** 33 | .. autoclass:: string2string.similarity.LCSubstringSimilarity 34 | :special-members: __init__ 35 | :members: 36 | 37 | 38 | ********************************** 39 | LCSubsequenceSimilarity Similarity 40 | ********************************** 41 | .. autoclass:: string2string.similarity.LCSubsequenceSimilarity 42 | :special-members: __init__ 43 | :members: 44 | 45 | 46 | *************** 47 | Jaro Similarity 48 | *************** 49 | .. 
autoclass:: string2string.similarity.JaroSimilarity 50 | :special-members: __init__ 51 | :members: -------------------------------------------------------------------------------- /fables/alignment-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/string2string/c4a72f59aafe8db42c4015709078064535dc4191/fables/alignment-example.png -------------------------------------------------------------------------------- /fables/similarity-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/string2string/c4a72f59aafe8db42c4015709078064535dc4191/fables/similarity-example.png -------------------------------------------------------------------------------- /fables/string2string-logo.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/string2string/c4a72f59aafe8db42c4015709078064535dc4191/fables/string2string-logo.pdf -------------------------------------------------------------------------------- /fables/string2string-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/string2string/c4a72f59aafe8db42c4015709078064535dc4191/fables/string2string-logo.png -------------------------------------------------------------------------------- /fables/string2string-overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/string2string/c4a72f59aafe8db42c4015709078064535dc4191/fables/string2string-overview.png -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.9" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | # Optionally declare the Python requirements required to build your docs 19 | python: 20 | install: 21 | - requirements: docs/requirements.txt 22 | - requirements: requirements.txt 23 | 24 | 25 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | numpy 3 | transformers 4 | datasets 5 | fasttext 6 | bert_score 7 | networkx 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open('README.md') as f: 4 | readme = f.read() 5 | 6 | setup( 7 | name="string2string", 8 | version="0.0.150", 9 | description="String-to-String Algorithms for Natural Language Processing", 10 | long_description=readme, 11 | long_description_content_type='text/markdown', 12 | url="https://github.com/stanfordnlp/string2string", 13 | author="Mirac Suzgun", 14 | author_email="msuzgun@cs.stanford.edu", 15 | license="MIT", 16 | python_requires='>=3.7', 17 | packages=find_packages(), 18 | install_requires=[ 
19 | "torch", 20 | "numpy", 21 | "transformers", 22 | "datasets", 23 | "faiss-cpu==1.7.3", 24 | "bert_score", 25 | "fasttext", 26 | "pandas", 27 | "joblib", 28 | ], 29 | tests_require=["pytest"], 30 | classifiers=[ 31 | "Programming Language :: Python :: 3.7", 32 | "Programming Language :: Python :: 3.8", 33 | "Programming Language :: Python :: 3.9", 34 | "Programming Language :: Python :: 3.10", 35 | "License :: OSI Approved :: MIT License", 36 | "Operating System :: OS Independent", 37 | "Typing :: Typed", 38 | ], 39 | keywords=[ 40 | "string matching", 41 | "pattern matching", 42 | "edit distance", 43 | "string to string correction", 44 | "string to string matching", 45 | "Levenshtein edit distance", 46 | "Hamming distance", 47 | "Damerau-Levenshtein distance", 48 | "Jaro-Winkler distance", 49 | "longest common subsequence", 50 | "longest common substring", 51 | "dynamic programming", 52 | "approximate string matching", 53 | "semantic similarity", 54 | "natural language processing", 55 | "NLP", 56 | "information retrieval", 57 | "rouge", 58 | "sacrebleu", 59 | "bertscore", 60 | "bartscore", 61 | "fasttext", 62 | "glove", 63 | "cosine similarity", 64 | "Smith-Waterman", 65 | "Needleman-Wunsch", 66 | "Hirschberg", 67 | "Karp-Rabin", 68 | "Knuth-Morris-Pratt", 69 | "Boyer-Moore", 70 | ], 71 | ) 72 | -------------------------------------------------------------------------------- /string2string/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/string2string/c4a72f59aafe8db42c4015709078064535dc4191/string2string/__init__.py -------------------------------------------------------------------------------- /string2string/alignment/__init__.py: -------------------------------------------------------------------------------- 1 | # The following trick allows us to import the classes directly from the alignment module: 2 | from .classical import ( 3 | StringAlignment, 4 | NeedlemanWunsch, 5 | Hirschberg, 6 | SmithWaterman, 7 | DTW, 8 | LongestCommonSubsequence, 9 | LongestCommonSubstring, 10 | ) 11 | -------------------------------------------------------------------------------- /string2string/distance/__init__.py: -------------------------------------------------------------------------------- 1 | # The following trick allows us to import the classes directly from the distance module: 2 | from .classical import ( 3 | StringAlgs, 4 | LevenshteinEditDistance, 5 | HammingDistance, 6 | DamerauLevenshteinDistance, 7 | JaccardIndex, 8 | ) -------------------------------------------------------------------------------- /string2string/distance/classical.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains functions for computing different distance metrics between two strings. 3 | 4 | The algorithms implemented in this module include the following: 5 | (a) Levenshtein edit distance ++ 6 | (b) Hamming edit distance ++ 7 | (c) Damerau–Levenshtein edit distance ++ 8 | (d) Jaccard distance ++ 9 | (e) Jaccard similarity ++ 10 | (f) Longest common substring ++ 11 | (g) Longest common subsequence ++ 12 | """ 13 | 14 | # Import relevant libraries and dependencies 15 | from typing import List, Union, Dict, Tuple 16 | import numpy as np 17 | 18 | # Parent class for all the string algorithms implemented in this module 19 | class StringAlgs: 20 | """ 21 | This class is the parent class for all the string algorithms implemented in this module. 
22 | """ 23 | # Initialize the class 24 | def __init__(self, 25 | match_weight: float = 0.0, 26 | ) -> None: 27 | # Set the match weight 28 | self.match_weight = match_weight 29 | 30 | 31 | # Levenshtein edit distance class 32 | class LevenshteinEditDistance(StringAlgs): 33 | def __init__(self, 34 | match_weight: float = 0.0, 35 | insert_weight: float = 1.0, 36 | delete_weight: float = 1.0, 37 | substitute_weight: float = 1.0, 38 | ) -> None: 39 | r""" 40 | This class initializes the Levenshtein edit distance algorithm. Levenshtein edit distance represents the minimum number of edit distance operations (insertion, deletion, and substitution) required to convert one string to another. 41 | 42 | The Levenshtein edit distance (with unit cost for each edit distance operation) is given by the following recurrence relation: 43 | 44 | .. math:: 45 | :nowrap: 46 | 47 | \begin{align} 48 | d[i, j] := \min( & d[i-1, j-1] + \texttt{mismatch}(i, j), \\ 49 | & d[i-1, j] + 1, \\ 50 | & d[i, j-1] + 1), 51 | \end{align} 52 | 53 | where :math:`\texttt{mismatch}(i, j)` is 1 if the i-th element in str1 is not equal to the j-th element in str2, and 0 otherwise. 54 | 55 | Arguments: 56 | match_weight (float): The weight of a match (default: 0.0). 57 | insert_weight (float): The weight of an insertion (default: 1.0). 58 | delete_weight (float): The weight of a deletion (default: 1.0). 59 | substitute_weight (float): The weight of a substitution (default: 1.0). 60 | 61 | Raises: 62 | AssertionError: If any of the weights are negative. 63 | """ 64 | # Set the match weight 65 | super().__init__(match_weight=match_weight) 66 | 67 | # Set the insert, delete, and substite weights 68 | self.insert_weight = insert_weight 69 | self.delete_weight = delete_weight 70 | self.substitute_weight = substitute_weight 71 | 72 | # Assert that all the weights are non-negative 73 | assert min(match_weight, insert_weight, delete_weight, substitute_weight) >= 0.0 74 | 75 | 76 | # Compute the Levenshtein edit distance between two strings using recursion 77 | def compute_recursive(self, 78 | str1: Union[str, List[str]], 79 | str2: Union[str, List[str]], 80 | ) -> float: 81 | r""" 82 | This function computes the Levenshtein edit distance between two strings (or lists of strings) using recursion. 83 | 84 | Arguments: 85 | str1 (str or list of str): The first string (or list of strings). 86 | str2 (str or list of str): The second string (or list of strings). 87 | 88 | Returns: 89 | The Levenshtein edit distance between the two strings. 90 | 91 | .. note:: 92 | * The solution presented here utilizes recursion to compute the Levenshtein edit distance between two strings. It has an exponential time complexity and is not recommended for pairs of strings with a large length. 93 | * The time complexity of this function is :math:`O(3^{m+n})`, where :math:`m` and :math:`n` are the lengths of the two strings. 
94 | """ 95 | # Base case 96 | if len(str1) == 0: 97 | return len(str2) * self.insert_weight 98 | elif len(str2) == 0: 99 | return len(str1) * self.delete_weight 100 | 101 | # Compute the mismatch 102 | mismatch = 0.0 if str1[-1] == str2[-1] else self.substitute_weight 103 | 104 | # Compute the Levenshtein edit distance 105 | return min( 106 | self.compute_recursive(str1[:-1], str2[:-1]) + mismatch, 107 | self.compute_recursive(str1[:-1], str2) + self.delete_weight, 108 | self.compute_recursive(str1, str2[:-1]) + self.insert_weight, 109 | ) 110 | 111 | 112 | # Compute the Levenshtein edit distance between two strings using memoization 113 | def compute_recursive_memoization(self, 114 | str1: Union[str, List[str]], 115 | str2: Union[str, List[str]], 116 | ) -> float: 117 | r""" 118 | This function computes the Levenshtein edit distance between two strings (or lists of strings) using memoization. 119 | 120 | Arguments: 121 | str1 (str or list of str): The first string (or list of strings). 122 | str2 (str or list of str): The second string (or list of strings). 123 | 124 | Returns: 125 | The Levenshtein edit distance between the two strings. 126 | 127 | .. note:: 128 | * The solution presented here utilizes memoization to compute the Levenshtein edit distance between two strings. 129 | * The time complexity of this function is :math:`\mathcal{O}(m n)`, where :math:`m` and :math:`n` are the lengths of the two strings. 130 | """ 131 | # Initialize the memoization dictionary 132 | memoization = {} 133 | 134 | # Compute the Levenshtein edit distance 135 | return self.compute_memoization_helper(str1, str2, memoization) 136 | 137 | 138 | # Compute the Levenshtein edit distance between two strings using memoization (helper function) 139 | def compute_memoization_helper(self, 140 | str1: Union[str, List[str]], 141 | str2: Union[str, List[str]], 142 | memoization: Dict[Tuple[str, str], float], 143 | ) -> float: 144 | r""" 145 | This is a helper function that computes the Levenshtein edit distance between two strings (or lists of strings) using memoization. 146 | 147 | Arguments: 148 | str1 (str or list of str): The first string (or list of strings). 149 | str2 (str or list of str): The second string (or list of strings). 150 | memoization (dict): The memoization dictionary. 151 | 152 | Returns: 153 | The Levenshtein edit distance between the two strings. 154 | 155 | .. note:: 156 | * The solution presented here utilizes memoization to compute the Levenshtein edit distance between two strings. 157 | * One can also use the :func:`functools.lru_cache` (@lru_cache()) decorator to memoize the function calls. However, for the sake of educational purposes, we have implemented memoization using a dictionary. 158 | * The time complexity of this function is quadratic, that is :math:`\mathcal{O}(nm)`, where m and n are the lengths of the two strings. 
159 | """ 160 | # Base case 161 | if len(str1) == 0: 162 | return len(str2) * self.insert_weight 163 | elif len(str2) == 0: 164 | return len(str1) * self.delete_weight 165 | 166 | # Check if the Levenshtein edit distance has already been computed 167 | if (str1, str2) in memoization: 168 | return memoization[(str1, str2)] 169 | 170 | # Compute the mismatch 171 | mismatch = 0.0 if str1[-1] == str2[-1] else self.substitute_weight 172 | 173 | # Compute the Levenshtein edit distance 174 | memoization[(str1, str2)] = min( 175 | self.compute_memoization_helper(str1[:-1], str2[:-1], memoization) + mismatch, 176 | self.compute_memoization_helper(str1[:-1], str2, memoization) + self.delete_weight, 177 | self.compute_memoization_helper(str1, str2[:-1], memoization) + self.insert_weight, 178 | ) 179 | 180 | # Return the Levenshtein edit distance 181 | return memoization[(str1, str2)] 182 | 183 | 184 | # Compute the Levenshtein edit distance between two strings using dynamic programming 185 | def compute_dynamic_programming(self, 186 | str1: Union[str, List[str]], 187 | str2: Union[str, List[str]], 188 | ) -> float: 189 | r""" 190 | This function computes the Levenshtein edit distance between two strings (or lists of strings) using dynamic programming (Wagner-Fischer algorithm). 191 | 192 | Arguments: 193 | str1 (str or list of str): The first string (or list of strings). 194 | str2 (str or list of str): The second string (or list of strings). 195 | 196 | Returns: 197 | The Levenshtein edit distance between the two strings. 198 | 199 | .. note:: 200 | * The solution presented here utilizes dynamic programming principles to compute the Levenshtein edit distance between two strings. 201 | * This solution is also known as the Wagner-Fischer algorithm. [WF1974]_ 202 | * The time complexity of this dynamic-programming-based solution is :math:`\mathcal{O}(nm)`, and the space complexity is :math:`\mathcal{O}(nm)`, where n and m are the lengths of the two strings, respectively. 203 | * However, by using only two rows of the distance matrix at a time, the space complexity of the dynamic programming solution can be reduced to :math:`\mathcal{O}(min(n, m))`. 204 | * The time complexity cannot be made strongly subquadratic time unless SETH is false. [BI2015]_ 205 | * Finally, we note that this solution can be extended to cases where each edit distance operation has a non-unit cost. 206 | 207 | .. [WF1974] Wagner, R.A. and Fischer, M.J., 1974. The string-to-string correction problem. Journal of the ACM (JACM), 21(1), pp.168-173. 208 | .. [BI2015] Backurs, A. and Indyk, P., 2015, June. Edit distance cannot be computed in strongly subquadratic time (unless SETH is false). In Proceedings of the forty-seventh annual ACM symposium on Theory of computing (pp. 51-58). 209 | """ 210 | # Lengths of strings str1 and str2, respectively. 211 | n = len(str1) 212 | m = len(str2) 213 | 214 | # Initialize the distance matrix. 215 | dist = np.zeros((n + 1, m + 1)) 216 | for i in range(1, n + 1): 217 | dist[i, 0] = self.delete_weight * i 218 | for j in range(1, m + 1): 219 | dist[0, j] = self.insert_weight * j 220 | 221 | # Dynamic programming step, where each operation has a unit cost: 222 | # d[i, j] := min(d[i-1, j-1] + mismatch(i, j), d[i-1, j] + 1, d[i, j-1] + 1), 223 | # where mismatch(i, j) is 1 if str1[i] != str2[j] and 0 otherwise. 224 | for i in range(1, n + 1): 225 | for j in range(1, m + 1): 226 | # Compute the minimum edit distance between str1[:i] and str2[:j]. 
227 | dist[i, j] = min( 228 | dist[i-1, j-1] + (self.substitute_weight if str1[i-1] != str2[j-1] else self.match_weight), 229 | dist[i-1, j] + self.delete_weight, 230 | dist[i, j-1] + self.insert_weight, 231 | ) 232 | 233 | # Return the Levenshtein edit distance between str1 and str2. 234 | return dist[n, m] 235 | 236 | 237 | # Compute the Levenshtein edit distance between two strings 238 | def compute(self, 239 | str1: Union[str, List[str]], 240 | str2: Union[str, List[str]], 241 | method: str = "dynamic-programming", 242 | ) -> float: 243 | r""" 244 | This function computes the Levenshtein edit distance between two strings (or lists of strings), using the method specified by the user. 245 | 246 | Arguments: 247 | str1 (str or list of str): The first string (or list of strings). 248 | str2 (str or list of str): The second string (or list of strings). 249 | method (str): The method to use to compute the Levenshtein edit distance (default: "dynamic-programming"). 250 | 251 | Returns: 252 | The Levenshtein edit distance between the two strings. 253 | 254 | .. note:: 255 | * The method can be one of the following: 256 | * "recursive": This method computes the Levenshtein edit distance using recursion. 257 | * "recursive-memoization": This method computes the Levenshtein edit distance using recursion with memoization. 258 | * "dynamic-programming": This method computes the Levenshtein edit distance using dynamic programming (Wagner-Fischer algorithm). 259 | * By default, the method is "dynamic-programming". 260 | 261 | """ 262 | # Dispatch to the method specified by the user (default: dynamic programming) 263 | if method == "recursive": 264 | return self.compute_recursive(str1, str2) 265 | elif method == "recursive-memoization": 266 | return self.compute_recursive_memoization(str1, str2) 267 | return self.compute_dynamic_programming(str1, str2) 268 | 269 | 270 | # Hamming (edit) distance class 271 | class HammingDistance(StringAlgs): 272 | def __init__(self, 273 | match_weight: float = 0.0, 274 | substitute_weight: float = 1.0, 275 | ) -> None: 276 | r""" 277 | This function initializes the class variables of the Hamming distance. 278 | 279 | The Hamming distance is the number of positions at which the corresponding symbols are different. [H1950]_ 280 | 281 | Arguments: 282 | match_weight (float): The weight of a match (default: 0.0). 283 | substitute_weight (float): The weight of a substitution (default: 1.0). 284 | 285 | Raises: 286 | AssertionError: If the substitute weight is negative. 287 | 288 | .. note:: 289 | * The Hamming distance has a time complexity of :math:`\mathcal{O}(n)`, where :math:`n` is the length of the two strings. 290 | 291 | .. [H1950] Hamming, R.W., 1950. Error detecting and error correcting codes. Bell System Technical Journal, 29(2), pp.147-160. 292 | """ 293 | # Set the match weight 294 | super().__init__(match_weight=match_weight) 295 | 296 | # Set the substitute weight 297 | self.substitute_weight = substitute_weight 298 | 299 | # Assert that the substitute weight is non-negative 300 | assert substitute_weight >= 0.0 301 | 302 | 303 | # Compute the Hamming distance between two strings 304 | def compute(self, 305 | str1: Union[str, List[str]], 306 | str2: Union[str, List[str]], 307 | ) -> float: 308 | """ 309 | This function computes the Hamming distance between two strings (or lists of strings). 310 | 311 | Arguments: 312 | str1 (str or list of str): The first string (or list of strings).
313 | str2 (str or list of str): The second string (or list of strings). 314 | 315 | Returns: 316 | The Hamming distance between the two strings. 317 | 318 | Raises: 319 | ValueError: If the two strings (or lists of strings) have different lengths. 320 | """ 321 | 322 | # Lengths of strings str1 and str2, respectively. 323 | n = len(str1) 324 | m = len(str2) 325 | 326 | # Assert that the two strings have the same length 327 | if n != m: 328 | raise ValueError("The two strings (or lists of strings) must have the same length.") 329 | 330 | # Compute the Hamming edit distance between str1 and str2. 331 | return sum( 332 | self.substitute_weight if str1[i] != str2[i] else self.match_weight 333 | for i in range(n) 334 | ) 335 | 336 | 337 | # Damerau-Levenshtein edit distance class 338 | class DamerauLevenshteinDistance(LevenshteinEditDistance): 339 | def __init__(self, 340 | match_weight: float = 0.0, 341 | insert_weight: float = 1.0, 342 | delete_weight: float = 1.0, 343 | substitute_weight: float = 1.0, 344 | adjacent_transpose_weight: float = 1.0, 345 | ) -> None: 346 | r""" 347 | This function initializes the class variables of the Damerau-Levenshtein distance. 348 | 349 | The Damerau-Levenshtein distance is the minimum number of insertions, deletions, substitutions, and transpositions required to transform one string into the other. [D1964]_ 350 | 351 | Arguments: 352 | match_weight (float): The weight of a match (default: 0.0). 353 | insert_weight (float): The weight of an insertion (default: 1.0). 354 | delete_weight (float): The weight of a deletion (default: 1.0). 355 | substitute_weight (float): The weight of a substitution (default: 1.0). 356 | adjacent_transpose_weight (float): The weight of an adjacent transposition (default: 1.0). 357 | 358 | Raises: 359 | AssertionError: If the insert, delete, substite, or adjacent transpose weights are negative. 360 | 361 | .. [D1964] Damerau, F.J., 1964. A technique for computer detection and correction of spelling errors. Communications of the ACM, 7(3), pp.171-176. 362 | """ 363 | # Set the weights of the distance operations 364 | super().__init__( 365 | match_weight=match_weight, 366 | insert_weight=insert_weight, 367 | delete_weight=delete_weight, 368 | substitute_weight=substitute_weight, 369 | ) 370 | 371 | # Set the adjacent transpose weight 372 | self.adjacent_transpose_weight = adjacent_transpose_weight 373 | 374 | # Assert that the adjacent transpose weight is non-negative 375 | assert adjacent_transpose_weight >= 0.0 376 | 377 | 378 | # Compute the Damerau-Levenshtein edit distance between two strings 379 | def compute(self, 380 | str1: Union[str, List[str]], 381 | str2: Union[str, List[str]], 382 | ) -> float: 383 | """ 384 | This function computes the Damerau-Levenshtein edit distance between two strings (or lists of strings). 385 | 386 | Arguments: 387 | str1 (str or list of str): The first string (or list of strings). 388 | str2 (str or list of str): The second string (or list of strings). 389 | 390 | Returns: 391 | The Damerau-Levenshtein distance between the two strings. 392 | 393 | .. note:: 394 | * The Damerau-Levenshtein distance is a variant of the Levenshtein distance that allows for adjacent transpositions. 395 | * The dynamic programming solution to the Damerau-Levenshtein distance has a time complexity of :math:`\mathcal{O}(nm)`, where n and m are the lengths of the two strings. 396 | """ 397 | 398 | # Lengths of strings str1 and str2, respectively. 
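        # (Editorial note, not part of the original source.) The adjacent-transposition case
        # handled below is what distinguishes this distance from plain Levenshtein: with unit
        # weights, compute("ab", "ba") returns 1.0 (a single adjacent transposition), whereas
        # the Levenshtein edit distance between "ab" and "ba" is 2.0.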
399 | n = len(str1) 400 | m = len(str2) 401 | 402 | # Initialize the distance matrix. 403 | dist = np.zeros((n + 1, m + 1)) 404 | for i in range(1, n + 1): 405 | dist[i, 0] = self.delete_weight * i 406 | for j in range(1, m + 1): 407 | dist[0, j] = self.insert_weight * j 408 | 409 | # Dynamic programming solution to the Damerau-Levenshtein edit distance is very similar to that of the Levenshtein edit distance. 410 | for i in range(1, n + 1): 411 | for j in range(1, m + 1): 412 | dist[i, j] = min( 413 | dist[i-1, j-1] + (self.substitute_weight if str1[i-1] != str2[j-1] else self.match_weight), 414 | dist[i-1, j] + self.delete_weight, 415 | dist[i, j-1] + self.insert_weight, 416 | ) 417 | # This is the only difference between the Damerau-Levenshtein edit distance and the Levenshtein edit distance. 418 | if i > 1 and j > 1 and str1[i-1] == str2[j-2] and str1[i-2] == str2[j-1]: 419 | dist[i, j] = min(dist[i, j], dist[i-2, j-2] + self.adjacent_transpose_weight) 420 | 421 | # Return the Damerau-Levenshtein edit distance between str1 and str2. 422 | return dist[n, m] 423 | 424 | 425 | # Jaccard index class 426 | class JaccardIndex: 427 | def __init__(self) -> None: 428 | r""" 429 | This function initializes the class variables of the Jaccard index. 430 | 431 | The Jaccard index is equal to 1.0 minus the Jaccard similarity coefficient. It is equal to 0.0 if and only if the two sets are equal. [J1938]_ 432 | 433 | .. [J1938] Jaccard, P., 1912. The Distribution of the Flora in the Alpine Zone. New Phytologist, 11(2), pp.37-50. 434 | """ 435 | pass 436 | 437 | # Compute the Jaccard index between two strings 438 | def compute(self, 439 | str1: Union[str, List[str]], 440 | str2: Union[str, List[str]], 441 | ) -> float: 442 | """ 443 | This function computes the Jaccard index between two strings (or lists of strings). 444 | 445 | Arguments: 446 | str1 (str or list of str): The first string (or list of strings). 447 | str2 (str or list of str): The second string (or list of strings). 448 | 449 | Returns: 450 | The Jaccard index between the two strings. 451 | """ 452 | # Compute the Jaccard index between str1 and str2. 453 | # The Jaccard index is, by definition, equal to 1.0 minus the Jaccard similarity coefficient. 454 | return 1. - len(set(str1).intersection(set(str2))) / len(set(str1).union(set(str2))) -------------------------------------------------------------------------------- /string2string/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # The following trick allows us to import the classes directly from the metrics module: 2 | from .exact_match import ExactMatch 3 | from .sbleu import sacreBLEU 4 | from .rouge import ROUGE -------------------------------------------------------------------------------- /string2string/metrics/exact_match.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains a class for the exact match metric. 3 | """ 4 | 5 | from typing import List, Dict 6 | 7 | # Exact match class 8 | class ExactMatch: 9 | def __init__(self) -> None: 10 | pass 11 | 12 | def compute(self, 13 | predictions: List[str], 14 | references: List[List[str]], 15 | lowercase: bool = True, 16 | ) -> Dict[str, float]: 17 | """ 18 | This function returns the exact match score between the predictions and the references. 19 | 20 | Arguments: 21 | predictions (List[str]): The list of predictions. 22 | references (List[List[str]]): The list of references. 
23 | 24 | Returns: 25 | float: The exact match score. 26 | 27 | Raises: 28 | AssertionError: If the number of predictions does not match the number of references. 29 | """ 30 | 31 | # Check that the number of predictions and references are the same length and that the length is not 0 32 | assert len(predictions) == len(references) and len(predictions) > 0 33 | 34 | # Compute the exact match score 35 | num_correct = 0. 36 | for prediction, reference in zip(predictions, references): 37 | # Lowercase the prediction and reference 38 | if lowercase: 39 | prediction = prediction.lower() 40 | reference = [ref.lower() for ref in reference] 41 | 42 | # Check if the prediction is in the reference 43 | if prediction in reference: 44 | num_correct += 1 45 | 46 | # Summary of the final scores 47 | final_scores = { 48 | 'score': num_correct / len(predictions), 49 | 'num_correct': num_correct, 50 | 'num_total': len(predictions), 51 | } 52 | 53 | # Return the final scores 54 | return final_scores -------------------------------------------------------------------------------- /string2string/metrics/rouge.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains a wrapper class for the ROUGE metric. 3 | 4 | ROUGE (Recall-Oriented Understudy for Gisting Evaluation) is a set of metrics for evaluating the quality of summaries in machine translation, text summarization, and other natural language generation tasks. 5 | """ 6 | 7 | from typing import Union, List, Dict 8 | from rouge_score import rouge_scorer 9 | from rouge_score.scoring import BootstrapAggregator 10 | from string2string.misc.default_tokenizer import Tokenizer 11 | 12 | # ROUGE class 13 | class ROUGE: 14 | """ 15 | This class is a wrapper for the ROUGE metric from Google Research's rouge_score package. 16 | """ 17 | 18 | def __init__(self, 19 | tokenizer: Tokenizer = None, 20 | ) -> None: 21 | """ 22 | This function initializes the ROUGE class, which is a wrapper for the ROUGE metric from Google Research's rouge_score package. 23 | 24 | Arguments: 25 | tokenizer (Tokenizer): The tokenizer to use. Default is None. 26 | 27 | Returns: 28 | None 29 | """ 30 | # Set the tokenizer 31 | if tokenizer is None: 32 | self.tokenizer = Tokenizer(word_delimiter=' ') 33 | else: 34 | self.tokenizer = tokenizer 35 | 36 | # Compute the ROUGE score 37 | def compute(self, 38 | predictions: List[str], 39 | references: List[List[str]], 40 | rouge_types: Union[str, List[str]] = ["rouge1", "rouge2", "rougeL", "rougeLsum"], 41 | use_stemmer: bool = False, 42 | interval_name: str = 'mid', 43 | score_type: str = 'fmeasure', 44 | ) -> Dict[str, float]: 45 | """ 46 | This function returns the ROUGE score between a list of predictions and list of list of references. 47 | 48 | Arguments: 49 | predictions (List[str]): The predictions. 50 | references (List[List[str]]): The references (or ground truth strings). 51 | rouge_types (Union[str, List[str]]): The ROUGE types to use. Default is ["rouge1", "rouge2", "rougeL", "rougeLsum"]. 52 | use_stemmer (bool): Whether to use a stemmer. Default is False. 53 | interval_name (str): The interval name. Default is "mid". 54 | score_type (str): The score type. Default is "fmeasure". 55 | 56 | Returns: 57 | Dict[str, float]: The ROUGE score (between 0 and 1). 58 | 59 | Raises: 60 | ValueError: If the number of predictions does not match the number of references. 61 | ValueError: If the interval name, score type or ROUGE type is invalid. 
62 | ValueError: If the prediction or reference is invalid. 63 | 64 | 65 | .. note:: 66 | * The ROUGE score is computed using the ROUGE metric from Google Research's rouge_score package. 67 | * By default, BootstrapAggregator is used to aggregate the scores. 68 | * By default, the interval name is "mid" and the score type is "fmeasure". 69 | """ 70 | 71 | # Check if the predictions and references are valid 72 | if len(predictions) != len(references): 73 | raise ValueError(f'Number of predictions ({len(predictions)}) does not match number of references ({len(references)})') 74 | 75 | # Check if the interval name is valid 76 | if interval_name not in ['low', 'mid', 'high']: 77 | raise ValueError(f'Invalid interval name: {interval_name}') 78 | 79 | # Check if the score type is valid 80 | if score_type not in ['precision', 'recall', 'fmeasure']: 81 | raise ValueError(f'Invalid score type: {score_type}') 82 | 83 | # Check if the ROUGE types are valid 84 | if not isinstance(rouge_types, list): 85 | rouge_types = [rouge_types] 86 | for rouge_type in rouge_types: 87 | if rouge_type not in ["rouge1", "rouge2", "rougeL", "rougeLsum"]: 88 | raise ValueError(f'Invalid ROUGE type: {rouge_type}') 89 | 90 | # Set the ROUGE scorer 91 | scorer = rouge_scorer.RougeScorer( 92 | rouge_types=rouge_types, 93 | use_stemmer=use_stemmer, 94 | tokenizer=self.tokenizer 95 | ) 96 | 97 | # Set the aggregator 98 | aggregator = BootstrapAggregator() 99 | 100 | # Compute the ROUGE score 101 | for prediction, reference in zip(predictions, references): 102 | # Check if the prediction and reference are valid 103 | if not isinstance(prediction, str): 104 | raise ValueError(f'Invalid prediction: {prediction}') 105 | if not isinstance(reference, list): 106 | raise ValueError(f'Invalid reference: {reference}') 107 | 108 | # Compute the ROUGE score 109 | scores = scorer.score_multi( 110 | targets=reference, 111 | prediction=prediction 112 | ) 113 | aggregator.add_scores(scores) 114 | 115 | # Aggregate the scores 116 | aggregate_score = aggregator.aggregate() 117 | 118 | # Get a summary of all the relevant BLEU score components 119 | final_scores = {rouge_type: getattr(aggregate_score[rouge_type], interval_name).__getattribute__(score_type) for rouge_type in rouge_types} 120 | 121 | # Return the final scores 122 | return final_scores -------------------------------------------------------------------------------- /string2string/metrics/sbleu.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains a wrapper class for the sacreBLEU metric from https://github.com/mjpost/sacreBLEU. 
3 | """ 4 | 5 | from typing import Union, Optional, List, Dict 6 | from string2string.misc.default_tokenizer import Tokenizer 7 | from sacrebleu import corpus_bleu 8 | 9 | # Pre-defined tokenizers for sacreBLEU 10 | # This list taken from https://github.com/mjpost/sacrebleu/blob/4f4124642c4eb0b7120e50119c669f0570a326a7/sacrebleu/metrics/bleu.py#L18 11 | ALLOWED_TOKENIZERS = { 12 | 'none': 'tokenizer_none.NoneTokenizer', 13 | 'zh': 'tokenizer_zh.TokenizerZh', 14 | '13a': 'tokenizer_13a.Tokenizer13a', 15 | 'intl': 'tokenizer_intl.TokenizerV14International', 16 | 'char': 'tokenizer_char.TokenizerChar', 17 | 'ja-mecab': 'tokenizer_ja_mecab.TokenizerJaMecab', 18 | 'ko-mecab': 'tokenizer_ko_mecab.TokenizerKoMecab', 19 | 'spm': 'tokenizer_spm.TokenizerSPM', 20 | 'flores101': 'tokenizer_spm.Flores101Tokenizer', 21 | 'flores200': 'tokenizer_spm.Flores200Tokenizer', 22 | } 23 | 24 | 25 | class sacreBLEU: 26 | """ 27 | This class contains the sacreBLEU metric. 28 | """ 29 | 30 | def __init__(self) -> None: 31 | """ 32 | Initializes the BLEU class. 33 | """ 34 | pass 35 | 36 | 37 | def compute(self, 38 | predictions: List[str], 39 | references: List[List[str]], 40 | smooth_method: str = 'exp', 41 | smooth_value: Optional[float] = None, 42 | lowercase: bool = False, 43 | tokenizer_name: Optional[str] = 'none', 44 | use_effective_order: bool = False, 45 | return_only: List[str] = ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len'] 46 | ): 47 | """ 48 | Returns the BLEU score between a list of predictions and list of list of references. 49 | 50 | Arguments: 51 | predictions (List[str]): The predictions. 52 | references (List[List[str]]): The references (or ground truth strings). 53 | smooth_method (str): The smoothing method. Default is "exp". Other options are "floor", "add-k" and "none". 54 | smooth_value (Optional[float]): The smoothing value for floor and add-k smoothing. Default is None. 55 | lowercase (bool): Whether to lowercase the text. Default is False. 56 | tokenizer_name (str): The tokenizer name. Default is "none". Other options are "zh", "13a", "intl", "char", "ja-mecab", "ko-mecab", "spm", "flores101" and "flores200". 57 | use_effective_order (bool): Whether to use the effective order. Default is False. 58 | return_only (Optional[List[str]]): The list of BLEU score components to return. Default is ['score', 'counts', 'totals', 'precisions', 'bp', 'sys_len', 'ref_len']. 59 | 60 | Returns: 61 | Dict[str, float]: The BLEU score (between 0 and 1). 62 | 63 | Raises: 64 | ValueError: If the number of predictions does not match the number of references. 65 | ValueError: If the tokenizer name is invalid. 66 | """ 67 | 68 | # Check that the number of predictions matches the number of references 69 | if len(predictions) != len(references): 70 | raise ValueError('The number of predictions does not match the number of references.') 71 | 72 | # Check that the tokenizer name is valid 73 | if tokenizer_name not in ALLOWED_TOKENIZERS: 74 | raise ValueError('The tokenizer name is invalid.') 75 | 76 | # Check that the size of each reference list is the same 77 | reference_size = len(references[0]) 78 | for reference in references: 79 | if len(reference) != reference_size: 80 | raise ValueError('The size of each reference list is not the same.') 81 | 82 | # Transform the references into a list of list of references. 83 | # This is necessary because sacrebleu.corpus_bleu expects a list of list of references. 
84 | transformed_references = [[refs[i] for refs in references] for i in range(reference_size)] 85 | 86 | # Compute the BLEU score using sacrebleu.corpus_bleu 87 | # This function returns "BLEUScore(score, correct, total, precisions, bp, sys_len, ref_len)" 88 | bleu_score = corpus_bleu( 89 | hypotheses=predictions, 90 | references=transformed_references, 91 | smooth_method=smooth_method, 92 | smooth_value=smooth_value, 93 | lowercase=lowercase, 94 | use_effective_order=use_effective_order, 95 | **(dict(tokenize=ALLOWED_TOKENIZERS[tokenizer_name]) if tokenizer_name != 'none' else {}), 96 | ) 97 | 98 | # Get a summary of all the relevant BLEU score components 99 | final_scores = {k: getattr(bleu_score, k) for k in return_only} 100 | 101 | # Return the BLEU score 102 | return final_scores 103 | 104 | 105 | # predictions = ["hello there general kenobi", "foo bar foobar"] 106 | # references = [["hello there general kenobi", "hello there !"], ["foo bar foobar", "foo bar foobar"]] 107 | 108 | # sbleu = sacreBLEU() 109 | # bleu_score = sbleu.compute(predictions, references) 110 | # print(bleu_score) -------------------------------------------------------------------------------- /string2string/misc/__init__.py: -------------------------------------------------------------------------------- 1 | # The following trick allows us to import the classes directly from the util module: 2 | from .default_tokenizer import Tokenizer 3 | from .hash_functions import HashFunction, PolynomialRollingHash 4 | from .model_embeddings import ModelEmbeddings 5 | from .word_embeddings import GloVeEmbeddings, FastTextEmbeddings 6 | -------------------------------------------------------------------------------- /string2string/misc/basic_functions.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | # Take the Cartesian product of two lists of strings (or lists of lists of strings) 4 | def cartesian_product( 5 | lst1: Union[List[str], List[List[str]]], 6 | lst2: Union[List[str], List[List[str]]], 7 | boolList: bool = False, 8 | list_of_list_separator: str = " ## ", 9 | ) -> Union[List[str], List[List[str]]]: 10 | """ 11 | This function returns the Cartesian product of two lists of strings (or lists of lists of strings). 12 | 13 | Arguments: 14 | lst1: The first list of strings (or lists of lists of strings). 15 | lst2: The second list of strings (or lists of lists of strings). 16 | boolList: A boolean flag indicating whether the output should be a list of strings (or lists of lists of strings) (default: False). 17 | 18 | Returns: 19 | The Cartesian product of the two lists of strings (or lists of lists of strings). 20 | """ 21 | if lst1 == []: 22 | return lst2 23 | elif lst2 == []: 24 | return lst1 25 | return [ 26 | s1 + ("" if not (boolList) else list_of_list_separator) + s2 27 | for s1 in lst1 28 | for s2 in lst2 29 | ] -------------------------------------------------------------------------------- /string2string/misc/default_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the default tokenizer. 3 | """ 4 | 5 | from typing import List 6 | 7 | # Tokenizer class 8 | class Tokenizer: 9 | """ 10 | This class contains the tokenizer. 11 | """ 12 | 13 | def __init__(self, 14 | word_delimiter: str = " ", 15 | ): 16 | """ 17 | Initializes the Tokenizer class. 18 | 19 | Arguments: 20 | word_delimiter (str): The word delimiter. Default is " ". 
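        Example (illustrative):

            >>> tok = Tokenizer(word_delimiter=' ')
            >>> tok.tokenize('hello world')
            ['hello', 'world']
            >>> tok.detokenize(['hello', 'world'])
            'hello world'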
21 | """ 22 | # Set the word delimiter 23 | self.word_delimiter = word_delimiter 24 | 25 | # Tokenize 26 | def tokenize(self, 27 | text: str, 28 | ) -> List[str]: 29 | """ 30 | Returns the tokens from a string. 31 | 32 | Arguments: 33 | text (str): The text to tokenize. 34 | 35 | Returns: 36 | List[str]: The tokens. 37 | """ 38 | return text.split(self.word_delimiter) 39 | 40 | # Detokenize 41 | def detokenize(self, 42 | tokens: List[str], 43 | ) -> str: 44 | """ 45 | Returns the string from a list of tokens. 46 | 47 | Arguments: 48 | tokens (List[str]): The tokens. 49 | 50 | Returns: 51 | str: The string. 52 | """ 53 | return self.word_delimiter.join(tokens) -------------------------------------------------------------------------------- /string2string/misc/hash_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the hash functions used in search algorithms. 3 | 4 | A hash function takes a string (or other object) and returns a number. 5 | The number is called the hash value, hash code, or simply the hash. The hash value is used to determine the location of the string in the hash table. 6 | - The hash function must be deterministic, meaning that the same string always produces the same hash value. 7 | - If two strings produce the same hash value, we say that the hash values collide. 8 | - The hash function must also be fast, so it is important to keep the number of operations to a minimum. 9 | """ 10 | 11 | from typing import List, Union, Tuple, Optional 12 | import numpy as np 13 | 14 | 15 | # A parent class for all hash functions 16 | class HashFunction: 17 | """ 18 | This class contains the parent class for all hash functions. 19 | """ 20 | def __init__(self): 21 | pass 22 | 23 | def compute(self, 24 | str1: str, 25 | ) -> int: 26 | """ 27 | Returns the hash value of a string. 28 | 29 | Arguments: 30 | str1 (str): The string. 31 | 32 | Returns: 33 | int: The hash value of the string. 34 | """ 35 | pass 36 | 37 | 38 | # Polynomial rolling hash function class 39 | class PolynomialRollingHash(HashFunction): 40 | """ 41 | This class contains the polynomial rolling hash function. 42 | """ 43 | 44 | def __init__(self, 45 | base: int = 10, # 256, 46 | modulus: int = 101, # 65537, 47 | ) -> None: 48 | """ 49 | Initializes the polynomial rolling hash function. 50 | 51 | Arguments: 52 | base (int): The base to use. Default is 256. 53 | modulus (int): The modulus to use. Default is 65537. 54 | 55 | Returns: 56 | None 57 | 58 | .. note:: 59 | * Why 65537? Because it is a Fermat prime. 60 | """ 61 | super().__init__() 62 | 63 | # Check the inputs 64 | assert base > 0, 'The base must be positive.' 65 | assert modulus > 0, 'The modulus must be positive.' 66 | 67 | # Set the attributes 68 | self.base = base 69 | self.modulus = modulus 70 | 71 | # Initialize the current hash value 72 | self.current_hash = 0 73 | 74 | 75 | def compute(self, 76 | str1: str, 77 | ) -> int: 78 | """ 79 | Returns the hash value of a string. 80 | 81 | Arguments: 82 | str1 (str): The string. 83 | 84 | Returns: 85 | int: The hash value of the string. 86 | """ 87 | # Compute the hash value of the string 88 | for char in str1: 89 | self.current_hash = (self.current_hash * self.base + ord(char)) % self.modulus 90 | 91 | # Return the hash value 92 | return self.current_hash 93 | 94 | 95 | def update(self, 96 | old_char: str, 97 | new_char: str, 98 | window_size: int, 99 | ) -> int: 100 | """ 101 | Updates the hash value of a string. 
102 | 103 | Arguments: 104 | old_char (str): The old character. 105 | new_char (str): The new character. 106 | 107 | Returns: 108 | int: The hash value of the string. 109 | """ 110 | # Update the hash value of the string 111 | self.current_hash = (self.current_hash - ord(old_char) * (self.base ** (window_size - 1))) % self.modulus 112 | self.current_hash = (self.current_hash * self.base + ord(new_char)) % self.modulus 113 | 114 | # Return the hash value 115 | return self.current_hash 116 | 117 | 118 | def reset(self) -> None: 119 | """ 120 | Resets the hash value. 121 | 122 | Arguments: 123 | None 124 | 125 | Returns: 126 | None 127 | """ 128 | # Reset the current hash value 129 | self.current_hash = 0 -------------------------------------------------------------------------------- /string2string/misc/model_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the ModelEmbeddings class. 3 | """ 4 | 5 | from typing import List, Union 6 | import os 7 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 8 | 9 | import torch 10 | from transformers import AutoTokenizer, AutoModel 11 | 12 | 13 | class ModelEmbeddings: 14 | """ 15 | This class is an abstract class for neural word embeddings. 16 | """ 17 | 18 | def __init__(self, 19 | model_name_or_path: str = 'facebook/bart-large', 20 | tokenizer_name_or_path: str = None, 21 | device: str = 'cpu', 22 | ) -> None: 23 | """ 24 | Constructor. 25 | 26 | Arguments: 27 | model_name_or_path (str): The name or path of the model to use (default: 'facebook/bart-large'). 28 | tokenizer (Tokenizer): The tokenizer to use (if None, the model name or path is used). 29 | device (str): The device to use (default: 'cpu'). 30 | 31 | Returns: 32 | None 33 | 34 | Raises: 35 | ValueError: If the model name or path is invalid. 36 | """ 37 | # Set the device 38 | self.device = device 39 | if self.device is None: 40 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 41 | 42 | # If the tokenizer is not specified, use the model name or path 43 | if tokenizer_name_or_path is None: 44 | tokenizer_name_or_path = model_name_or_path 45 | 46 | # Load the tokenizer 47 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) 48 | 49 | # Load the model 50 | self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device) 51 | 52 | # Set the model to evaluation mode (since we do not need the gradients) 53 | self.model.eval() 54 | 55 | 56 | # Auxiliary function to get the last hidden state 57 | def get_last_hidden_state(self, 58 | embeddings: torch.Tensor, 59 | ) -> torch.Tensor: 60 | """ 61 | Returns the last hidden state (e.g., [CLS] token's) of the input embeddings. 62 | 63 | Arguments: 64 | embeddings (torch.Tensor): The input embeddings. 65 | 66 | Returns: 67 | torch.Tensor: The last hidden state. 68 | """ 69 | 70 | # Get the last hidden state 71 | last_hidden_state = embeddings.last_hidden_state 72 | 73 | # Return the last hidden state 74 | return last_hidden_state[:, 0, :] 75 | 76 | 77 | # Auxiliary function to get the mean pooling 78 | def get_mean_pooling(self, 79 | embeddings: torch.Tensor, 80 | ) -> torch.Tensor: 81 | """ 82 | Returns the mean pooling of the input embeddings. 83 | 84 | Arguments: 85 | embeddings (torch.Tensor): The input embeddings. 86 | 87 | Returns: 88 | torch.Tensor: The mean pooling. 
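        .. note::
            Mean pooling averages the token embeddings over the sequence dimension, so an input
            batch of shape (batch_size, sequence_length, hidden_size) yields a tensor of shape
            (batch_size, hidden_size).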
89 | """ 90 | 91 | # Get the mean pooling 92 | mean_pooling = embeddings.last_hidden_state.mean(dim=1) 93 | 94 | # Return the mean pooling 95 | return mean_pooling 96 | 97 | # Get the embeddings 98 | def get_embeddings(self, 99 | text: Union[str, List[str]], 100 | embedding_type: str = 'last_hidden_state', 101 | ) -> torch.Tensor: 102 | """ 103 | Returns the embeddings of the input text. 104 | 105 | Arguments: 106 | text (Union[str, List[str]]): The input text. 107 | embedding_type (str, optional): The type of embedding to use. Defaults to 'last_hidden_state'. 108 | 109 | Returns: 110 | torch.Tensor: The embeddings. 111 | 112 | Raises: 113 | ValueError: If the embedding type is invalid. 114 | """ 115 | 116 | # Check if the embedding type is valid 117 | if embedding_type not in ['last_hidden_state', 'mean_pooling']: 118 | raise ValueError(f'Invalid embedding type: {embedding_type}. Only "last_hidden_state" and "mean_pooling" are supported.') 119 | 120 | # Tokenize the input text 121 | encoded_text = self.tokenizer( 122 | text, 123 | padding=True, 124 | truncation=True, 125 | return_tensors='pt', 126 | ) 127 | 128 | # Move the input text to the device 129 | encoded_text = encoded_text.to(self.device) 130 | 131 | # encoded_inputs = {k: v.to(self.device) for k, v in encoded_inputs.items()} 132 | 133 | # Get the embeddings 134 | with torch.no_grad(): 135 | embeddings = self.model(**encoded_text) 136 | 137 | # Get the proper embedding type 138 | if embedding_type == 'last_hidden_state': 139 | # Get the last hidden state 140 | embeddings = self.get_last_hidden_state(embeddings) 141 | elif embedding_type == 'mean_pooling': 142 | # Get the mean pooling 143 | embeddings = self.get_mean_pooling(embeddings) 144 | 145 | # Return the embeddings 146 | return embeddings -------------------------------------------------------------------------------- /string2string/misc/plotting_functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the functions for plotting and visualizing the results. 3 | """ 4 | 5 | # Matplotlib 6 | import matplotlib 7 | import matplotlib.pyplot as plt 8 | 9 | # Plotly 10 | import plotly.graph_objects as go 11 | import plotly.express as px 12 | 13 | # Other necessary packages 14 | import numpy as np 15 | import torch 16 | from typing import List, Union, Tuple, Optional 17 | Coordinate = Union[int, float] 18 | 19 | # Plot the pairwise alignment between two strings (or lists of strings) 20 | def plot_pairwise_alignment( 21 | seq1_pieces: Union[str, List[Union[str, int, float]], np.ndarray], 22 | seq2_pieces: Union[str, List[Union[str, int, float]], np.ndarray], 23 | alignment: List[Tuple[int, int]] = [], 24 | str2colordict: Optional[dict] = None, 25 | padding_factor: float = 1.4, 26 | linewidth: float = 1.5, 27 | border_to_box: float = 0.2, 28 | title: str = 'Pairwise Alignment', 29 | seq1_name: str = 'Seq 1', 30 | seq2_name: str = 'Seq 2', 31 | show: bool = True, 32 | save: bool = False, 33 | save_path: str = 'pairwise_alignment.png', 34 | save_dpi: int = 300, 35 | save_bbox_inches: str = 'tight', 36 | ): 37 | """ 38 | This function is designed to generate a plot that displays the alignment between two given lists of characters, strings, integers, or floats (or a numpy array). To create this plot, the function takes in the two lists and a list of tuples that specifies the alignment between the two lists. 
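    For example (illustrative values), the call below would draw the two character sequences in two
    rows and connect the four aligned positions with lines:

        >>> plot_pairwise_alignment(list('kitten'), list('sitting'), alignment=[(0, 0), (1, 1), (2, 2), (3, 3)])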
39 | 40 | Arguments: 41 | seq1_pieces (Union[str, List[Union[str, int, float], np.ndarray]]): The pieces of the first string or list of strings. 42 | seq2_pieces (Union[str, List[Union[str, int, float], np.ndarray]]): The pieces of the second string or list of strings. 43 | alignment (List[Tuple[int, int]]): The pairwise alignment between the two strings. 44 | str2colordict: Optional[dict] = None: A dictionary of colors for each character/string in the union of the two strings. 45 | padding_factor (float, optional): The factor to use for the padding (default is 1.4). 46 | linewidth (float, optional): The linewidth to use for the alignment (default is 1.5). 47 | border_to_box (float, optional): The gap between the border and the box (default is 0.2). 48 | title (str, optional): The title of the plot (default is 'Pairwise Alignment'). 49 | seq1_name (str, optional): The name of the first sequence (default is 'Seq 1'). 50 | seq2_name (str, optional): The name of the second sequence (default is 'Seq 2'). 51 | show (bool, optional): Whether to show the plot (default is True). 52 | save (bool, optional): Whether to save the plot (default is False). 53 | save_path (str, optional): The path to save the plot (default is 'pairwise_alignment.png'). 54 | save_dpi (int, optional): The dpi to use for the plot (default is 300). 55 | save_bbox_inches (str, optional): The bbox_inches to use for the plot (default is 'tight'). 56 | 57 | Returns: 58 | None 59 | 60 | .. note:: 61 | The pairwise alignment is a list of tuples of the form (i, j) where i is the index of the character in the first string and j is the index of the character in the second string. 62 | """ 63 | # Raise an error if str1 and seq2_pieces are not of the same type 64 | if type(seq1_pieces) != type(seq2_pieces): 65 | raise TypeError('seq1_pieces and seq2_pieces must be of the same type.') 66 | 67 | # Raise an error if save is True and save_path is None 68 | if save and save_path is None: 69 | raise ValueError('Save path is not specified.') 70 | 71 | # Get the length of the strings 72 | len1 = len(seq1_pieces) 73 | len2 = len(seq2_pieces) 74 | 75 | # Get the maximum length 76 | max_len = max(len1, len2) 77 | 78 | # Get the maximum length of the elements in the strings str1 and seq2_pieces 79 | max_len_chr1 = max([len(str(x)) for x in seq1_pieces]) 80 | max_len_chr2 = max([len(str(x)) for x in seq2_pieces]) 81 | max_len_chr = max(max_len_chr1, max_len_chr2) 82 | 83 | if max_len_chr > 20: 84 | raise ValueError('The maximum length of the characters in the strings must be less than 20.') 85 | 86 | # Get the scaling factor 87 | factor = 0.5 + (max_len_chr // 10.0) * 0.5 88 | 89 | # Get the x and y coordinates of the characters 90 | x_char = np.concatenate((np.arange(len1), np.arange(len2))) 91 | y_char = np.concatenate((np.zeros(len1), np.ones(len2))) 92 | 93 | # Get the characters 94 | chars = np.concatenate((np.array(list(seq1_pieces)), np.array(list(seq2_pieces)))) 95 | 96 | # Create the figure 97 | _, ax = plt.subplots(figsize=(2 * max_len * factor, 2 * 2)) 98 | 99 | # Get the alignment 100 | alignment = np.array(alignment) 101 | 102 | # Check if the alignment is not None and not empty 103 | if len(alignment) > 0: 104 | indices1 = alignment[:, 0] 105 | indices2 = alignment[:, 1] 106 | # Draw the alignment 107 | for i in range(len(indices1)): 108 | ax.plot([indices1[i], indices2[i]], [border_to_box, 1.-border_to_box], 'o-', color='#336699', linewidth=0.75, zorder=2) 109 | 110 | # Draw the characters/strings 111 | for i, char in 
enumerate(chars): 112 | # Get the color of the character if it is in the dictionary 113 | strip_char = char.strip() 114 | fc_color = str2colordict[strip_char] if (str2colordict is not None and strip_char in str2colordict) else (0.88, 0.94, 1.0, 1.0) 115 | ax.text( 116 | x_char[i], 117 | y_char[i], 118 | char, size=12, 119 | ha='center', va='center', 120 | bbox=dict(facecolor=fc_color, 121 | edgecolor='#336699', 122 | linewidth=linewidth, 123 | boxstyle=f'square,pad={padding_factor}', 124 | alpha=0.99, 125 | )) 126 | 127 | # Set the limits of the axes 128 | ax.set_xlim(-0.5, max_len - 0.5) 129 | ax.set_ylim(-0.5, 1.5) 130 | 131 | # Set the ticks of the axes 132 | ax.set_yticks([0, 1]) 133 | 134 | # Set the tick labels of the axes 135 | ax.set_yticklabels([seq1_name, seq2_name], fontsize=12)#, fontweight='bold') 136 | 137 | # Set the title of the axes 138 | ax.set_title(title, fontsize=14, fontweight='book') 139 | 140 | # Turn off the spines 141 | ax.spines[:].set_visible(False) 142 | 143 | # Turn off the ticks 144 | ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False) 145 | 146 | # tight layout 147 | plt.tight_layout() 148 | 149 | # Show the plot 150 | if show: 151 | plt.show() 152 | 153 | # Save the plot 154 | if save: 155 | plt.savefig(save_path, dpi=save_dpi, bbox_inches=save_bbox_inches) 156 | 157 | 158 | # Plot a heatmap 159 | def plot_heatmap( 160 | data: Union[List[List[Union[str, int, float]]], np.ndarray], 161 | title: str = 'Heatmap', 162 | x_label: str = 'X', 163 | y_label: str = 'Y', 164 | x_ticks: List[str] = None, 165 | y_ticks: List[str] = None, 166 | colorbar_kwargs: dict = None, 167 | color_threshold: float = None, 168 | textcolors=("black", "white"), 169 | valfmt="{x:.1f}", 170 | legend: bool = False, 171 | show: bool = True, 172 | save: bool = False, 173 | save_path: str = 'heatmap.png', 174 | save_dpi: int = 300, 175 | save_bbox_inches: str = 'tight', 176 | **kwargs) -> None: 177 | """ 178 | This function creates a heatmap visualization based on a given 2D array of data. The input array can represent a variety of data structures, such as a confusion matrix or a correlation matrix, and can be represented as a list of lists or a numpy array. The resulting plot will visually represent the data in the input array using a color-coded grid. 179 | 180 | Arguments: 181 | data (Union[List[List[Union[str, int, float]]], np.ndarray]): The data to plot. 182 | title (str, optional): The title of the plot (default: 'Heatmap'). 183 | x_label (str, optional): The label of the x-axis (default: 'X'). 184 | y_label (str, optional): The label of the y-axis (default: 'Y'). 185 | x_ticks (List[str], optional): The ticks of the x-axis (default: None). 186 | y_ticks (List[str], optional): The ticks of the y-axis (default: None). 187 | colorbar_kwargs (dict, optional): The keyword arguments for the colorbar (default: None). 188 | color_threshold (float, optional): The threshold to use for the color (default: None). 189 | textcolors (tuple, optional): The colors to use for the text (default: ("black", "white")). 190 | valfmt (str, optional): The format to use for the values (default: "{x:.1f}"). 191 | legend (bool, optional): Whether to show the legend (default: False). 192 | show (bool, optional): Whether to show the plot (default: True). 193 | save (bool, optional): Whether to save the plot (default: False). 194 | save_path (str, optional): The path to save the plot (default: 'heatmap.png'). 
195 | save_dpi (int, optional): The dpi to use for the plot (default: 300). 196 | save_bbox_inches (str, optional): The bbox_inches to use for the plot (default: 'tight'). 197 | **kwargs: The keyword arguments for the heatmap. 198 | """ 199 | # Create the figure and axes 200 | fig, ax = plt.subplots() 201 | 202 | # Create the heatmap 203 | im = ax.imshow(data, **kwargs) 204 | 205 | # Create the colorbar 206 | if colorbar_kwargs is None: 207 | colorbar_kwargs = {} 208 | 209 | # Create the colorbar 210 | if legend: 211 | ax.figure.colorbar(im, ax=ax, **colorbar_kwargs) 212 | 213 | # Set the x-axis label 214 | ax.set_xlabel(x_label, fontweight='medium') 215 | 216 | # Set the y-axis label 217 | ax.set_ylabel(y_label, fontweight='medium') 218 | 219 | # Set the x-axis ticks 220 | if x_ticks is not None: 221 | ax.set_xticks(np.arange(len(x_ticks))) 222 | ax.set_xticklabels(x_ticks) 223 | 224 | # Set the y-axis ticks 225 | if y_ticks is not None: 226 | ax.set_yticks(np.arange(len(y_ticks))) 227 | ax.set_yticklabels(y_ticks) 228 | 229 | # Set the tick parameters of the axes 230 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) 231 | 232 | # Rotate the tick labels and set their alignment. 233 | plt.setp(ax.get_xticklabels(), ha="center", rotation_mode="anchor") 234 | 235 | # Turn off the spines 236 | ax.spines[:].set_visible(False) 237 | 238 | # Turn off the ticks 239 | ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) 240 | ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) 241 | 242 | # Set the grid and tick parameters 243 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 244 | ax.tick_params(which="minor", bottom=False, left=False) 245 | ax.xaxis.set_label_position('top') 246 | 247 | # Color threshold for the heatmap 248 | if color_threshold is not None: 249 | color_threshold = im.norm(color_threshold) 250 | else: 251 | color_threshold = im.norm(data.max())/2. 
252 | 253 | # Text annotations 254 | kw = dict(horizontalalignment="center", 255 | verticalalignment="center") 256 | 257 | # Value format for the text annotations 258 | if isinstance(valfmt, str): 259 | valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) 260 | 261 | # Loop over the data and create text annotations 262 | for i in range(data.shape[0]): 263 | for j in range(data.shape[1]): 264 | kw.update(color=textcolors[int(im.norm(data[i, j]) > color_threshold)]) 265 | im.axes.text(j, i, valfmt(data[i, j], None), **kw) 266 | 267 | # Set the title 268 | fig.suptitle(title, fontsize=14, fontweight='semibold') 269 | 270 | # Set the tight layout 271 | plt.tight_layout() 272 | 273 | # Show the plot 274 | if show: 275 | plt.show() 276 | 277 | # Save the plot 278 | if save: 279 | plt.savefig(save_path, dpi=save_dpi, bbox_inches=save_bbox_inches) 280 | 281 | 282 | def plot_corpus_embeds_with_plotly( 283 | corpus_embeddings: Union[List[List[Coordinate]], np.ndarray, torch.Tensor], 284 | corpus_labels: List[str], 285 | corpus_hover_texts: List[str], 286 | corpus_scatter_kwargs: Optional[dict] = {}, 287 | layoot_dict: Optional[dict] = None, 288 | query_embeddings: Optional[Union[List[List[Coordinate]], np.ndarray]] = None, 289 | query_labels: Optional[List[str]] = None, 290 | query_hover_texts: List[str] = None, 291 | query_modes: Optional[Union[List[str], str]] = 'markers', 292 | query_marker_dict: Optional[dict] = None, 293 | show_plot: bool = True, 294 | save_path: Optional[str] = None, 295 | ) -> go.Figure: 296 | """ 297 | The purpose of this function is to generate a 2D scatter plot using plotly, based on a given corpus of embeddings and their corresponding labels. The function takes in the embeddings and labels as input, and plots them in the scatter plot with each point represented by a particular color and shape based on its label. Additionally, the function can also take in a query embedding and its corresponding label as optional inputs, which will be plotted separately on the scatter plot with a distinct color and shape. 298 | 299 | Arguments: 300 | corpus_embeddings: A list of lists or a numpy array or a torch tensor of corpus embeddings (e.g. sentence embeddings). 301 | corpus_labels: A list of labels for the corpus embeddings. 302 | corpus_hover_texts: A list of hover texts for the corpus embeddings. 303 | corpus_scatter_kwargs: A dictionary of keyword arguments for the corpus scatter plot (e.g. marker size, marker color, etc.) (default: {}). 304 | layoot_dict: A dictionary of keyword arguments for the layout of the plot (e.g. title, x-axis title, y-axis title, etc.) (default: None). 305 | query_embeddings: A list of lists or a numpy array of query embeddings (e.g. sentence embeddings) (default: None). 306 | query_labels: A list of labels for the query embeddings (default: None). 307 | query_hover_texts: A list of hover texts for the query embeddings (default: None). 308 | query_modes: A list of modes for the query embeddings (default: 'markers'). 309 | query_marker_dict: A dictionary of keyword arguments for the query scatter plot (e.g. marker size, marker color, etc.) (default: None). 310 | show_plot: A boolean whether to show the plot (default: True). 311 | save_path: A string of the path to save the plot (e.g., 'corpus_embeddings.html') (default: None). 312 | 313 | Returns: 314 | go.Figure: A plotly figure object. 315 | 316 | .. note:: 317 | Please refer to the Hands-on Tutorial on Semantic Search with HUPD Patent Data for a good demonstration of how to use this function. 
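    Example (illustrative sketch; the embeddings and labels below are randomly generated):

        >>> import numpy as np
        >>> corpus_embeds = np.random.rand(100, 2)
        >>> corpus_labels = ['class A'] * 50 + ['class B'] * 50
        >>> hover_texts = [f'document {i}' for i in range(100)]
        >>> fig = plot_corpus_embeds_with_plotly(corpus_embeds, corpus_labels, hover_texts, show_plot=False)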
318 | """ 319 | 320 | # If the corpus_embeddings are a torch tensor or a list, we convert them to a numpy array 321 | if isinstance(corpus_embeddings, torch.Tensor): 322 | corpus_embeddings = corpus_embeddings.detach().cpu().numpy() 323 | elif isinstance(corpus_embeddings, list): 324 | corpus_embeddings = np.array(corpus_embeddings) 325 | 326 | # Let us plot the corpus embeddings 327 | fig = px.scatter(corpus_embeddings, x=0, y=1, color=corpus_labels, hover_name=corpus_hover_texts, **corpus_scatter_kwargs) 328 | 329 | # If we have query embeddings, we plot them as well 330 | if query_embeddings is not None: 331 | # If the query_embeddings are a torch tensor or a list, we convert them to a numpy array 332 | if isinstance(query_embeddings, torch.Tensor): 333 | query_embeddings = query_embeddings.detach().cpu().numpy() 334 | elif isinstance(query_embeddings, list): 335 | query_embeddings = np.array(query_embeddings) 336 | 337 | # Check if markers are specified for the query embeddings 338 | q_marker_dict = query_marker_dict if query_marker_dict is not None else dict(size=10, color='black') 339 | 340 | # If the query_modes is a string, we convert it to a list of the same length as the query_embeddings 341 | if isinstance(query_modes, str): 342 | query_modes = [query_modes] * len(query_embeddings) 343 | 344 | # Let us plot the query embeddings on top of the corpus embeddings, one by one 345 | for i, query_embedding in enumerate(query_embeddings): 346 | q_mode = query_modes[i] if query_modes is not None else 'markers' 347 | q_label = query_labels[i] if query_labels is not None else 'Query' 348 | q_hover_text = query_hover_texts[i] if query_hover_texts is not None else 'Query' 349 | 350 | fig.add_trace(go.Scatter(x=[query_embedding[0]], y=[query_embedding[1]], mode=q_mode, marker=q_marker_dict, name=q_label, hovertext=q_hover_text)) 351 | 352 | # If we have a layout dictionary, we update the figure layout with it 353 | if layoot_dict is not None: 354 | fig.update_layout(layoot_dict) 355 | 356 | # If we want to save the plot, we do it here 357 | if save_path is not None: 358 | fig.write_html(save_path) 359 | 360 | # If we want to show the plot, we do it here 361 | if show_plot: 362 | fig.show() 363 | 364 | return fig -------------------------------------------------------------------------------- /string2string/misc/word_embeddings.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the word embeddings class. 3 | """ 4 | # from tqdm import tqdm 5 | import numpy as np 6 | from typing import List, Union 7 | import torch 8 | import os 9 | from torch import Tensor 10 | from torch.nn import functional as F 11 | import fasttext 12 | import fasttext.util 13 | from string2string.misc.default_tokenizer import Tokenizer 14 | 15 | 16 | class NeuralEmbeddings: 17 | """ 18 | This class is an abstract class for neural word embeddings. 19 | """ 20 | 21 | def __init__(self, 22 | tokenizer: Tokenizer = None, 23 | ) -> None: 24 | """ 25 | Constructor. 26 | 27 | Arguments: 28 | tokenizer (Tokenizer): The tokenizer to use. 29 | """ 30 | # Set the tokenizer 31 | if tokenizer is None: 32 | self.tokenizer = Tokenizer(word_delimiter=" ") 33 | 34 | 35 | 36 | def __call__(self, 37 | tokens: Union[List[str], str], 38 | ) -> Tensor: 39 | """ 40 | This function returns the embeddings of the given tokens. 41 | 42 | Arguments: 43 | tokens (Union[List[str], str]): The tokens to embed. 44 | 45 | Returns: 46 | Tensor: The embeddings of the given tokens. 
47 | """ 48 | # Check the tokens 49 | if isinstance(tokens, str): 50 | tokens = self.tokenizer.tokenize(tokens) 51 | 52 | # Embed the tokens 53 | return self.embedding_layer(torch.tensor([self.vocabulary_dict[token] for token in tokens])) 54 | 55 | 56 | def get_embedding(self, 57 | tokens: Union[List[str], str] 58 | ) -> Tensor: 59 | """ 60 | This function returns the embeddings of the given tokens. 61 | 62 | Arguments: 63 | tokens (Union[List[str], str]): The tokens to embed. 64 | 65 | Returns: 66 | Tensor: The embeddings of the given tokens. 67 | """ 68 | return self.__call__(tokens) 69 | 70 | 71 | # GloVe embeddings class 72 | class GloVeEmbeddings(NeuralEmbeddings): 73 | """ 74 | This class implements the GloVe word embeddings. 75 | """ 76 | # Pre-trained GloVe embeddings 77 | # Source: https://github.com/stanfordnlp/GloVe#download-pre-trained-word-vectors 78 | MODEL_OPTIONS = { 79 | 'glove.6B.200d': { 80 | 'Description': 'Wikipedia 2014 + Gigaword 5 (6B tokens, 400K vocab, uncased, 300d vectors, 822 MB download)', 81 | 'URL': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.6B.zip', 82 | }, 83 | 'glove.twitter.27B': { 84 | 'Description': 'Twitter (27B tokens, 1.2M vocab, uncased, 200d vectors, 1.42 GB download)', 85 | 'URL': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.twitter.27B.zip', 86 | }, 87 | 'glove.42B.300d': { 88 | 'Description': 'Common Crawl (42B tokens, 1.9M vocab, uncased, 300d vectors, 1.75 GB download)', 89 | 'URL': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.42B.300d.zip', 90 | }, 91 | 'glove.840B.300d': { 92 | 'Description': 'Common Crawl (840B tokens, 2.2M vocab, cased, 300d vectors, 2.03 GB download)', 93 | 'URL': 'https://huggingface.co/stanfordnlp/glove/resolve/main/glove.840B.300d.zip', 94 | }, 95 | } 96 | 97 | def __init__(self, 98 | model: str = 'glove.6B.200D', 99 | dim: int = 50, 100 | force_download: bool = False, 101 | dir = None, 102 | tokenizer: Tokenizer = None, 103 | ) -> None: 104 | r""" 105 | This function initializes the GloVe embeddings class. 106 | 107 | Arguments: 108 | model (str): The model to use. Default is 'glove.6B.200D'. (Options are: 'glove.6B.200D', 'glove.twitter.27B', 'glove.42B.300d', 'glove.840B.300d'.) 109 | dim (int): The dimension of the embeddings. Default is 300. 110 | force_download (bool): Whether to force download the model. Default is False. 111 | dir (str): The directory to save or load the model. Default is None. 112 | tokenizer (Tokenizer): The tokenizer to use. Default is None. 113 | 114 | Returns: 115 | None 116 | 117 | Raises: 118 | ValueError: If the model is not in the MODEL_OPTIONS [glove.6B.200D', 'glove.twitter.27B', 'glove.42B.300d', 'glove.840B.300d']. 119 | 120 | 121 | .. attention:: 122 | 123 | If you use this class, please make sure to cite the following paper: 124 | 125 | .. code-block:: latex 126 | 127 | @inproceedings{pennington2014glove, 128 | title={Glove: Global vectors for word representation}, 129 | author={Pennington, Jeffrey and Socher, Richard and Manning, Christopher D}, 130 | booktitle={Proceedings of the 2014 conference on empirical methods in natural language processing (EMNLP)}, 131 | pages={1532--1543}, 132 | year={2014} 133 | } 134 | 135 | 136 | .. note:: 137 | * If directory is None, the model will be saved in the torch hub directory. 138 | * If the model is not downloaded, it will be downloaded automatically. 
139 | """ 140 | # Check model 141 | if model not in self.MODEL_OPTIONS: 142 | raise ValueError(f'Invalid model: {model}.') 143 | 144 | # Set the attributes 145 | self.model = model 146 | self.force_download = force_download 147 | self.dir = dir 148 | self.token_size = self.model.split('.')[1] 149 | self.dim = dim 150 | 151 | # Set the path 152 | if self.dir is None: 153 | self.dir = f'{torch.hub.get_dir()}/{self.model}' 154 | 155 | # Remove the trailing slash 156 | if self.dir[-1] == '/': 157 | self.dir = self.dir[:-1] 158 | 159 | # Download the embeddings if they do not exist or if force_download is True 160 | if not os.path.exists(self.dir) or self.force_download: 161 | 162 | # Create the directory if it does not exist 163 | if not (os.path.exists(self.dir)): 164 | os.system(f'mkdir {self.dir}') 165 | 166 | # Download the glove .zip file 167 | print(f'Downloading the {self.model} zip file...') 168 | torch.hub.download_url_to_file( 169 | url=self.MODEL_OPTIONS[self.model]['URL'], 170 | dst=f'{self.dir}/glove.zip', 171 | ) 172 | 173 | # Unzip the glove .txt files 174 | print(f'Unzipping the {self.model} zip file...') 175 | os.system(f'unzip {self.dir}/glove.zip -d {self.dir}') 176 | 177 | # Delete the zip file 178 | os.system(f'rm {self.dir}/glove.zip') 179 | 180 | # Process each glove .txt file and save it as a .pt file 181 | for file in os.listdir(self.dir): 182 | # Extract the words and the embeddings from the glove .txt file and save them as a .pt file 183 | 184 | # Example of a glove .txt file: 185 | # the 0.418 0.24968 -0.41242 0.1217 ... 186 | # ... 187 | # and 0.26818 0.14346 -0.27877 0.016257 ... 188 | # ... 189 | 190 | print(f'Processing {file}...') 191 | 192 | # Load the file 193 | with open(f'{self.dir}/{file}', 'r') as f: 194 | lines = f.readlines() 195 | 196 | # Extract the dimension of the embeddings from the file name (e.g. 
glove.6B.200d.txt -> 200) 197 | file_embed_dim = file.split('.')[2][:-1] 198 | 199 | # Extract the words and the embeddings 200 | words = [] 201 | embeddings = np.zeros((len(lines), int(file_embed_dim))) 202 | for i, line in enumerate(lines): 203 | line = line.split(' ') 204 | words.append(line[0]) 205 | embeddings[i] = np.array([float(x) for x in line[1:]]) 206 | 207 | # Convert the embeddings to a tensor 208 | embeddings = torch.from_numpy(embeddings) 209 | 210 | # Save the words and the embeddings as a .pt file 211 | torch.save(words, f'{self.dir}/{file[:-4]}.words.pt') 212 | torch.save(embeddings, f'{self.dir}/{file[:-4]}.embeddings.pt') 213 | 214 | # Delete the glove .txt files 215 | os.system(f'rm -r {self.dir}/*.txt') 216 | 217 | # Load the weights and the vocabulary 218 | weights = torch.load(f'{self.dir}/glove.{self.token_size}.{self.dim}d.embeddings.pt') 219 | vocabulary = torch.load(f'{self.dir}/glove.{self.token_size}.{self.dim}d.words.pt') 220 | 221 | # If the embeddings already exist 222 | else: 223 | # Load the weights and the vocabulary 224 | weights = torch.load(f'{self.dir}/glove.{self.token_size}.{self.dim}d.embeddings.pt') 225 | vocabulary = torch.load(f'{self.dir}/glove.{self.token_size}.{self.dim}d.words.pt') 226 | 227 | # Create the vocabulary dictionary to be fed to the embedding layer 228 | self.vocabulary_dict = {word: i for i, word in enumerate(vocabulary)} 229 | 230 | # Create the embedding layer 231 | self.embedding_layer = torch.nn.Embedding.from_pretrained( 232 | embeddings=weights, 233 | freeze=True, 234 | ) 235 | 236 | # Set the tokenizer 237 | if tokenizer is None: 238 | self.tokenizer = Tokenizer() 239 | else: 240 | self.tokenizer = tokenizer 241 | 242 | 243 | def __call__(self, 244 | tokens: Union[List[str], str], 245 | ) -> Tensor: 246 | """ 247 | This function returns the embeddings of the given tokens. 248 | 249 | Arguments: 250 | tokens (Union[List[str], str]): The tokens to embed. 251 | 252 | Returns: 253 | Tensor: The embeddings of the given tokens. 254 | """ 255 | return super().__call__(tokens) 256 | 257 | 258 | def get_embedding(self, 259 | tokens: Union[List[str], str] 260 | ) -> Tensor: 261 | r""" 262 | This function returns the embeddings of the given tokens. 263 | 264 | Arguments: 265 | tokens (Union[List[str], str]): The tokens to embed. 266 | 267 | Returns: 268 | Tensor: The embeddings of the given tokens. 269 | """ 270 | return self.__call__(tokens) 271 | 272 | 273 | # FastTextEmbeddings class 274 | class FastTextEmbeddings(NeuralEmbeddings): 275 | """ 276 | This class implements the FastText embeddings. 277 | """ 278 | def __init__(self, 279 | model: str = 'cc.en.300.bin', 280 | force_download: bool = True, 281 | dir: str = None, 282 | ) -> None: 283 | r""" 284 | This function initializes the FastTextEmbeddings class. 285 | 286 | Arguments: 287 | model (str): The model to use. Some of the available models are: 288 | 289 | - 'cc.en.300.bin': The English model trained on Common Crawl (300 dimensions) 290 | - 'cc.hi.300.bin': The Hindi model trained on Common Crawl (300 dimensions) 291 | - 'cc.fr.300.bin': The French model trained on Common Crawl (300 dimensions) 292 | - 'cc.yi.300.bin': The Yiddish model trained on Common Crawl (300 dimensions) 293 | - ... 
294 | - 'wiki.en': The English model trained on Wikipedia (300 dimensions) 295 | - 'wiki.simple': The Simple English model trained on Wikipedia (300 dimensions) 296 | - 'wiki.ar': The Arabic model trained on Wikipedia (300 dimensions) 297 | - 'wiki.bg': The Bulgarian model trained on Wikipedia (300 dimensions) 298 | - 'wiki.ca': The Catalan model trained on Wikipedia (300 dimensions) 299 | - 'wiki.zh': The Chinese model trained on Wikipedia (300 dimensions) 300 | - 'wiki.sw': The Swahili model trained on Wikipedia (300 dimensions) 301 | - 'wiki.fr': The French model trained on Wikipedia (300 dimensions) 302 | - 'wiki.de': The German model trained on Wikipedia (300 dimensions) 303 | - 'wiki.es': The Spanish model trained on Wikipedia (300 dimensions) 304 | - 'wiki.it': The Italian model trained on Wikipedia (300 dimensions) 305 | - 'wiki.pt': The Portuguese model trained on Wikipedia (300 dimensions) 306 | - 'wiki.ru': The Russian model trained on Wikipedia (300 dimensions) 307 | - 'wiki.tr': The Turkish model trained on Wikipedia (300 dimensions) 308 | - 'wiki.uk': The Ukrainian model trained on Wikipedia (300 dimensions) 309 | - 'wiki.vi': The Vietnamese model trained on Wikipedia (300 dimensions) 310 | - 'wiki.id': The Indonesian model trained on Wikipedia (300 dimensions) 311 | - 'wiki.ja': The Japanese model trained on Wikipedia (300 dimensions) 312 | - ... 313 | 314 | force_download (bool): Whether to force the download of the model. Default: False. 315 | dir (str): The directory to save and load the model. 316 | 317 | Returns: 318 | None 319 | 320 | Raises: 321 | ValueError: If the given model is not available. 322 | 323 | .. attention:: 324 | 325 | If you make use of this code, please cite the following papers (depending on the model you use): 326 | 327 | .. code-block:: latex 328 | 329 | @inproceedings{mikolov2018advances, 330 | title={Advances in Pre-Training Distributed Word Representations}, 331 | author={Mikolov, Tomas and Grave, Edouard and Bojanowski, Piotr and Puhrsch, Christian and Joulin, Armand}, 332 | booktitle={Proceedings of the International Conference on Language Resources and Evaluation (LREC 2018)}, 333 | year={2018} 334 | } 335 | 336 | .. code-block:: latex 337 | 338 | @article{bojanowski2017enriching, 339 | title={Enriching Word Vectors with Subword Information}, 340 | author={Bojanowski, Piotr and Grave, Edouard and Joulin, Armand and Mikolov, Tomas}, 341 | journal={Transactions of the Association for Computational Linguistics}, 342 | volume={5}, 343 | year={2017}, 344 | issn={2307-387X}, 345 | pages={135--146} 346 | } 347 | 348 | .. code-block:: latex 349 | 350 | @article{joulin2016fasttext, 351 | title={FastText.zip: Compressing text classification models}, 352 | author={Joulin, Armand and Grave, Edouard and Bojanowski, Piotr and Douze, Matthijs and J{\'e}gou, H{\'e}rve and Mikolov, Tomas}, 353 | journal={arXiv preprint arXiv:1612.03651}, 354 | year={2016} 355 | } 356 | 357 | .. note:: 358 | 359 | * The models are downloaded from https://fasttext.cc/docs/en/english-vectors.html. 360 | * The models are saved in the torch hub directory, if no directory is specified. 
361 | * 362 | """ 363 | 364 | # Set the attributes 365 | self.model = model 366 | self.dir = dir 367 | self.force_download = force_download 368 | 369 | # Set the path 370 | if self.dir is None: 371 | # For convenience, we save the model in the torch hub directory 372 | self.dir = f'{torch.hub.get_dir()}/{self.model}' 373 | 374 | # Remove the trailing slash 375 | if self.dir[-1] == '/': 376 | self.dir = self.dir[:-1] 377 | 378 | # Download the embeddings if they do not exist or if force_download is True 379 | if not os.path.exists(self.dir) or self.force_download: 380 | # Create the directory if it does not exist 381 | if not os.path.exists(self.dir): 382 | os.system(f'mkdir {self.dir}') 383 | 384 | # Download using wget 385 | if 'wiki' in model: 386 | # https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.en.zip 387 | os.system(f'wget https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/{model}.zip -P {self.dir}') 388 | os.system(f'unzip {self.dirl}.zip -d {self.dir}') 389 | os.system(f'rm {self.dir}.zip') 390 | else: 391 | # https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz 392 | os.system(f'wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/{model}.gz -P {self.dir}') 393 | os.system(f'gunzip {self.dir}.gz -d {self.dir}') 394 | os.system(f'rm {self.dir}.gz') 395 | 396 | # Load the model 397 | ft = fasttext.load_model(f'{self.dir}/{model}') 398 | 399 | # Get the vocabulary 400 | words = ft.get_words() 401 | 402 | # Convert the embeddings to a tensor 403 | embeddings =torch.tensor(ft.get_input_matrix()) 404 | 405 | # Save the words and the embeddings as a .pt file 406 | torch.save(words, f'{self.dir}/{model}.words.pt') 407 | torch.save(embeddings, f'{self.dir}/{model}.embeddings.pt') 408 | 409 | # Delete the model 410 | del ft 411 | 412 | else: 413 | try: 414 | # Load the words and the embeddings 415 | words = torch.load(f'{self.dir}/{model}.words.pt') 416 | embeddings = torch.load(f'{self.dir}/{model}.embeddings.pt') 417 | except: 418 | raise Exception(f'Please install the {model} model first by setting force_download to True.') 419 | 420 | # Create the vocabulary dictionary to be fed to the embedding layer 421 | self.vocabulary_dict = {word: i for i, word in enumerate(words)} 422 | 423 | # Create the embedding layer 424 | self.embedding_layer = torch.nn.Embedding.from_pretrained( 425 | embeddings=embeddings, 426 | freeze=True, 427 | ) 428 | 429 | def __call__(self, 430 | tokens: Union[List[str], str], 431 | ) -> Tensor: 432 | """ 433 | This function returns the embeddings of the given tokens. 434 | 435 | Arguments: 436 | tokens (Union[List[str], str]): The tokens to embed. 437 | 438 | Returns: 439 | Tensor: The embeddings of the given tokens. 440 | """ 441 | return super().__call__(tokens) 442 | 443 | 444 | def get_embedding(self, 445 | tokens: Union[List[str], str] 446 | ) -> Tensor: 447 | """ 448 | This function returns the embeddings of the given tokens. 449 | 450 | Arguments: 451 | tokens (Union[List[str], str]): The tokens to embed. 452 | 453 | Returns: 454 | Tensor: The embeddings of the given tokens. 
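        Example (illustrative sketch; assumes the 'cc.en.300.bin' vectors have already been downloaded and cached):

            >>> ft = FastTextEmbeddings(model='cc.en.300.bin', force_download=False)
            >>> ft.get_embedding('hello world').shape
            torch.Size([2, 300])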
455 | """ 456 | return self.__call__(tokens) -------------------------------------------------------------------------------- /string2string/search/__init__.py: -------------------------------------------------------------------------------- 1 | # The following trick allows us to import the classes directly from the search module: 2 | from .classical import ( 3 | SearchAlgorithm, 4 | NaiveSearch, 5 | RabinKarpSearch, 6 | KMPSearch, 7 | BoyerMooreSearch, 8 | ) 9 | from .faiss_search import FaissSearch -------------------------------------------------------------------------------- /string2string/search/classical.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the following algorithms: 3 | (-) Naive search algorithm ++ 4 | (a) Rabin-Karp algorithm ++ 5 | (b) Boyer-Moore algorithm ++ 6 | (c) Knuth-Morris-Pratt algorithm 7 | (d) Suffix Tree algorithm 8 | (e) Suffix Array algorithm 9 | (f) Suffix Automaton algorithm 10 | (g) Aho-Corasick algorithm (basis of fgrep/grep in Unix) ++ (not implemented) 11 | (h) Ukkonen's algorithm -- (not implemented) 12 | (i) Wu-Manber algorithm ++ (not implemented) 13 | (j) Z-Algorithm ++ (not implemented) 14 | """ 15 | 16 | from typing import List, Union, Tuple, Optional 17 | from string2string.misc.hash_functions import HashFunction, PolynomialRollingHash 18 | 19 | 20 | # Parent class for all search algorithms 21 | class SearchAlgorithm: 22 | """ 23 | This class contains the parent class for all search algorithms. 24 | """ 25 | 26 | def __init__(self) -> None: 27 | """ 28 | This function initializes the abstract class for all search algorithms. 29 | 30 | Returns: 31 | None 32 | """ 33 | pass 34 | 35 | def search(self, 36 | pattern: str, 37 | text: str, 38 | ) -> int: 39 | """ 40 | Searches for the pattern in a text. 41 | 42 | Arguments: 43 | pattern (str): The pattern to search for. 44 | text (str): The text to search in. 45 | 46 | Returns: 47 | int: The index of the pattern in the text. 48 | """ 49 | pass 50 | 51 | 52 | class NaiveSearch(SearchAlgorithm): 53 | """ 54 | This class contains the naive search algorithm. 55 | """ 56 | 57 | def __init__(self) -> None: 58 | """ 59 | Initializes the class. 60 | 61 | Returns: 62 | None 63 | """ 64 | super().__init__() 65 | 66 | 67 | def search(self, 68 | pattern: str, 69 | text: str, 70 | ) -> int: 71 | """ 72 | Searches for the pattern in the text. 73 | 74 | Arguments: 75 | text (str): The text to search in. 76 | 77 | Returns: 78 | int: The index of the pattern in the text (or -1 if the pattern is not found). 79 | 80 | Raises: 81 | AssertionError: If the inputs are invalid. 82 | """ 83 | # Check the inputs 84 | assert isinstance(pattern, str), 'The pattern must be a string.' 85 | assert isinstance(text, str), 'The text must be a string.' 86 | 87 | # Set the attributes 88 | self.pattern = pattern 89 | self.pattern_length = len(self.pattern) 90 | 91 | # Loop over the text 92 | for i in range(len(text) - self.pattern_length + 1): 93 | # Check if the strings match 94 | if text[i:i + self.pattern_length] == self.pattern: 95 | return i 96 | 97 | # Return -1 if the pattern is not found 98 | return -1 99 | 100 | 101 | # Rabin-Karp search algorithm class 102 | class RabinKarpSearch(SearchAlgorithm): 103 | """ 104 | This class contains the Rabin-Karp search algorithm. 
105 | """ 106 | 107 | def __init__(self, 108 | hash_function: HashFunction = PolynomialRollingHash(), 109 | ) -> None: 110 | """ 111 | This function initializes the Rabin-Karp search algorithm class, which uses a hash function to search for a pattern in a text. [RK1987]_ 112 | 113 | Arguments: 114 | hash_function (HashFunction): The hash function to use. 115 | 116 | Returns: 117 | None 118 | 119 | Raises: 120 | AssertionError: If the inputs are invalid. 121 | 122 | .. [RK1987] Karp, R.M. and Rabin, M.O., 1987. Efficient randomized pattern-matching algorithms. IBM Journal of Research and Development, 31(2), pp.249-260. 123 | """ 124 | assert isinstance(hash_function, HashFunction), 'The hash function must be a HashFunction object.' 125 | 126 | # Set the attributes 127 | # self.pattern = pattern 128 | self.hash_function = hash_function 129 | 130 | # # Compute the hash value of the pattern 131 | # self.pattern_hash = self.hash_function.compute(self.pattern) 132 | 133 | # # Length of the pattern 134 | # self.pattern_length = len(self.pattern) 135 | 136 | def itialize_pattern_hash(self, 137 | pattern: str, 138 | ) -> None: 139 | """ 140 | This function initializes the pattern hash value. 141 | 142 | Arguments: 143 | pattern (str): The pattern to search for. 144 | 145 | Returns: 146 | None 147 | 148 | Raises: 149 | AssertionError: If the inputs are invalid. 150 | """ 151 | # Check the inputs 152 | assert isinstance(pattern, str), 'The pattern must be a string.' 153 | 154 | # Reset the hash function 155 | self.hash_function.reset() 156 | 157 | # Set the attributes 158 | self.pattern = pattern 159 | 160 | # Compute the hash value of the pattern 161 | self.pattern_hash = self.hash_function.compute(self.pattern) 162 | 163 | # Length of the pattern 164 | self.pattern_length = len(self.pattern) 165 | 166 | 167 | def search(self, 168 | pattern: str, 169 | text: str, 170 | ) -> int: 171 | """ 172 | This function searches for the pattern in the text. 173 | 174 | Arguments: 175 | pattern (str): The pattern to search for. 176 | text (str): The text to search in. 177 | 178 | Returns: 179 | int: The index of the pattern in the text (or -1 if the pattern is not found). 180 | 181 | Raises: 182 | AssertionError: If the inputs are invalid. 183 | 184 | 185 | """ 186 | # Check the inputs 187 | assert isinstance(text, str), 'The text must be a string.' 188 | 189 | # Initialize the pattern hash 190 | self.itialize_pattern_hash(pattern) 191 | 192 | # Reset the hash function (in case it was used before) [Important!] 
193 | self.hash_function.reset() 194 | 195 | # Compute the hash value of the first window 196 | window_hash = self.hash_function.compute(text[:self.pattern_length]) 197 | 198 | # Loop over the text 199 | for i in range(len(text) - self.pattern_length + 1): 200 | # print('Window hash: {}'.format(window_hash)) 201 | 202 | # Check if the hash values match 203 | if window_hash == self.pattern_hash: 204 | # print('Hash values match at index {}.'.format(i)) 205 | j = 0 206 | # Check if the strings match 207 | while text[i + j] == self.pattern[j]: 208 | j += 1 209 | if j == self.pattern_length: 210 | return i 211 | # Update the hash value of the window 212 | if i < len(text) - self.pattern_length: 213 | window_hash = self.hash_function.update(text[i], text[i + self.pattern_length], self.pattern_length) 214 | 215 | # Return -1 if the pattern is not found 216 | return -1 217 | 218 | 219 | # Knuth-Morris-Pratt (KMP) search algorithm class 220 | class KMPSearch(SearchAlgorithm): 221 | """ 222 | This class contains the KMP search algorithm. 223 | """ 224 | 225 | def __init__(self) -> None: 226 | r""" 227 | This function initializes the Knuth-Morris-Pratt (KMP) search algorithm class. [KMP1977]_ 228 | 229 | Arguments: 230 | None 231 | 232 | Returns: 233 | None 234 | 235 | .. note:: 236 | * The current version of the KMP algorithm utilizes an auxiliary list called the lps_array, which stands for "longest proper prefix which is also a suffix". The lps_array is a list of integers where lps_array[i] represents the length of the longest proper prefix of the pattern that is also a suffix of the pattern[:i+1]. 237 | * By precomputing the lps_array, the KMP algorithm avoids unnecessary character comparisons while searching for the pattern in the text. The algorithm scans the text from left to right and compares characters in the pattern with characters in the text. When a mismatch occurs, the algorithm uses the values in the lps_array to determine the next character in the pattern to compare with the text. 238 | * An alternative implementation of the KMP algorithm exists, which uses a finite state automaton (FSA) instead of the lps_array, but this is not implemented in this version of the package. 239 | 240 | .. [KMP1977] Knuth, D.E., Morris, J.H. and Pratt, V.R., 1977. Fast pattern matching in strings. SIAM journal on computing, 6(2), pp.323-350. 241 | """ 242 | super().__init__() 243 | 244 | # Initialize_lps function 245 | def initialize_lps(self) -> None: 246 | r""" 247 | This function initializes the longest proper prefix suffix (lps) array, which contains the length of the longest proper prefix that is also a suffix of the pattern. 248 | 249 | That is: For each index i in the lps array, lps[i] is the length of the longest proper prefix that is also a suffix of the pattern[:i + 1]. In other words, if k = lps[i], then pattern[:k] is equal to pattern[i - k + 1:i + 1] (with the condition that pattern[:k+1] is not equal to pattern[i - k:i + 1]). The lps array is used in the Knuth-Morris-Pratt (KMP) algorithm to avoid unnecessary comparisons when searching for a pattern in a text. 250 | 251 | Arguments: 252 | None (the pattern set by the search method is used).
253 | 254 | Returns: 255 | None 256 | """ 257 | # Initialize the list of longest proper prefix which is also a suffix 258 | self.lps = [0] * self.pattern_length 259 | 260 | # Loop over the pattern 261 | i = 1 # denotes the index of the character in the pattern 262 | j = 0 # denotes the length of the longest proper prefix which is also a suffix of the pattern[:i] 263 | while i < self.pattern_length: 264 | # Check if the characters match 265 | if self.pattern[i] == self.pattern[j]: 266 | j += 1 267 | self.lps[i] = j 268 | i += 1 269 | else: 270 | if j != 0: 271 | j = self.lps[j - 1] 272 | else: 273 | self.lps[i] = 0 274 | i += 1 275 | 276 | # Search for the pattern in the text 277 | def search(self, 278 | pattern: str, 279 | text: str, 280 | ) -> int: 281 | """ 282 | This function searches for the pattern in the text. 283 | 284 | Arguments: 285 | pattern (str): The pattern to search for. 286 | text (str): The text to search in. 287 | 288 | Returns: 289 | int: The index of the pattern in the text (or -1 if the pattern is not found) 290 | 291 | Raises: 292 | AssertionError: If the text is not a string. 293 | 294 | .. note:: 295 | * This is the main function of the KMP search algorithm class. 296 | """ 297 | # Check the inputs 298 | assert isinstance(text, str), 'The text must be a string.' 299 | 300 | # Set the attributes 301 | self.pattern = pattern 302 | self.pattern_length = len(self.pattern) 303 | 304 | # Initialize the lps array 305 | self.initialize_lps() 306 | 307 | # Loop over the text 308 | i = 0 309 | j = 0 310 | while i < len(text): 311 | # Check if the characters match 312 | if self.pattern[j] == text[i]: 313 | i += 1 314 | j += 1 315 | # Check if the pattern is found 316 | if j == self.pattern_length: 317 | return i - j 318 | # Check if the characters do not match 319 | elif i < len(text) and self.pattern[j] != text[i]: 320 | if j != 0: 321 | j = self.lps[j - 1] 322 | else: 323 | i += 1 324 | 325 | # Return -1 if the pattern is not found 326 | return -1 327 | 328 | 329 | 330 | # Boyer-Moore search algorithm class 331 | class BoyerMooreSearch: 332 | """ 333 | This class contains the Boyer-Moore search algorithm. 334 | """ 335 | 336 | def __init__(self) -> None: 337 | """ 338 | This function initializes the Boyer-Moore search algorithm class. [BM1977]_ 339 | 340 | The Bayer-Moore search algorithm is a string searching algorithm that uses a heuristic to skip over large sections of the search string, resulting in faster search times than traditional algorithms such as brute-force or Knuth-Morris-Pratt. It is particularly useful for searching for patterns in large amounts of text. 341 | 342 | .. [BM1977] Boyer, RS and Moore, JS. "A fast string searching algorithm." Communications of the ACM 20.10 (1977): 762-772. 343 | 344 | A Correct Preprocessing Algorithm for Boyer–Moore String-Searching 345 | 346 | https://www.cs.jhu.edu/~langmea/resources/lecture_notes/strings_matching_boyer_moore.pdf 347 | 348 | """ 349 | super().__init__() 350 | 351 | 352 | # This is what we call the "prefix - suffix" match case of the good suffix rule 353 | def aux_get_suffix_prefix_length(self, 354 | i: int, 355 | ) -> int: 356 | """ 357 | This auxiliary function is used to compute the length of the longest suffix of pattern[i:] that matches a "prefix" of the pattern. 358 | 359 | Arguments: 360 | i (int): The index of the suffix. 361 | 362 | Returns: 363 | int: The length of the longest suffix of pattern[i:] that matches a "prefix" of the pattern. 364 | """ 365 | 366 | # pattern [ ....... 
i ................j]
367 |         # Initialize j to the end of the pattern
368 |         j = self.pattern_length - 1
369 | 
370 |         # pattern [ ....... i ....... j .......]
371 |         # Move j to the left until we find a mismatch or until j == i
372 |         while j >= i and self.pattern[j] == self.pattern[j - i]:
373 |             # pattern [ ... j-i ..... i ... j .......]
374 |             j -= 1
375 | 
376 |         return self.pattern_length - (j - 1)
377 | 
378 | 
379 |     # This is what we call the "substring match" case of the good suffix rule
380 |     def aux_get_matching_substring_length(self,
381 |         j: int,
382 |     ) -> int:
383 |         """
384 |         This auxiliary function is used to compute the length of the longest suffix of the pattern that matches a substring of the pattern that ends at the index j.
385 | 
386 |         It is used in the "substring match" case of the good suffix rule. More specifically, it is used to find when the suffix of the pattern does not match the text at all. Hence, we find the longest suffix of the pattern that matches a substring of the pattern that ends at the index j.
387 | 
388 |         Arguments:
389 |             j (int): The end index of the substring.
390 | 
391 |         Returns:
392 |             int: The length of the longest suffix of the pattern that matches a substring of the pattern that ends at the index j.
393 | 
394 |         """
395 |         # Loop over the suffixes of the pattern
396 |         for i in range(j, -1, -1):
397 |             # Check if the substring matches the suffix
398 |             if self.pattern[i:i+(j+1)] == self.pattern[self.pattern_length-(j+1):]:
399 |                 return j - i + 1
400 |         # Otherwise, if we get here, the substring does not match any suffix of the pattern
401 |         return 0
402 | 
403 | 
404 |     # Creates the "good suffix" skip table
405 |     def create_skip_gs(self) -> None:
406 |         """
407 |         This function creates the "good suffix" skip table. (It is used in the preprocessing step of the Boyer-Moore search algorithm.)
408 | 
409 |         Arguments:
410 |             None
411 | 
412 |         Returns:
413 |             None
414 | 
415 |         """
416 |         # Create the good suffix "skip" table
417 |         # TODO(msuzgun): Has an error!
418 |         self.skip_gs = [0] * self.pattern_length
419 |         # skip_gs[i] denotes the number of cells to the right we need to skip if the current character is the i-th character of the pattern
420 | 
421 |         # First, we compute the length of the longest suffix of pattern [i:] that matches a prefix of the pattern
422 |         for i in range(self.pattern_length - 1):
423 |             self.skip_gs[i] = self.aux_get_suffix_prefix_length(i)
424 | 
425 |         # Set the skip value of the last position to 1
426 |         self.skip_gs[-1] = 1
427 | 
428 |         # Second, we compute the length of the longest suffix of the pattern that matches a substring of the pattern that ends at the index j
429 |         for j in range(self.pattern_length - 2):
430 |             k = (self.pattern_length - 1) - self.aux_get_matching_substring_length(j)
431 |             if self.skip_gs[k] == 0:
432 |                 self.skip_gs[k] = self.pattern_length - 1 - j
433 | 
434 | 
435 |     # Creates the "bad character" skip table
436 |     def create_skip_bc(self) -> None:
437 |         """
438 |         This function creates the "bad character" skip table. (It is used in the preprocessing step of the Boyer-Moore search algorithm.)
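    |         (For illustration: for the pattern "abcab", this yields {'a': 3, 'b': 1, 'c': 2}; the final character of the pattern is not recorded by the loop below.)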
439 | 440 | Arguments: 441 | None 442 | 443 | Returns: 444 | None 445 | """ 446 | # Create the bad character "skip" table 447 | self.last_occurence = {} 448 | 449 | # last_occurence[c] denotes the index of the last occurence of the character c in the pattern 450 | for j in range(self.pattern_length - 1): 451 | self.last_occurence[self.pattern[j]] = j 452 | 453 | # Set the default skip value to the pattern length 454 | self.last_occurence.setdefault(None, self.pattern_length) 455 | 456 | 457 | # Searches for the pattern in the text using the Boyer-Moore algorithm 458 | def search(self, 459 | pattern: str, 460 | text: str, 461 | ) -> int: 462 | """ 463 | This function searches for the pattern in the text using the Boyer-Moore algorithm. 464 | 465 | Arguments: 466 | pattern (str): The pattern to search for. 467 | text (str): The text to search in. 468 | 469 | Returns: 470 | int: The index of the pattern in the text (or -1 if the pattern is not found) 471 | 472 | Raises: 473 | AssertionError: If the text or the pattern is not a string. 474 | """ 475 | # Check both the pattern and the text 476 | assert isinstance(pattern, str), 'The pattern must be a string.' 477 | assert isinstance(text, str), 'The text must be a string.' 478 | 479 | # Set the attributes 480 | self.pattern = pattern 481 | 482 | # Length of the pattern 483 | self.pattern_length = len(self.pattern) 484 | 485 | # Preprocess the pattern by creating the skip tables for the bad character and good suffix rules, respectively. 486 | self.create_skip_bc() 487 | self.create_skip_gs() 488 | 489 | 490 | # Loop over the text 491 | i = 0 492 | while i <= len(text) - self.pattern_length: 493 | # Loop over the pattern 494 | j = self.pattern_length - 1 495 | while j >= 0 and text[i + j] == self.pattern[j]: 496 | j -= 1 497 | # Check if the pattern is found 498 | if j < 0: 499 | return i 500 | # Update i 501 | i += max(j - self.last_occurence.get(text[i + j], self.pattern_length), 1) 502 | 503 | # Return -1 if the pattern is not found 504 | return -1 -------------------------------------------------------------------------------- /string2string/search/faiss_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains a wrapper for the Faiss library by Facebook AI Research. 3 | """ 4 | 5 | from typing import List, Union, Optional, Dict, Any 6 | import os 7 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 8 | 9 | import torch 10 | from transformers import AutoTokenizer, AutoModel 11 | from datasets import Dataset 12 | 13 | import pandas as pd 14 | 15 | # FAISS library wrapper class 16 | class FaissSearch: 17 | def __init__(self, 18 | model_name_or_path: str = 'facebook/bart-large', 19 | tokenizer_name_or_path: str = 'facebook/bart-large', 20 | device: str = 'cpu', 21 | ) -> None: 22 | r""" 23 | This function initializes the wrapper for the FAISS library, which is used to perform semantic search. 24 | 25 | 26 | .. attention:: 27 | 28 | * If you use this class, please make sure to cite the following paper: 29 | 30 | .. 
code-block:: latex 31 | 32 | @article{johnson2019billion, 33 | title={Billion-scale similarity search with {GPUs}}, 34 | author={Johnson, Jeff and Douze, Matthijs and J{\'e}gou, Herv{\'e}}, 35 | journal={IEEE Transactions on Big Data}, 36 | volume={7}, 37 | number={3}, 38 | pages={535--547}, 39 | year={2019}, 40 | publisher={IEEE} 41 | } 42 | 43 | * The code is based on the following GitHub repository: 44 | https://github.com/facebookresearch/faiss 45 | 46 | Arguments: 47 | model_name_or_path (str, optional): The name or path of the model to use. Defaults to 'facebook/bart-large'. 48 | tokenizer_name_or_path (str, optional): The name or path of the tokenizer to use. Defaults to 'facebook/bart-large'. 49 | device (str, optional): The device to use. Defaults to 'cpu'. 50 | 51 | Returns: 52 | None 53 | """ 54 | 55 | # Set the device 56 | self.device = device 57 | 58 | # If the tokenizer is not specified, use the model name or path 59 | if tokenizer_name_or_path is None: 60 | tokenizer_name_or_path = model_name_or_path 61 | 62 | # Load the tokenizer 63 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) 64 | 65 | # Load the model 66 | self.model = AutoModel.from_pretrained(model_name_or_path).to(self.device) 67 | 68 | # Set the model to evaluation mode (since we do not need the gradients) 69 | self.model.eval() 70 | 71 | # Initialize the dataset 72 | self.dataset = None 73 | 74 | 75 | # Auxiliary function to get the last hidden state 76 | def get_last_hidden_state(self, 77 | embeddings: torch.Tensor, 78 | ) -> torch.Tensor: 79 | """ 80 | This function returns the last hidden state (e.g., [CLS] token's) of the input embeddings. 81 | 82 | Arguments: 83 | embeddings (torch.Tensor): The input embeddings. 84 | 85 | Returns: 86 | torch.Tensor: The last hidden state. 87 | """ 88 | 89 | # Get the last hidden state 90 | last_hidden_state = embeddings.last_hidden_state 91 | 92 | # Return the last hidden state 93 | return last_hidden_state[:, 0, :] 94 | 95 | 96 | # Auxiliary function to get the mean pooling 97 | def get_mean_pooling(self, 98 | embeddings: torch.Tensor, 99 | ) -> torch.Tensor: 100 | """ 101 | This function returns the mean pooling of the input embeddings. 102 | 103 | Arguments: 104 | embeddings (torch.Tensor): The input embeddings. 105 | 106 | Returns: 107 | torch.Tensor: The mean pooling. 108 | """ 109 | 110 | # Get the mean pooling 111 | mean_pooling = embeddings.last_hidden_state.mean(dim=1) 112 | 113 | # Return the mean pooling 114 | return mean_pooling 115 | 116 | 117 | 118 | # Get the embeddings 119 | def get_embeddings(self, 120 | text: Union[str, List[str]], 121 | embedding_type: str = 'last_hidden_state', 122 | batch_size: int = 8, 123 | num_workers: int = 4, 124 | ) -> torch.Tensor: 125 | """ 126 | This function returns the embeddings of the input text. 127 | 128 | Arguments: 129 | text (Union[str, List[str]]): The input text. 130 | embedding_type (str, optional): The type of embedding to use. Defaults to 'last_hidden_state'. 131 | batch_size (int, optional): The batch size to use. Defaults to 8. 132 | num_workers (int, optional): The number of workers to use. Defaults to 4. 133 | 134 | Returns: 135 | torch.Tensor: The embeddings. 136 | 137 | Raises: 138 | ValueError: If the embedding type is invalid. 139 | """ 140 | 141 | # Check if the embedding type is valid 142 | if embedding_type not in ['last_hidden_state', 'mean_pooling']: 143 | raise ValueError(f'Invalid embedding type: {embedding_type}. 
Only "last_hidden_state" and "mean_pooling" are supported.') 144 | 145 | # Tokenize the input text 146 | encoded_text = self.tokenizer( 147 | text, 148 | padding=True, 149 | truncation=True, 150 | return_tensors='pt', 151 | ) 152 | 153 | # Move the input text to the device 154 | encoded_text = encoded_text.to(self.device) 155 | 156 | # encoded_inputs = {k: v.to(self.device) for k, v in encoded_inputs.items()} 157 | 158 | # Get the embeddings 159 | with torch.no_grad(): 160 | embeddings = self.model(**encoded_text) 161 | 162 | # Get the proper embedding type 163 | if embedding_type == 'last_hidden_state': 164 | # Get the last hidden state 165 | embeddings = self.get_last_hidden_state(embeddings) 166 | elif embedding_type == 'mean_pooling': 167 | # Get the mean pooling 168 | embeddings = self.get_mean_pooling(embeddings) 169 | 170 | # Return the embeddings 171 | return embeddings 172 | 173 | 174 | # Add FAISS index 175 | def add_faiss_index(self, 176 | column_name: str = 'embeddings', 177 | metric_type: Optional[int] = None, 178 | batch_size: int = 8, 179 | **kwargs, 180 | ) -> None: 181 | """ 182 | This function adds a FAISS index to the dataset. 183 | 184 | Arguments: 185 | column_name (str, optional): The name of the column containing the embeddings. Defaults to 'embeddings'. 186 | index_type (str, optional): The index type to use. Defaults to 'Flat'. 187 | metric_type (str, optional): The metric type to use. Defaults to 'L2'. 188 | 189 | Returns: 190 | None 191 | 192 | Raises: 193 | ValueError: If the dataset is not initialized. 194 | """ 195 | 196 | # Check if the dataset is initialized 197 | if self.dataset is None: 198 | raise ValueError('The dataset is not initialized. Please initialize the dataset first.') 199 | 200 | print('Adding FAISS index...') 201 | self.dataset.add_faiss_index( 202 | column_name, 203 | # metric_type=metric_type, 204 | # device=self.device, 205 | # batch_size=batch_size, 206 | faiss_verbose=True, 207 | # **kwargs, 208 | ) 209 | 210 | def save_faiss_index(self, 211 | index_name: str, 212 | file_path: str, 213 | ) -> None: 214 | """ 215 | This function saves the FAISS index to the specified file path. 216 | * This is a wrapper function for the `save_faiss_index` function in the `Dataset` class. 217 | 218 | Arguments: 219 | index_name (str): The name of the FAISS index (e.g., "embeddings") 220 | file_path (str): The file path to save the FAISS index. 221 | 222 | Returns: 223 | None 224 | 225 | Raises: 226 | ValueError: If the dataset is not initialized. 227 | """ 228 | 229 | # Check if the dataset is initialized 230 | if self.dataset is None: 231 | raise ValueError('The dataset is not initialized. Please initialize the dataset first.') 232 | 233 | print('Saving FAISS index...') 234 | self.dataset.save_faiss_index(index_name=index_name, file=file_path) 235 | 236 | 237 | def load_faiss_index(self, 238 | index_name: str, 239 | file_path: str, 240 | device: str = 'cpu', 241 | ) -> None: 242 | """ 243 | This function loads the FAISS index from the specified file path. 244 | * This is a wrapper function for the `load_faiss_index` function in the `Dataset` class. 245 | 246 | Arguments: 247 | index_name (str): The name of the FAISS index (e.g., "embeddings") 248 | file_path (str): The file path to load the FAISS index from. 249 | device (str, optional): The device to use ("cpu" or "cuda") (default: "cpu"). 250 | 251 | Returns: 252 | None 253 | 254 | Raises: 255 | ValueError: If the dataset is not initialized. 
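    | 
    |         Example (an illustrative sketch; the file names below are placeholders):
    | 
    |         .. code-block:: python
    | 
    |             faiss_search = FaissSearch(model_name_or_path='facebook/bart-large')
    |             faiss_search.load_dataset_from_json('corpus_with_embeddings.json')
    |             faiss_search.load_faiss_index(index_name='embeddings', file_path='corpus_index.faiss')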
256 | """ 257 | 258 | # Check if the dataset is initialized 259 | if self.dataset is None: 260 | raise ValueError('The dataset is not initialized. Please initialize the dataset first.') 261 | 262 | print('Loading FAISS index...') 263 | self.dataset.load_faiss_index(index_name=index_name, file=file_path, device=device) 264 | 265 | 266 | # Initialize the corpus using a dictionary or pandas DataFrame or HuggingFace Datasets object 267 | def initialize_corpus(self, 268 | corpus: Union[Dict[str, List[str]], pd.DataFrame, Dataset], 269 | section: str = 'text', 270 | index_column_name: str = 'embeddings', 271 | embedding_type: str = 'last_hidden_state', 272 | batch_size: Optional[int] = None, 273 | num_workers: Optional[int] = None, 274 | save_path: Optional[str] = None, 275 | ) -> Dataset: 276 | """ 277 | This function initializes a dataset using a dictionary or pandas DataFrame or HuggingFace Datasets object. 278 | 279 | Arguments: 280 | dataset_dict (Dict[str, List[str]]): The dataset dictionary. 281 | section (str): The section of the dataset to use whose embeddings will be used for semantic search (e.g., 'text', 'title', etc.) (default: 'text'). 282 | index_column_name (str): The name of the column containing the embeddings (default: 'embeddings') 283 | embedding_type (str): The type of embedding to use (default: 'last_hidden_state'). 284 | batch_size (int, optional): The batch size to use (default: 8). 285 | max_length (int, optional): The maximum length of the input sequences. 286 | num_workers (int, optional): The number of workers to use. 287 | save_path (Optional[str], optional): The path to save the dataset (default: None). 288 | 289 | Returns: 290 | Dataset: The dataset object (HuggingFace Datasets). 291 | 292 | Raises: 293 | ValueError: If the dataset is not a dictionary or pandas DataFrame or HuggingFace Datasets object. 294 | """ 295 | 296 | # Create the dataset 297 | if isinstance(corpus, dict): 298 | self.dataset = Dataset.from_dict(corpus) 299 | elif isinstance(corpus, pd.DataFrame): 300 | self.dataset = Dataset.from_pandas(corpus) 301 | elif isinstance(corpus, Dataset): 302 | self.dataset = corpus 303 | else: 304 | raise ValueError('The dataset must be a dictionary or pandas DataFrame.') 305 | 306 | # Set the embedding_type 307 | self.embedding_type = embedding_type 308 | 309 | 310 | # Tokenize the dataset 311 | # self.dataset = self.dataset.map( 312 | # lambda x: x[section], 313 | # batched=True, 314 | # batch_size=batch_size, 315 | # num_proc=num_workers, 316 | # ) 317 | 318 | # Map the section of the dataset to the embeddings 319 | self.dataset = self.dataset.map( 320 | lambda x: { 321 | index_column_name: self.get_embeddings(x[section], embedding_type=self.embedding_type).detach().cpu().numpy()[0] 322 | }, 323 | # batched=True, 324 | batch_size=batch_size, 325 | num_proc=num_workers, 326 | ) 327 | 328 | # Save the dataset 329 | if save_path is not None: 330 | self.dataset.to_json(save_path) 331 | 332 | # Add FAISS index 333 | self.add_faiss_index( 334 | column_name=index_column_name, 335 | ) 336 | 337 | # Return the dataset 338 | return self.dataset 339 | 340 | 341 | # Initialize the dataset using a JSON file 342 | def load_dataset_from_json(self, 343 | json_path: str, 344 | ) -> Dataset: 345 | """ 346 | This function loads a dataset from a JSON file. 347 | 348 | Arguments: 349 | json_path (str): The path to the JSON file. 350 | 351 | Returns: 352 | Dataset: The dataset. 
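    | 
    |         .. note::
    |             The JSON file is typically one produced earlier by ``initialize_corpus(..., save_path=...)``; after loading, a FAISS index still needs to be added (via ``add_faiss_index``) or loaded (via ``load_faiss_index``) before calling ``search``.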
353 | """ 354 | 355 | # Load the dataset 356 | self.dataset = Dataset.from_json(json_path) 357 | 358 | # Return the dataset 359 | return self.dataset 360 | 361 | 362 | # Search for the most similar elements in the dataset, given a query 363 | def search(self, 364 | query: str, 365 | k: int = 1, 366 | index_column_name: str = 'embeddings', 367 | ) -> pd.DataFrame: 368 | """ 369 | This function searches for the most similar elements in the dataset, given a query. 370 | 371 | Arguments: 372 | query (str): The query. 373 | k (int, optional): The number of elements to return (default: 1). 374 | index_column_name (str, optional): The name of the column containing the embeddings (default: 'embeddings') 375 | 376 | Returns: 377 | pd.DataFrame: The most similar elements in the dataset (text, score, etc.), sorted by score. 378 | 379 | Remarks: 380 | The returned elements are dictionaries containing the text and the score. 381 | """ 382 | 383 | # Get the embeddings of the query 384 | query_embeddings = self.get_embeddings([query], embedding_type=self.embedding_type).detach().cpu().numpy() 385 | 386 | # Search for the most similar elements in the dataset 387 | scores, similar_elts = self.dataset.get_nearest_examples( 388 | index_name=index_column_name, 389 | query=query_embeddings, 390 | k=k, 391 | ) 392 | 393 | # Convert the results to a pandas DataFrame 394 | results_df = pd.DataFrame.from_dict(similar_elts) 395 | 396 | # Add the scores 397 | results_df['score'] = scores 398 | 399 | # Sort the results by score 400 | results_df.sort_values("score", ascending=True, inplace=True) 401 | 402 | # Return the most similar elements 403 | return results_df -------------------------------------------------------------------------------- /string2string/similarity/__init__.py: -------------------------------------------------------------------------------- 1 | # The following trick allows us to import the classes directly from the similarity module: 2 | from .bertscore import BERTScore 3 | from .bartscore import BARTScore 4 | from .cosine_similarity import CosineSimilarity 5 | from .classical import LCSubstringSimilarity, LCSubsequenceSimilarity, JaroSimilarity 6 | -------------------------------------------------------------------------------- /string2string/similarity/bartscore.py: -------------------------------------------------------------------------------- 1 | """ 2 | This class contains the original implementation of the BARTScore algorithm by Yuan et al. (2021). 3 | 4 | BARTScore: BART-based Evaluation Metric for Text Generation 5 | 6 | @inproceedings{bartscore2021, 7 | author = {Yuan, Weizhe and Neubig, Graham and Liu, Pengfei}, 8 | booktitle = {Advances in Neural Information Processing Systems}, 9 | editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. 
Wortman Vaughan}, 10 | pages = {27263--27277}, 11 | publisher = {Curran Associates, Inc.}, 12 | title = {BARTScore: Evaluating Generated Text as Text Generation}, 13 | url = {https://proceedings.neurips.cc/paper/2021/file/e4d2b6e6fdeca3e60e0f1a62fee3d9dd-Paper.pdf}, 14 | volume = {34}, 15 | year = {2021} 16 | } 17 | 18 | Disclaimer: 19 | This code is adapted from https://github.com/neulab/BARTScore/blob/main/bart_score.py 20 | """ 21 | 22 | import numpy as np 23 | from typing import List, Union, Dict 24 | import traceback 25 | 26 | import torch 27 | import torch.nn as nn 28 | from transformers import BartTokenizer, BartForConditionalGeneration 29 | 30 | 31 | # BARTScore class 32 | class BARTScore: 33 | """ 34 | This class implements the BARTScore algorithm. 35 | """ 36 | 37 | def __init__(self, 38 | model_name_or_path='facebook/bart-large-cnn', 39 | tokenizer_name_or_path: str = None, 40 | device: str = 'cpu', 41 | max_length=1024, 42 | ) -> None: 43 | r""" 44 | This function initializes the BARTScore class, which computes the BARTScore between two pieces of text. 45 | 46 | Arguments: 47 | model_name_or_path (str): The name or path of the model. Defaults to 'facebook/bart-large-cnn'. 48 | tokenizer_name_or_path (str): The name or path of the tokenizer. Defaults to None. 49 | device (str): The device to use. Defaults to 'cpu'. 50 | max_length (int): The maximum length of the input. Defaults to 1024. 51 | 52 | Returns: 53 | None 54 | 55 | Raises: 56 | ValueError: If the device is not 'cpu' or 'cuda'. 57 | 58 | .. attention:: 59 | 60 | If you use this class, please make sure to cite the following paper: 61 | 62 | .. code-block:: latex 63 | 64 | @inproceedings{bartscore2021, 65 | author = {Yuan, Weizhe and Neubig, Graham and Liu, Pengfei}, 66 | booktitle = {Advances in Neural Information Processing Systems}, 67 | editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan}, 68 | pages = {27263--27277}, 69 | publisher = {Curran Associates, Inc.}, 70 | title = {BARTScore: Evaluating Generated Text as Text Generation}, 71 | url = {https://proceedings.neurips.cc/paper/2021/file/e4d2b6e6fdeca3e60e0f1a62fee3d9dd-Paper.pdf}, 72 | volume = {34}, 73 | year = {2021} 74 | } 75 | 76 | .. note:: 77 | * The default model is the BART-large-cnn model. 78 | * If the tokenizer name or path is not specified, then the model name or path will be used. 79 | * If the device is 'cuda', then the model will be loaded onto the GPU. 80 | * If device is not specified, use the GPU if available, otherwise use the CPU. 81 | 82 | """ 83 | 84 | if tokenizer_name_or_path is None: 85 | tokenizer_name_or_path = model_name_or_path 86 | 87 | # Set the attributes 88 | self.device = device 89 | self.max_length = max_length 90 | 91 | # Load model and tokenizer 92 | self.tokenizer = BartTokenizer.from_pretrained(tokenizer_name_or_path) 93 | self.model = BartForConditionalGeneration.from_pretrained(model_name_or_path) 94 | self.model.eval() 95 | self.model.to(device) 96 | 97 | # Set up loss 98 | self.loss_fct = nn.NLLLoss(reduction='none', ignore_index=self.model.config.pad_token_id) 99 | self.lsm = nn.LogSoftmax(dim=1) 100 | 101 | 102 | # Loads the model weights from a specified path 103 | def load(self, 104 | weights_path=None, 105 | ) -> None: 106 | """ 107 | This function loads the model weights from a specified path. 108 | 109 | Arguments: 110 | weights_path (str): The path to the weights. 
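    |                 If None, this defaults to 'models/bart.pth'.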
111 | 112 | Returns: 113 | None 114 | """ 115 | if weights_path is None: 116 | weights_path = 'models/bart.pth' 117 | 118 | self.model.load_state_dict(torch.load(weights_path, map_location=self.device)) 119 | 120 | 121 | # Compute the BARTScore between source sentences and target sentences 122 | def compute(self, 123 | source_sentences: List[str], 124 | target_sentences: Union[List[str], List[List[str]]], 125 | batch_size: int = 4, 126 | agg: str = 'mean', 127 | ) -> Dict[str, List[float]]: 128 | """ 129 | This function scores the target sentences against the source sentences using BARTScore. 130 | 131 | Arguments: 132 | source_sentences (List[str]): The source sentences. 133 | target_sentences (Union[List[str], List[List[str]]]): The target sentences. 134 | batch_size (int): The batch size to use (default: 4) 135 | agg (str): The aggregation method. Defaults to 'mean'; used only when target_sentences is a list of lists. 136 | 137 | Returns: 138 | Dict[str, List[float]]: The BARTScore for each example. 139 | 140 | Raises: 141 | ValueError: If the number of source sentences and target sentences do not match. 142 | """ 143 | # Check the number of source sentences and target sentences 144 | if len(source_sentences) != len(target_sentences): 145 | raise ValueError(f'Number of source sentences ({len(source_sentences)}) and number of target sentences ({len(target_sentences)}) do not match.') 146 | 147 | # If the target sentences are a list of lists, then call the multi_ref_score function 148 | if isinstance(target_sentences[0], list): 149 | return self.compute_multi_ref_score( 150 | source_sentences=source_sentences, 151 | target_sentences=target_sentences, 152 | batch_size=batch_size, 153 | agg=agg 154 | ) 155 | 156 | # Score for each example 157 | score_list = [] 158 | 159 | for i in range(0, len(source_sentences), batch_size): 160 | # Get the current batch 161 | src_batch = source_sentences[i: i + batch_size] 162 | tgt_batch = target_sentences[i: i + batch_size] 163 | try: 164 | with torch.no_grad(): 165 | # Encode the batch 166 | encoded_src = self.tokenizer( 167 | src_batch, 168 | max_length=self.max_length, 169 | truncation=True, 170 | padding=True, 171 | return_tensors='pt' 172 | ) 173 | encoded_tgt = self.tokenizer( 174 | tgt_batch, 175 | max_length=self.max_length, 176 | truncation=True, 177 | padding=True, 178 | return_tensors='pt' 179 | ) 180 | 181 | # Get the input ids and attention masks for the source and target sentences 182 | src_tokens = encoded_src['input_ids'].to(self.device) 183 | src_mask = encoded_src['attention_mask'].to(self.device) 184 | tgt_tokens = encoded_tgt['input_ids'].to(self.device) 185 | tgt_mask = encoded_tgt['attention_mask'] 186 | tgt_len = tgt_mask.sum(dim=1).to(self.device) 187 | 188 | # Feed the batch to the model and get the loss 189 | output = self.model( 190 | input_ids=src_tokens, 191 | attention_mask=src_mask, 192 | labels=tgt_tokens 193 | ) 194 | logits = output.logits.view(-1, self.model.config.vocab_size) 195 | # Compute the loss 196 | loss = self.loss_fct(self.lsm(logits), tgt_tokens.view(-1)) 197 | loss = loss.view(tgt_tokens.shape[0], -1) 198 | loss = loss.sum(dim=1) / tgt_len 199 | # Get the score 200 | curr_score_list = [-x.item() for x in loss] 201 | # Append the score to the list 202 | score_list += curr_score_list 203 | 204 | except: 205 | # If there is an error, print the traceback 206 | raise Exception(f'Error in scoring batch {i // batch_size}:\n{traceback.format_exc()}') 207 | return {'score': np.array(score_list)} 208 | 209 | 210 | # 
Score a batch of examples with multiple references
211 |     def compute_multi_ref_score(self,
212 |         source_sentences: List[str],
213 |         target_sentences: List[List[str]],
214 |         batch_size: int = 4,
215 |         agg: str = "mean",
216 |     ) -> Dict[str, List[float]]:
217 |         """
218 |         Score a batch of examples with multiple references.
219 | 
220 |         Arguments:
221 |             source_sentences (List[str]): The source sentences.
222 |             target_sentences (List[List[str]]): The target sentences.
223 |             agg (str): The aggregation method. Can be "mean" or "max".
224 |             batch_size (int): The batch size.
225 | 
226 |         Returns:
227 |             Dict[str, List[float]]: The BARTScore for each example.
228 | 
229 |         Raises:
230 |             ValueError: If the number of source sentences and target sentences do not match.
231 |         """
232 | 
233 |         # Assert we have the same number of references
234 |         ref_nums = [len(x) for x in target_sentences]
235 |         if len(set(ref_nums)) > 1:
236 |             raise Exception("You have different number of references per test sample.")
237 | 
238 |         ref_num = len(target_sentences[0])
239 |         score_matrix = []
240 |         for i in range(ref_num):
241 |             curr_target_sentences = [x[i] for x in target_sentences]
242 |             scores = self.compute(source_sentences, curr_target_sentences, batch_size)
243 |             score_matrix.append(scores['score'])  # self.compute returns a dict; collect its per-example score array
244 |         if agg == "mean":
245 |             score_list = np.mean(score_matrix, axis=0)
246 |         elif agg == "max":
247 |             score_list = np.max(score_matrix, axis=0)
248 |         else:
249 |             raise NotImplementedError(f"Aggregation method {agg} not implemented yet.")
250 |         return {"score": score_list}
--------------------------------------------------------------------------------
/string2string/similarity/bertscore.py:
--------------------------------------------------------------------------------
1 | """
2 | This class contains the original implementation of the BERTScore algorithm by Zhang et al. (2020).
3 | 
4 | BERTScore: Evaluating Text Generation with BERT
5 | 
6 | @inproceedings{bertscore2020,
7 |     title={BERTScore: Evaluating Text Generation with BERT},
8 |     author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi},
9 |     booktitle={International Conference on Learning Representations},
10 |     year={2020},
11 |     url={https://openreview.net/forum?id=SkeHuCVFDr}
12 | }
13 | 
14 | Disclaimer:
15 |     This code is adapted from https://github.com/Tiiiger/bert_score
16 | """
17 | 
18 | from typing import List, Union, Optional, Tuple
19 | 
20 | import os
21 | import sys
22 | import time
23 | import pandas as pd
24 | from collections import defaultdict
25 | import torch
26 | from bert_score.utils import (bert_cos_score_idf, get_hash,
27 |                               get_idf_dict, get_model, get_tokenizer,
28 |                               lang2model, model2layers)
29 | 
30 | 
31 | class BERTScore:
32 |     """
33 |     This class implements the BERTScore algorithm.
34 |     """
35 | 
36 |     def __init__(self,
37 |         model_name_or_path: str = None,
38 |         lang: str = None,
39 |         num_layers: int = None,
40 |         all_layers: bool = False,
41 |         use_fast_tokenizer: bool = False,
42 |         device: str = 'cpu',
43 |         baseline_path: str = None,
44 |     ) -> None:
45 |         r"""
46 |         This function initializes the BERTScore class, which computes the BERTScore between two texts.
47 | 
48 |         Arguments:
49 |             model_name_or_path (str): BERT model type to use (e.g., bert-base-uncased).
50 |             lang (str): Language of the texts (e.g., en).
51 |             num_layers (int): Number of layers to use.
52 |             all_layers (bool): Whether to use all layers.
53 |             use_fast_tokenizer (bool): Whether to use the fast tokenizer.
54 |             device (str): Device to use (e.g., cpu or cuda).
55 | baseline_path (str): Path to the baseline file. 56 | 57 | Returns: 58 | None 59 | 60 | Raises: 61 | ValueError: If model_name_or_path and lang are both None. 62 | 63 | .. attention:: 64 | 65 | If you use this class, please make sure to cite the following paper: 66 | 67 | .. code-block:: latex 68 | 69 | @inproceedings{bertscore2020, 70 | title={BERTScore: Evaluating Text Generation with BERT}, 71 | author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi}, 72 | booktitle={International Conference on Learning Representations}, 73 | year={2020}, 74 | url={https://openreview.net/forum?id=SkeHuCVFDr} 75 | } 76 | 77 | 78 | .. note:: 79 | * If model_name_or_path is not specified, use the default model for the language. 80 | * If num_layers is not specified, use the default number of layers. 81 | * If device is not specified, use the GPU if available, otherwise use the CPU. 82 | * If baseline_path is not specified, use the default baseline file. 83 | """ 84 | 85 | # Check the arguments 86 | if model_name_or_path is None and lang is None: 87 | raise ValueError("You must specify either model_name_or_path or lang") 88 | 89 | # Set the attributes 90 | self.model_name_or_path = model_name_or_path 91 | self.lang = lang 92 | self.num_layers = num_layers 93 | self.all_layers = all_layers 94 | self.use_fast_tokenizer = use_fast_tokenizer 95 | self.baseline_path = baseline_path 96 | 97 | # If model_name_or_path is not specified, use the default model for the language 98 | if self.model_name_or_path is None: 99 | self.lang = lang.lower() 100 | self.model_name_or_path = lang2model[self.lang] 101 | 102 | # If num_layers is not specified, use the default number of layers 103 | if num_layers is None: 104 | self.num_layers = model2layers[self.model_name_or_path] 105 | 106 | # Set the device 107 | self.device = device 108 | if self.device is None: 109 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 110 | 111 | # Load model and tokenizer 112 | self.tokenizer = get_tokenizer(self.model_name_or_path, self.use_fast_tokenizer) 113 | self.model = get_model(self.model_name_or_path, self.num_layers, self.all_layers) 114 | self.model.eval() 115 | self.model.to(device) 116 | 117 | 118 | # Compute the BERTScore between source sentences and target sentences 119 | def compute(self, 120 | source_sentences: List[str], 121 | target_sentences: Union[List[str], List[List[str]]], 122 | batch_size: int = 4, 123 | idf: bool = False, 124 | nthreads: int = 4, 125 | return_hash: bool = False, 126 | rescale_with_baseline: bool = False, 127 | verbose: bool = False, 128 | ) -> Union[dict, Optional[str]]: 129 | """ 130 | This function scores the source sentences based on their similarity to the target sentences using BERTScore. 131 | 132 | Arguments: 133 | source_sentences (list of str): candidate sentences 134 | target_sentences (list of str or list of list of str): reference sentences 135 | batch_size (int): bert score processing batch size 136 | idf (bool or dict): use idf weighting, can also be a precomputed idf_dict 137 | nthreads (int): number of threads 138 | return_hash (bool): return hashcode of the setting 139 | rescale_with_baseline (bool): rescale bertscore with pre-computed baseline 140 | verbose (bool): turn on intermediate status update 141 | 142 | Returns: 143 | (Dict[str, Tensor], Optional[str]): A dictionary containing the precision, recall, and F1 score, and the hashcode (if return_hash is True). 
144 | where the precision, recall, and F1 score are tensors of shape (len(source_sentences), 145 | 146 | Raises: 147 | ValueError: If the number of source sentences and target sentences do not match. 148 | """ 149 | 150 | # Check the arguments 151 | if len(source_sentences) != len(target_sentences): 152 | raise ValueError("The number of candidates and references do not match") 153 | 154 | # If the target sentences are grouped, flatten them 155 | ref_group_boundaries = None 156 | if not isinstance(target_sentences[0], str): 157 | ref_group_boundaries = [] 158 | ori_source_sentences, ori_target_sentences = source_sentences, target_sentences 159 | source_sentences, target_sentences = [], [] 160 | count = 0 161 | for cand, ref_group in zip(ori_source_sentences, ori_target_sentences): 162 | source_sentences += [cand] * len(ref_group) 163 | target_sentences += ref_group 164 | ref_group_boundaries.append((count, count + len(ref_group))) 165 | count += len(ref_group) 166 | 167 | if rescale_with_baseline and self.baseline_path is None: 168 | raise ValueError("Need to specify baseline_path when rescaling with baseline") 169 | 170 | # Get the IDF dict 171 | if not idf: 172 | idf_dict = defaultdict(lambda: 1.0) 173 | # set idf for [SEP] and [CLS] to 0 174 | idf_dict[self.tokenizer.sep_token_id] = 0 175 | idf_dict[self.tokenizer.cls_token_id] = 0 176 | elif isinstance(idf, dict): 177 | if verbose: 178 | print("using predefined IDF dict...") 179 | idf_dict = idf 180 | else: 181 | if verbose: 182 | print("preparing IDF dict...") 183 | start = time.perf_counter() 184 | idf_dict = get_idf_dict(target_sentences, self.tokenizer, nthreads=nthreads) 185 | if verbose: 186 | print("done in {:.2f} seconds".format(time.perf_counter() - start)) 187 | 188 | if verbose: 189 | print("calculating scores...") 190 | 191 | start = time.perf_counter() 192 | 193 | # Get all the predictions 194 | all_preds = bert_cos_score_idf( 195 | model = self.model, 196 | refs = target_sentences, 197 | hyps = source_sentences, 198 | tokenizer= self.tokenizer, 199 | idf_dict = idf_dict, 200 | verbose = verbose, 201 | device = self.device, 202 | batch_size=batch_size, 203 | all_layers=self.all_layers, 204 | ).cpu() 205 | 206 | # If the target sentences are grouped, take the max score 207 | if ref_group_boundaries is not None: 208 | max_preds = [] 209 | for beg, end in ref_group_boundaries: 210 | max_preds.append(all_preds[beg:end].max(dim=0)[0]) 211 | all_preds = torch.stack(max_preds, dim=0) 212 | 213 | # Rescale with baseline 214 | use_custom_baseline = self.baseline_path is not None 215 | if rescale_with_baseline: 216 | if self.baseline_path is None: 217 | self.baseline_path = os.path.join( 218 | os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_name_or_path}.tsv" 219 | ) 220 | if os.path.isfile(self.baseline_path): 221 | if not self.all_layers: 222 | baselines = torch.from_numpy( 223 | pd.read_csv(self.baseline_path).iloc[self.num_layers].to_numpy() 224 | )[1:].float() 225 | else: 226 | baselines = ( 227 | torch.from_numpy(pd.read_csv(self.baseline_path).to_numpy())[:, 1:] 228 | .unsqueeze(1) 229 | .float() 230 | ) 231 | 232 | all_preds = (all_preds - baselines) / (1 - baselines) 233 | else: 234 | print( 235 | f"Warning: Baseline not Found for {self.model_name_or_path} on {self.lang} at {self.baseline_path}", 236 | file=sys.stderr, 237 | ) 238 | 239 | # Get the final output 240 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F 241 | scores = { 242 | "precision": out[0].numpy(), 243 | 
"recall": out[1].numpy(), 244 | "f1": out[2].numpy(), 245 | } 246 | 247 | # Print the time 248 | if verbose: 249 | time_diff = time.perf_counter() - start 250 | print( 251 | f"done in {time_diff:.2f} seconds, {len(target_sentences) / time_diff:.2f} sentences/sec" 252 | ) 253 | 254 | # If return hash, return both the output and the hash 255 | if return_hash: 256 | return tuple( 257 | [ 258 | scores, 259 | get_hash( 260 | self.model_name_or_path, 261 | self.num_layers, 262 | idf, 263 | rescale_with_baseline, 264 | use_custom_baseline=use_custom_baseline, 265 | use_fast_tokenizer=self.use_fast_tokenizer, 266 | ), 267 | ] 268 | ) 269 | # Otherwise, just return the output 270 | return scores -------------------------------------------------------------------------------- /string2string/similarity/classical.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains the classes for the similarity metrics and functions. 3 | """ 4 | 5 | 6 | from typing import List, Union, Tuple, Optional 7 | import numpy as np 8 | 9 | # # Import the LongestCommonSubsequence class 10 | from string2string.alignment.classical import LongestCommonSubsequence, LongestCommonSubstring 11 | 12 | # Longest Common Subsequence based similarity class 13 | class LCSubsequenceSimilarity(LongestCommonSubsequence): 14 | """ 15 | This class contains the Longest Common Subsequence similarity metric. 16 | 17 | This class inherits from the LongestCommonSubsequence class. 18 | """ 19 | 20 | def __init__(self): 21 | super().__init__() 22 | 23 | def compute(self, 24 | str1: Union[str, List[str]], 25 | str2: Union[str, List[str]], 26 | denominator: str = 'max', 27 | ) -> float: 28 | """ 29 | Returns the LCS-similarity between two strings. 30 | 31 | Arguments: 32 | str1 (Union[str, List[str]]): The first string or list of strings. 33 | str2 (Union[str, List[str]]): The second string or list of strings. 34 | denominator (str): The denominator to use. Options are 'max' and 'sum'. Default is 'max'. 35 | 36 | Returns: 37 | float: The similarity between the two strings. 38 | 39 | Raises: 40 | ValueError: If the denominator is invalid. 41 | """ 42 | 43 | # Get the numerator 44 | numerator, _ = super().compute(str1, str2) 45 | 46 | if denominator == 'max': 47 | return (numerator / max(len(str1), len(str2))) 48 | elif denominator == 'sum': 49 | return (2. * numerator / (len(str1) + len(str2))) 50 | else: 51 | raise ValueError('Invalid denominator.') 52 | 53 | 54 | # Longest Common Substring based similarity class 55 | class LCSubstringSimilarity(LongestCommonSubstring): 56 | """ 57 | This class contains the Longest Common Substring similarity metric. 58 | 59 | This class inherits from the LongestCommonSubstring class. 60 | """ 61 | def __init__(self): 62 | super().__init__() 63 | 64 | def compute(self, 65 | str1: Union[str, List[str]], 66 | str2: Union[str, List[str]], 67 | denominator: str = 'max', 68 | ) -> float: 69 | """ 70 | Returns the LCS-similarity between two strings. 71 | 72 | Arguments: 73 | str1 (Union[str, List[str]]): The first string or list of strings. 74 | str2 (Union[str, List[str]]): The second string or list of strings. 75 | denominator (str): The denominator to use. Options are 'max' and 'sum'. Default is 'max'. 76 | 77 | Returns: 78 | float: The similarity between the two strings. 79 | 80 | Raises: 81 | ValueError: If the denominator is invalid. 
82 | """ 83 | # Get the numerator 84 | numerator, _ = super().compute(str1, str2) 85 | 86 | if denominator == 'max': 87 | return (numerator / max(len(str1), len(str2))) 88 | elif denominator == 'sum': 89 | return (2. * numerator / (len(str1) + len(str2))) 90 | else: 91 | raise ValueError('Invalid denominator.') 92 | 93 | # Jaro similarity class 94 | class JaroSimilarity: 95 | """ 96 | This class contains the Jaro similarity metric. 97 | """ 98 | 99 | def __init__(self): 100 | pass 101 | 102 | def compute(self, 103 | str1: Union[str, List[str]], 104 | str2: Union[str, List[str]], 105 | ) -> float: 106 | """ 107 | This function returns the Jaro similarity between two strings. 108 | 109 | Arguments: 110 | str1 (Union[str, List[str]]): The first string or list of strings. 111 | str2 (Union[str, List[str]]): The second string or list of strings. 112 | 113 | Returns: 114 | float: The Jaro similarity between the two strings. 115 | """ 116 | # Get the length of the strings 117 | len1 = len(str1) 118 | len2 = len(str2) 119 | 120 | # Get the maximum distance, which we denote by k 121 | k = max(len1, len2) // 2 - 1 122 | 123 | # Initialize the number of matching characters and the number of transpositions 124 | num_matches = 0 125 | num_transpositions = 0 126 | 127 | # Initialize the list of matching flags for the strings 128 | matches1 = [False] * len1 129 | matches2 = [False] * len2 130 | 131 | # Loop through the characters in the first string and find the matching characters 132 | for i in range(len1): 133 | # Get the lower and upper bounds for the search 134 | lower_bound = max(0, i - k) 135 | upper_bound = min(len2, i + k + 1) 136 | 137 | # Loop through the characters in the second string 138 | for j in range(lower_bound, upper_bound): 139 | # Check if the characters match 140 | if not matches2[j] and str1[i] == str2[j]: 141 | # Increment the number of matches 142 | num_matches += 1 143 | 144 | # Set the matching flags 145 | matches1[i] = True 146 | matches2[j] = True 147 | 148 | # Break out of the loop 149 | break 150 | 151 | # Check if there are no matches 152 | if num_matches == 0: 153 | return 0. 154 | 155 | # Loop through again but this time find the number of transpositions 156 | # That is, the number of times where there are two matching characters but there is another "matched" character in between them 157 | moving_index = 0 158 | for i in range(len1): 159 | # Check if the character is a match 160 | if matches1[i]: 161 | # Find the next match 162 | for j in range(moving_index, len2): 163 | # Check if the character is a match 164 | if matches2[j]: 165 | # Set the moving index 166 | moving_index = j + 1 167 | 168 | # Check if the characters are not in the right order 169 | if str1[i] != str2[j]: 170 | # Increment the number of transpositions 171 | num_transpositions += 1 172 | 173 | # Break out of the loop 174 | break 175 | 176 | num_transpositions = num_transpositions // 2 177 | 178 | # Return the Jaro similarity 179 | return (num_matches / len1 + num_matches / len2 + (num_matches - num_transpositions) / num_matches) / 3.0 180 | -------------------------------------------------------------------------------- /string2string/similarity/cosine_similarity.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains an implementation of the cosine similarity algorithm (for embedding vectors). 
3 | """ 4 | 5 | from typing import List, Union, Tuple 6 | import torch 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | import numpy as np 10 | 11 | from string2string.misc.word_embeddings import GloVeEmbeddings 12 | 13 | 14 | # Cosine similarity class 15 | class CosineSimilarity: 16 | def __init__(self) -> None: 17 | r""" 18 | This function initializes the CosineSimilarity class. 19 | """ 20 | pass 21 | 22 | # Compute (tensor) 23 | def _compute_tensor(self, 24 | x1: Tensor, 25 | x2: Tensor, 26 | dim: int = 1, 27 | eps: float = 1e-8 28 | ) -> Tensor: 29 | r""" 30 | Computes the cosine similarity between two tensors along a given dimension. 31 | 32 | Arguments: 33 | x1 (Tensor): First tensor. 34 | x2 (Tensor): Second tensor. 35 | dim (int): Dimension to compute cosine similarity. 36 | eps (float): Epsilon value. 37 | 38 | Returns: 39 | Tensor: Cosine similarity between two tensors along a given dimension. 40 | """ 41 | # Make sure that x1 and x2 are float tensors 42 | if x1.dtype != torch.float: 43 | x1 = x1.float() 44 | if x2.dtype != torch.float: 45 | x2 = x2.float() 46 | # Compute cosine similarity between two tensors 47 | return F.cosine_similarity(x1, x2, dim, eps) 48 | 49 | 50 | # Compute (numpy) 51 | def _compute_numpy(self, 52 | x1: np.ndarray, 53 | x2: np.ndarray, 54 | dim: int = 1, 55 | eps: float = 1e-8 56 | ) -> np.ndarray: 57 | r""" 58 | Computes the cosine similarity between two numpy arrays along a given dimension. 59 | 60 | Arguments: 61 | x1 (np.ndarray): First numpy array. 62 | x2 (np.ndarray): Second numpy array. 63 | dim (int): Dimension (or axis in the numpy realm) to compute cosine similarity. 64 | eps (float): Epsilon value (to prevent division by zero). 65 | 66 | Returns: 67 | np.ndarray: Cosine similarity between two numpy arrays along a given dimension. 68 | """ 69 | # Compute cosine similarity between two numpy arrays along a given dimension "dim" 70 | return np.sum(x1 * x2, axis=dim) / np.maximum(np.linalg.norm(x1, axis=dim) * np.linalg.norm(x2, axis=dim), eps) 71 | 72 | 73 | # Compute 74 | def compute(self, 75 | x1: Union[Tensor, np.ndarray], 76 | x2: Union[Tensor, np.ndarray], 77 | dim: int = 0, 78 | eps: float = 1e-8 79 | ) -> Union[Tensor, np.ndarray]: 80 | r""" 81 | Computes the cosine similarity between two tensors (or numpy arrays) along a given dimension. 82 | 83 | * For two (non-zero) vectors, :math:`x_1` and :math:`x_2`, the cosine similarity is defined as follows: 84 | 85 | .. math:: 86 | :nowrap: 87 | 88 | \begin{align} 89 | \texttt{cosine-similarity}(x_1, x_2) & = |x_1|| \ ||x_2|| \cos(\theta) \\ 90 | & = \frac{x_1 \cdot x_2}{||x_1|| \ ||x_2||} \\ 91 | & = \frac{\sum_{i=1}^n x_{1i} x_{2i}}{\sqrt{\sum_{i=1}^n x_{1i}^2} \sqrt{\sum_{i=1}^n x_{2i}^2}} 92 | \end{align} 93 | 94 | where :math:`\theta` denotes the angle between the vectors, :math:`\cdot` the dot product, and :math:`||\cdot||` the norm operator. 95 | 96 | * In practice, the cosine similarity is computed as follows: 97 | 98 | .. math:: 99 | :nowrap: 100 | 101 | \begin{align} 102 | \texttt{cosine-similarity}(x_1, x_2) & = \frac{x_1 \cdot x_2}{\max(||x_1|| ||x_2||, \epsilon)} 103 | \end{align} 104 | 105 | where :math:`\epsilon` is a small value to avoid division by zero. 106 | 107 | 108 | Arguments: 109 | x1 (Union[Tensor, np.ndarray]): First tensor (or numpy array). 110 | x2 (Union[Tensor, np.ndarray]): Second tensor (or numpy array). 111 | dim (int): Dimension to compute cosine similarity (default: 0). 112 | eps (float): Epsilon value (to avoid division by zero). 
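    | 
    |         For illustration: for x1 = [1.0, 0.0] and x2 = [1.0, 1.0] (with dim=0), the cosine similarity is 1/sqrt(2), i.e., approximately 0.707.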
113 | 114 | Returns: 115 | Union[Tensor, np.ndarray]: Cosine similarity between two tensors (or numpy arrays) along a given dimension. 116 | 117 | Raises: 118 | TypeError: If x1 and x2 are not of the same type (either tensor or numpy array). 119 | TypeError: If x1 and x2 are not tensors or numpy arrays. 120 | """ 121 | # Check if x1 and x2 are of the same type (either tensor or numpy array) 122 | if type(x1) != type(x2): 123 | raise TypeError("x1 and x2 must be of the same type (either tensor or numpy array).") 124 | 125 | # If x1 and x2 are tensors 126 | if type(x1) == Tensor: 127 | # Compute cosine similarity 128 | return self._compute_tensor(x1, x2, dim, eps) 129 | # If x1 and x2 are numpy arrays 130 | elif type(x1) == np.ndarray: 131 | # Compute cosine similarity 132 | return self._compute_numpy(x1, x2, dim, eps) 133 | # If x1 and x2 are not tensors or numpy arrays 134 | else: 135 | raise TypeError("x1 and x2 must be either tensors or numpy arrays.") -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Unit Tests 2 | 3 | The purpose of this directory is to provide a set of functional tests that can be used to verify the accuracy of the algorithms and functions utilized within the library. 4 | 5 | To run these tests, one can utilize [pytest](https://docs.pytest.org/en/7.3.x/getting-started.html#run-multiple-tests) to execute all files with names of the form `test_*.py` or `*_test.py` located in the current directory and its subdirectories. 6 | 7 | To install `pytest`, please run the following command in your terminal: 8 | 9 | ```bash 10 | pip install -U pytest 11 | ``` 12 | 13 | Executing the `pytest` command in the current directory should generate an output similar to the following: 14 | 15 | ```python 16 | >>> pytest 17 | ============================================================================= test session starts ============================================================================= 18 | platform darwin -- Python 3.9.12, pytest-7.2.2, pluggy-1.0.0 19 | rootdir: /Users/machine/string2string 20 | collected 15 items 21 | 22 | test_alignment.py ....... [ 46%] 23 | test_distance.py ..... [ 80%] 24 | test_rogue.py . [ 86%] 25 | test_sacrebleu.py . [ 93%] 26 | test_search.py . [100%] 27 | 28 | ============================================================================= 15 passed in 6.05s ============================================================================== 29 | ``` 30 | -------------------------------------------------------------------------------- /tests/test_alignment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the distance module. 
3 | """ 4 | import string 5 | import unittest 6 | from unittest import TestCase 7 | import random 8 | 9 | from string2string.alignment import ( 10 | NeedlemanWunsch, 11 | Hirschberg, 12 | SmithWaterman, 13 | DTW, 14 | LongestCommonSubsequence, 15 | LongestCommonSubstring, 16 | ) 17 | 18 | class AlignmentTestCase(TestCase): 19 | # Testing LngestCommonSubsequence 20 | def test_longest_common_subsequence(self): 21 | lcsubsequence = LongestCommonSubsequence() 22 | # Example 1 23 | length, candidates = lcsubsequence.compute( 24 | "aa", "aa", 25 | returnCandidates=True, 26 | ) 27 | self.assertEqual(length, 2.0) 28 | self.assertCountEqual(candidates, ["aa"]) 29 | # Example 2 30 | length, candidates = lcsubsequence.compute( 31 | "ab", "ba", returnCandidates=True 32 | ) 33 | self.assertEqual(length, 1.0) 34 | self.assertCountEqual(candidates, ["a", "b"]) 35 | # Example 3 36 | length, candidates = lcsubsequence.compute( 37 | "ab", "cd", returnCandidates=True 38 | ) 39 | self.assertEqual(length, 0.0) 40 | self.assertCountEqual(candidates, []) 41 | # Example 4 42 | length, candidates = lcsubsequence.compute( 43 | "ab", "xxaaabyy", returnCandidates=True 44 | ) 45 | self.assertEqual(length, 2.0) 46 | self.assertCountEqual(candidates, ["ab"]) 47 | # Example 5 48 | length, candidates = lcsubsequence.compute( 49 | "abcd", "xcxaaabydy", returnCandidates=True 50 | ) 51 | self.assertEqual(length, 3.0) 52 | self.assertCountEqual(candidates, ["abd"]) 53 | # Example 6 54 | length, candidates = lcsubsequence.compute( 55 | "aabbccdd", "dcdcbaba", returnCandidates=True 56 | ) 57 | self.assertEqual(length, 2.0) 58 | self.assertCountEqual(candidates, ["dd", "cc", "bb", "aa", "cd", "ab"]) 59 | # Example 7 60 | length, candidates = lcsubsequence.compute( 61 | ["abcd"], ["xcxaaabydy"], 62 | returnCandidates=True, 63 | ) 64 | self.assertEqual(length, 0.0) 65 | self.assertCountEqual(candidates,[]) 66 | # Example 8 67 | length, candidates = lcsubsequence.compute( 68 | ["a", "bb", "c"], 69 | ["a", "bb", "c"], 70 | returnCandidates=True, 71 | ) 72 | self.assertEqual(length, 3.0) 73 | self.assertCountEqual(candidates, [["a", "bb", "c"]]) 74 | # Example 9 75 | length, candidates = lcsubsequence.compute( 76 | ["a", "b", "c", "dd"], 77 | ["x", "c", "x", "a", "a", "a", "b", "y", "dd", "y"], 78 | returnCandidates=True, 79 | ) 80 | self.assertEqual(length, 3.0) 81 | self.assertCountEqual(candidates, [["a", "b", "dd"]]) 82 | # Example 10 83 | length, candidates = lcsubsequence.compute( 84 | ["a", "t", "b", "c", "y", "dd", "xyz"], 85 | ["x", "c", "x", "t", "a", "a", "a", "b", "y", "dd", "y", "xyz"], 86 | returnCandidates=True, 87 | ) 88 | self.assertEqual(length, 5.0) 89 | self.assertCountEqual( 90 | candidates, [["t", "b", "y", "dd", "xyz"], ["a", "b", "y", "dd", "xyz"]] 91 | ) 92 | 93 | 94 | # Testing LongestCommonSubstring 95 | def test_longest_common_subsubtring(self): 96 | lcsubstring = LongestCommonSubstring() 97 | # Example 1 98 | length, candidates = lcsubstring.compute( 99 | "aa", "aa", 100 | returnCandidates=True, 101 | ) 102 | self.assertEqual(length, 2) 103 | self.assertCountEqual(candidates, ["aa"]) 104 | # Example 2 105 | length, candidates = lcsubstring.compute( 106 | "aabb", "aa", 107 | returnCandidates=True, 108 | ) 109 | self.assertEqual(length, 2) 110 | self.assertCountEqual(candidates, ["aa"]) 111 | # Example 3 112 | length, candidates = lcsubstring.compute( 113 | "aabbaa", "aa", 114 | returnCandidates=True, 115 | ) 116 | self.assertEqual(length, 2) 117 | self.assertCountEqual(candidates, ["aa"]) 118 | # Example 4 
119 | length, candidates = lcsubstring.compute( 120 | "xyxy", "yxyx", 121 | returnCandidates=True, 122 | ) 123 | self.assertEqual(length, 3) 124 | self.assertCountEqual(candidates, ["xyx", "yxy"]) 125 | # Example 4 126 | length, candidates = lcsubstring.compute( 127 | "xyxy", "yxyx", 128 | returnCandidates=True, 129 | ) 130 | self.assertEqual(length, 3) 131 | self.assertCountEqual(candidates, ["xyx", "yxy"]) 132 | # Example 5 133 | length, candidates = lcsubstring.compute( 134 | ["x", "y", "x", "y"], 135 | ["y", "x", "y", "x"], 136 | returnCandidates=True, 137 | ) 138 | self.assertEqual(length, 3) 139 | self.assertCountEqual( 140 | set(map(tuple, candidates)), 141 | set(map(tuple, [["x", "y", "x"], ["y", "x", "y"]])) 142 | ) 143 | # Example 6 144 | length, candidates = lcsubstring.compute( 145 | ["a", "a", "a", "a"], ["a"], 146 | returnCandidates=True, 147 | ) 148 | self.assertEqual(length, 1) 149 | self.assertCountEqual( 150 | set(map(tuple, candidates)), 151 | set(map(tuple, [["a"]])) 152 | ) 153 | # Example 7 154 | length, candidates = lcsubstring.compute( 155 | "x", "xxxx", 156 | returnCandidates=True, 157 | ) 158 | self.assertEqual(length, 1) 159 | self.assertCountEqual(candidates, ["x"]) 160 | # Example 8 161 | length, candidates = lcsubstring.compute( 162 | " julia ", " julie ", 163 | returnCandidates=True, 164 | ) 165 | self.assertEqual(length, 5) 166 | self.assertCountEqual(candidates, [" juli"]) 167 | 168 | 169 | # Testing NeedlemanWunsch 170 | def test_needleman_wunsch(self): 171 | # First set of examples 172 | needlemanwunsch = NeedlemanWunsch( 173 | match_weight=1, 174 | mismatch_weight=-1, 175 | gap_weight=-1, 176 | ) 177 | # Example 1 178 | aligned_str1, aligned_str2 = needlemanwunsch.get_alignment( 179 | str1 = ["a", "b", "bb"], 180 | str2 = ["a", "bb", "b", "bb"], 181 | ) 182 | self.assertEqual(aligned_str1, "a | - | b | bb") 183 | self.assertEqual(aligned_str2, "a | bb | b | bb") 184 | # Example 2 185 | aligned_str1, aligned_str2 = needlemanwunsch.get_alignment( 186 | str1 = "abcbd", 187 | str2 = "abcde", 188 | ) 189 | self.assertEqual(aligned_str1, "a | b | c | b | d | -") 190 | self.assertEqual(aligned_str2, "a | b | c | - | d | e") 191 | # Example 3 192 | aligned_str1, aligned_str2 = needlemanwunsch.get_alignment( 193 | str1 = "AATGCATGCGTT", 194 | str2 = "AATGATTACATT", 195 | ) 196 | self.assertTrue(aligned_str1 == 'A | A | T | G | C | A | T | G | - | C | G | T | T' or aligned_str1 == "A | A | T | G | C | A | - | T | G | C | G | T | T") 197 | self.assertEqual(aligned_str2, "A | A | T | G | - | A | T | T | A | C | A | T | T") 198 | 199 | # Another set of examples 200 | needlemanwunsch = NeedlemanWunsch( 201 | match_weight=2, 202 | mismatch_weight=-1, 203 | gap_weight=-2, 204 | ) 205 | # Example 4 206 | aligned_str1, aligned_str2 = needlemanwunsch.get_alignment( 207 | str1 = "AGTACGCA", 208 | str2 = "TATGC", 209 | ) 210 | self.assertEqual(aligned_str1, "A | G | T | A | C | G | C | A") 211 | self.assertEqual(aligned_str2, "- | - | T | A | T | G | C | -") 212 | # Example 5 213 | aligned_str1, aligned_str2 = needlemanwunsch.get_alignment( 214 | str1 = ['G', 'A', 'TWW', 'T', 'AWW', 'C', 'A'], 215 | str2 = ['G', 'CAA', 'A', 'T', 'XXG', 'C', 'U'], 216 | ) 217 | self.assertEqual(aligned_str1, "G | A | TWW | T | AWW | C | A") 218 | self.assertEqual(aligned_str2, "G | CAA | A | T | XXG | C | U") 219 | # Example 6 220 | aligned_str1, aligned_str2 = needlemanwunsch.get_alignment( 221 | str1 = "GATTACA", 222 | str2 = "GCATGCU", 223 | ) 224 | self.assertEqual(aligned_str1, "G | A | 
T | T | A | C | A") 225 | self.assertEqual(aligned_str2, "G | C | A | T | G | C | U") 226 | 227 | 228 | # Testing Hirschberg 229 | def test_hirschberg(self): 230 | # First set of examples 231 | hirschberg = Hirschberg( 232 | match_weight=1, 233 | mismatch_weight=-1, 234 | gap_weight=-1, 235 | ) 236 | # Example 1 237 | aligned_str1, aligned_str2 = hirschberg.get_alignment( 238 | str1 = ["a", "b", "bb"], 239 | str2 = ["a", "bb", "b", "bb"], 240 | ) 241 | self.assertEqual(aligned_str1, "a | - | b | bb") 242 | self.assertEqual(aligned_str2, "a | bb | b | bb") 243 | # Example 2 244 | aligned_str1, aligned_str2 = hirschberg.get_alignment( 245 | str1 = "abcbd", 246 | str2 = "abcde", 247 | ) 248 | self.assertEqual(aligned_str1, "a | b | c | b | d | -") 249 | self.assertEqual(aligned_str2, "a | b | c | - | d | e") 250 | # Example 3 251 | aligned_str1, aligned_str2 = hirschberg.get_alignment( 252 | str1 = "AATGCATGCGTT", 253 | str2 = "AATGATTACATT", 254 | ) 255 | self.assertTrue(aligned_str1 == 'A | A | T | G | C | A | T | G | - | C | G | T | T' or aligned_str1 == "A | A | T | G | C | A | - | T | G | C | G | T | T") 256 | self.assertEqual(aligned_str2, "A | A | T | G | - | A | T | T | A | C | A | T | T") 257 | 258 | # Another set of examples 259 | hirschberg = Hirschberg( 260 | match_weight=2, 261 | mismatch_weight=-1, 262 | gap_weight=-2, 263 | ) 264 | # Example 4 265 | aligned_str1, aligned_str2 = hirschberg.get_alignment( 266 | str1 = "AGTACGCA", 267 | str2 = "TATGC", 268 | ) 269 | self.assertEqual(aligned_str1, "A | G | T | A | C | G | C | A") 270 | self.assertEqual(aligned_str2, "- | - | T | A | T | G | C | -") 271 | # Example 5 272 | aligned_str1, aligned_str2 = hirschberg.get_alignment( 273 | str1 = ['G', 'A', 'TWW', 'T', 'AWW', 'C', 'A'], 274 | str2 = ['G', 'CAA', 'A', 'T', 'XXG', 'C', 'U'], 275 | ) 276 | self.assertEqual(aligned_str1, "G | A | TWW | T | AWW | C | A") 277 | self.assertEqual(aligned_str2, "G | CAA | A | T | XXG | C | U") 278 | # Example 6 279 | aligned_str1, aligned_str2 = hirschberg.get_alignment( 280 | str1 = "GATTACA", 281 | str2 = "GCATGCU", 282 | ) 283 | self.assertEqual(aligned_str1, "G | A | T | T | A | C | A") 284 | self.assertEqual(aligned_str2, "G | C | A | T | G | C | U") 285 | 286 | 287 | # Testing SmithWaterman 288 | def test_smithwaterman(self): 289 | smithwaterman = SmithWaterman( 290 | match_weight=1, 291 | mismatch_weight=-1, 292 | gap_weight=-1, 293 | gap_char="-", 294 | ) 295 | # Example 1 296 | aligned_str1, aligned_str2 = smithwaterman.get_alignment( 297 | str1 = "abcbd", 298 | str2 = "abcde", 299 | ) 300 | self.assertEqual(aligned_str1, "a | b | c") 301 | self.assertEqual(aligned_str2, "a | b | c") 302 | # Example 2 303 | aligned_str1, aligned_str2 = smithwaterman.get_alignment( 304 | str1 = "GAATGCATGCGTT", 305 | str2 = "TAATGCATGCGGT", 306 | ) 307 | self.assertEqual(aligned_str1, "A | A | T | G | C | A | T | G | C | G") 308 | self.assertEqual(aligned_str2, "A | A | T | G | C | A | T | G | C | G") 309 | # Example 3 310 | aligned_str1, aligned_str2 = smithwaterman.get_alignment( 311 | str1 = "TACGGGCCCGCTAC", 312 | str2 = "TAGCCCTATCGGTCA", 313 | ) 314 | self.assertEqual(aligned_str1, "T | A | - | C | G | G") 315 | self.assertEqual(aligned_str2, "T | A | T | C | G | G") 316 | # Example 4 317 | aligned_str1, aligned_str2 = smithwaterman.get_alignment( 318 | str1 = "GAGTCGCTACGGGCCCGCTAC", 319 | str2 = "TAGCCTATGCACCTATCGGTCA", 320 | ) 321 | self.assertEqual(aligned_str1, "C | T | A | - | C | G | G") 322 | self.assertEqual(aligned_str2, "C | T | A | T | C 
| G | G") 323 | 324 | 325 | # Testing DTW 326 | def test_dtw(self): 327 | dtw = DTW() 328 | # Example 1 329 | alignment = dtw.get_alignment_path( 330 | sequence1=[1, 2, 3], 331 | sequence2=[1, 2, 3, 4], 332 | distance='absolute_difference', 333 | ) 334 | self.assertCountEqual(alignment, [(0, 0), (1, 1), (2, 2), (2, 3)]) 335 | # Example 2 336 | alignment = dtw.get_alignment_path( 337 | sequence1=[1, 2, 3], 338 | sequence2=[1, 2, 3], 339 | distance='absolute_difference', 340 | ) 341 | self.assertCountEqual(alignment, [(0, 0), (1, 1), (2, 2)]) 342 | # Example 3 343 | alignment = dtw.get_alignment_path( 344 | sequence1="abc", 345 | sequence2="abcd", 346 | distance='absolute_difference', 347 | ) 348 | self.assertCountEqual(alignment, [(0, 0), (1, 1), (2, 2), (2, 3)]) 349 | # Example 4 350 | alignment = dtw.get_alignment_path( 351 | sequence1=["a", "b", "c"], 352 | sequence2=["a", "b", "c", "d"], 353 | distance='absolute_difference', 354 | ) 355 | self.assertCountEqual(alignment, [(0, 0), (1, 1), (2, 2), (2, 3)]) 356 | # Example 5 357 | alignment = dtw.get_alignment_path( 358 | sequence1=[10, 20, 30], 359 | sequence2=[20, 50, 60, 30], 360 | distance='absolute_difference', 361 | ) 362 | self.assertCountEqual(alignment, [(0, 0), (1, 0), (2, 1), (2, 2), (2, 3)]) 363 | 364 | 365 | # Auxiliary function that generates a random string of length n 366 | def generate_random_string(self, bound: int): 367 | alphabet = string.ascii_lowercase 368 | n = random.randint(1, bound) 369 | return "".join(random.choice(alphabet) for i in range(n)) 370 | 371 | 372 | # Testing parallelization/multiprocessing 373 | def test_parallelization(self): 374 | # Generate 50 random strings of length 100 375 | random_test_pairs = [ 376 | (self.generate_random_string(100), self.generate_random_string(100)) 377 | for i in range(50) 378 | ] 379 | 380 | # Instantiate the LongestCommonSubstring class 381 | lcsubstring = LongestCommonSubstring() 382 | 383 | # Compute the results serially 384 | results_serial = [ 385 | lcsubstring.compute(str1=str1, str2=str2) for str1, str2 in random_test_pairs 386 | ] 387 | # Compute the results in parallel 388 | results_parallel = lcsubstring.compute_multiple_pairs(random_test_pairs) 389 | # Check that the results are the same 390 | self.assertEqual(results_serial, results_parallel) 391 | 392 | 393 | # Instantiate the LongestCommonSubsequence class 394 | lcsubsequence = LongestCommonSubsequence() 395 | # Compute the results serially 396 | results_serial = [ 397 | lcsubsequence.compute(str1=str1, str2=str2) for str1, str2 in random_test_pairs 398 | ] 399 | # Compute the results in parallel 400 | results_parallel = lcsubsequence.compute_multiple_pairs(random_test_pairs) 401 | # Check that the results are the same 402 | self.assertEqual(results_serial, results_parallel) 403 | 404 | if __name__ == "__main__": 405 | unittest.main() 406 | -------------------------------------------------------------------------------- /tests/test_distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the distance module. 3 | """ 4 | import unittest 5 | from unittest import TestCase 6 | 7 | from string2string.distance import LevenshteinEditDistance, HammingDistance, DamerauLevenshteinDistance, JaccardIndex 8 | 9 | class DistanceTestCase(TestCase): 10 | def test_levenshtein_edit_distance_unit_operations(self): 11 | ## Case 1: Costs of insertion, deletion, and substitution are all 1. 
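## For reference, these expected values follow the standard unit-cost recurrence
## (a textbook sketch; the package's own implementation may differ in details):
##   D[i][0] = i,  D[0][j] = j
##   D[i][j] = min(D[i-1][j] + 1,                        # deletion
##                 D[i][j-1] + 1,                        # insertion
##                 D[i-1][j-1] + (s1[i-1] != s2[j-1]))   # substitution / match
## e.g. "kitten" -> "sitting" needs k->s, e->i, and an inserted g, so the distance is 3 (Examples 3-4 below).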
12 | edit_distance = LevenshteinEditDistance() 13 | # Example 0 14 | dist = edit_distance.compute("", "") 15 | self.assertEqual(dist, 0.0) 16 | # Example 1 17 | dist = edit_distance.compute("aa", "bb") 18 | self.assertEqual(dist, 2.0) 19 | # Example 2 20 | dist = edit_distance.compute("monty-python", "monty-python") 21 | self.assertEqual(dist, 0.0) 22 | # Example 3 23 | dist = edit_distance.compute("kitten", "sitting") 24 | self.assertEqual(dist, 3.0) 25 | # Example 4 26 | dist = edit_distance.compute("sitting", "kitten") 27 | self.assertEqual(dist, 3.0) 28 | # Example 5 29 | dist = edit_distance.compute("aaaaa", "a") 30 | self.assertEqual(dist, 4.0) 31 | # Example 6 32 | dist = edit_distance.compute("", "abcdef") 33 | self.assertEqual(dist, 6.0) 34 | # Example 7 35 | dist = edit_distance.compute("abcdef", "") 36 | self.assertEqual(dist, 6.0) 37 | # Example 8 38 | dist = edit_distance.compute("algorithm", "al-Khwarizmi") 39 | self.assertEqual(dist, 8.0) 40 | # Example 9 41 | dist = edit_distance.compute("qrrq", "rqqr") 42 | self.assertEqual(dist, 3.0) 43 | # Example 10 44 | dist = edit_distance.compute(["kurt", "godel"], ["godel", "kurt"]) 45 | self.assertEqual(dist, 2.0) 46 | # Example 11 47 | dist = edit_distance.compute( 48 | ["kurt", "godel", "kurt"], ["godel", "kurt"] 49 | ) 50 | self.assertEqual(dist, 1.0) 51 | 52 | def test_levenshtein_edit_distance_weighted_operations(self): 53 | ## Case 2: insertion = 2., deletion = 2., substitution = 1., match = 0. 54 | weighted_edit_distance = LevenshteinEditDistance( 55 | match_weight=0.0, 56 | insert_weight=2.0, 57 | delete_weight=2.0, 58 | substitute_weight=1.0, 59 | ) 60 | # Example 1 61 | dist = weighted_edit_distance.compute("aa", "bb") 62 | self.assertEqual(dist, 2.0) 63 | # Example 2 64 | dist = weighted_edit_distance.compute("aca", "bcb") 65 | self.assertEqual(dist, 2.0) 66 | # Example 3 67 | dist = weighted_edit_distance.compute("aa", "") 68 | self.assertEqual(dist, 4.0) 69 | # Example 4 70 | dist = weighted_edit_distance.compute("", "aa") 71 | self.assertEqual(dist, 4.0) 72 | # Example 5 73 | dist = weighted_edit_distance.compute("witty", "witty") 74 | self.assertEqual(dist, 0.0) 75 | # Example 6 76 | dist = weighted_edit_distance.compute("ttss", "stst") 77 | self.assertEqual(dist, 2.0) 78 | 79 | def test_damerau_levenshtein_edit_distance_unit_operations(self): 80 | ## Case 1: Costs of insertion, deletion, substitution, and transposition are all 1. 
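## Damerau-Levenshtein extends the Levenshtein recurrence above with one extra move:
## two adjacent symbols may be transposed at cost 1 (again a textbook description,
## not necessarily the exact implementation in string2string.distance).
##   "ab" -> "ba"              : one transposition           => 1.0 (Example 3 below)
##   "wikiepdia" -> "wikipedia": one transposition of "ep"   => 1.0 (Example 8 below)
## Plain Levenshtein would charge 2 for each of these pairs.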
81 | dameraulevenshteindist = DamerauLevenshteinDistance() 82 | # Example 0 83 | dist = dameraulevenshteindist.compute("", "") 84 | self.assertEqual(dist, 0.0) 85 | # Example 1 86 | dist = dameraulevenshteindist.compute("aa", "bb") 87 | self.assertEqual(dist, 2.0) 88 | # Example 2 89 | dist = dameraulevenshteindist.compute( 90 | "monty-python", "monty-python" 91 | ) 92 | self.assertEqual(dist, 0.0) 93 | # Example 3 94 | dist = dameraulevenshteindist.compute("ab", "ba") 95 | self.assertEqual(dist, 1.0) 96 | # Example 4 97 | dist = dameraulevenshteindist.compute("sitting", "kitten") 98 | self.assertEqual(dist, 3.0) 99 | # Example 5 100 | dist = dameraulevenshteindist.compute("baaaaa", "ab") 101 | self.assertEqual(dist, 5.0) 102 | # Example 6 103 | dist = dameraulevenshteindist.compute("ababab", "bababa") 104 | self.assertEqual(dist, 2.0) 105 | # Example 7 106 | dist = dameraulevenshteindist.compute("abxymn", "bayxnm") 107 | self.assertEqual(dist, 3.0) 108 | # Example 8 109 | dist = dameraulevenshteindist.compute("wikiepdia", "wikipedia") 110 | self.assertEqual(dist, 1.0) 111 | # Example 9 112 | dist = dameraulevenshteindist.compute("microaoft", "microsoft") 113 | self.assertEqual(dist, 1.0) 114 | # Example 10 115 | dist = dameraulevenshteindist.compute( 116 | ["kurt", "godel"], ["godel", "kurt"] 117 | ) 118 | self.assertEqual(dist, 1.0) 119 | # Example 11 120 | dist = dameraulevenshteindist.compute( 121 | ["kurt", "godel", "kurt"], ["godel", "kurt"] 122 | ) 123 | self.assertEqual(dist, 1.0) 124 | # Example 12 125 | dist = dameraulevenshteindist.compute("microaoft", "microsoft") 126 | self.assertEqual(dist, 1.0) 127 | 128 | def test_hamming_edit_distance(self): 129 | hamming_distance = HammingDistance() 130 | # Example 1 131 | dist = hamming_distance.compute("aa", "bb") 132 | self.assertEqual(dist, 2.0) 133 | # Example 2 134 | dist = hamming_distance.compute("aac", "abc") 135 | self.assertEqual(dist, 1.0) 136 | # Example 3 137 | dist = hamming_distance.compute("Turing1912", "during1921") 138 | self.assertEqual(dist, 3.0) 139 | # Example 4 140 | dist = hamming_distance.compute("John von Neumann", "John von Neumann") 141 | self.assertEqual(dist, 0.0) 142 | # Example 5 143 | dist = hamming_distance.compute("Earth", "earth") 144 | self.assertEqual(dist, 1.0) 145 | # Example 6 146 | with self.assertRaises(ValueError): 147 | dist = hamming_distance.compute(" ", "abc") 148 | # Example 7 149 | dist = hamming_distance.compute("", "") 150 | self.assertEqual(dist, 0.0) 151 | # Example 8 152 | dist = hamming_distance.compute( 153 | ["", "abc", "234", "#"], ["", "abc", "123", "#"] 154 | ) 155 | self.assertEqual(dist, 1.0) 156 | # Example 9 157 | dist = hamming_distance.compute( 158 | ["a", "ab", "abc", "abcd", "abc", "ab", "a"], 159 | ["a", "ab", "abc", "abcd", "abc", "ab", "a"], 160 | ) 161 | self.assertEqual(dist, 0.0) 162 | 163 | 164 | def test_jaccard_index(self): 165 | jaccard_index = JaccardIndex() 166 | # Example 1 167 | dist = jaccard_index.compute("aa", "bb") 168 | self.assertEqual(dist, 1.0) 169 | # Example 2 170 | dist = jaccard_index.compute("ab", "ba") 171 | self.assertEqual(dist, 0.0) 172 | # Example 3 173 | dist = jaccard_index.compute("ab", "baaaaab") 174 | self.assertEqual(dist, 0.0) 175 | # Example 4 176 | dist = jaccard_index.compute("ab", "bbbbaaaacd") 177 | self.assertEqual(dist, 0.5) 178 | # Example 5 179 | dist = jaccard_index.compute("ab", "cd") 180 | self.assertEqual(dist, 1.0) 181 | # Example 6 182 | dist = jaccard_index.compute( 183 | "The quick brown fox jumps over the lazy 
dog", 184 | "The quick brown cat jumps over the lazy dog" 185 | ) 186 | self.assertEqual(dist, 0.0714285714285714) 187 | # Example 7 188 | dist = jaccard_index.compute("apple", "banana") 189 | self.assertEqual(dist, 0.8333333333333334) 190 | # Example 8 191 | dist = jaccard_index.compute( 192 | ['a','p', 'p', 'l', 'e'], 193 | ['b', 'a', 'n', 'a', 'n', 'a'] 194 | ) 195 | self.assertEqual(dist, 0.8333333333333334) 196 | # Example 9 197 | dist = jaccard_index.compute( 198 | ['a','p', 'p', 'l', 'e'], 199 | ['a','p', 'p', 'p', 'l', 'e', 'e'], 200 | ) 201 | self.assertEqual(dist, 0.0) 202 | 203 | 204 | if __name__ == "__main__": 205 | unittest.main() -------------------------------------------------------------------------------- /tests/test_rogue.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the ROUGE module. 3 | """ 4 | import unittest 5 | from unittest import TestCase 6 | 7 | from string2string.metrics import ROUGE 8 | 9 | class ROUGE_TestCase(TestCase): 10 | def test_rogue(self): 11 | # Initialize the ROUGE metric 12 | rogue = ROUGE() 13 | # Example 1 14 | candidates = ["The cat is sitting on the mat.", "The dog is barking at the mailman.", "The bird is singing in the tree."] 15 | references = [["The cat is sitting on the mat."], ["The dog is barking at the postman."], ["The bird sings on the tree."]] 16 | # Compute the ROUGE score 17 | result = rogue.compute(candidates, references) 18 | print(result) 19 | r1, r2, rl, rlsum = result['rouge1'], result['rouge2'], result['rougeL'], result['rougeLsum'] 20 | # Check that the score is correct 21 | self.assertAlmostEqual(r1, 0.824, delta=0.01) 22 | self.assertAlmostEqual(r2, 0.732, delta=0.01) 23 | self.assertAlmostEqual(rl, 0.824, delta=0.01) 24 | self.assertAlmostEqual(rlsum, 0.824, delta=0.01) 25 | 26 | # Example 2 27 | candidates = ['The quick brown fox jumps over the lazy dog.', 'This is a test.'] 28 | references = [['The quick brown fox jumps over the lazy dog.'], ['This is only a test.']] 29 | # Compute the ROUGE score 30 | result = rogue.compute(candidates, references) 31 | r1, r2, rl, rlsum = result['rouge1'], result['rouge2'], result['rougeL'], result['rougeLsum'] 32 | # Check that the score is correct 33 | self.assertAlmostEqual(r1, 0.944, delta=0.01) 34 | self.assertAlmostEqual(r2, 0.786, delta=0.01) 35 | self.assertAlmostEqual(rl, 0.944, delta=0.01) 36 | self.assertAlmostEqual(rlsum, 0.944, delta=0.01) 37 | 38 | # Example 3 39 | candidates = ['I am eating lunch.', 'He is studying.'] 40 | references = [['I am having lunch.'], ['He is studying hard.']] 41 | # Compute the ROUGE score 42 | result = rogue.compute(candidates, references) 43 | r1, r2, rl, rlsum = result['rouge1'], result['rouge2'], result['rougeL'], result['rougeLsum'] 44 | # Check that the score is correct 45 | self.assertAlmostEqual(r1, 0.661, delta=0.01) 46 | self.assertAlmostEqual(r2, 0.367, delta=0.01) 47 | self.assertAlmostEqual(rl, 0.661, delta=0.01) 48 | self.assertAlmostEqual(rlsum, 0.661, delta=0.01) 49 | 50 | # Example 4 51 | candidates = ['Random sentence.', 'Random sentence.'] 52 | references = [['Sentence.'], ['Sentence.']] 53 | # Compute the ROUGE score 54 | result = rogue.compute(candidates, references) 55 | r1, r2, rl, rlsum = result['rouge1'], result['rouge2'], result['rougeL'], result['rougeLsum'] 56 | # Check that the score is correct 57 | self.assertAlmostEqual(r1, 0.0, delta=0.01) 58 | self.assertAlmostEqual(r2, 0.0, delta=0.01) 59 | self.assertAlmostEqual(rl, 0.0, delta=0.01) 60 | 
self.assertAlmostEqual(rlsum, 0.0, delta=0.01) 61 | 62 | # Example 5 63 | candidates = ['Random sentence 1.', 'Random sentence 2.'] 64 | references = [['Random sentence 1.'], ['Random sentence 2.']] 65 | # Compute the ROUGE score 66 | result = rogue.compute(candidates, references) 67 | r1, r2, rl, rlsum = result['rouge1'], result['rouge2'], result['rougeL'], result['rougeLsum'] 68 | # Check that the score is correct 69 | self.assertAlmostEqual(r1, 1.0, delta=0.01) 70 | self.assertAlmostEqual(r2, 1.0, delta=0.01) 71 | self.assertAlmostEqual(rl, 1.0, delta=0.01) 72 | self.assertAlmostEqual(rlsum, 1.0, delta=0.01) 73 | 74 | if __name__ == "__main__": 75 | unittest.main() -------------------------------------------------------------------------------- /tests/test_sacrebleu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the sacreBLEU module. 3 | """ 4 | import unittest 5 | from unittest import TestCase 6 | 7 | from string2string.metrics import sacreBLEU 8 | 9 | 10 | class SacreBLEUTestCase(TestCase): 11 | def test_sacrebleu(self): 12 | # Initialize the sacreBLEU metric 13 | sbleu = sacreBLEU() 14 | # Example 1 15 | candidates = ["The cat is sitting on the mat.", "The dog is barking at the mailman.", "The bird is singing in the tree."] 16 | references = [["The cat is sitting on the mat."], ["The dog is barking at the postman."], ["The bird sings on the tree."]] 17 | # Compute the sacreBLEU score 18 | result = sbleu.compute(candidates, references) 19 | score = result['score'] 20 | # Check that the score is correct 21 | self.assertAlmostEqual(score, 66.37, delta=0.01) 22 | 23 | # Example 2 24 | candidates = ['The quick brown fox jumps over the lazy dog.', 'This is a test.'] 25 | references = [['The quick brown fox jumps over the lazy dog.'], ['This is only a test.']] 26 | # Compute the sacreBLEU score 27 | result = sbleu.compute(candidates, references) 28 | score = result['score'] 29 | # Check that the score is correct 30 | self.assertAlmostEqual(score, 81.90, delta=0.01) 31 | 32 | # Example 3 33 | candidates = ['I am eating lunch.', 'He is studying.'] 34 | references = [['I am having lunch.'], ['He is studying hard.']] 35 | # Compute the sacreBLEU score 36 | result = sbleu.compute(candidates, references) 37 | score = result['score'] 38 | # Check that the score is correct 39 | self.assertAlmostEqual(score, 32.28, delta=0.01) 40 | 41 | # Example 4 42 | candidates = ['Random sentence.', 'Random sentence.'] 43 | references = [['Sentence.'], ['Sentence.']] 44 | # Compute the sacreBLEU score 45 | result = sbleu.compute(candidates, references) 46 | score = result['score'] 47 | # Check that the score is correct 48 | self.assertAlmostEqual(score, 0.0, delta=0.01) 49 | 50 | # Example 5 51 | candidates = ['Random sentence 1.', 'Random sentence 2.'] 52 | references = [['Random sentence 1.'], ['Random sentence 2.']] 53 | # Compute the sacreBLEU score 54 | result = sbleu.compute(candidates, references) 55 | score = result['score'] 56 | # Check that the score is correct 57 | self.assertAlmostEqual(score, 100., delta=0.01) 58 | 59 | candidates = ['The sun is shining.', 'The birds are chirping.', 'She is playing the guitar.', 'He is cooking dinner.'] 60 | references = [['The sun is shining.', 'The sun is bright.'], ['The birds are singing.', 'The harold is singing.'], ['Julie is playing the flute.', 'She is playing the piano.'], ['Chef is cooking dinner.', 'He is cooking lunch.']] 61 | # Compute the sacreBLEU score 62 | result = 
sbleu.compute(candidates, references) 63 | print(result) 64 | 65 | if __name__ == "__main__": 66 | unittest.main() -------------------------------------------------------------------------------- /tests/test_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Unit tests for the search module. 3 | """ 4 | import random 5 | import unittest 6 | from unittest import TestCase 7 | 8 | from string2string.misc import PolynomialRollingHash 9 | from string2string.search import RabinKarpSearch, KMPSearch, BoyerMooreSearch, NaiveSearch 10 | 11 | class SearchTestCase(TestCase): 12 | def test_lexical_search_algs(self): 13 | # Initialize the rolling hash function 14 | rolling_hash = PolynomialRollingHash( 15 | base=10, 16 | modulus=65537, 17 | ) 18 | # Initialize the Rabin-Karp search algorithm 19 | rabin_karp = RabinKarpSearch(hash_function=rolling_hash) 20 | # Initialize the KMP search algorithm 21 | knuth_morris_pratt = KMPSearch() 22 | # Initialize the Boyer-Moore search algorithm 23 | boyer_moore = BoyerMooreSearch() 24 | # Initialize the naive search algorithm 25 | naive = NaiveSearch() 26 | 27 | # Example 1 28 | pattern = 'Jane Austen' 29 | text = 'Sense and Sensibility, Pride and Prejudice, Emma, Mansfield Park, Northanger Abbey, Persuasion, and Lady Susan were written by Jane Austen and are important works of English literature.' 30 | # Search for the pattern in the text using all four algorithms 31 | idx_rabin_karp = rabin_karp.search(pattern, text) 32 | idx_knuth_morris_pratt = knuth_morris_pratt.search(pattern, text) 33 | idx_boyer_moore = boyer_moore.search(pattern, text) 34 | idx_naive = naive.search(pattern, text) 35 | # Check that all four indices are the same 36 | self.assertEqual(idx_rabin_karp, idx_knuth_morris_pratt) 37 | self.assertEqual(idx_rabin_karp, idx_boyer_moore) 38 | self.assertEqual(idx_rabin_karp, idx_naive) 39 | 40 | # Examples 2-11 (randomly generated) 41 | for _ in range(10): 42 | # Randomly generate a pattern and a text, using random strings of length 5 and 100, respectively 43 | pattern = ''.join(random.choices(['a', 'b', 'c'], k=5)) 44 | text = ''.join(random.choices(['a', 'b', 'c'], k=100)) 45 | # Search for the pattern in the text using all four algorithms 46 | idx_rabin_karp = rabin_karp.search(pattern, text) 47 | idx_knuth_morris_pratt = knuth_morris_pratt.search(pattern, text) 48 | idx_boyer_moore = boyer_moore.search(pattern, text) 49 | idx_naive = naive.search(pattern, text) 50 | # Check that all four indices are the same 51 | self.assertEqual(idx_rabin_karp, idx_knuth_morris_pratt) 52 | self.assertEqual(idx_rabin_karp, idx_boyer_moore) 53 | self.assertEqual(idx_rabin_karp, idx_naive) 54 | 55 | 56 | if __name__ == "__main__": 57 | unittest.main() --------------------------------------------------------------------------------
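The snippet below is a minimal standalone sketch of the same comparison exercised in tests/test_search.py: it reuses only the constructors and the search(pattern, text) call that appear in the file above, while the sample sentence and the final assert are illustrative additions. Assuming the package and its test dependencies are installed, running "python -m unittest discover tests" from the repository root should execute every test file shown here.

# Minimal standalone sketch mirroring tests/test_search.py (illustrative, not part of the repository).
from string2string.misc import PolynomialRollingHash
from string2string.search import RabinKarpSearch, KMPSearch, BoyerMooreSearch, NaiveSearch

pattern = "Jane Austen"
text = "Emma and Persuasion were written by Jane Austen."

searchers = [
    RabinKarpSearch(hash_function=PolynomialRollingHash(base=10, modulus=65537)),
    KMPSearch(),
    BoyerMooreSearch(),
    NaiveSearch(),
]
# Every algorithm should report the same match position for the same pattern/text pair.
indices = [searcher.search(pattern, text) for searcher in searchers]
assert len(set(indices)) == 1, indices
print(indices[0])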