├── test ├── __init__.py ├── test_vectors.py ├── test_utils.py ├── test_usif.py ├── test_sif.py ├── test_data │ └── test_sentences.txt ├── test_inputs.py ├── test_sentencevectors.py ├── test_average.py └── test_base_s2v.py ├── .isort.cfg ├── .dockerignore ├── .gitattributes ├── media └── fse.png ├── fse ├── models │ ├── __init__.py │ ├── voidptr.h │ ├── average_inner.pxd │ ├── utils.py │ ├── sif.py │ ├── usif.py │ ├── average.py │ ├── average_inner.pyx │ └── sentencevectors.py ├── __init__.py ├── vectors.py └── inputs.py ├── MANIFEST.in ├── release.sh ├── Dockerfile ├── .lgtm.yml ├── .travis.yml ├── tests.sh ├── .gitignore ├── setup.py ├── evaluation ├── readme.txt └── LICENSE.txt ├── notebooks └── Speed Comparision.ipynb └── README.md /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile = black -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | *.c 4 | *.so -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /media/fse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oborchers/Fast_Sentence_Embeddings/HEAD/media/fse.png -------------------------------------------------------------------------------- /fse/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .average import Average 2 | from .sif import SIF 3 | from .usif import uSIF 4 | from .sentencevectors import SentenceVectors 5 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include fse/test/test_data * 2 | include README.md 3 | 4 | include fse/models/voidptr.h 5 | 6 | include fse/models/average_inner.pyx 7 | include fse/models/average_inner.pxd -------------------------------------------------------------------------------- /release.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | docformatter --in-place **/*.py --wrap-summaries 88 --wrap-descriptions 88 3 | isort --atomic **/*.py 4 | black . 
5 | 6 | pytest -v --cov=fse --cov-report=term-missing -------------------------------------------------------------------------------- /fse/models/voidptr.h: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #if PY_VERSION_HEX >= 0x03020000 4 | 5 | /* 6 | ** compatibility with python >= 3.2, which doesn't have CObject anymore 7 | */ 8 | static void * PyCObject_AsVoidPtr(PyObject *obj) 9 | { 10 | void *ret = PyCapsule_GetPointer(obj, NULL); 11 | if (ret == NULL) { 12 | PyErr_Clear(); 13 | } 14 | return ret; 15 | } 16 | 17 | #endif -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.6.0 2 | 3 | RUN pip install -U pip 4 | 5 | RUN pip install scipy \ 6 | smart_open \ 7 | scikit-learn \ 8 | wordfreq \ 9 | huggingface-hub \ 10 | psutil 11 | 12 | ARG gensim==4.0.0 13 | RUN pip install -U "gensim==$gensim" pytest coverage 14 | 15 | ADD . /home 16 | WORKDIR /home 17 | RUN rm -rf build dist 18 | 19 | RUN pip install -e . 20 | 21 | CMD [ "pytest", "-vv" ] -------------------------------------------------------------------------------- /.lgtm.yml: -------------------------------------------------------------------------------- 1 | path_classifiers: 2 | test: 3 | - test 4 | - exclude: "**/test_*" 5 | - exclude: "fse/test" 6 | - exclude: "test" 7 | 8 | extraction: 9 | python: 10 | python_setup: 11 | requirements: 12 | - cython>=0.29 13 | cpp: 14 | index: 15 | build_command: 16 | - python3 setup.py build 17 | after_prepare: 18 | - pip3 install --upgrade --user cython 19 | - export PATH="$HOME/.local/bin:$PATH" 20 | 21 | queries: 22 | - include: py/file-not-closed -------------------------------------------------------------------------------- /fse/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from fse import models 4 | from fse.models import SIF, Average, SentenceVectors, uSIF 5 | from fse.vectors import FTVectors, Vectors 6 | 7 | from .inputs import ( 8 | BaseIndexedList, 9 | CIndexedList, 10 | CSplitCIndexedList, 11 | CSplitIndexedList, 12 | IndexedLineDocument, 13 | IndexedList, 14 | SplitCIndexedList, 15 | SplitIndexedList, 16 | ) 17 | 18 | 19 | class NullHandler(logging.Handler): 20 | def emit(self, record): 21 | pass 22 | 23 | 24 | logger = logging.getLogger("fse") 25 | if len(logger.handlers) == 0: # To ensure reload() doesn't add another one 26 | logger.addHandler(NullHandler()) 27 | 28 | 29 | __version__ = "0.2.0" 30 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | if: (type = push AND branch IN (master, develop)) OR (type = pull_request AND NOT branch =~ /no-ci/) 2 | sudo: false 3 | 4 | cache: 5 | apt: true 6 | directories: 7 | - $HOME/.cache/pip 8 | - $HOME/.ccache 9 | - $HOME/.pip-cache 10 | 11 | language: python 12 | python: 13 | - "3.6" 14 | - "3.7" 15 | - "3.8" 16 | - "3.9" 17 | 18 | branches: 19 | only: 20 | - master 21 | - develop 22 | 23 | matrix: 24 | include: 25 | - name: "Python 3.8.0 on macOS" 26 | os: osx 27 | osx_image: xcode11.2 28 | language: shell 29 | 30 | install: 31 | - pip3 install -U pip coveralls 32 | - pip3 install -U psutil cython numpy 33 | - pip3 install . 
34 | 35 | script: 36 | coverage run --source fse setup.py test 37 | 38 | after_success: 39 | coveralls 40 | -------------------------------------------------------------------------------- /tests.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | export GENSIM_VERSION=4.0.0 4 | DOCKER_BUILDKIT=1 docker build -t fse-$GENSIM_VERSION --build-arg gensim=$GENSIM_VERSION . 5 | docker run --rm "fse-$GENSIM_VERSION" 6 | 7 | export GENSIM_VERSION=4.0.1 8 | DOCKER_BUILDKIT=1 docker build -t fse-$GENSIM_VERSION --build-arg gensim=$GENSIM_VERSION . 9 | docker run --rm "fse-$GENSIM_VERSION" 10 | 11 | export GENSIM_VERSION=4.1.0 12 | DOCKER_BUILDKIT=1 docker build -t fse-$GENSIM_VERSION --build-arg gensim=$GENSIM_VERSION . 13 | docker run --rm "fse-$GENSIM_VERSION" 14 | 15 | export GENSIM_VERSION=4.1.1 16 | DOCKER_BUILDKIT=1 docker build -t fse-$GENSIM_VERSION --build-arg gensim=$GENSIM_VERSION . 17 | docker run --rm "fse-$GENSIM_VERSION" 18 | 19 | export GENSIM_VERSION=4.1.2 20 | DOCKER_BUILDKIT=1 docker build -t fse-$GENSIM_VERSION --build-arg gensim=$GENSIM_VERSION . 21 | docker run --rm "fse-$GENSIM_VERSION" -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled source # 2 | ################### 3 | *.com 4 | *.class 5 | *.dll 6 | *.exe 7 | *.o 8 | *.so 9 | *.pyc 10 | *.pyo 11 | *.pyd 12 | 13 | # Packages # 14 | ############ 15 | # it's better to unpack these files and commit the raw source 16 | # git has its own built in compression methods 17 | *.7z 18 | *.dmg 19 | *.gz 20 | *.iso 21 | *.jar 22 | *.rar 23 | *.tar 24 | *.zip 25 | 26 | # Logs and databases # 27 | ###################### 28 | *.log 29 | *.sql 30 | *.sqlite 31 | *.pkl 32 | *.bak 33 | *.npy 34 | *.npz 35 | *.code-workspace 36 | 37 | # OS generated files # 38 | ###################### 39 | .DS_Store? 40 | .DS_Store 41 | ehthumbs.db 42 | Icon? 
43 | Thumbs.db 44 | *.icloud 45 | 46 | # Folders # 47 | ########### 48 | legacy 49 | latex 50 | draft 51 | fse.egg-info/ 52 | 53 | # Other # 54 | ######### 55 | .ipynb_checkpoints/ 56 | .settings/ 57 | .vscode/ 58 | .eggs 59 | fse*.egg-info 60 | *.pptx 61 | *.doc 62 | *.docx 63 | *.dict 64 | .coverage 65 | *.bak 66 | /build/ 67 | /dist/ 68 | *.prof 69 | *.lprof 70 | *.bin 71 | *.old 72 | *.model 73 | *_out.txt 74 | vectors 75 | *.vectors 76 | *.joblib 77 | 78 | test_emb* -------------------------------------------------------------------------------- /fse/models/average_inner.pxd: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False 2 | # cython: wraparound=False 3 | # cython: cdivision=True 4 | # cython: embedsignature=True 5 | # coding: utf-8 6 | 7 | # Author: Oliver Borchers 8 | # Copyright (C) Oliver Borchers 9 | 10 | cimport numpy as np 11 | 12 | cdef extern from "voidptr.h": 13 | void* PyCObject_AsVoidPtr(object obj) 14 | 15 | ctypedef np.float32_t REAL_t 16 | ctypedef np.uint32_t uINT_t 17 | 18 | # BLAS routine signatures 19 | ctypedef void (*saxpy_ptr) (const int *N, const float *alpha, const float *X, const int *incX, float *Y, const int *incY) nogil 20 | ctypedef void (*sscal_ptr) (const int *N, const float *alpha, const float *X, const int *incX) nogil 21 | 22 | cdef saxpy_ptr saxpy 23 | cdef sscal_ptr sscal 24 | 25 | DEF MAX_WORDS = 10000 26 | DEF MAX_NGRAMS = 40 27 | 28 | cdef struct BaseSentenceVecsConfig: 29 | int size, workers 30 | 31 | # Vectors 32 | REAL_t *mem 33 | REAL_t *word_vectors 34 | REAL_t *word_weights 35 | REAL_t *sentence_vectors 36 | 37 | uINT_t word_indices[MAX_WORDS] 38 | uINT_t sent_adresses[MAX_WORDS] 39 | uINT_t sentence_boundary[MAX_WORDS + 1] 40 | 41 | cdef struct FTSentenceVecsConfig: 42 | int size, workers, min_n, max_n, bucket 43 | 44 | REAL_t oov_weight 45 | 46 | # Vectors 47 | REAL_t *mem 48 | REAL_t *word_vectors # Note: these will be the vocab vectors, not wv.vectors 49 | REAL_t *ngram_vectors 50 | REAL_t *word_weights 51 | 52 | REAL_t *sentence_vectors 53 | 54 | # REAL_t *work memory for summation? 
55 | uINT_t word_indices[MAX_WORDS] 56 | uINT_t sent_adresses[MAX_WORDS] 57 | uINT_t sentence_boundary[MAX_WORDS + 1] 58 | 59 | # For storing the oov items 60 | uINT_t subwords_idx_len[MAX_WORDS] 61 | uINT_t *subwords_idx 62 | 63 | cdef init_base_s2v_config(BaseSentenceVecsConfig *c, model, target, memory) 64 | 65 | cdef init_ft_s2v_config(FTSentenceVecsConfig *c, model, target, memory) -------------------------------------------------------------------------------- /test/test_vectors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | """Automated tests for checking the average model.""" 8 | 9 | import logging 10 | import unittest 11 | from pathlib import Path 12 | from unittest.mock import patch 13 | 14 | from fse.vectors import FTVectors, Vectors 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | TEST_DATA = Path(__file__).parent / "test_data" 19 | 20 | 21 | class TestVectors(unittest.TestCase): 22 | def test_from_pretrained(self): 23 | """Test the pretrained vectors.""" 24 | vectors = Vectors.from_pretrained("glove-wiki-gigaword-50") 25 | self.assertEqual(vectors.vector_size, 50) 26 | self.assertEqual(vectors.vectors.shape, (400000, 50)) 27 | 28 | vectors = Vectors.from_pretrained("glove-wiki-gigaword-50", mmap="r") 29 | self.assertEqual(vectors.vector_size, 50) 30 | self.assertEqual(vectors.vectors.shape, (400000, 50)) 31 | 32 | def test_missing_model(self): 33 | """Tests a missing model.""" 34 | 35 | with self.assertRaises(ValueError): 36 | Vectors.from_pretrained("unittest") 37 | 38 | with patch("fse.vectors.snapshot_download") as mock: 39 | mock.side_effect = RuntimeError 40 | with self.assertRaises(RuntimeError): 41 | Vectors.from_pretrained("unittest") 42 | 43 | 44 | class TestFTVectors(unittest.TestCase): 45 | def test_from_pretrained(self): 46 | """Test the pretrained vectors.""" 47 | with patch("fse.vectors.snapshot_download") as mock, patch( 48 | "fse.vectors.FastTextKeyedVectors.load" 49 | ): 50 | mock.return_value = TEST_DATA.as_posix() 51 | FTVectors.from_pretrained("ft") 52 | 53 | def test_missing_model(self): 54 | """Tests a missing model.""" 55 | 56 | with self.assertRaises(ValueError): 57 | FTVectors.from_pretrained("unittest") 58 | 59 | with patch("fse.vectors.snapshot_download") as mock: 60 | mock.side_effect = RuntimeError 61 | with self.assertRaises(RuntimeError): 62 | FTVectors.from_pretrained("unittest") 63 | 64 | 65 | if __name__ == "__main__": 66 | logging.basicConfig( 67 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 68 | ) 69 | unittest.main() 70 | -------------------------------------------------------------------------------- /fse/vectors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | # Licensed under GNU General Public License v3.0 7 | 8 | """Class to obtain BaseKeyedVector from.""" 9 | 10 | from pathlib import Path 11 | 12 | from gensim.models.fasttext import FastTextKeyedVectors 13 | from gensim.models.keyedvectors import KeyedVectors 14 | from huggingface_hub import snapshot_download 15 | from requests import HTTPError 16 | 17 | _SUFFIX: str = ".model" 18 | 19 | 20 | class Vectors(KeyedVectors): 21 | """Class to instantiates vectors from pretrained models.""" 22 | 23 | @classmethod 24 | def 
from_pretrained(cls, model: str, mmap: str = None): 25 | """Method to load vectors from a pre-trained model. 26 | 27 | Parameters 28 | ---------- 29 | model : :str: Name of the model to load from the hub. For example: "glove-wiki-gigaword-50" 30 | mmap : :str: Whether to load the vectors in mmap mode. 31 | 32 | Returns 33 | ------- 34 | Vectors 35 | An object of pretrained vectors. 36 | """ 37 | try: 38 | path = Path(snapshot_download(repo_id=f"fse/{model}")) 39 | except HTTPError as err: 40 | if err.response.status_code == 404: 41 | raise ValueError(f"model {model} does not exist") 42 | raise 43 | 44 | assert path.exists(), "something went wrong. the file wasn't downloaded." 45 | 46 | return super(Vectors, cls).load( 47 | (path / (model + _SUFFIX)).as_posix(), mmap=mmap 48 | ) 49 | 50 | 51 | class FTVectors(FastTextKeyedVectors): 52 | """Class to instantiate FT vectors from pretrained models.""" 53 | 54 | @classmethod 55 | def from_pretrained(cls, model: str, mmap: str = None): 56 | """Method to load vectors from a pre-trained model. 57 | 58 | Parameters 59 | ---------- 60 | model : :str: Name of the model to load from the hub. For example: "glove-wiki-gigaword-50" 61 | mmap : :str: Whether to load the vectors in mmap mode. 62 | 63 | Returns 64 | ------- 65 | Vectors 66 | An object of pretrained vectors. 67 | """ 68 | try: 69 | path = Path(snapshot_download(repo_id=f"fse/{model}")) 70 | except HTTPError as err: 71 | if err.response.status_code == 404: 72 | raise ValueError(f"model {model} does not exist") 73 | raise 74 | 75 | assert path.exists(), "something went wrong. the file wasn't downloaded." 76 | 77 | return super(FTVectors, cls).load( 78 | (path / (model + _SUFFIX)).as_posix(), mmap=mmap 79 | ) 80 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from numpy.testing import assert_allclose, assert_raises 7 | 8 | from fse.models.utils import compute_principal_components, remove_principal_components 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | TEST_DATA = Path(__file__).parent / "test_data" 13 | 14 | 15 | class TestUtils(unittest.TestCase): 16 | def test_compute_components(self): 17 | m = np.random.uniform(size=(500, 10)).astype(np.float32) 18 | out = compute_principal_components(vectors=m) 19 | self.assertEqual(2, len(out)) 20 | self.assertEqual(1, len(out[1])) 21 | self.assertEqual(np.float32, out[1].dtype) 22 | 23 | m = np.random.uniform(size=(500, 10)) 24 | out = compute_principal_components(vectors=m, components=5) 25 | self.assertEqual(2, len(out)) 26 | self.assertEqual(5, len(out[1])) 27 | 28 | def test_compute_large_components(self): 29 | m = np.random.uniform(size=(int(2e6), 100)).astype(np.float32) 30 | out = compute_principal_components(vectors=m, cache_size_gb=0.2) 31 | self.assertEqual(2, len(out)) 32 | self.assertEqual(1, len(out[1])) 33 | self.assertEqual(np.float32, out[1].dtype) 34 | 35 | def test_remove_components_inplace(self): 36 | m = np.ones((500, 10), dtype=np.float32) 37 | c = np.copy(m) 38 | out = compute_principal_components(vectors=m) 39 | remove_principal_components(m, svd_res=out) 40 | assert_allclose(m, 0.0, atol=1e-5) 41 | with assert_raises(AssertionError): 42 | assert_allclose(m, c) 43 | 44 | def test_remove_components(self): 45 | m = np.ones((500, 10), dtype=np.float32) 46 | c = np.copy(m) 47 | out = 
compute_principal_components(vectors=m) 48 | res = remove_principal_components(m, svd_res=out, inplace=False) 49 | assert_allclose(res, 0.0, atol=1e-5) 50 | assert_allclose(m, c) 51 | 52 | def test_remove_weighted_components_inplace(self): 53 | m = np.ones((500, 10), dtype=np.float32) 54 | c = np.copy(m) 55 | out = compute_principal_components(vectors=m) 56 | remove_principal_components(m, svd_res=out, weights=np.array([0.5])) 57 | assert_allclose(m, 0.75, atol=1e-5) 58 | with assert_raises(AssertionError): 59 | assert_allclose(m, c) 60 | 61 | def test_remove_weighted_components(self): 62 | m = np.ones((500, 10), dtype=np.float32) 63 | c = np.copy(m) 64 | out = compute_principal_components(vectors=m) 65 | res = remove_principal_components( 66 | m, svd_res=out, weights=np.array([0.5]), inplace=False 67 | ) 68 | assert_allclose(res, 0.75, atol=1e-5) 69 | assert_allclose(m, c) 70 | 71 | def test_madvise(self): 72 | from sys import platform 73 | 74 | from fse.models.utils import set_madvise_for_mmap 75 | 76 | if platform in ["linux", "linux2", "darwin", "aix"]: 77 | p = TEST_DATA / "test_vectors" 78 | madvise = set_madvise_for_mmap(True) 79 | shape = (500, 10) 80 | mat = np.random.normal(size=shape) 81 | memvecs = np.memmap(p, dtype=np.float32, mode="w+", shape=shape) 82 | memvecs[:] = mat[:] 83 | del memvecs 84 | 85 | mat = np.memmap(p, dtype=np.float32, mode="r", shape=shape) 86 | 87 | self.assertEqual( 88 | madvise(mat.ctypes.data, mat.size * mat.dtype.itemsize, 1), 0 89 | ) 90 | p.unlink() 91 | 92 | 93 | if __name__ == "__main__": 94 | logging.basicConfig( 95 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 96 | ) 97 | unittest.main() 98 | -------------------------------------------------------------------------------- /fse/models/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | from typing import Tuple 8 | from sklearn.decomposition import TruncatedSVD 9 | 10 | from numpy import finfo, ndarray, float32 as REAL, ones, dtype 11 | from numpy.random import choice 12 | 13 | from time import time 14 | 15 | import logging 16 | 17 | from sys import platform 18 | 19 | import ctypes 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | EPS = finfo(REAL).eps 24 | 25 | 26 | def set_madvise_for_mmap(return_madvise: bool = False) -> object: 27 | """Method used to set madvise parameters. 
28 | This addresses the memmap issue raised in https://github.com/numpy/numpy/issues/13172 29 | The issue is not applicable on Windows 30 | 31 | Parameters 32 | ---------- 33 | return_madvise : bool 34 | Returns the madvise object for unittests, see test_utils.py 35 | 36 | Returns 37 | ------- 38 | object 39 | madvise object 40 | 41 | """ 42 | 43 | if platform in ["linux", "linux2", "darwin", "aix"]: 44 | if platform == "darwin": 45 | # Path different for Macos 46 | madvise = ctypes.CDLL("libc.dylib").madvise 47 | if platform in ["linux", "linux2", "aix"]: 48 | madvise = ctypes.CDLL("libc.so.6").madvise 49 | madvise.argtypes = [ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int] 50 | madvise.restype = ctypes.c_int 51 | 52 | if return_madvise: 53 | return madvise 54 | 55 | 56 | def compute_principal_components( 57 | vectors: ndarray, components: int = 1, cache_size_gb: float = 1.0 58 | ) -> Tuple[ndarray, ndarray]: 59 | """Method used to compute the first singular vectors of a given (sub)matrix 60 | 61 | Parameters 62 | ---------- 63 | vectors : ndarray 64 | (Sentence) vectors to compute the truncated SVD on 65 | components : int, optional 66 | Number of singular values/vectors to compute 67 | cache_size_gb : float, optional 68 | Cache size for computing the principal components in GB 69 | 70 | Returns 71 | ------- 72 | ndarray, ndarray 73 | Singular values and singular vectors 74 | """ 75 | start = time() 76 | num_vectors = vectors.shape[0] 77 | svd = TruncatedSVD( 78 | n_components=components, n_iter=7, random_state=42, algorithm="randomized" 79 | ) 80 | 81 | sample_size = int( 82 | 1024 ** 3 * cache_size_gb / (vectors.shape[1] * dtype(REAL).itemsize) 83 | ) 84 | 85 | if sample_size > num_vectors: 86 | svd.fit(vectors) 87 | else: 88 | logger.info(f"sampling {sample_size} vectors to compute principal components") 89 | sample_indices = choice(range(num_vectors), replace=False, size=sample_size) 90 | svd.fit(vectors[sample_indices, :]) 91 | 92 | elapsed = time() 93 | logger.info( 94 | f"computing {components} principal components took {int(elapsed-start)}s" 95 | ) 96 | return svd.singular_values_.astype(REAL), svd.components_.astype(REAL) 97 | 98 | 99 | def remove_principal_components( 100 | vectors: ndarray, 101 | svd_res: Tuple[ndarray, ndarray], 102 | weights: ndarray = None, 103 | inplace: bool = True, 104 | ) -> ndarray: 105 | """Method used to remove the first singular vectors of a given matrix 106 | 107 | Parameters 108 | ---------- 109 | vectors : ndarray 110 | (Sentence) vectors to remove components from 111 | svd_res : (ndarray, ndarray) 112 | Tuple consisting of the singular values and components to remove from the vectors 113 | weights : ndarray, optional 114 | Weights to be used to weigh the components which are removed from the vectors 115 | inplace : bool, optional 116 | If true, removes the components from the vectors inplace (memory efficient) 117 | 118 | Returns 119 | ------- 120 | ndarray 121 | Vectors with the principal components removed (None if inplace=True) 122 | """ 123 | components = svd_res[1].astype(REAL) 124 | 125 | start = time() 126 | if weights is None: 127 | w_comp = components * ones(len(components), dtype=REAL)[:, None] 128 | else: 129 | w_comp = components * (weights[:, None].astype(REAL)) 130 | 131 | output = None 132 | if len(components) == 1: 133 | if not inplace: 134 | output = vectors - vectors.dot(w_comp.transpose()) * w_comp 135 | else: 136 | vectors -= vectors.dot(w_comp.transpose()) * w_comp 137 | else: 138 | if not inplace: 139 | output = vectors - 
vectors.dot(w_comp.transpose()).dot(w_comp) 140 | else: 141 | vectors -= vectors.dot(w_comp.transpose()).dot(w_comp) 142 | elapsed = time() 143 | 144 | logger.info( 145 | f"removing {len(components)} principal components took {int(elapsed-start)}s" 146 | ) 147 | if not inplace: 148 | return output 149 | -------------------------------------------------------------------------------- /test/test_usif.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from gensim.models import Word2Vec 7 | 8 | from fse.inputs import IndexedLineDocument 9 | from fse.models.usif import uSIF 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | CORPUS = Path(__file__).parent / "test_data" / "test_sentences.txt" 14 | DIM = 50 15 | W2V = Word2Vec(min_count=1, vector_size=DIM) 16 | with open(CORPUS, "r") as file: 17 | SENTENCES = [l.split() for _, l in enumerate(file)] 18 | W2V.build_vocab(SENTENCES) 19 | 20 | 21 | class TestuSIFFunctions(unittest.TestCase): 22 | def setUp(self): 23 | self.sentences = IndexedLineDocument(CORPUS) 24 | self.model = uSIF(W2V, lang_freq="en") 25 | 26 | def test_parameter_sanity(self): 27 | with self.assertRaises(ValueError): 28 | m = uSIF(W2V, length=0) 29 | m._check_parameter_sanity() 30 | with self.assertRaises(ValueError): 31 | m = uSIF(W2V, components=-1, length=11) 32 | m._check_parameter_sanity() 33 | 34 | def test_pre_train_calls(self): 35 | kwargs = {"average_length": 10} 36 | self.model._pre_train_calls(**kwargs) 37 | self.assertEqual(10, self.model.length) 38 | 39 | def test_post_train_calls(self): 40 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 41 | self.model._post_train_calls() 42 | self.assertTrue(np.allclose(self.model.sv.vectors, 0, atol=1e-5)) 43 | 44 | def test_post_train_calls_no_removal(self): 45 | self.model.components = 0 46 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 47 | self.model._post_train_calls() 48 | self.assertTrue(np.allclose(self.model.sv.vectors, 1, atol=1e-5)) 49 | 50 | def test_post_inference_calls(self): 51 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 52 | self.model._post_train_calls() 53 | 54 | output = np.ones((200, 10), dtype=np.float32) 55 | self.model._post_inference_calls(output=output) 56 | self.assertTrue(np.allclose(output, 0, atol=1e-5)) 57 | 58 | def test_post_inference_calls_no_svd(self): 59 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 60 | self.model.svd_res = None 61 | with self.assertRaises(RuntimeError): 62 | self.model._post_inference_calls(output=None) 63 | 64 | def test_post_inference_calls_no_removal(self): 65 | self.model.components = 0 66 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 67 | self.model._post_train_calls() 68 | self.model._post_inference_calls(output=None) 69 | self.assertTrue(np.allclose(self.model.sv.vectors, 1, atol=1e-5)) 70 | 71 | def test_dtype_sanity_word_weights(self): 72 | self.model.word_weights = np.ones_like(self.model.word_weights, dtype=int) 73 | with self.assertRaises(TypeError): 74 | self.model._check_dtype_santiy() 75 | 76 | def test_dtype_sanity_svd_vals(self): 77 | self.model.svd_res = ( 78 | np.ones_like(self.model.word_weights, dtype=int), 79 | np.array(0, dtype=np.float32), 80 | ) 81 | with self.assertRaises(TypeError): 82 | self.model._check_dtype_santiy() 83 | 84 | def test_dtype_sanity_svd_vecs(self): 85 | self.model.svd_res = ( 86 | np.array(0, dtype=np.float32), 87 
| np.ones_like(self.model.word_weights, dtype=int), 88 | ) 89 | with self.assertRaises(TypeError): 90 | self.model._check_dtype_santiy() 91 | 92 | def test_compute_usif_weights(self): 93 | w = "Good" 94 | pw = 1.916650481770269e-08 95 | idx = self.model.wv.key_to_index[w] 96 | self.model.length = 11 97 | a = 0.17831555484795414 98 | usif = a / ((a / 2) + pw) 99 | self.model._compute_usif_weights() 100 | self.assertTrue(np.allclose(self.model.word_weights[idx], usif)) 101 | 102 | def test_train(self): 103 | output = self.model.train(self.sentences) 104 | self.assertEqual((100, 1450), output) 105 | self.assertTrue(np.isfinite(self.model.sv.vectors).all()) 106 | 107 | def test_broken_vocab(self): 108 | w2v = Word2Vec(min_count=1, vector_size=DIM) 109 | 110 | with open(CORPUS, "r") as file: 111 | w2v.build_vocab([l.split() for l in file]) 112 | for k in w2v.wv.key_to_index: 113 | w2v.wv.set_vecattr(k, "count", -1) 114 | 115 | model = uSIF(w2v) 116 | 117 | with self.assertRaises(ValueError): 118 | model.train(self.sentences) 119 | 120 | def test_zero_div_error(self): 121 | """From issue: #47.""" 122 | 123 | model = uSIF(W2V, length=12, components=1) 124 | model._compute_usif_weights() 125 | 126 | 127 | if __name__ == "__main__": 128 | logging.basicConfig( 129 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 130 | ) 131 | unittest.main() 132 | -------------------------------------------------------------------------------- /test/test_sif.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from gensim.models import Word2Vec 7 | 8 | from fse.inputs import IndexedLineDocument 9 | from fse.models.sif import SIF 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | TEST_DATA = Path(__file__).parent / "test_data" 14 | CORPUS = TEST_DATA / "test_sentences.txt" 15 | DIM = 50 16 | W2V = Word2Vec(min_count=1, vector_size=DIM) 17 | with open(CORPUS, "r") as file: 18 | SENTENCES = [l.split() for _, l in enumerate(file)] 19 | W2V.build_vocab(SENTENCES) 20 | 21 | 22 | class TestSIFFunctions(unittest.TestCase): 23 | def setUp(self): 24 | self.sentences = IndexedLineDocument(CORPUS) 25 | self.model = SIF(W2V, lang_freq="en") 26 | 27 | def test_parameter_sanity(self): 28 | with self.assertRaises(ValueError): 29 | m = SIF(W2V, alpha=-1) 30 | m._check_parameter_sanity() 31 | with self.assertRaises(ValueError): 32 | m = SIF(W2V, components=-1) 33 | m._check_parameter_sanity() 34 | with self.assertRaises(ValueError): 35 | m = SIF(W2V) 36 | m.word_weights = np.ones_like(m.word_weights) + 2 37 | m._check_parameter_sanity() 38 | 39 | def test_pre_train_calls(self): 40 | self.model._pre_train_calls() 41 | 42 | def test_post_train_calls(self): 43 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 44 | self.model._post_train_calls() 45 | self.assertTrue(np.allclose(self.model.sv.vectors, 0, atol=1e-5)) 46 | 47 | def test_post_train_calls_no_removal(self): 48 | self.model.components = 0 49 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 50 | self.model._post_train_calls() 51 | self.assertTrue(np.allclose(self.model.sv.vectors, 1, atol=1e-5)) 52 | 53 | def test_post_inference_calls(self): 54 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 55 | self.model._post_train_calls() 56 | 57 | output = np.ones((200, 10), dtype=np.float32) 58 | self.model._post_inference_calls(output=output) 59 | self.assertTrue(np.allclose(output, 0, 
atol=1e-5)) 60 | 61 | def test_post_inference_calls_no_svd(self): 62 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 63 | self.model.svd_res = None 64 | with self.assertRaises(RuntimeError): 65 | self.model._post_inference_calls(output=None) 66 | 67 | def test_post_inference_calls_no_removal(self): 68 | self.model.components = 0 69 | self.model.sv.vectors = np.ones((200, 10), dtype=np.float32) 70 | self.model._post_train_calls() 71 | self.model._post_inference_calls(output=None) 72 | self.assertTrue(np.allclose(self.model.sv.vectors, 1, atol=1e-5)) 73 | 74 | def test_dtype_sanity_word_weights(self): 75 | self.model.word_weights = np.ones_like(self.model.word_weights, dtype=int) 76 | with self.assertRaises(TypeError): 77 | self.model._check_dtype_santiy() 78 | 79 | def test_dtype_sanity_svd_vals(self): 80 | self.model.svd_res = ( 81 | np.ones_like(self.model.word_weights, dtype=int), 82 | np.array(0, dtype=np.float32), 83 | ) 84 | with self.assertRaises(TypeError): 85 | self.model._check_dtype_santiy() 86 | 87 | def test_dtype_sanity_svd_vecs(self): 88 | self.model.svd_res = ( 89 | np.array(0, dtype=np.float32), 90 | np.ones_like(self.model.word_weights, dtype=int), 91 | ) 92 | with self.assertRaises(TypeError): 93 | self.model._check_dtype_santiy() 94 | 95 | def test_compute_sif_weights(self): 96 | w = "Good" 97 | pw = 1.916650481770269e-08 98 | alpha = self.model.alpha 99 | sif = alpha / (alpha + pw) 100 | 101 | idx = self.model.wv.key_to_index[w] 102 | self.model._compute_sif_weights() 103 | self.assertTrue(np.allclose(self.model.word_weights[idx], sif)) 104 | 105 | def test_train(self): 106 | output = self.model.train(self.sentences) 107 | self.assertEqual((100, 1450), output) 108 | self.assertTrue(np.isfinite(self.model.sv.vectors).all()) 109 | self.assertEqual(2, len(self.model.svd_res)) 110 | 111 | def test_save_issue(self): 112 | model = SIF(W2V) 113 | model.train(self.sentences) 114 | 115 | p = TEST_DATA / "test_emb.model" 116 | model.save(str(p)) 117 | model = SIF.load(str(p)) 118 | p.unlink() 119 | 120 | self.assertEqual(2, len(model.svd_res)) 121 | model.sv.similar_by_sentence("test sentence".split(), model=model) 122 | 123 | def test_broken_vocab(self): 124 | w2v = Word2Vec(min_count=1, vector_size=DIM) 125 | with open(CORPUS, "r") as file: 126 | w2v.build_vocab([l.split() for l in file]) 127 | 128 | for k in w2v.wv.key_to_index: 129 | w2v.wv.set_vecattr(k, "count", -1) 130 | 131 | model = SIF(w2v) 132 | with self.assertRaises(ValueError): 133 | model.train(self.sentences) 134 | 135 | 136 | if __name__ == "__main__": 137 | logging.basicConfig( 138 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 139 | ) 140 | unittest.main() 141 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # For License information, see corresponding LICENSE file. 
6 | 7 | """Template setup.py Read more on 8 | https://docs.python.org/3.7/distutils/setupscript.html.""" 9 | 10 | import distutils 11 | import itertools 12 | import os 13 | import platform 14 | import shutil 15 | 16 | from setuptools import Extension, find_packages, setup 17 | from setuptools.command.build_ext import build_ext 18 | 19 | NAME = "fse" 20 | VERSION = "1.0.0" 21 | DESCRIPTION = "Fast Sentence Embeddings for Gensim" 22 | AUTHOR = "Oliver Borchers" 23 | AUTHOR_EMAIL = "o.borchers@oxolo.com" 24 | URL = "https://github.com/oborchers/Fast_Sentence_Embeddings" 25 | LICENSE = "GPL-3.0" 26 | REQUIRES_PYTHON = ">=3.6" 27 | NUMPY_STR = "numpy >= 1.11.3" 28 | CYTHON_STR = "Cython==0.29.23" 29 | 30 | INSTALL_REQUIRES = [ 31 | NUMPY_STR, 32 | "scipy >= 0.18.1", 33 | "smart_open >= 1.5.0", 34 | "scikit-learn >= 0.19.1", 35 | "gensim>=4", 36 | "wordfreq >= 2.2.1", 37 | "huggingface-hub", 38 | "psutil", 39 | "dataclasses; python_version < '3.7'", 40 | ] 41 | SETUP_REQUIRES = [NUMPY_STR] 42 | 43 | c_extensions = { 44 | "fse.models.average_inner": "fse/models/average_inner.c", 45 | } 46 | cpp_extensions = {} 47 | 48 | 49 | def need_cython(): 50 | """Return True if we need Cython to translate any of the extensions. 51 | 52 | If the extensions have already been translated to C/C++, then we don"t need to 53 | install Cython and perform the translation. 54 | """ 55 | expected = list(c_extensions.values()) + list(cpp_extensions.values()) 56 | return any([not os.path.isfile(f) for f in expected]) 57 | 58 | 59 | def make_c_ext(use_cython=False): 60 | for module, source in c_extensions.items(): 61 | if use_cython: 62 | source = source.replace(".c", ".pyx") 63 | extra_args = [] 64 | # extra_args.extend(["-g", "-O0"]) # uncomment if optimization limiting crash info 65 | yield Extension( 66 | module, 67 | sources=[source], 68 | language="c", 69 | extra_compile_args=extra_args, 70 | ) 71 | 72 | 73 | def make_cpp_ext(use_cython=False): 74 | extra_args = [] 75 | system = platform.system() 76 | 77 | if system == "Linux": 78 | extra_args.append("-std=c++11") 79 | elif system == "Darwin": 80 | extra_args.extend(["-stdlib=libc++", "-std=c++11"]) 81 | # extra_args.extend(["-g", "-O0"]) # uncomment if 82 | # optimization limiting crash info 83 | for module, source in cpp_extensions.items(): 84 | if use_cython: 85 | source = source.replace(".cpp", ".pyx") 86 | yield Extension( 87 | module, 88 | sources=[source], 89 | language="c++", 90 | extra_compile_args=extra_args, 91 | extra_link_args=extra_args, 92 | ) 93 | 94 | 95 | # 96 | # We use use_cython=False here for two reasons: 97 | # 98 | # 1. Cython may not be available at this stage 99 | # 2. The actual translation from Cython to C/C++ happens inside CustomBuildExt 100 | # 101 | ext_modules = list( 102 | itertools.chain(make_c_ext(use_cython=False), make_cpp_ext(use_cython=False)) 103 | ) 104 | 105 | 106 | class CustomBuildExt(build_ext): 107 | """Custom build_ext action with bootstrapping. 108 | 109 | We need this in order to use numpy and Cython in this script without importing them 110 | at module level, because they may not be available yet. 
111 | """ 112 | 113 | # 114 | # http://stackoverflow.com/questions/19919905/how-to-bootstrap-numpy-installation-in-setup-py 115 | # 116 | def finalize_options(self): 117 | build_ext.finalize_options(self) 118 | # Prevent numpy from thinking it is still in its setup process: 119 | # https://docs.python.org/2/library/__builtin__.html#module-__builtin__ 120 | __builtins__.__NUMPY_SETUP__ = False 121 | 122 | import numpy 123 | 124 | self.include_dirs.append(numpy.get_include()) 125 | 126 | if need_cython(): 127 | import Cython.Build 128 | 129 | Cython.Build.cythonize(list(make_c_ext(use_cython=True))) 130 | Cython.Build.cythonize(list(make_cpp_ext(use_cython=True))) 131 | 132 | 133 | class CleanExt(distutils.cmd.Command): 134 | description = "Remove C sources, C++ sources and binaries for fse extensions" 135 | user_options = [] 136 | 137 | def initialize_options(self): 138 | pass 139 | 140 | def finalize_options(self): 141 | pass 142 | 143 | def run(self): 144 | for root, dirs, files in os.walk("fse"): 145 | files = [ 146 | os.path.join(root, f) 147 | for f in files 148 | if os.path.splitext(f)[1] in (".c", ".cpp", ".so") 149 | ] 150 | for f in files: 151 | self.announce("removing %s" % f, level=distutils.log.INFO) 152 | os.unlink(f) 153 | 154 | if os.path.isdir("build"): 155 | self.announce("recursively removing build", level=distutils.log.INFO) 156 | shutil.rmtree("build") 157 | 158 | 159 | cmdclass = {"build_ext": CustomBuildExt, "clean_ext": CleanExt} 160 | 161 | if need_cython(): 162 | INSTALL_REQUIRES.append(CYTHON_STR) 163 | SETUP_REQUIRES.append(CYTHON_STR) 164 | 165 | setup( 166 | name=NAME, 167 | version=VERSION, 168 | description=DESCRIPTION, 169 | author=AUTHOR, 170 | author_email=AUTHOR_EMAIL, 171 | packages=find_packages(), 172 | python_requires=REQUIRES_PYTHON, 173 | install_requires=INSTALL_REQUIRES, 174 | setup_requires=SETUP_REQUIRES, 175 | ext_modules=ext_modules, 176 | cmdclass=cmdclass, 177 | zip_safe=False, 178 | include_package_data=True, 179 | ) 180 | -------------------------------------------------------------------------------- /evaluation/readme.txt: -------------------------------------------------------------------------------- 1 | 2 | STS Benchmark: Main English dataset 3 | 4 | Semantic Textual Similarity 2012-2017 Dataset 5 | 6 | http://ixa2.si.ehu.eus/stswiki 7 | 8 | 9 | STS Benchmark comprises a selection of the English datasets used in 10 | the STS tasks organized by us in the context of SemEval between 2012 11 | and 2017. 12 | 13 | In order to provide a standard benchmark to compare among systems, we 14 | organized it into train, development and test. The development part 15 | can be used to develop and tune hyperparameters of the systems, and 16 | the test part should be only used once for the final system. 17 | 18 | The benchmark comprises 8628 sentence pairs. 
This is the breakdown 19 | according to genres and train-dev-test splits: 20 | 21 | train dev test total 22 | ----------------------------- 23 | news 3299 500 500 4299 24 | caption 2000 625 525 3250 25 | forum 450 375 254 1079 26 | ----------------------------- 27 | total 5749 1500 1379 8628 28 | 29 | For reference, this is the breakdown according to the original names 30 | and task years of the datasets: 31 | 32 | genre file years train dev test 33 | ------------------------------------------------ 34 | news MSRpar 2012 1000 250 250 35 | news headlines 2013-16 1999 250 250 36 | news deft-news 2014 300 0 0 37 | captions MSRvid 2012 1000 250 250 38 | captions images 2014-15 1000 250 250 39 | captions track5.en-en 2017 0 125 125 40 | forum deft-forum 2014 450 0 0 41 | forum answers-forums 2015 0 375 0 42 | forum answer-answer 2016 0 0 254 43 | 44 | In addition to the standard benchmark, we also include other datasets 45 | (see readme.txt in "companion" directory). 46 | 47 | 48 | Introduction 49 | ------------ 50 | 51 | Given two sentences of text, s1 and s2, the systems need to compute 52 | how similar s1 and s2 are, returning a similarity score between 0 and 53 | 5. The dataset comprises naturally occurring pairs of sentences drawn 54 | from several domains and genres, annotated by crowdsourcing. See 55 | papers by Agirre et al. (2012; 2013; 2014; 2015; 2016; 2017). 56 | 57 | Format 58 | ------ 59 | 60 | Each file is encoded in utf-8 (a superset of ASCII), and has the 61 | following tab separated fields: 62 | 63 | genre filename year score sentence1 sentence2 64 | 65 | optionally there might be some license-related fields after sentence2. 66 | 67 | NOTE: Given that some sentence pairs have been reused here and 68 | elsewhere, systems should NOT use the following datasets to develop or 69 | train their systems (see below for more details on datasets): 70 | 71 | - Any of the datasets in Semeval STS competitions, including Semeval 72 | 2014 task 1 (also known as SICK). 73 | - The test part of MSR-Paraphrase (development and train are fine). 74 | - The text of the videos in MSR-Video. 75 | 76 | 77 | Evaluation script 78 | ----------------- 79 | 80 | The official evaluation is the Pearson correlation coefficient. Given 81 | an output file comprising the system scores (one per line) in a file 82 | called sys.txt, you can use the evaluation script as follows: 83 | 84 | $ perl correlation.pl sts-dev.txt sys.txt 85 | 86 | 87 | Other 88 | ----- 89 | 90 | Please check http://ixa2.si.ehu.eus/stswiki 91 | 92 | We recommend that interested researchers join the (low traffic) 93 | mailing list: 94 | 95 | http://groups.google.com/group/STS-semeval 96 | 97 | Notse on datasets and licenses 98 | ------------------------------ 99 | 100 | If using this data in your research please cite (Agirre et al. 2017) 101 | and the STS website: http://ixa2.si.ehu.eus/stswiki. 
102 | 103 | Please see LICENSE.txt 104 | 105 | 106 | Organizers of tasks by year 107 | --------------------------- 108 | 109 | 2012 Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre 110 | 111 | 2013 Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre, 112 | WeiWei Guo 113 | 114 | 2014 Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab, 115 | Aitor Gonzalez-Agirre, Weiwei Guo, Rada Mihalcea, German Rigau, 116 | Janyce Wiebe 117 | 118 | 2015 Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab, 119 | Aitor Gonzalez-Agirre, Weiwei Guo, Inigo Lopez-Gazpio, Montse 120 | Maritxalar, Rada Mihalcea, German Rigau, Larraitz Uria, Janyce 121 | Wiebe 122 | 123 | 2016 Eneko Agirre, Carmen Banea, Daniel Cer, Mona Diab, Aitor 124 | Gonzalez-Agirre, Rada Mihalcea, German Rigau, Janyce 125 | Wiebe 126 | 127 | 2017 Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia 128 | Specia 129 | 130 | 131 | References 132 | ---------- 133 | 134 | Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre. Task 6: A 135 | Pilot on Semantic Textual Similarity. Procceedings of Semeval 2012 136 | 137 | Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre, WeiWei 138 | Guo. *SEM 2013 shared task: Semantic Textual 139 | Similarity. Procceedings of *SEM 2013 140 | 141 | Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab, 142 | Aitor Gonzalez-Agirre, Weiwei Guo, Rada Mihalcea, German Rigau, 143 | Janyce Wiebe. Task 10: Multilingual Semantic Textual 144 | Similarity. Proceedings of SemEval 2014. 145 | 146 | Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab, 147 | Aitor Gonzalez-Agirre, Weiwei Guo, Inigo Lopez-Gazpio, Montse 148 | Maritxalar, Rada Mihalcea, German Rigau, Larraitz Uria, Janyce 149 | Wiebe. Task 2: Semantic Textual Similarity, English, Spanish and 150 | Pilot on Interpretability. Proceedings of SemEval 2015. 151 | 152 | Eneko Agirre, Carmen Banea, Daniel Cer, Mona Diab, Aitor 153 | Gonzalez-Agirre, Rada Mihalcea, German Rigau, Janyce 154 | Wiebe. Semeval-2016 Task 1: Semantic Textual Similarity, 155 | Monolingual and Cross-Lingual Evaluation. Proceedings of SemEval 156 | 2016. 157 | 158 | Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia 159 | Specia. Semeval-2017 Task 1: Semantic Textual Similarity 160 | Multilingual and Crosslingual Focused Evaluation. Proceedings of 161 | SemEval 2017. 162 | 163 | Clive Best, Erik van der Goot, Ken Blackler, Tefilo Garcia, and David 164 | Horby. 2005. Europe media monitor - system description. In EUR 165 | Report 22173-En, Ispra, Italy. 166 | 167 | Cyrus Rashtchian, Peter Young, Micah Hodosh, and Julia Hockenmaier. 168 | Collecting Image Annotations Using Amazon's Mechanical Turk. In 169 | Proceedings of the NAACL HLT 2010 Workshop on Creating Speech and 170 | Language Data with Amazon's Mechanical Turk. 171 | 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /evaluation/LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Notes on datasets and licenses 3 | ------------------------------ 4 | 5 | If using this data in your research please cite the following paper 6 | and the url of the STS website: http://ixa2.si.ehu.eus/stswiki: 7 | 8 | Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia 9 | Specia. Semeval-2017 Task 1: Semantic Textual Similarity 10 | Multilingual and Crosslingual Focused Evaluation. Proceedings of 11 | SemEval 2017. 
12 | 13 | The scores are released under a "Commons Attribution - Share Alike 4.0 14 | International License" http://creativecommons.org/licenses/by-sa/4.0/ 15 | 16 | The text of each dataset has a license of its own, as follows: 17 | 18 | - MSR-Paraphrase, Microsoft Research Paraphrase Corpus. In order to use 19 | MSRpar, researchers need to agree with the license terms from 20 | Microsoft Research: 21 | http://research.microsoft.com/en-us/downloads/607d14d9-20cd-47e3-85bc-a2f65cd28042/ 22 | 23 | - headlines: Mined from several news sources by European Media Monitor 24 | (Best et al. 2005). using the RSS feed. European Media Monitor (EMM) 25 | Real Time News Clusters are the top news stories for the last 4 26 | hours, updated every ten minutes. The article clustering is fully 27 | automatic. The selection and placement of stories are determined 28 | automatically by a computer program. This site is a joint project of 29 | DG-JRC and DG-COMM. The information on this site is subject to a 30 | disclaimer (see 31 | http://europa.eu/geninfo/legal_notices_en.htm). Please acknowledge 32 | EMM when (re)using this material. 33 | http://emm.newsbrief.eu/rss?type=rtn&language=en&duplicates=false 34 | 35 | - deft-news: A subset of news article data in the DEFT 36 | project. 37 | 38 | - MSR-Video, Microsoft Research Video Description Corpus. In order to 39 | use MSRvideo, researchers need to agree with the license terms from 40 | Microsoft Research: 41 | http://research.microsoft.com/en-us/downloads/38cf15fd-b8df-477e-a4e4-a4680caa75af/ 42 | 43 | - image: The Image Descriptions data set is a subset of 44 | the PASCAL VOC-2008 data set (Rashtchian et al., 2010) . PASCAL 45 | VOC-2008 data set consists of 1,000 images and has been used by a 46 | number of image description systems. The image captions of the data 47 | set are released under a CreativeCommons Attribution-ShareAlike 48 | license, the descriptions itself are free. 49 | 50 | - track5.en-en: This text is a subset of the Stanford Natural 51 | Language Inference (SNLI) corpus, by The Stanford NLP Group is 52 | licensed under a Creative Commons Attribution-ShareAlike 4.0 53 | International License. Based on a work at 54 | http://shannon.cs.illinois.edu/DenotationGraph/. 55 | https://creativecommons.org/licenses/by-sa/4.0/ 56 | 57 | - answers-answers: user content from stack-exchange. Check the license 58 | below in ======ANSWERS-ANSWERS====== 59 | 60 | - answers-forums: user content from stack-exchange. Check the license 61 | below in ======ANSWERS-FORUMS====== 62 | 63 | 64 | 65 | ======ANSWER-ANSWER====== 66 | 67 | Creative Commons Attribution-ShareAlike 3.0 Unported (CC BY-SA 3.0) 68 | http://creativecommons.org/licenses/by-sa/3.0/ 69 | 70 | Attribution Requirements: 71 | 72 | "* Visually display or otherwise indicate the source of the content 73 | as coming from the Stack Exchange Network. This requirement is 74 | satisfied with a discreet text blurb, or some other unobtrusive but 75 | clear visual indication. 
76 | 77 | * Ensure that any Internet use of the content includes a hyperlink 78 | directly to the original question on the source site on the Network 79 | (e.g., http://stackoverflow.com/questions/12345) 80 | 81 | * Visually display or otherwise clearly indicate the author names for 82 | every question and answer used 83 | 84 | * Ensure that any Internet use of the content includes a hyperlink for 85 | each author name directly back to his or her user profile page on the 86 | source site on the Network (e.g., 87 | http://stackoverflow.com/users/12345/username), directly to the Stack 88 | Exchange domain, in standard HTML (i.e. not through a Tinyurl or other 89 | such indirect hyperlink, form of obfuscation or redirection), without 90 | any “nofollow” command or any other such means of avoiding detection by 91 | search engines, and visible even with JavaScript disabled." 92 | 93 | (https://archive.org/details/stackexchange) 94 | 95 | 96 | 97 | ======ANSWERS-FORUMS====== 98 | 99 | 100 | Stack Exchange Inc. generously made the data used to construct the STS 2015 answer-answer statement pairs available under a Creative Commons Attribution-ShareAlike (cc-by-sa) 3.0 license. 101 | 102 | The license is reproduced below from: https://archive.org/details/stackexchange 103 | 104 | The STS.input.answers-forums.txt file should be redistributed with this LICENSE text and the accompanying files in LICENSE.answers-forums.zip. The tsv files in the zip file contain the additional information that's needed to comply with the license. 105 | 106 | -- 107 | 108 | All user content contributed to the Stack Exchange network is cc-by-sa 3.0 licensed, intended to be shared and remixed. We even provide all our data as a convenient data dump. 109 | 110 | http://creativecommons.org/licenses/by-sa/3.0/ 111 | 112 | But our cc-by-sa 3.0 licensing, while intentionally permissive, does *require attribution*: 113 | 114 | "Attribution — You must attribute the work in the manner specified by the author or licensor (but not in any way that suggests that they endorse you or your use of the work)." 115 | 116 | Specifically the attribution requirements are as follows: 117 | 118 | 1. Visually display or otherwise indicate the source of the content as coming from the Stack Exchange Network. This requirement is satisfied with a discreet text blurb, or some other unobtrusive but clear visual indication. 119 | 120 | 2. Ensure that any Internet use of the content includes a hyperlink directly to the original question on the source site on the Network (e.g., http://stackoverflow.com/questions/12345) 121 | 122 | 3. Visually display or otherwise clearly indicate the author names for every question and answer so used. 123 | 124 | 4. Ensure that any Internet use of the content includes a hyperlink for each author name directly back to his or her user profile page on the source site on the Network (e.g., http://stackoverflow.com/users/12345/username), directly to the Stack Exchange domain, in standard HTML (i.e. not through a Tinyurl or other such indirect hyperlink, form of obfuscation or redirection), without any “nofollow” command or any other such means of avoiding detection by search engines, and visible even with JavaScript disabled. 125 | 126 | Our goal is to maintain the spirit of fair attribution. That means attribution to the website, and more importantly, to the individuals who so generously contributed their time to create that content in the first place! 
127 | 128 | For more information, see the Stack Exchange Terms of Service: http://stackexchange.com/legal/terms-of-service 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /fse/models/sif.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | from fse.models.average import Average 8 | from fse.models.utils import compute_principal_components, remove_principal_components 9 | 10 | from gensim.models.keyedvectors import KeyedVectors 11 | 12 | from numpy import ndarray, float32 as REAL, zeros, isfinite 13 | 14 | import logging 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | class SIF(Average): 20 | def __init__( 21 | self, 22 | model: KeyedVectors, 23 | alpha: float = 1e-3, 24 | components: int = 1, 25 | cache_size_gb: float = 1.0, 26 | sv_mapfile_path: str = None, 27 | wv_mapfile_path: str = None, 28 | workers: int = 1, 29 | lang_freq: str = None, 30 | ): 31 | """Smooth-inverse frequency (SIF) weighted sentence embeddings model. Performs a weighted averaging operation over all 32 | words in a sentence. After training, the model removes a number of singular vectors. 33 | 34 | The implementation is based on Arora et al. (2017): A Simple but Tough-to-Beat Baseline for Sentence Embeddings. 35 | For more information, see the original paper and the accompanying reference implementation. 36 | 37 | Parameters 38 | ---------- 39 | model : :class:`~gensim.models.keyedvectors.KeyedVectors` or :class:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel` 40 | This object essentially contains the mapping between words and embeddings. To compute the sentence embeddings 41 | the wv.vocab and wv.vector elements are required. 42 | alpha : float, optional 43 | Alpha is the weighting factor used to down-weight each individual word. 44 | components : int, optional 45 | Corresponds to the number of singular vectors to remove from the sentence embeddings. 46 | cache_size_gb : float, optional 47 | Cache size for computing the singular vectors in GB. 48 | sv_mapfile_path : str, optional 49 | Optional path to store the sentence-vectors in for very large datasets. Used for memmap. 50 | wv_mapfile_path : str, optional 51 | Optional path to store the word-vectors in for very large datasets. Used for memmap. 52 | Use sv_mapfile_path and wv_mapfile_path to train disk-to-disk without needing much RAM. 53 | workers : int, optional 54 | Number of working threads, used for multithreading. For most tasks (few words in a sentence) 55 | a value of 1 should be more than enough. 56 | lang_freq : str, optional 57 | Some pre-trained embeddings, i.e. "GoogleNews-vectors-negative300.bin", do not contain information about 58 | the frequency of a word. As the frequency is required for estimating the word weights, we induce 59 | frequencies into the wv.vocab.count based on :class:`~wordfreq` 60 | If no frequency information is available, you can choose the language to estimate the frequency. 61 | See https://github.com/LuminosoInsight/wordfreq 62 | 63 | """ 64 | 65 | self.alpha = float(alpha) 66 | self.components = int(components) 67 | self.cache_size_gb = float(cache_size_gb) 68 | self.svd_res = None 69 | 70 | if lang_freq is None: 71 | logger.info( 72 | "make sure you are using a model with valid word-frequency information. Otherwise, use the lang_freq argument." 73 | ) 74 | 75 | super(SIF, self).__init__( 76 | model=model, 77 | sv_mapfile_path=sv_mapfile_path, 78 | wv_mapfile_path=wv_mapfile_path, 79 | workers=workers, 80 | lang_freq=lang_freq, 81 | ) 82 | 83 | def _check_parameter_sanity(self): 84 | """ Check the sanity of all parameters """ 85 | if not all(self.word_weights <= 1.0) or not all(self.word_weights >= 0.0): 86 | raise ValueError("For SIF, all word weights must be 0 <= w_weight <= 1") 87 | if self.alpha <= 0.0: 88 | raise ValueError("Alpha must be greater than zero.") 89 | if self.components < 0.0: 90 | raise ValueError("Components must be greater than or equal to zero") 91 | 92 | def _pre_train_calls(self, **kwargs): 93 | """Function calls to perform before training """ 94 | self._compute_sif_weights() 95 | 96 | def _post_train_calls(self): 97 | """ Function calls to perform after training, such as computing eigenvectors """ 98 | if self.components > 0: 99 | self.svd_res = compute_principal_components( 100 | self.sv.vectors, 101 | components=self.components, 102 | cache_size_gb=self.cache_size_gb, 103 | ) 104 | remove_principal_components( 105 | self.sv.vectors, svd_res=self.svd_res, inplace=True 106 | ) 107 | else: 108 | self.svd_res = 0 109 | logger.info("no removal of principal components") 110 | 111 | def _post_inference_calls(self, output: ndarray, **kwargs): 112 | """ Function calls to perform after training & inference """ 113 | if self.svd_res is None: 114 | raise RuntimeError( 115 | "You must first train the model to obtain SVD components" 116 | ) 117 | elif self.components > 0: 118 | remove_principal_components(output, svd_res=self.svd_res, inplace=True) 119 | else: 120 | logger.info("no removal of principal components") 121 | 122 | def _check_dtype_santiy(self): 123 | """ Check the dtypes of all attributes """ 124 | if self.word_weights.dtype != REAL: 125 | raise TypeError(f"type of word_weights is wrong: {self.word_weights.dtype}") 126 | if self.svd_res is not None: 127 | if self.svd_res[0].dtype != REAL: 128 | raise TypeError(f"type of svd values is wrong: {self.svd_res[0].dtype}") 129 | if self.svd_res[1].dtype != REAL: 130 | raise TypeError( 131 | f"type of svd components is wrong: {self.svd_res[1].dtype}" 132 | ) 133 | 134 | def _compute_sif_weights(self): 135 | """ Precomputes the SIF weights for all words in the vocabulary """ 136 | logger.info(f"pre-computing SIF weights for {len(self.wv)} words") 137 | v = len(self.wv) 138 | corpus_size = 0 139 | 140 | pw = zeros(v, dtype=REAL) 141 | for word in self.wv.key_to_index: 142 | c = self.wv.get_vecattr(word, "count") 143 | if c < 0: 144 | raise ValueError("vocab count is negative") 145 | corpus_size += c 146 | pw[self.wv.key_to_index[word]] = c 147 | pw /= corpus_size 148 | 149 | self.word_weights = (self.alpha / (self.alpha + pw)).astype(REAL) 150 | 151 | if not all(isfinite(self.word_weights)) or any(self.word_weights < 0): 152 | raise RuntimeError( 153 | "Encountered nan values. " 154 | "This likely happens because the word frequency information is wrong/missing. " 155 | "Consider restarting with the lang_freq argument to infer frequency. 
" 156 | ) 157 | -------------------------------------------------------------------------------- /test/test_data/test_sentences.txt: -------------------------------------------------------------------------------- 1 | Good stuff i just wish it lasted longer 2 | Hp makes qualilty stuff 3 | I like it 4 | Try it you will like it 5 | El producto es excelente me encanto , recomiendo al vendedor y el producto ampliamente , se cumplio con el tiempo estimado de envio 6 | Unfortunatly my first was the wrong battery 7 | The second order did the job and was received within 5 days and working as good as ever 8 | I did not get a return label to send the first battery back 9 | Could you please send one to me so i can get the credit back on my credit card 10 | Thank you for your help 11 | I bought the 32 gig touch for my 12 year old granddaughter and she loves it , especially the video camera feature 12 | It is amazing to me how adept the young people are in using this great tool 13 | She had a older itouch which is now being used by her 6 year old sister 14 | They both love the vast array of apps that are available from apple 15 | I recommend this as a great gift 16 | Apple products are amazing 17 | Mine just came in the mail today and i am about to set it up 18 | I went to lg 19 | Com and they have it for sale there for the msrp of $ 79 20 | Save yourself money and buy it direct from lg 21 | I feel like i just got screwed 22 | As advised by many reviews , i bought a size bigger and that is really a good piece of advice because the belt fits just right at the middle hole 23 | I must admit i had little hope for these cheap little buds 24 | But after popping them in for a test drive , i must admit that they are so comfy 25 | I keep forgetting they are in and have to keep checking 26 | The sound quality is pretty amazing for such a cheap pair of buds , clear , loud , not like some of the $ 20 - $ 30 ear buds i have tried that sound muffled and quiet , and i end up pitching 27 | I do not normally leave reviews on amazon , but these are great little buds that deserve a try 28 | I have just bought 5 more pairs , some as gifts , but mostly for me to hoard , as i will no doubt loose / break this pair eventually 29 | The only thing i can not say , is how they stay in with movement , i mainly use earbuds when i am working on a computer or when walking the dogs for a few hours 30 | So we will see how they hold up 31 | I have to say i was very disappointed with this trimmer after using it for the first time 32 | I bought it to replace the nose / ear trimming portion of my all - in - one groomer that finally gave out after 7 years 33 | The instructions for cleaning the trimmer are more confusing than they need to be and the actual cleaning is tedious 34 | I was more than a little miffed that using the device on my nose often resulted in quick little tugs as the blade yanked a hair out rather than cutting it 35 | And after all that , i had to go back and use it again right away because it left huge conspicuous hairs untouched after the first use 36 | All in all it took about 5 minutes , whereas i would expect such a procedure to take maybe 30 - 45 seconds 37 | The blade this replaced was one of 5 uses for a similarly cheap groomer , and it worked infinitely better 38 | I am going to have to look and see if they still make that model and replace this junker 39 | I was really skeptical at first , but the phone was even better than i expected 40 | I absolutely love it 41 | I had no issues with it whatsoever and will 
continue doing business with them 42 | Really awesome 43 | I love the color , the ease of usage , hte size matterof fact everything about this nano except that i droped it on the ceramic floor and broke the glass 44 | It still works 45 | That was a bummer 46 | Looks good and fits snug 47 | I am happy with it 48 | I use it at the gym 49 | I love the shade of green 50 | I was originally turned onto this via starbucks as an iced tea so now i usually use to make iced tea and was so happy to find it in the k - cups 51 | A basic phone which does what it is suppose to do , the thing i like most is the distinctive ring 52 | Easy to set up and use 53 | I really like this phone 54 | It is easy to use , and minimal directions 55 | There is alot of bells and whistles , but nothing that can not be programmed in a few seconds 56 | I have not had the experience others have had on this 57 | I love this thing 58 | It is like having a coach with you 59 | You know when you are done with your run , it is going to show you the efforts of your run 60 | I live out in the sticks , but it never has had an issue connecting with the satellite and the web site shows me all of the pertinent info on my run for the day 61 | Best thing i have used to track performance when running 62 | I have always been a big fan of hp products 63 | This cartridge works perfectly and gives a nice clear and clean print 64 | Provides a great cheap alternative to the apple tv or an expensive smart tv 65 | I was so happy with mine that i have since purchased more as gifts for family members 66 | Easy to set up 67 | User interface is intuitive and easy to navigate 68 | Poised to only get better as google adds new features and supported appsvery happy with purchase and would recommend to others 69 | I think amazon should take this letter in case people are scammers first , if amazon does not seem to me that are accomplices then , buy a bold 9700 phone to reach venezuela just switch it on and the phone does not work the screen goes blank all , i am very upset seems that mocked me as a customer lost my money and here amazon does nothing or draws attention to the store 70 | By the way do not buy a new phone repowered it really disrespectful 71 | It does not have the standard rear - facing camera for normal picture taking and videos 72 | It has a face - time camera only 73 | This makes it hard to take pictures with the thing 74 | I broke some keys on my laptop , so rather than spend $ 150 to have it repaired , i bought this instead and my computer is good as new 75 | I am a novice with computer repairs and this took me about five minutes to replace 76 | Pretty good for the price 77 | Separate ink cartridge for color really saves you money 78 | Very difficult to make this wireless printer work with my laptop 79 | Just settled for wired use 80 | Love the phone , good speed , all that 81 | Only thing i find awkward about is , so that when i try to change the volume from the controls on the left of the phone , i often accidentally press the power button on the right of the phone 82 | Which is annoying 83 | I have to then press the power button again and log in to continue what i was doing 84 | The product came within the time suggested 85 | It works great 86 | Worth the money i paid for it 87 | Good buy for sure 88 | This headset actually works well , but it is so uncomfortable to wear that i do not use it 89 | I found success with the panasonic kx - tca60 hands - free headset with comfort fit headband and the headset buddy adapter 01 - ph25 - pc35 
phone headset to pc adapter 90 | That combination not only sounds great , but it so comfortable i forget to take it off 91 | Fast shipping on sharp good ceramic knife set 92 | Have held sharpness for a while of use and will look for ceramic sharpening options 93 | I bought this as a gift at christmas 94 | I have heard nothing but praise for it 95 | You can steam anything you want onto your tv from your computer 96 | Makes any tv a smart tv for less than $ 40 97 | I received the battery as ordered 98 | Replaced the old battery with the new 99 | It is not holding a charge as long as the old battery did when it was new 100 | I am not sure if it is a tracfone problem or the battery 101 | -------------------------------------------------------------------------------- /test/test_inputs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | 8 | """Automated tests for checking the input methods.""" 9 | 10 | import logging 11 | import unittest 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | 16 | from fse.inputs import ( 17 | BaseIndexedList, 18 | CIndexedList, 19 | CSplitCIndexedList, 20 | CSplitIndexedList, 21 | IndexedLineDocument, 22 | IndexedList, 23 | SplitCIndexedList, 24 | SplitIndexedList, 25 | ) 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | class TestBaseIndexedList(unittest.TestCase): 31 | def setUp(self): 32 | self.list_a = ["the dog is good", "it's nice and comfy"] 33 | self.list_b = ["lorem ipsum dolor", "si amet"] 34 | self.list_c = [s.split() for s in self.list_a] 35 | self.set_a = set(["hello there", "its a set"]) 36 | self.arr_a = np.array(self.list_a) 37 | self.l = BaseIndexedList(self.list_a) 38 | self.ll = BaseIndexedList(self.list_a, self.list_b, self.list_c) 39 | 40 | def test_init(self): 41 | _ = BaseIndexedList(self.list_a) 42 | 43 | def test_init_mult_arg(self): 44 | self.assertEqual(6, len(self.ll.items)) 45 | 46 | def test_init_ndarray(self): 47 | _ = BaseIndexedList(self.arr_a) 48 | 49 | def test__check_list_type(self): 50 | with self.assertRaises(TypeError): 51 | self.l._check_list_type(1) 52 | with self.assertRaises(TypeError): 53 | self.l._check_list_type("Hello") 54 | 55 | def test__check_str_type(self): 56 | self.assertEqual(1, self.l._check_str_type("Hello")) 57 | with self.assertRaises(TypeError): 58 | self.l._check_str_type(1) 59 | with self.assertRaises(TypeError): 60 | self.l._check_str_type([]) 61 | 62 | def test__len(self): 63 | self.assertEqual(2, len(self.l)) 64 | 65 | def test__str(self): 66 | self.assertEqual("['the dog is good', \"it's nice and comfy\"]", str(self.l)) 67 | 68 | def test__getitem(self): 69 | with self.assertRaises(NotImplementedError): 70 | self.l[0] 71 | 72 | def test__delitem(self): 73 | self.ll.__delitem__(0) 74 | self.assertEqual(5, len(self.ll)) 75 | 76 | def test__setitem(self): 77 | self.ll.__setitem__(0, "is it me?") 78 | self.assertEqual("is it me?", self.ll.items[0]) 79 | 80 | def test_append(self): 81 | self.ll.append("is it me?") 82 | self.assertEqual("is it me?", self.ll.items[-1]) 83 | 84 | def test_extend(self): 85 | self.ll.extend(self.list_a) 86 | self.assertEqual(8, len(self.ll)) 87 | 88 | self.ll.extend(self.set_a) 89 | self.assertEqual(10, len(self.ll)) 90 | 91 | def test_extend_ndarr(self): 92 | l = BaseIndexedList(np.array([str(i) for i in [1, 2, 3, 4]])) 93 | l.extend(np.array([str(i) for i in [1, 2, 3, 4]])) 94 | 
self.assertEqual(8, len(l)) 95 | 96 | 97 | class TestIndexedList(unittest.TestCase): 98 | def setUp(self): 99 | self.list_a = ["the dog is good", "it's nice and comfy"] 100 | self.list_b = [s.split() for s in self.list_a] 101 | self.il = IndexedList(self.list_a, self.list_b) 102 | 103 | def test_init(self): 104 | _ = IndexedList(self.list_a) 105 | 106 | def test_getitem(self): 107 | self.assertEqual(("the dog is good", 0), self.il[0]) 108 | 109 | def test_split(self): 110 | l = SplitIndexedList(self.list_a) 111 | self.assertEqual("the dog is good".split(), l[0][0]) 112 | 113 | 114 | class TestCIndexedList(unittest.TestCase): 115 | def setUp(self): 116 | self.list_a = ["The Dog is good", "it's nice and comfy"] 117 | self.il = CIndexedList(self.list_a, custom_index=[1, 1]) 118 | 119 | def test_cust_index(self): 120 | self.assertEqual(1, self.il[0][1]) 121 | 122 | def test_wrong_len(self): 123 | with self.assertRaises(RuntimeError): 124 | CIndexedList(self.list_a, custom_index=[1]) 125 | 126 | def test_mutable_funcs(self): 127 | with self.assertRaises(NotImplementedError): 128 | self.il.__delitem__(0) 129 | with self.assertRaises(NotImplementedError): 130 | self.il.__setitem__(0, "the") 131 | 132 | with self.assertRaises(NotImplementedError): 133 | self.il.insert(0, "the") 134 | with self.assertRaises(NotImplementedError): 135 | self.il.append("the") 136 | with self.assertRaises(NotImplementedError): 137 | self.il.extend(["the", "dog"]) 138 | 139 | 140 | class TestCSplitIndexedList(unittest.TestCase): 141 | def setUp(self): 142 | self.list_a = ["The Dog is good", "it's nice and comfy"] 143 | self.il = CSplitIndexedList(self.list_a, custom_split=self.split_func) 144 | 145 | def split_func(self, input): 146 | return input.lower().split() 147 | 148 | def test_getitem(self): 149 | self.assertEqual("the dog is good".split(), self.il[0][0]) 150 | 151 | 152 | class TestSplitCIndexedList(unittest.TestCase): 153 | def setUp(self): 154 | self.list_a = ["The Dog is good", "it's nice and comfy"] 155 | self.il = SplitCIndexedList(self.list_a, custom_index=[1, 1]) 156 | 157 | def test_getitem(self): 158 | self.assertEqual(("The Dog is good".split(), 1), self.il[0]) 159 | 160 | def test_mutable_funcs(self): 161 | with self.assertRaises(NotImplementedError): 162 | self.il.__delitem__(0) 163 | with self.assertRaises(NotImplementedError): 164 | self.il.__setitem__(0, "the") 165 | 166 | with self.assertRaises(NotImplementedError): 167 | self.il.insert(0, "the") 168 | with self.assertRaises(NotImplementedError): 169 | self.il.append("the") 170 | with self.assertRaises(NotImplementedError): 171 | self.il.extend(["the", "dog"]) 172 | 173 | 174 | class TestCSplitCIndexedList(unittest.TestCase): 175 | def setUp(self): 176 | self.list_a = ["The Dog is good", "it's nice and comfy"] 177 | self.il = CSplitCIndexedList( 178 | self.list_a, custom_split=self.split_func, custom_index=[1, 1] 179 | ) 180 | 181 | def split_func(self, input): 182 | return input.lower().split() 183 | 184 | def test_getitem(self): 185 | self.assertEqual(("the dog is good".split(), 1), self.il[0]) 186 | 187 | def test_mutable_funcs(self): 188 | with self.assertRaises(NotImplementedError): 189 | self.il.__delitem__(0) 190 | with self.assertRaises(NotImplementedError): 191 | self.il.__setitem__(0, "the") 192 | 193 | with self.assertRaises(NotImplementedError): 194 | self.il.insert(0, "the") 195 | with self.assertRaises(NotImplementedError): 196 | self.il.append("the") 197 | with self.assertRaises(NotImplementedError): 198 | self.il.extend(["the", 
"dog"]) 199 | 200 | 201 | class TestIndexedLineDocument(unittest.TestCase): 202 | def setUp(self): 203 | self.p = Path(__file__).parent / "test_data" / "test_sentences.txt" 204 | self.doc = IndexedLineDocument(self.p) 205 | 206 | def test_getitem(self): 207 | self.assertEqual("Good stuff i just wish it lasted longer", self.doc[0]) 208 | self.assertEqual("Save yourself money and buy it direct from lg", self.doc[19]) 209 | self.assertEqual( 210 | "I am not sure if it is a tracfone problem or the battery", self.doc[-1] 211 | ) 212 | 213 | def test_yield(self): 214 | first = ("Good stuff i just wish it lasted longer".split(), 0) 215 | last = ("I am not sure if it is a tracfone problem or the battery".split(), 99) 216 | for i, obj in enumerate(self.doc): 217 | if i == 0: 218 | self.assertEqual(first, obj) 219 | if i == 99: 220 | self.assertEqual(last, obj) 221 | 222 | 223 | if __name__ == "__main__": 224 | logging.basicConfig( 225 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 226 | ) 227 | unittest.main() 228 | -------------------------------------------------------------------------------- /fse/models/usif.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | import logging 8 | 9 | from gensim.models.keyedvectors import KeyedVectors 10 | from numpy import float32 as REAL 11 | from numpy import isfinite, ndarray, zeros 12 | 13 | from fse.models.average import Average 14 | from fse.models.utils import ( 15 | EPS, 16 | compute_principal_components, 17 | remove_principal_components, 18 | ) 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class uSIF(Average): 24 | def __init__( 25 | self, 26 | model: KeyedVectors, 27 | length: int = None, 28 | components: int = 5, 29 | cache_size_gb: float = 1.0, 30 | sv_mapfile_path: str = None, 31 | wv_mapfile_path: str = None, 32 | workers: int = 1, 33 | lang_freq: str = None, 34 | ): 35 | """Unsupervised smooth-inverse frequency (uSIF) weighted sentence embeddings 36 | model. Performs a weighted averaging operation over all words in a sentences. 37 | After training, the model removes a number of weighted singular vectors. 38 | 39 | The implementation is based on Ethayarajh (2018): Unsupervised Random Walk Sentence Embeddings: A Strong but Simple Baseline. 40 | For more information, see and 41 | 42 | Parameters 43 | ---------- 44 | model : :class:`~gensim.models.keyedvectors.KeyedVectors` or :class:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel` 45 | This object essentially contains the mapping between words and embeddings. To compute the sentence embeddings 46 | the wv.vocab and wv.vector elements are required. 47 | length : int, optional 48 | Corresponds to the average number of words in a sentence in the training corpus. 49 | If length is None, then the model takes the average number of words from 50 | :meth: `~fse.models.base_s2v.BaseSentence2VecModel.scan_sentences` 51 | Is equivalent to n in the paper. 52 | components : int, optional 53 | Corresponds to the number of singular vectors to remove from the sentence embeddings. 54 | Is equivalent to m in the paper. 55 | cache_size_gb : float, optional 56 | Cache size for computing the singular vectors in GB. 57 | sv_mapfile_path : str, optional 58 | Optional path to store the sentence-vectors in for very large datasets. Used for memmap. 
59 | wv_mapfile_path : str, optional 60 | Optional path to store the word-vectors in for very large datasets. Used for memmap. 61 | Use sv_mapfile_path and wv_mapfile_path to train disk-to-disk without needing much ram. 62 | workers : int, optional 63 | Number of working threads, used for multithreading. For most tasks (few words in a sentence) 64 | a value of 1 should be more than enough. 65 | lang_freq : str, optional 66 | Some pre-trained embeddings, i.e. "GoogleNews-vectors-negative300.bin", do not contain information about 67 | the frequency of a word. As the frequency is required for estimating the word weights, we induce 68 | frequencies into the wv.vocab.count based on :class:`~wordfreq` 69 | If no frequency information is available, you can choose the language to estimate the frequency. 70 | See https://github.com/LuminosoInsight/wordfreq 71 | """ 72 | 73 | self.length = length 74 | self.components = int(components) 75 | self.cache_size_gb = float(cache_size_gb) 76 | self.svd_res = None 77 | self.svd_weights = None 78 | 79 | if lang_freq is None: 80 | logger.info( 81 | "make sure you are using a model with valid word-frequency information. Otherwise use lang_freq argument." 82 | ) 83 | 84 | super(uSIF, self).__init__( 85 | model=model, 86 | sv_mapfile_path=sv_mapfile_path, 87 | wv_mapfile_path=wv_mapfile_path, 88 | workers=workers, 89 | lang_freq=lang_freq, 90 | ) 91 | 92 | def _check_parameter_sanity(self): 93 | """Check the sanity of all paramters.""" 94 | if self.length <= 0.0: 95 | raise ValueError("Length must be greater than zero.") 96 | if self.components < 0.0: 97 | raise ValueError("Components must be greater or equal zero") 98 | 99 | def _pre_train_calls(self, **kwargs): 100 | """Function calls to perform before training.""" 101 | self.length = kwargs["average_length"] if self.length is None else self.length 102 | self._compute_usif_weights() 103 | 104 | def _post_train_calls(self): 105 | """Function calls to perform after training, such as computing eigenvectors.""" 106 | if self.components > 0: 107 | self.svd_res = compute_principal_components( 108 | self.sv.vectors, 109 | components=self.components, 110 | cache_size_gb=self.cache_size_gb, 111 | ) 112 | self.svd_weights = (self.svd_res[0] ** 2) / ( 113 | self.svd_res[0] ** 2 114 | ).sum().astype(REAL) 115 | remove_principal_components( 116 | self.sv.vectors, 117 | svd_res=self.svd_res, 118 | weights=self.svd_weights, 119 | inplace=True, 120 | ) 121 | else: 122 | self.svd_res = 0 123 | logger.info(f"no removal of principal components") 124 | 125 | def _post_inference_calls(self, output: ndarray, **kwargs): 126 | """Function calls to perform after training & inference.""" 127 | if self.svd_res is None: 128 | raise RuntimeError( 129 | "You must first train the model to obtain SVD components" 130 | ) 131 | elif self.components > 0: 132 | remove_principal_components( 133 | output, svd_res=self.svd_res, weights=self.svd_weights, inplace=True 134 | ) 135 | else: 136 | logger.info(f"no removal of principal components") 137 | 138 | def _check_dtype_santiy(self): 139 | """Check the dtypes of all attributes.""" 140 | if self.word_weights.dtype != REAL: 141 | raise TypeError(f"type of word_weights is wrong: {self.word_weights.dtype}") 142 | if self.svd_res is not None: 143 | if self.svd_res[0].dtype != REAL: 144 | raise TypeError(f"type of svd values is wrong: {self.svd_res[0].dtype}") 145 | if self.svd_res[1].dtype != REAL: 146 | raise TypeError( 147 | f"type of svd components is wrong: {self.svd_res[1].dtype}" 148 | ) 149 | if 
self.svd_weights.dtype != REAL: 150 | raise TypeError( 151 | f"type of svd weights is wrong: {self.svd_weights.dtype}" 152 | ) 153 | 154 | def _compute_usif_weights(self): 155 | """Precomputes the uSIF weights.""" 156 | logger.info(f"pre-computing uSIF weights for {len(self.wv)} words") 157 | v = len(self.wv) 158 | corpus_size = 0 159 | 160 | pw = zeros(v, dtype=REAL) 161 | for word in self.wv.key_to_index: 162 | c = self.wv.get_vecattr(word, "count") 163 | if c < 0: 164 | raise ValueError("vocab count is negative") 165 | corpus_size += c 166 | pw[self.wv.key_to_index[word]] = c 167 | pw /= corpus_size 168 | 169 | threshold = 1 - (1 - (1 / v)) ** self.length 170 | alpha = sum(pw > threshold) / v 171 | z = v / 2 172 | a = (1 - alpha) / ((alpha * z) + EPS) 173 | 174 | self.word_weights = (a / ((a / 2) + pw)).astype(REAL) 175 | 176 | if not all(isfinite(self.word_weights)): 177 | raise RuntimeError( 178 | "Encountered nan values. " 179 | "This likely happens because the word frequency information is wrong/missing. " 180 | "Consider restarting using lang_freq argument to infer frequency. " 181 | ) 182 | -------------------------------------------------------------------------------- /notebooks/Speed Comparision.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Speed comparision" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from fse import Vectors, Average\n", 17 | "from fse.models.average import train_average_np\n", 18 | "from fse.models.average_inner import train_average_cy\n", 19 | "\n", 20 | "from fse.models.average import MAX_WORDS_IN_BATCH\n", 21 | "\n", 22 | "from fse import IndexedList\n", 23 | "\n", 24 | "import numpy as np\n", 25 | "\n", 26 | "import gensim.downloader as api\n", 27 | "data = api.load(\"quora-duplicate-questions\")\n", 28 | "\n", 29 | "sentences = []\n", 30 | "batch_size = 0\n", 31 | "for d in data:\n", 32 | " strings = d[\"question1\"].split()\n", 33 | " if len(strings) + batch_size < MAX_WORDS_IN_BATCH:\n", 34 | " sentences.append(strings)\n", 35 | " batch_size += len(strings)\n", 36 | "sentences = IndexedList(sentences)" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "# Test W2V Model" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "import gensim.downloader as api\n", 53 | "\n", 54 | "w2v = Vectors.from_pretrained(\"glove-wiki-gigaword-100\")\n", 55 | "ft = api.load(\"fasttext-wiki-news-subwords-300\")" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "To test if the fast version is available, you need to import the variable FAST_VERSION from fse.models.average. \n", 63 | "1 : The cython version is available\n", 64 | "-1 : The cython version is not available.\n", 65 | "\n", 66 | "If the cython compiliation fails, you will be notified." 
67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": 3, 72 | "metadata": {}, 73 | "outputs": [ 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "1" 78 | ] 79 | }, 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "from fse.models.average import FAST_VERSION\n", 87 | "FAST_VERSION" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 4, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "119 µs ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" 100 | ] 101 | } 102 | ], 103 | "source": [ 104 | "%%timeit\n", 105 | "w2v_avg = Average(w2v)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 5, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "name": "stdout", 115 | "output_type": "stream", 116 | "text": [ 117 | "1.18 s ± 37.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" 118 | ] 119 | } 120 | ], 121 | "source": [ 122 | "%%timeit\n", 123 | "w2v_avg = Average(w2v, lang_freq=\"en\")" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "metadata": {}, 129 | "source": [ 130 | "The slowest part during the init is the induction of frequencies for words, as some pre-trained embeddings do not come with frequencies for words. This is only necessary for the SIF and uSIF Model, not for the Average model." 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": 6, 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "w2v_avg = Average(w2v)\n", 140 | "statistics = w2v_avg.scan_sentences(sentences)\n", 141 | "w2v_avg.prep.prepare_vectors(sv=w2v_avg.sv, total_sentences=statistics[\"max_index\"], update=False)\n", 142 | "memory = w2v_avg._get_thread_working_mem()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 7, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "35.7 ms ± 1.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "%%timeit\n", 160 | "train_average_np(model=w2v_avg, indexed_sentences=sentences, target=w2v_avg.sv.vectors, memory=memory)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "code", 165 | "execution_count": 8, 166 | "metadata": {}, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "1.93 ms ± 26.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" 173 | ] 174 | } 175 | ], 176 | "source": [ 177 | "%%timeit\n", 178 | "train_average_cy(model=w2v_avg, indexed_sentences=sentences, target=w2v_avg.sv.vectors, memory=memory)" 179 | ] 180 | }, 181 | { 182 | "cell_type": "markdown", 183 | "metadata": {}, 184 | "source": [ 185 | "For 90 sentences, the Cython version is about 8-15 faster than the numpy version when using a Word2Vec type model." 
186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 9, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "data": { 195 | "text/plain": [ 196 | "True" 197 | ] 198 | }, 199 | "execution_count": 9, 200 | "metadata": {}, 201 | "output_type": "execute_result" 202 | } 203 | ], 204 | "source": [ 205 | "out_w2v_np = np.zeros_like(w2v_avg.sv.vectors)\n", 206 | "out_w2v_cy = np.zeros_like(w2v_avg.sv.vectors)\n", 207 | "train_average_np(model=w2v_avg, indexed_sentences=sentences, target=out_w2v_np, memory=w2v_avg._get_thread_working_mem())\n", 208 | "train_average_cy(model=w2v_avg, indexed_sentences=sentences, target=out_w2v_cy, memory=w2v_avg._get_thread_working_mem())\n", 209 | "\n", 210 | "np.allclose(out_w2v_np, out_w2v_cy)" 211 | ] 212 | }, 213 | { 214 | "cell_type": "markdown", 215 | "metadata": {}, 216 | "source": [ 217 | "# Test FastTextModel" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": 10, 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "ft_avg = Average(ft)\n", 227 | "statistics = ft_avg.scan_sentences(sentences)\n", 228 | "ft_avg.prep.prepare_vectors(sv=ft_avg.sv, total_sentences=statistics[\"max_index\"], update=False)\n", 229 | "memory = ft_avg._get_thread_working_mem()" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 11, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "36.8 ms ± 554 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "%%timeit\n", 247 | "train_average_np(model=ft_avg, indexed_sentences=sentences, target=ft_avg.sv.vectors, memory=memory)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 12, 253 | "metadata": {}, 254 | "outputs": [ 255 | { 256 | "name": "stdout", 257 | "output_type": "stream", 258 | "text": [ 259 | "2.54 ms ± 44.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" 260 | ] 261 | } 262 | ], 263 | "source": [ 264 | "%%timeit\n", 265 | "train_average_cy(model=ft_avg, indexed_sentences=sentences, target=ft_avg.sv.vectors, memory=memory)" 266 | ] 267 | }, 268 | { 269 | "cell_type": "markdown", 270 | "metadata": {}, 271 | "source": [ 272 | "With a FastText type model, the cython routine is about 5-10 times faster." 
273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 13, 278 | "metadata": {}, 279 | "outputs": [ 280 | { 281 | "data": { 282 | "text/plain": [ 283 | "True" 284 | ] 285 | }, 286 | "execution_count": 13, 287 | "metadata": {}, 288 | "output_type": "execute_result" 289 | } 290 | ], 291 | "source": [ 292 | "out_ft_np = np.zeros_like(ft_avg.sv.vectors)\n", 293 | "out_ft_cy = np.zeros_like(ft_avg.sv.vectors)\n", 294 | "train_average_np(model=ft_avg, indexed_sentences=sentences, target=out_ft_np, memory=ft_avg._get_thread_working_mem())\n", 295 | "train_average_cy(model=ft_avg, indexed_sentences=sentences, target=out_ft_cy, memory=ft_avg._get_thread_working_mem())\n", 296 | "\n", 297 | "np.allclose(out_ft_np, out_ft_cy)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [] 306 | } 307 | ], 308 | "metadata": { 309 | "kernelspec": { 310 | "display_name": "Python 3", 311 | "language": "python", 312 | "name": "python3" 313 | }, 314 | "language_info": { 315 | "codemirror_mode": { 316 | "name": "ipython", 317 | "version": 3 318 | }, 319 | "file_extension": ".py", 320 | "mimetype": "text/x-python", 321 | "name": "python", 322 | "nbconvert_exporter": "python", 323 | "pygments_lexer": "ipython3", 324 | "version": "3.8.5" 325 | } 326 | }, 327 | "nbformat": 4, 328 | "nbformat_minor": 4 329 | } 330 | -------------------------------------------------------------------------------- /test/test_sentencevectors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | 8 | """Automated tests for checking the sentence vectors.""" 9 | 10 | import logging 11 | import unittest 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | from gensim.models import Word2Vec 16 | 17 | from fse.inputs import IndexedLineDocument, IndexedList 18 | from fse.models.average import Average 19 | from fse.models.sentencevectors import SentenceVectors, _l2_norm 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | TEST_DATA = Path(__file__).parent / "test_data" 24 | CORPUS = TEST_DATA / "test_sentences.txt" 25 | DIM = 5 26 | W2V = Word2Vec(min_count=1, vector_size=DIM, seed=42) 27 | with open(CORPUS, "r") as file: 28 | SENTENCES = [l.split() for _, l in enumerate(file)] 29 | W2V.build_vocab(SENTENCES) 30 | 31 | rng = np.random.default_rng(12345) 32 | W2V.wv.vectors = rng.uniform(size=W2V.wv.vectors.shape).astype(np.float32) 33 | 34 | 35 | class TestSentenceVectorsFunctions(unittest.TestCase): 36 | def setUp(self): 37 | self.sv = SentenceVectors(2) 38 | self.sv.vectors = np.arange(10).reshape(5, 2) 39 | 40 | def test_getitem(self): 41 | self.assertTrue(([0, 1] == self.sv[0]).all()) 42 | self.assertTrue(([[0, 1], [4, 5]] == self.sv[[0, 2]]).all()) 43 | 44 | def test_isin(self): 45 | self.assertIn(0, self.sv) 46 | self.assertNotIn(5, self.sv) 47 | 48 | def test_init_sims_wo_replace(self): 49 | self.sv.init_sims() 50 | self.assertIsNotNone(self.sv.vectors_norm) 51 | self.assertFalse((self.sv.vectors == self.sv.vectors_norm).all()) 52 | 53 | v1 = self.sv.vectors[0] 54 | v1 = v1 / np.sqrt(np.sum(v1 ** 2)) 55 | 56 | v2 = self.sv.vectors[1] 57 | v2 = v2 / np.sqrt(np.sum(v2 ** 2)) 58 | 59 | self.assertTrue(np.allclose(v1, self.sv.vectors_norm[0])) 60 | self.assertTrue(np.allclose(v2, self.sv.vectors_norm[1])) 61 | self.assertTrue(np.allclose(v2, 
self.sv.get_vector(1, True))) 62 | 63 | def test_get_vector(self): 64 | self.assertTrue(([0, 1] == self.sv.get_vector(0)).all()) 65 | self.assertTrue(([2, 3] == self.sv.get_vector(1)).all()) 66 | 67 | def test_init_sims_w_replace(self): 68 | self.sv.init_sims(True) 69 | self.assertTrue(np.allclose(self.sv.vectors[0], self.sv.vectors_norm[0])) 70 | 71 | def test_init_sims_w_mapfile(self): 72 | p = TEST_DATA / "test_vectors" 73 | self.sv.mapfile_path = str(p.absolute()) 74 | self.sv.init_sims() 75 | p = TEST_DATA / "test_vectors.vectors_norm" 76 | self.assertTrue(p.exists()) 77 | p.unlink() 78 | 79 | def test_save_load(self): 80 | p = TEST_DATA / "test_vectors.vectors" 81 | self.sv.save(str(p.absolute())) 82 | self.assertTrue(p.exists()) 83 | sv2 = SentenceVectors.load(str(p.absolute())) 84 | self.assertTrue((self.sv.vectors == sv2.vectors).all()) 85 | p.unlink() 86 | 87 | def test_save_load_with_memmap(self): 88 | p = TEST_DATA / "test_vectors" 89 | p_target = TEST_DATA / "test_vectors.vectors" 90 | p_not_exists = TEST_DATA / "test_vectors.vectors.npy" 91 | 92 | sv = SentenceVectors(2, mapfile_path=str(p)) 93 | 94 | shape = (1000, 1000) 95 | sv.vectors = np.ones(shape, dtype=np.float32) 96 | 97 | memvecs = np.memmap(p_target, dtype=np.float32, mode="w+", shape=shape) 98 | memvecs[:] = sv.vectors[:] 99 | del memvecs 100 | 101 | self.assertTrue(p_target.exists()) 102 | sv.save(str(p.absolute())) 103 | self.assertTrue(p.exists()) 104 | self.assertFalse(p_not_exists.exists()) 105 | 106 | sv = SentenceVectors.load(str(p.absolute())) 107 | self.assertEqual(shape, sv.vectors.shape) 108 | 109 | for t in [p, p_target]: 110 | t.unlink() 111 | 112 | def test_len(self): 113 | self.assertEqual(5, len(self.sv)) 114 | 115 | def test_similarity(self): 116 | v1 = self.sv.vectors[0] 117 | v1 = v1 / np.sqrt(np.sum(v1 ** 2)) 118 | 119 | v2 = self.sv.vectors[1] 120 | v2 = v2 / np.sqrt(np.sum(v2 ** 2)) 121 | 122 | self.assertTrue(np.allclose(v1.dot(v2), self.sv.similarity(0, 1))) 123 | self.assertTrue(np.allclose(1 - v1.dot(v2), self.sv.distance(0, 1))) 124 | 125 | def test_most_similar(self): 126 | sent_ind = IndexedList(SENTENCES) 127 | sentences = IndexedLineDocument(CORPUS) 128 | m = Average(W2V) 129 | m.train(sentences) 130 | o = m.sv.most_similar(positive=0) 131 | self.assertEqual(50, o[0][0]) 132 | self.assertEqual(58, o[1][0]) 133 | o = m.sv.most_similar(positive=0, indexable=sentences) 134 | self.assertEqual( 135 | "A basic phone which does what it is suppose to do , the thing i like most is the distinctive ring", 136 | o[0][0], 137 | ) 138 | 139 | o = m.sv.most_similar(positive=0, indexable=sent_ind) 140 | self.assertEqual( 141 | "A basic phone which does what it is suppose to do , the thing i like most is the distinctive ring".split(), 142 | o[0][0][0], 143 | ) 144 | 145 | def test_most_similar_vec(self): 146 | sentences = IndexedLineDocument(CORPUS) 147 | m = Average(W2V) 148 | m.train(sentences) 149 | m.sv.init_sims() 150 | v = m.sv.get_vector(0, use_norm=True) 151 | o = m.sv.most_similar(positive=v) 152 | self.assertEqual(0, o[0][0]) 153 | self.assertEqual(50, o[1][0]) 154 | self.assertEqual(58, o[2][0]) 155 | 156 | def test_most_similar_vectors(self): 157 | sentences = IndexedLineDocument(CORPUS) 158 | m = Average(W2V) 159 | m.train(sentences) 160 | m.sv.init_sims() 161 | v = m.sv[[0, 1]] 162 | o = m.sv.most_similar(positive=v) 163 | self.assertEqual(10, o[0][0]) 164 | self.assertEqual(11, o[1][0]) 165 | 166 | def test_most_similar_wrong_indexable(self): 167 | def indexable(self): 168 | pass 169 | 
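# `indexable` here is a plain function rather than an object supporting item
# lookup (such as the IndexedList or IndexedLineDocument used in the other
# tests), so most_similar is expected to reject it with a RuntimeError below.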
170 | sentences = IndexedLineDocument(CORPUS) 171 | m = Average(W2V) 172 | m.train(sentences) 173 | with self.assertRaises(RuntimeError): 174 | m.sv.most_similar(positive=0, indexable=indexable) 175 | 176 | def test_most_similar_topn(self): 177 | sentences = IndexedLineDocument(CORPUS) 178 | m = Average(W2V) 179 | m.train(sentences) 180 | o = m.sv.most_similar(positive=0, topn=20) 181 | self.assertEqual(20, len(o)) 182 | 183 | def test_most_similar_restrict_size(self): 184 | sentences = IndexedLineDocument(CORPUS) 185 | m = Average(W2V) 186 | m.train(sentences) 187 | o = m.sv.most_similar(positive=20, topn=20, restrict_size=5) 188 | self.assertEqual(5, len(o)) 189 | 190 | def test_most_similar_restrict_size_tuple(self): 191 | sentences = IndexedLineDocument(CORPUS) 192 | m = Average(W2V) 193 | m.train(sentences) 194 | o = m.sv.most_similar(positive=20, topn=20, restrict_size=(5, 25)) 195 | self.assertEqual(19, len(o)) 196 | self.assertEqual(22, o[0][0]) 197 | 198 | o = m.sv.most_similar(positive=1, topn=20, restrict_size=(5, 25)) 199 | self.assertEqual(20, len(o)) 200 | self.assertEqual(11, o[0][0]) 201 | 202 | o = m.sv.most_similar( 203 | positive=1, topn=20, restrict_size=(5, 25), indexable=sentences 204 | ) 205 | self.assertEqual(20, len(o)) 206 | self.assertEqual(11, o[0][1]) 207 | 208 | def test_similar_by_word(self): 209 | sentences = IndexedLineDocument(CORPUS) 210 | m = Average(W2V) 211 | m.train(sentences) 212 | o = m.sv.similar_by_word(word="the", wv=m.wv) 213 | self.assertEqual(5, o[0][0]) 214 | o = m.sv.similar_by_word(word="the", wv=m.wv, indexable=sentences) 215 | self.assertEqual(5, o[0][1]) 216 | 217 | def test_similar_by_vector(self): 218 | sentences = IndexedLineDocument(CORPUS) 219 | m = Average(W2V) 220 | m.train(sentences) 221 | o = m.sv.similar_by_vector(m.wv["the"]) 222 | self.assertEqual(5, o[0][0]) 223 | 224 | def test_similar_by_sentence(self): 225 | sentences = IndexedLineDocument(CORPUS) 226 | m = Average(W2V) 227 | m.train(sentences) 228 | o = m.sv.similar_by_sentence(sentence=["the", "product", "is", "good"], model=m) 229 | self.assertEqual(26, o[0][0]) 230 | 231 | def test_similar_by_sentence_wrong_model(self): 232 | sentences = IndexedLineDocument(CORPUS) 233 | m = Average(W2V) 234 | m.train(sentences) 235 | with self.assertRaises(RuntimeError): 236 | m.sv.similar_by_sentence( 237 | sentence=["the", "product", "is", "good"], model=W2V 238 | ) 239 | 240 | def test_l2_norm(self): 241 | out = np.random.normal(size=(200, 50)).astype(np.float32) 242 | result = _l2_norm(out, False) 243 | lens = np.sqrt(np.sum((result ** 2), axis=-1)) 244 | self.assertTrue(np.allclose(1, lens, atol=1e-6)) 245 | 246 | out = np.random.normal(size=(200, 50)).astype(np.float32) 247 | out = _l2_norm(out, True) 248 | lens = np.sqrt(np.sum((out ** 2), axis=-1)) 249 | self.assertTrue(np.allclose(1, lens, atol=1e-6)) 250 | 251 | 252 | if __name__ == "__main__": 253 | logging.basicConfig( 254 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 255 | ) 256 | unittest.main() 257 | -------------------------------------------------------------------------------- /test/test_average.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | """Automated tests for checking the average model.""" 8 | 9 | import logging 10 | import unittest 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | from gensim.models 
import FastText, Word2Vec 15 | 16 | from fse.models.average import Average, train_average_np 17 | from fse.models.base_s2v import EPS 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | TEST_DATA = Path(__file__).parent / "test_data" 23 | CORPUS = TEST_DATA / "test_sentences.txt" 24 | 25 | 26 | DIM = 5 27 | W2V = Word2Vec(min_count=1, vector_size=DIM) 28 | with open(CORPUS, "r") as file: 29 | SENTENCES = [l.split() for _, l in enumerate(file)] 30 | W2V.build_vocab(SENTENCES) 31 | W2V.wv.vectors[:,] = np.arange( 32 | len(W2V.wv.vectors), dtype=np.float32 33 | )[:, None] 34 | 35 | 36 | class TestAverageFunctions(unittest.TestCase): 37 | def setUp(self): 38 | self.sentences = [ 39 | ["They", "admit"], 40 | ["So", "Apple", "bought", "buds"], 41 | ["go", "12345"], 42 | ["pull", "12345678910111213"], 43 | ] 44 | self.sentences = [(s, i) for i, s in enumerate(self.sentences)] 45 | self.model = Average(W2V) 46 | self.model.prep.prepare_vectors( 47 | sv=self.model.sv, total_sentences=len(self.sentences), update=False 48 | ) 49 | self.model._pre_train_calls() 50 | 51 | def test_cython(self): 52 | from fse.models.average_inner import ( 53 | FAST_VERSION, 54 | MAX_NGRAMS_IN_BATCH, 55 | MAX_WORDS_IN_BATCH, 56 | ) 57 | 58 | self.assertTrue(FAST_VERSION) 59 | self.assertEqual(10000, MAX_WORDS_IN_BATCH) 60 | self.assertEqual(40, MAX_NGRAMS_IN_BATCH) 61 | 62 | def test_average_train_np_w2v(self): 63 | self.model.sv.vectors = np.zeros_like(self.model.sv.vectors, dtype=np.float32) 64 | mem = self.model._get_thread_working_mem() 65 | output = train_average_np( 66 | self.model, self.sentences, self.model.sv.vectors, mem 67 | ) 68 | self.assertEqual((4, 7), output) 69 | self.assertTrue((179 == self.model.sv[0]).all()) 70 | self.assertTrue((140.75 == self.model.sv[1]).all()) 71 | self.assertTrue((self.model.wv.key_to_index["go"] == self.model.sv[2]).all()) 72 | 73 | def test_average_train_cy_w2v(self): 74 | self.model.sv.vectors = np.zeros_like(self.model.sv.vectors, dtype=np.float32) 75 | mem = self.model._get_thread_working_mem() 76 | 77 | from fse.models.average_inner import train_average_cy 78 | 79 | output = train_average_cy( 80 | self.model, self.sentences, self.model.sv.vectors, mem 81 | ) 82 | self.assertEqual((4, 7), output) 83 | self.assertTrue((179 == self.model.sv[0]).all()) 84 | self.assertTrue((140.75 == self.model.sv[1]).all()) 85 | self.assertTrue((self.model.wv.key_to_index["go"] == self.model.sv[2]).all()) 86 | 87 | def test_average_train_np_ft(self): 88 | ft = FastText(min_count=1, vector_size=DIM) 89 | ft.build_vocab(SENTENCES) 90 | m = Average(ft) 91 | m.prep.prepare_vectors( 92 | sv=m.sv, total_sentences=len(self.sentences), update=False 93 | ) 94 | m._pre_train_calls() 95 | m.wv.vectors = m.wv.vectors_vocab = np.ones_like(m.wv.vectors, dtype=np.float32) 96 | m.wv.vectors_ngrams = np.full_like(m.wv.vectors_ngrams, 2, dtype=np.float32) 97 | mem = m._get_thread_working_mem() 98 | output = train_average_np(m, self.sentences, m.sv.vectors, mem) 99 | self.assertEqual((4, 10), output) 100 | self.assertTrue((1.0 == m.sv[0]).all()) 101 | self.assertTrue((1.5 == m.sv[2]).all()) 102 | self.assertTrue((2 == m.sv[3]).all()) 103 | # "go" -> [1,1...] 
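# "go" is in-vocab, so it contributes its all-ones vector with unit weight;
# the OOV token is reconstructed from its n-gram buckets as worked out below: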
104 | # oov: "12345" -> (14 hashes * 2) / 14 = 2 105 | # (2 + 1) / 2 = 1.5 106 | 107 | def test_average_train_cy_ft(self): 108 | ft = FastText(min_count=1, vector_size=DIM) 109 | ft.build_vocab(SENTENCES) 110 | m = Average(ft) 111 | m.prep.prepare_vectors( 112 | sv=m.sv, total_sentences=len(self.sentences), update=False 113 | ) 114 | m._pre_train_calls() 115 | m.wv.vectors = m.wv.vectors_vocab = np.ones_like(m.wv.vectors, dtype=np.float32) 116 | m.wv.vectors_ngrams = np.full_like(m.wv.vectors_ngrams, 2, dtype=np.float32) 117 | mem = m._get_thread_working_mem() 118 | 119 | from fse.models.average_inner import train_average_cy 120 | 121 | output = train_average_cy(m, self.sentences, m.sv.vectors, mem) 122 | self.assertEqual((4, 10), output) 123 | self.assertTrue((1.0 + EPS == m.sv[0]).all()) 124 | self.assertTrue(np.allclose(1.5, m.sv[2])) 125 | self.assertTrue(np.allclose(2, m.sv[3])) 126 | 127 | def test_cy_equal_np_w2v(self): 128 | m1 = Average(W2V) 129 | m1.prep.prepare_vectors( 130 | sv=m1.sv, total_sentences=len(self.sentences), update=False 131 | ) 132 | m1._pre_train_calls() 133 | mem1 = m1._get_thread_working_mem() 134 | o1 = train_average_np(m1, self.sentences, m1.sv.vectors, mem1) 135 | 136 | m2 = Average(W2V) 137 | m2.prep.prepare_vectors( 138 | sv=m2.sv, total_sentences=len(self.sentences), update=False 139 | ) 140 | m2._pre_train_calls() 141 | mem2 = m2._get_thread_working_mem() 142 | 143 | from fse.models.average_inner import train_average_cy 144 | 145 | o2 = train_average_cy(m2, self.sentences, m2.sv.vectors, mem2) 146 | 147 | self.assertEqual(o1, o2) 148 | self.assertTrue((m1.sv.vectors == m2.sv.vectors).all()) 149 | 150 | def test_cy_equal_np_w2v_random(self): 151 | w2v = Word2Vec(min_count=1, vector_size=DIM) 152 | # Random initialization 153 | w2v.build_vocab(SENTENCES) 154 | 155 | m1 = Average(w2v) 156 | m1.prep.prepare_vectors( 157 | sv=m1.sv, total_sentences=len(self.sentences), update=False 158 | ) 159 | m1._pre_train_calls() 160 | mem1 = m1._get_thread_working_mem() 161 | train_average_np(m1, self.sentences, m1.sv.vectors, mem1) 162 | 163 | m2 = Average(w2v) 164 | m2.prep.prepare_vectors( 165 | sv=m2.sv, total_sentences=len(self.sentences), update=False 166 | ) 167 | m2._pre_train_calls() 168 | mem2 = m2._get_thread_working_mem() 169 | 170 | from fse.models.average_inner import train_average_cy 171 | 172 | train_average_cy(m2, self.sentences, m2.sv.vectors, mem2) 173 | 174 | self.assertTrue(np.allclose(m1.sv.vectors, m2.sv.vectors, atol=1e-6)) 175 | 176 | def test_cy_equal_np_ft_random(self): 177 | ft = FastText(vector_size=20, min_count=1) 178 | ft.build_vocab(SENTENCES) 179 | 180 | m1 = Average(ft) 181 | m1.prep.prepare_vectors( 182 | sv=m1.sv, total_sentences=len(self.sentences), update=False 183 | ) 184 | m1._pre_train_calls() 185 | 186 | from fse.models.average_inner import MAX_NGRAMS_IN_BATCH 187 | 188 | m1.batch_ngrams = MAX_NGRAMS_IN_BATCH 189 | mem1 = m1._get_thread_working_mem() 190 | o1 = train_average_np(m1, self.sentences[:2], m1.sv.vectors, mem1) 191 | 192 | m2 = Average(ft) 193 | m2.prep.prepare_vectors( 194 | sv=m2.sv, total_sentences=len(self.sentences), update=False 195 | ) 196 | m2._pre_train_calls() 197 | mem2 = m2._get_thread_working_mem() 198 | 199 | from fse.models.average_inner import train_average_cy 200 | 201 | o2 = train_average_cy(m2, self.sentences[:2], m2.sv.vectors, mem2) 202 | 203 | self.assertEqual(o1, o2) 204 | self.assertTrue(np.allclose(m1.sv.vectors, m2.sv.vectors, atol=1e-6)) 205 | 206 | def test_do_train_job(self): 207 | 
self.model.prep.prepare_vectors( 208 | sv=self.model.sv, total_sentences=len(SENTENCES), update=True 209 | ) 210 | mem = self.model._get_thread_working_mem() 211 | self.assertEqual( 212 | (100, 1450), 213 | self.model._do_train_job( 214 | [(s, i) for i, s in enumerate(SENTENCES)], 215 | target=self.model.sv.vectors, 216 | memory=mem, 217 | ), 218 | ) 219 | self.assertEqual((104, DIM), self.model.sv.vectors.shape) 220 | 221 | def test_train(self): 222 | self.assertEqual( 223 | (100, 1450), self.model.train([(s, i) for i, s in enumerate(SENTENCES)]) 224 | ) 225 | 226 | def test_train_single_from_disk(self): 227 | p = TEST_DATA / "test_vecs" 228 | p_res = TEST_DATA / "test_vecs.vectors" 229 | p_target = TEST_DATA / "test_vecs_wv.vectors" 230 | 231 | se1 = Average(W2V) 232 | se2 = Average( 233 | W2V, sv_mapfile_path=str(p.absolute()), wv_mapfile_path=str(p.absolute()) 234 | ) 235 | se1.train([(s, i) for i, s in enumerate(SENTENCES)]) 236 | se2.train([(s, i) for i, s in enumerate(SENTENCES)]) 237 | 238 | self.assertTrue(p_target.exists()) 239 | self.assertTrue((se1.wv.vectors == se2.wv.vectors).all()) 240 | self.assertFalse(se2.wv.vectors.flags.writeable) 241 | 242 | self.assertTrue((se1.sv.vectors == se2.sv.vectors).all()) 243 | p_res.unlink() 244 | p_target.unlink() 245 | 246 | def test_train_multi_from_disk(self): 247 | p = TEST_DATA / "test_vecs" 248 | p_res = TEST_DATA / "test_vecs.vectors" 249 | p_target = TEST_DATA / "test_vecs_wv.vectors" 250 | 251 | se1 = Average(W2V, workers=2) 252 | se2 = Average( 253 | W2V, 254 | workers=2, 255 | sv_mapfile_path=str(p.absolute()), 256 | wv_mapfile_path=str(p.absolute()), 257 | ) 258 | se1.train([(s, i) for i, s in enumerate(SENTENCES)]) 259 | se2.train([(s, i) for i, s in enumerate(SENTENCES)]) 260 | 261 | self.assertTrue(p_target.exists()) 262 | self.assertTrue((se1.wv.vectors == se2.wv.vectors).all()) 263 | self.assertFalse(se2.wv.vectors.flags.writeable) 264 | 265 | self.assertTrue((se1.sv.vectors == se2.sv.vectors).all()) 266 | p_res.unlink() 267 | p_target.unlink() 268 | 269 | def test_check_parameter_sanity(self): 270 | se = Average(W2V) 271 | se.word_weights = np.full(20, 2.0, dtype=np.float32) 272 | with self.assertRaises(ValueError): 273 | se._check_parameter_sanity() 274 | 275 | 276 | if __name__ == "__main__": 277 | logging.basicConfig( 278 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 279 | ) 280 | unittest.main() 281 | -------------------------------------------------------------------------------- /fse/models/average.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | """This module implements the base class to compute average representations for sentences, using highly optimized C routines, 8 | data streaming and Pythonic interfaces. 9 | 10 | The implementation is based on Iyyer et al. (2015): Deep Unordered Composition Rivals Syntactic Methods for Text Classification. 11 | For more information, see . 12 | 13 | The training algorithms is based on the Gensim implementation of Word2Vec, FastText, and Doc2Vec. 14 | For more information, see: :class:`~gensim.models.word2vec.Word2Vec`, :class:`~gensim.models.fasttext.FastText`, or 15 | :class:`~gensim.models.doc2vec.Doc2Vec`. 16 | 17 | Initialize and train a :class:`~fse.models.sentence2vec.Sentence2Vec` model 18 | 19 | .. 
sourcecode:: pycon 20 | 21 | >>> from gensim.models.word2vec import Word2Vec 22 | >>> sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] 23 | >>> model = Word2Vec(sentences, min_count=1, vector_size=20) 24 | 25 | >>> from fse.models.average import Average 26 | >>> avg = Average(model) 27 | >>> avg.train([(s, i) for i, s in enumerate(sentences)]) 28 | >>> avg.sv.vectors.shape 29 | (2, 20) 30 | 31 | """ 32 | 33 | from __future__ import division 34 | 35 | from fse.models.base_s2v import BaseSentence2VecModel 36 | 37 | from gensim.models.keyedvectors import KeyedVectors 38 | from gensim.models.fasttext import ft_ngram_hashes 39 | 40 | from numpy import ( 41 | ndarray, 42 | float32 as REAL, 43 | sum as np_sum, 44 | multiply as np_mult, 45 | zeros, 46 | max as np_max, 47 | ) 48 | 49 | from typing import List, Tuple 50 | 51 | import logging 52 | 53 | logger = logging.getLogger(__name__) 54 | 55 | 56 | def train_average_np( 57 | model: BaseSentence2VecModel, 58 | indexed_sentences: List[tuple], 59 | target: ndarray, 60 | memory: ndarray, 61 | ) -> Tuple[int, int]: 62 | """Training on a sequence of sentences and update the target ndarray. 63 | 64 | Called internally from :meth:`~fse.models.average.Average._do_train_job`. 65 | 66 | Warnings 67 | -------- 68 | This is the non-optimized, pure Python version. If you have a C compiler, 69 | fse will use an optimized code path from :mod:`fse.models.average_inner` instead. 70 | 71 | Parameters 72 | ---------- 73 | model : :class:`~fse.models.base_s2v.BaseSentence2VecModel` 74 | The BaseSentence2VecModel model instance. 75 | indexed_sentences : iterable of tuple 76 | The sentences used to train the model. 77 | target : ndarray 78 | The target ndarray. We use the index from indexed_sentences 79 | to write into the corresponding row of target. 80 | memory : ndarray 81 | Private memory for each working thread 82 | 83 | Returns 84 | ------- 85 | int, int 86 | Number of effective sentences (non-zero) and effective words in the vocabulary used 87 | during training the sentence embedding. 88 | 89 | """ 90 | size = model.wv.vector_size 91 | 92 | w_vectors = model.wv.vectors 93 | w_weights = model.word_weights 94 | 95 | s_vectors = target 96 | 97 | is_ft = model.is_ft 98 | 99 | mem = memory[0] 100 | 101 | if is_ft: 102 | # NOTE: For Fasttext: Use wv.vectors_vocab 103 | # Using the wv.vectors from fasttext had horrible effects on the sts results 104 | # I suspect this is because the wv.vectors are based on the averages of 105 | # wv.vectors_vocab + wv.vectors_ngrams, which will all point into very 106 | # similar directions. 
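# A rough sketch of the OOV handling implemented in the else-branch of the
# loop below: a word absent from wv.key_to_index is represented by the mean
# of its character n-gram bucket vectors, scaled by oov_weight, i.e.
#     hashes = ft_ngram_hashes(word, min_n, max_n, bucket)[:max_ngrams]
#     oov_vec = oov_weight * np_sum(ngram_vectors[hashes], axis=0) / len(hashes)
# Words that yield no usable n-grams are skipped (implicit addition of zero).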
107 | max_ngrams = model.batch_ngrams 108 | w_vectors = model.wv.vectors_vocab 109 | ngram_vectors = model.wv.vectors_ngrams 110 | min_n = model.wv.min_n 111 | max_n = model.wv.max_n 112 | bucket = model.wv.bucket 113 | oov_weight = np_max(w_weights) 114 | 115 | eff_sentences, eff_words = 0, 0 116 | 117 | if not is_ft: 118 | for obj in indexed_sentences: 119 | mem.fill(0.0) 120 | sent = obj[0] 121 | sent_adr = obj[1] 122 | 123 | word_indices = [ 124 | model.wv.key_to_index[word] 125 | for word in sent 126 | if word in model.wv.key_to_index 127 | ] 128 | eff_sentences += 1 129 | if not len(word_indices): 130 | continue 131 | eff_words += len(word_indices) 132 | 133 | mem += np_sum( 134 | np_mult(w_vectors[word_indices], w_weights[word_indices][:, None]), 135 | axis=0, 136 | ) 137 | mem *= 1 / len(word_indices) 138 | s_vectors[sent_adr] = mem.astype(REAL) 139 | else: 140 | for obj in indexed_sentences: 141 | mem.fill(0.0) 142 | sent = obj[0] 143 | sent_adr = obj[1] 144 | 145 | if not len(sent): 146 | continue 147 | mem = zeros(size, dtype=REAL) 148 | 149 | eff_sentences += 1 150 | eff_words += len(sent) # Counts everything in the sentence 151 | 152 | for word in sent: 153 | if word in model.wv.key_to_index: 154 | word_index = model.wv.key_to_index[word] 155 | mem += w_vectors[word_index] * w_weights[word_index] 156 | else: 157 | ngram_hashes = ft_ngram_hashes(word, min_n, max_n, bucket)[ 158 | :max_ngrams 159 | ] 160 | if len(ngram_hashes) == 0: 161 | continue 162 | mem += oov_weight * ( 163 | np_sum(ngram_vectors[ngram_hashes], axis=0) / len(ngram_hashes) 164 | ) 165 | # Implicit addition of zero if oov does not contain any ngrams 166 | s_vectors[sent_adr] = mem / len(sent) 167 | 168 | return eff_sentences, eff_words 169 | 170 | 171 | try: 172 | from fse.models.average_inner import train_average_cy 173 | from fse.models.average_inner import ( 174 | FAST_VERSION, 175 | MAX_WORDS_IN_BATCH, 176 | MAX_NGRAMS_IN_BATCH, 177 | ) 178 | 179 | train_average = train_average_cy 180 | except ImportError: 181 | FAST_VERSION = -1 182 | MAX_WORDS_IN_BATCH = 10000 183 | MAX_NGRAMS_IN_BATCH = 40 184 | train_average = train_average_np 185 | 186 | 187 | class Average(BaseSentence2VecModel): 188 | """Train, use and evaluate averaged sentence vectors. 189 | 190 | The model can be stored/loaded via its :meth:`~fse.models.average.Average.save` and 191 | :meth:`~fse.models.average.Average.load` methods. 192 | 193 | Some important attributes are the following: 194 | 195 | Attributes 196 | ---------- 197 | wv : :class:`~gensim.models.keyedvectors.KeyedVectors` 198 | This object essentially contains the mapping between words and embeddings. After training, it can be used 199 | directly to query those embeddings in various ways. See the module level docstring for examples. 200 | 201 | sv : :class:`~fse.models.sentencevectors.SentenceVectors` 202 | This object contains the sentence vectors inferred from the training data. There will be one such vector 203 | for each unique docusentence supplied during training. They may be individually accessed using the index. 204 | 205 | prep : :class:`~fse.models.base_s2v.BaseSentence2VecPreparer` 206 | The prep object is used to transform and initialize the sv.vectors. Aditionally, it can be used 207 | to move the vectors to disk for training with memmap. 
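    Examples
    --------
    See the module-level example above for end-to-end usage. After calling
    :meth:`~fse.models.average.Average.train` on ``(words, index)`` tuples, the
    sentence vectors can be queried through ``sv``, for instance ``model.sv[0]``,
    ``model.sv.similarity(0, 1)``, or ``model.sv.most_similar(positive=0)``.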
208 | 209 | """ 210 | 211 | def __init__( 212 | self, 213 | model: KeyedVectors, 214 | sv_mapfile_path: str = None, 215 | wv_mapfile_path: str = None, 216 | workers: int = 1, 217 | lang_freq: str = None, 218 | **kwargs 219 | ): 220 | """Average (unweighted) sentence embeddings model. Performs a simple averaging operation over all 221 | words in a sentences without further transformation. 222 | 223 | The implementation is based on Iyyer et al. (2015): Deep Unordered Composition Rivals Syntactic Methods for Text Classification. 224 | For more information, see . 225 | 226 | Parameters 227 | ---------- 228 | model : :class:`~gensim.models.keyedvectors.KeyedVectors` or :class:`~gensim.models.base_any2vec.BaseWordEmbeddingsModel` 229 | This object essentially contains the mapping between words and embeddings. To compute the sentence embeddings 230 | the wv.vocab and wv.vector elements are required. 231 | sv_mapfile_path : str, optional 232 | Optional path to store the sentence-vectors in for very large datasets. Used for memmap. 233 | wv_mapfile_path : str, optional 234 | Optional path to store the word-vectors in for very large datasets. Used for memmap. 235 | Use sv_mapfile_path and wv_mapfile_path to train disk-to-disk without needing much ram. 236 | workers : int, optional 237 | Number of working threads, used for multithreading. For most tasks (few words in a sentence) 238 | a value of 1 should be more than enough. 239 | lang_freq : str, optional 240 | Some pre-trained embeddings, i.e. "GoogleNews-vectors-negative300.bin", do not contain information about 241 | the frequency of a word. As the frequency is required for estimating the word weights, we induce 242 | frequencies into the wv.vocab.count based on :class:`~wordfreq` 243 | If no frequency information is available, you can choose the language to estimate the frequency. 
244 | See https://github.com/LuminosoInsight/wordfreq 245 | 246 | """ 247 | 248 | super(Average, self).__init__( 249 | model=model, 250 | sv_mapfile_path=sv_mapfile_path, 251 | wv_mapfile_path=wv_mapfile_path, 252 | workers=workers, 253 | lang_freq=lang_freq, 254 | batch_words=MAX_WORDS_IN_BATCH, 255 | batch_ngrams=MAX_NGRAMS_IN_BATCH, 256 | fast_version=FAST_VERSION, 257 | ) 258 | 259 | def _do_train_job( 260 | self, data_iterable: List[tuple], target: ndarray, memory: ndarray 261 | ) -> Tuple[int, int]: 262 | """ Internal routine which is called on training and performs averaging for all entries in the iterable """ 263 | eff_sentences, eff_words = train_average( 264 | model=self, indexed_sentences=data_iterable, target=target, memory=memory 265 | ) 266 | return eff_sentences, eff_words 267 | 268 | def _check_parameter_sanity(self, **kwargs): 269 | """ Check the sanity of all child paramters """ 270 | if not all(self.word_weights == 1.0): 271 | raise ValueError("All word weights must equal one for averaging") 272 | 273 | def _pre_train_calls(self, **kwargs): 274 | """Function calls to perform before training """ 275 | pass 276 | 277 | def _post_train_calls(self, **kwargs): 278 | """ Function calls to perform after training, such as computing eigenvectors """ 279 | pass 280 | 281 | def _post_inference_calls(self, **kwargs): 282 | """Function calls to perform after training & inference 283 | Examples include the removal of components 284 | """ 285 | pass 286 | 287 | def _check_dtype_santiy(self, **kwargs): 288 | """ Check the dtypes of all child attributes""" 289 | pass 290 | -------------------------------------------------------------------------------- /fse/models/average_inner.pyx: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env cython 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: cdivision=True 5 | # cython: embedsignature=True 6 | # coding: utf-8 7 | 8 | # Author: Oliver Borchers 9 | # Copyright (C) Oliver Borchers 10 | 11 | """Optimized cython functions for computing sentence embeddings""" 12 | 13 | import cython 14 | import numpy as np 15 | 16 | cimport numpy as np 17 | 18 | from gensim.models.fasttext import compute_ngrams_bytes, ft_hash_bytes 19 | 20 | from libc.string cimport memset 21 | from libc.stdio cimport printf 22 | 23 | import scipy.linalg.blas as fblas 24 | 25 | cdef saxpy_ptr saxpy=PyCObject_AsVoidPtr(fblas.saxpy._cpointer) # y += alpha * x 26 | cdef sscal_ptr sscal=PyCObject_AsVoidPtr(fblas.sscal._cpointer) # x = alpha * x 27 | 28 | cdef int ONE = 1 29 | cdef int ZERO = 0 30 | 31 | cdef REAL_t ONEF = 1.0 32 | cdef REAL_t ZEROF = 0.0 33 | 34 | DEF MAX_WORDS = 10000 35 | DEF MAX_NGRAMS = 40 36 | 37 | cdef init_base_s2v_config(BaseSentenceVecsConfig *c, model, target, memory): 38 | """Load BaseAny2Vec parameters into a BaseSentenceVecsConfig struct. 39 | 40 | Parameters 41 | ---------- 42 | c : FTSentenceVecsConfig * 43 | A pointer to the struct to initialize. 44 | model : fse.models.base_s2v.BaseSentence2VecModel 45 | The model to load. 46 | target : np.ndarray 47 | The target array to write the averages to. 48 | memory : np.ndarray 49 | Private working memory for each worker. 50 | Consists of 2 nd.arrays. 
51 | 52 | """ 53 | c[0].workers = model.workers 54 | c[0].size = model.sv.vector_size 55 | 56 | c[0].mem = (np.PyArray_DATA(memory[0])) 57 | 58 | c[0].word_vectors = (np.PyArray_DATA(model.wv.vectors)) 59 | c[0].word_weights = (np.PyArray_DATA(model.word_weights)) 60 | 61 | c[0].sentence_vectors = (np.PyArray_DATA(target)) 62 | 63 | cdef init_ft_s2v_config(FTSentenceVecsConfig *c, model, target, memory): 64 | """Load Fasttext parameters into a FTSentenceVecsConfig struct. 65 | 66 | Parameters 67 | ---------- 68 | c : FTSentenceVecsConfig * 69 | A pointer to the struct to initialize. 70 | model : fse.models.base_s2v.BaseSentence2VecModel 71 | The model to load. 72 | target : np.ndarray 73 | The target array to write the averages to. 74 | memory : np.ndarray 75 | Private working memory for each worker. 76 | Consists of 2 nd.arrays. 77 | 78 | """ 79 | 80 | c[0].workers = model.workers 81 | c[0].size = model.sv.vector_size 82 | c[0].min_n = model.wv.min_n 83 | c[0].max_n = model.wv.max_n 84 | c[0].bucket = model.wv.bucket 85 | 86 | c[0].oov_weight = np.max(model.word_weights) 87 | 88 | c[0].mem = (np.PyArray_DATA(memory[0])) 89 | 90 | memory[1].fill(ZERO) # Reset the ngram storage before filling the struct 91 | c[0].subwords_idx = (np.PyArray_DATA(memory[1])) 92 | 93 | c[0].word_vectors = (np.PyArray_DATA(model.wv.vectors_vocab)) 94 | c[0].ngram_vectors = (np.PyArray_DATA(model.wv.vectors_ngrams)) 95 | c[0].word_weights = (np.PyArray_DATA(model.word_weights)) 96 | 97 | c[0].sentence_vectors = (np.PyArray_DATA(target)) 98 | 99 | cdef object populate_base_s2v_config(BaseSentenceVecsConfig *c, wv, indexed_sentences): 100 | """Prepare C structures for BaseAny2VecModel so we can go "full C" and release the Python GIL. 101 | 102 | We create indices over the sentences. We also perform some calculations for 103 | each token/ngram and store the result up front to save time. 104 | 105 | Parameters 106 | ---------- 107 | c : BaseSentenceVecsConfig* 108 | A pointer to the struct that will contain the populated indices. 109 | wv : obj 110 | The word vector object 111 | indexed_sentences : iterable of tuple 112 | The sentences to read 113 | 114 | Returns 115 | ------- 116 | eff_words : int 117 | The number of in-vocabulary tokens. 118 | eff_sents : int 119 | The number of non-empty sentences. 120 | 121 | """ 122 | 123 | cdef uINT_t eff_words = ZERO # Effective words encountered in a sentence 124 | cdef uINT_t eff_sents = ZERO # Effective sentences encountered 125 | 126 | c.sentence_boundary[0] = ZERO 127 | 128 | for obj in indexed_sentences: 129 | if not obj[0]: 130 | continue 131 | for token in obj[0]: 132 | word = token if token in wv.key_to_index else None 133 | if word is None: 134 | continue 135 | c.word_indices[eff_words] = wv.key_to_index[token] 136 | c.sent_adresses[eff_words] = obj[1] 137 | 138 | eff_words += ONE 139 | if eff_words == MAX_WORDS: 140 | break 141 | eff_sents += 1 142 | c.sentence_boundary[eff_sents] = eff_words 143 | 144 | if eff_words == MAX_WORDS: 145 | break 146 | 147 | return eff_sents, eff_words 148 | 149 | cdef object populate_ft_s2v_config(FTSentenceVecsConfig *c, wv, indexed_sentences): 150 | """Prepare C structures for FastText so we can go "full C" and release the Python GIL. 151 | 152 | We create indices over the sentences. We also perform some calculations for 153 | each token/ngram and store the result up front to save time. 154 | 155 | Parameters 156 | ---------- 157 | c : FTSentenceVecsConfig* 158 | A pointer to the struct that will contain the populated indices. 
159 | wv : obj 160 | The word vector object 161 | indexed_sentences : iterable of tuples 162 | The sentences to read 163 | 164 | Returns 165 | ------- 166 | eff_words : int 167 | The number of in-vocabulary tokens. 168 | eff_sents : int 169 | The number of non-empty sentences. 170 | 171 | """ 172 | 173 | cdef uINT_t eff_words = ZERO # Effective words encountered in a sentence 174 | cdef uINT_t eff_sents = ZERO # Effective sentences encountered 175 | 176 | c.sentence_boundary[0] = ZERO 177 | 178 | for obj in indexed_sentences: 179 | if not obj[0]: 180 | continue 181 | for token in obj[0]: 182 | c.sent_adresses[eff_words] = obj[1] 183 | if token in wv.key_to_index: 184 | # In Vocabulary 185 | c.word_indices[eff_words] = wv.key_to_index[token] 186 | c.subwords_idx_len[eff_words] = ZERO 187 | else: 188 | # OOV words --> write ngram indices to memory 189 | c.word_indices[eff_words] = ZERO 190 | 191 | encoded_ngrams = compute_ngrams_bytes(token, c.min_n, c.max_n) 192 | hashes = [ft_hash_bytes(n) % c.bucket for n in encoded_ngrams] 193 | 194 | c.subwords_idx_len[eff_words] = min(len(encoded_ngrams), MAX_NGRAMS) 195 | for i, h in enumerate(hashes[:MAX_NGRAMS]): 196 | c.subwords_idx[(eff_words * MAX_NGRAMS) + i] = h 197 | 198 | eff_words += ONE 199 | 200 | if eff_words == MAX_WORDS: 201 | break 202 | 203 | eff_sents += 1 204 | c.sentence_boundary[eff_sents] = eff_words 205 | 206 | if eff_words == MAX_WORDS: 207 | break 208 | 209 | return eff_sents, eff_words 210 | 211 | cdef void compute_base_sentence_averages(BaseSentenceVecsConfig *c, uINT_t num_sentences) nogil: 212 | """Perform optimized sentence-level averaging for BaseAny2Vec model. 213 | 214 | Parameters 215 | ---------- 216 | c : BaseSentenceVecsConfig * 217 | A pointer to a fully initialized and populated struct. 218 | num_sentences : uINT_t 219 | The number of sentences used to train the model. 220 | 221 | Notes 222 | ----- 223 | This routine does not provide oov support. 224 | 225 | """ 226 | cdef: 227 | int size = c.size 228 | 229 | uINT_t sent_idx, sent_start, sent_end, sent_row 230 | 231 | uINT_t i, word_idx, word_row 232 | 233 | REAL_t sent_len, inv_count 234 | 235 | for sent_idx in range(num_sentences): 236 | memset(c.mem, 0, size * cython.sizeof(REAL_t)) 237 | 238 | sent_start = c.sentence_boundary[sent_idx] 239 | sent_end = c.sentence_boundary[sent_idx + 1] 240 | sent_len = ZEROF 241 | 242 | for i in range(sent_start, sent_end): 243 | sent_len += ONEF 244 | sent_row = c.sent_adresses[i] * size 245 | word_row = c.word_indices[i] * size 246 | word_idx = c.word_indices[i] 247 | 248 | saxpy(&size, &c.word_weights[word_idx], &c.word_vectors[word_row], &ONE, c.mem, &ONE) 249 | 250 | if sent_len > ZEROF: 251 | inv_count = ONEF / sent_len 252 | # If we perform the a*x on memory, the computation is compatible with many-to-one mappings 253 | # because it doesn't rescale the overall result 254 | saxpy(&size, &inv_count, c.mem, &ONE, &c.sentence_vectors[sent_row], &ONE) 255 | 256 | cdef void compute_ft_sentence_averages(FTSentenceVecsConfig *c, uINT_t num_sentences) nogil: 257 | """Perform optimized sentence-level averaging for FastText model. 258 | 259 | Parameters 260 | ---------- 261 | c : FTSentenceVecsConfig * 262 | A pointer to a fully initialized and populated struct. 263 | num_sentences : uINT_t 264 | The number of sentences used to train the model. 265 | 266 | Notes 267 | ----- 268 | This routine DOES provide oov support. 
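For an OOV token, the routine adds the mean of its (at most MAX_NGRAMS) ngram vectors, scaled by the maximum word weight. A rough Python equivalent, taken from the NumPy reference implementation in average.py::

    mem += oov_weight * (np_sum(ngram_vectors[ngram_hashes], axis=0) / len(ngram_hashes))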
269 | 270 | """ 271 | cdef: 272 | int size = c.size 273 | 274 | uINT_t sent_idx, sent_start, sent_end, sent_row 275 | 276 | uINT_t ngram_row, ngrams 277 | 278 | uINT_t i, j, word_idx, word_row 279 | 280 | REAL_t sent_len 281 | REAL_t inv_count, inv_ngram 282 | REAL_t oov_weight = c.oov_weight 283 | 284 | 285 | for sent_idx in range(num_sentences): 286 | memset(c.mem, 0, size * cython.sizeof(REAL_t)) 287 | sent_start = c.sentence_boundary[sent_idx] 288 | sent_end = c.sentence_boundary[sent_idx + 1] 289 | sent_len = ZEROF 290 | 291 | for i in range(sent_start, sent_end): 292 | sent_len += ONEF 293 | sent_row = c.sent_adresses[i] * size 294 | 295 | word_idx = c.word_indices[i] 296 | ngrams = c.subwords_idx_len[i] 297 | 298 | if ngrams == 0: 299 | word_row = c.word_indices[i] * size 300 | saxpy(&size, &c.word_weights[word_idx], &c.word_vectors[word_row], &ONE, c.mem, &ONE) 301 | else: 302 | inv_ngram = (ONEF / ngrams) * c.oov_weight 303 | for j in range(ngrams): 304 | ngram_row = c.subwords_idx[(i * MAX_NGRAMS)+j] * size 305 | saxpy(&size, &inv_ngram, &c.ngram_vectors[ngram_row], &ONE, c.mem, &ONE) 306 | 307 | if sent_len > ZEROF: 308 | inv_count = ONEF / sent_len 309 | saxpy(&size, &inv_count, c.mem, &ONE, &c.sentence_vectors[sent_row], &ONE) 310 | 311 | def train_average_cy(model, indexed_sentences, target, memory): 312 | """Training on a sequence of sentences and update the target ndarray. 313 | 314 | Called internally from :meth:`~fse.models.average.Average._do_train_job`. 315 | 316 | Parameters 317 | ---------- 318 | model : :class:`~fse.models.base_s2v.BaseSentence2VecModel` 319 | The BaseSentence2VecModel model instance. 320 | indexed_sentences : iterable of tuple 321 | The sentences used to train the model. 322 | target : ndarray 323 | The target ndarray. We use the index from indexed_sentences 324 | to write into the corresponding row of target. 325 | memory : ndarray 326 | Private memory for each working thread. 327 | 328 | Returns 329 | ------- 330 | int, int 331 | Number of effective sentences (non-zero) and effective words in the vocabulary used 332 | during training the sentence embedding. 333 | """ 334 | 335 | cdef uINT_t eff_sentences = 0 336 | cdef uINT_t eff_words = 0 337 | cdef BaseSentenceVecsConfig w2v 338 | cdef FTSentenceVecsConfig ft 339 | 340 | if not model.is_ft: 341 | init_base_s2v_config(&w2v, model, target, memory) 342 | 343 | eff_sentences, eff_words = populate_base_s2v_config(&w2v, model.wv, indexed_sentences) 344 | 345 | with nogil: 346 | compute_base_sentence_averages(&w2v, eff_sentences) 347 | else: 348 | init_ft_s2v_config(&ft, model, target, memory) 349 | 350 | eff_sentences, eff_words = populate_ft_s2v_config(&ft, model.wv, indexed_sentences) 351 | 352 | with nogil: 353 | compute_ft_sentence_averages(&ft, eff_sentences) 354 | 355 | return eff_sentences, eff_words 356 | 357 | def init(): 358 | return 1 359 | 360 | MAX_WORDS_IN_BATCH = MAX_WORDS 361 | MAX_NGRAMS_IN_BATCH = MAX_NGRAMS 362 | FAST_VERSION = init() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Build Status 3 | Coverage Status 4 | Downloads 5 | Language grade: Python 6 | Code style: black 7 | License: GPL3 8 |

9 |

10 | fse 11 |

12 | 13 | Fast Sentence Embeddings 14 | ================================== 15 | 16 | Fast Sentence Embeddings is a Python library that serves as an addition to Gensim. This library is intended to compute *sentence vectors* for large collections of sentences or documents with as little hassle as possible: 17 | 18 | ``` 19 | from fse import Vectors, Average, IndexedList 20 | 21 | vecs = Vectors.from_pretrained("glove-wiki-gigaword-50") 22 | model = Average(vecs) 23 | 24 | sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] 25 | 26 | model.train(IndexedList(sentences)) 27 | 28 | model.sv.similarity(0,1) 29 | ``` 30 | 31 | If you want to support fse, take a quick [survey](https://forms.gle/8uSU323fWUVtVwcAA) to help improve it. 32 | 33 | Audience 34 | ------------ 35 | 36 | This package builds upon Gensim and is intended to compute sentence/paragraph vectors for large databases. Use this package if: 37 | - (Sentence) Transformers are too slow 38 | - Your dataset is too large for existing solutions (spaCy) 39 | - Using GPUs is not an option. 40 | 41 | The average (online) inference time for a well-optimized (and batched) sentence-transformer is around 1ms-10ms per sentence. If that is not enough and you are willing to sacrifice a bit in terms of quality, this is your package. 42 | 43 | Features 44 | ------------ 45 | 46 | Find the corresponding blog post(s) here (code may be outdated): 47 | 48 | - [Visualizing 100,000 Amazon Products](https://towardsdatascience.com/vis-amz-83dea6fcb059) 49 | - [Sentence Embeddings. Fast, please!](https://towardsdatascience.com/fse-2b1ffa791cf9) 50 | 51 | **fse** implements three algorithms for sentence embeddings. You can choose 52 | between *unweighted sentence averages*, *smooth inverse frequency averages*, and *unsupervised smooth inverse frequency averages*. 53 | 54 | Key features of **fse** are: 55 | 56 | **[X]** Up to 500,000 sentences / second (1) 57 | 58 | **[X]** Provides HUB access to various pre-trained models for convenience 59 | 60 | **[X]** Supports Average, SIF, and uSIF Embeddings 61 | 62 | **[X]** Full support for Gensim's Word2Vec and all other compatible classes 63 | 64 | **[X]** Full support for Gensim's FastText with out-of-vocabulary words 65 | 66 | **[X]** Induction of word frequencies for pre-trained embeddings 67 | 68 | **[X]** Incredibly fast Cython core routines 69 | 70 | **[X]** Dedicated input file formats for easy usage (including disk streaming) 71 | 72 | **[X]** RAM-to-disk training for large corpora 73 | 74 | **[X]** Disk-to-disk training for even larger corpora 75 | 76 | **[X]** Many fail-safe checks for easy usage 77 | 78 | **[X]** Simple interface for developing your own models 79 | 80 | **[X]** Extensive documentation of all functions 81 | 82 | **[X]** Optimized input classes 83 | 84 | (1) May vary significantly from system to system (e.g., when swap memory is used) and with the preprocessing applied. 85 | I regularly observe 300k-500k sentences/s for preprocessed data on my MacBook (2016). 86 | Visit **Tutorial.ipynb** for an example. 87 | 88 | 89 | Installation 90 | ------------ 91 | 92 | This software depends on NumPy, SciPy, scikit-learn, Gensim, and wordfreq. 93 | You must have them installed prior to installing fse. 94 | 95 | As with Gensim, it is also recommended you install a BLAS library before installing fse.
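Once fse is installed (see the commands below), a quick way to verify that the optimized Cython routines were compiled is to inspect the `FAST_VERSION` flag defined in `fse/models/average.py`; a value of `-1` means the slower NumPy fallback is in use:

```
from fse.models.average import FAST_VERSION
print(FAST_VERSION)  # -1 indicates the Cython routines were not built and the NumPy fallback is active
```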
96 | 97 | The simple way to install **fse** is: 98 | 99 | pip install -U fse 100 | 101 | In case you want to build from source, just run: 102 | 103 | python setup.py install 104 | 105 | If building the Cython extension fails (you will be notified), try: 106 | 107 | pip install -U git+https://github.com/oborchers/Fast_Sentence_Embeddings 108 | 109 | Usage 110 | ------------- 111 | 112 | Using pre-trained models with **fse** is easy. You can just use them from the hub and download them accordingly. 113 | They will be stored locally so you can re-use them later. 114 | 115 | ``` 116 | from fse import Vectors, Average, IndexedList 117 | vecs = Vectors.from_pretrained("glove-wiki-gigaword-50") 118 | model = Average(vecs) 119 | 120 | sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] 121 | 122 | model.train(IndexedList(sentences)) 123 | 124 | model.sv.similarity(0,1) 125 | ``` 126 | 127 | If your vectors are large and you don't have a lot of RAM, you can supply the `mmap` argument as follows to read the vectors from disk instead of loading them into RAM: 128 | 129 | ``` 130 | Vectors.from_pretrained("glove-wiki-gigaword-50", mmap="r") 131 | ``` 132 | 133 | To check which vectors are on the hub, please check: https://huggingface.co/fse. For example, you will find: 134 | - glove-twitter-25 135 | - glove-twitter-50 136 | - glove-twitter-100 137 | - glove-twitter-200 138 | - glove-wiki-gigaword-100 139 | - glove-wiki-gigaword-300 140 | - word2vec-google-news-300 141 | - paragram-25 142 | - paranmt-300 143 | - paragram-300-sl999 144 | - paragram-300-ws353 145 | - fasttext-wiki-news-subwords-300 146 | - fasttext-crawl-subwords-300 (Use with `FTVectors`) 147 | 148 | In order to use **fse** with a custom model you must first estimate a Gensim model which contains a 149 | gensim.models.keyedvectors.BaseKeyedVectors class, for example *Word2Vec* or *Fasttext*. Then you can proceed to compute sentence embeddings for a corpus as follows: 150 | 151 | ``` 152 | from gensim.models import FastText 153 | sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] 154 | ft = FastText(sentences, min_count=1, vector_size=10) 155 | 156 | from fse import Average, IndexedList 157 | model = Average(ft) 158 | model.train(IndexedList(sentences)) 159 | 160 | model.sv.similarity(0,1) 161 | ``` 162 | 163 | fse offers multi-thread support out of the box. However, for most applications a *single thread will most likely be sufficient*. 164 | 165 | Additional Information 166 | ------------- 167 | 168 | Within the folder nootebooks you can find the following guides: 169 | 170 | **Tutorial.ipynb** offers a detailed walk-through of some of the most important functions fse has to offer. 171 | 172 | **STS-Benchmarks.ipynb** contains an example of how to use the library with pre-trained models to 173 | replicate the STS Benchmark results [4] reported in the papers. 174 | 175 | **Speed Comparision.ipynb** compares the speed between the numpy and the cython routines. 176 | 177 | In order to use the **fse** model, you first need some pre-trained gensim 178 | word embedding model, which is then used by **fse** to compute the sentence embeddings. 179 | 180 | After computing sentence embeddings, you can use them in supervised or 181 | unsupervised NLP applications, as they serve as a formidable baseline. 
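The trained vectors live in `model.sv` (one row per sentence), so they can be handed straight to downstream libraries. A small illustrative sketch with scikit-learn, continuing the example above (the labels are invented for the illustration):

```
from sklearn.linear_model import LogisticRegression

X = model.sv.vectors   # shape: (number of sentences, vector_size)
y = [0, 1]             # hypothetical labels, one per sentence
clf = LogisticRegression().fit(X, y)
```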
182 | 183 | The models presented are based on 184 | - Deep-averaging embeddings [1] 185 | - Smooth inverse frequency embeddings [2] 186 | - Unsupervised smooth inverse frequency embeddings [3] 187 | 188 | Credits to Radim Řehůřek and all contributors for the **awesome** library 189 | and code that [Gensim](https://github.com/RaRe-Technologies/gensim) provides. A whole lot of the code found in this lib is based on Gensim. 190 | 191 | To install **fse** on Colab, check out: https://colab.research.google.com/drive/1qq9GBgEosG7YSRn7r6e02T9snJb04OEi 192 | 193 | Results 194 | ------------ 195 | 196 | Model | Vectors | params | [STS Benchmark](http://ixa2.si.ehu.es/stswiki/index.php/STSbenchmark#Results) 197 | :---: | :---: | :---: | :---: 198 | `CBOW` | `paranmt-300` | | 79.82 199 | `uSIF` | `paranmt-300` | length=11 | 79.00 200 | `SIF-10` | `paranmt-300` | components=10 | 76.72 201 | `SIF-10` | `paragram-300-sl999` | components=10 | 74.21 202 | `SIF-10` | `paragram-300-ws353` | components=10 | 74.03 203 | `SIF-10` | `fasttext-crawl-subwords-300` | components=10 | 73.38 204 | `uSIF` | `paragram-300-sl999` | length=11 | 73.04 205 | `SIF-10` | `fasttext-wiki-news-subwords-300` | components=10 | 72.29 206 | `uSIF` | `paragram-300-ws353` | length=11 | 71.84 207 | `SIF-10` | `glove-twitter-200` | components=10 | 71.62 208 | `SIF-10` | `glove-wiki-gigaword-300` | components=10 | 71.35 209 | `SIF-10` | `word2vec-google-news-300` | components=10 | 71.12 210 | `SIF-10` | `glove-wiki-gigaword-200` | components=10 | 70.62 211 | `SIF-10` | `glove-twitter-100` | components=10 | 69.65 212 | `uSIF` | `fasttext-crawl-subwords-300` | length=11 | 69.40 213 | `uSIF` | `fasttext-wiki-news-subwords-300` | length=11 | 68.63 214 | `SIF-10` | `glove-wiki-gigaword-100` | components=10 | 68.34 215 | `uSIF` | `glove-wiki-gigaword-300` | length=11 | 67.60 216 | `uSIF` | `glove-wiki-gigaword-200` | length=11 | 67.11 217 | `uSIF` | `word2vec-google-news-300` | length=11 | 66.99 218 | `uSIF` | `glove-twitter-200` | length=11 | 66.67 219 | `SIF-10` | `glove-twitter-50` | components=10 | 65.52 220 | `uSIF` | `glove-wiki-gigaword-100` | length=11 | 65.33 221 | `uSIF` | `paragram-25` | length=11 | 64.22 222 | `uSIF` | `glove-twitter-100` | length=11 | 64.13 223 | `SIF-10` | `glove-wiki-gigaword-50` | components=10 | 64.11 224 | `uSIF` | `glove-wiki-gigaword-50` | length=11 | 62.06 225 | `CBOW` | `word2vec-google-news-300` | | 61.54 226 | `uSIF` | `glove-twitter-50` | length=11 | 60.41 227 | `SIF-10` | `paragram-25` | components=10 | 59.07 228 | `uSIF` | `glove-twitter-25` | length=11 | 55.06 229 | `CBOW` | `paragram-300-ws353` | | 54.72 230 | `SIF-10` | `glove-twitter-25` | components=10 | 54.16 231 | `CBOW` | `paragram-300-sl999` | | 51.46 232 | `CBOW` | `fasttext-crawl-subwords-300` | | 48.49 233 | `CBOW` | `glove-wiki-gigaword-300` | | 44.46 234 | `CBOW` | `glove-wiki-gigaword-200` | | 42.40 235 | `CBOW` | `paragram-25` | | 40.13 236 | `CBOW` | `glove-wiki-gigaword-100` | | 38.12 237 | `CBOW` | `glove-wiki-gigaword-50` | | 37.47 238 | `CBOW` | `glove-twitter-200` | | 34.94 239 | `CBOW` | `glove-twitter-100` | | 33.81 240 | `CBOW` | `glove-twitter-50` | | 30.78 241 | `CBOW` | `glove-twitter-25` | | 26.15 242 | `CBOW` | `fasttext-wiki-news-subwords-300` | | 26.08 243 | 244 | Changelog 245 | ------------- 246 | 247 | 1.0.0: 248 | - Added support for gensim>=4. This library is no longer compatible with gensim<4. For migration, see the [README](https://github.com/RaRe-Technologies/gensim/wiki/Migrating-from-Gensim-3.x-to-4). 
249 | - `size` argument is now `vector_size` 250 | 251 | 0.2.0: 252 | - Added `Vectors` and `FTVectors` class and hub support by `from_pretrained` 253 | - Extended benchmark 254 | - Fixed zero division bug for uSIF 255 | - Moved tests out of the main folder 256 | - Moved sts out of the main folder 257 | 258 | 0.1.17: 259 | - Fixed dependency issue where you cannot install fse properly 260 | - Updated readme 261 | - Updated travis python versions (3.6, 3.9) 262 | 263 | 0.1.15 from 0.1.11: 264 | - Fixed major FT Ngram computation bug 265 | - Rewrote the input class. Turns out NamedTuple was pretty slow. 266 | - Added further unittests 267 | - Added documentation 268 | - Major speed improvements 269 | - Fixed division by zero for empty sentences 270 | - Fixed overflow when infer method is used with too many sentences 271 | - Fixed similar_by_sentence bug 272 | 273 | Literature 274 | ------------- 275 | 276 | 1. Iyyer M, Manjunatha V, Boyd-Graber J, Daumé III H (2015) Deep Unordered 277 | Composition Rivals Syntactic Methods for Text Classification. Proc. 53rd Annu. 278 | Meet. Assoc. Comput. Linguist. 7th Int. Jt. Conf. Nat. Lang. Process., 1681–1691. 279 | 280 | 2. Arora S, Liang Y, Ma T (2017) A Simple but Tough-to-Beat Baseline for Sentence 281 | Embeddings. Int. Conf. Learn. Represent. (Toulon, France), 1–16. 282 | 283 | 3. Ethayarajh K (2018) Unsupervised Random Walk Sentence Embeddings: A Strong but Simple Baseline. 284 | Proceedings of the 3rd Workshop on Representation Learning for NLP. (Toulon, France), 91–100. 285 | 286 | 4. Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia Specia. Semeval-2017 Task 1: Semantic Textual Similarity Multilingual and Crosslingual Focused Evaluation. Proceedings of SemEval 2017. 287 | 288 | 289 | Copyright 290 | ------------- 291 | 292 | **Disclaimer**: I am working full time. Unfortunately, I have yet to find time to add all the features I'd like to see. Especially the API needs some overhaul and we need support for gensim 4.0.0. 293 | 294 | I am looking for active contributors to keep this package alive. Please feel free to ping me at if you are interested. 295 | 296 | Author: Oliver Borchers 297 | 298 | Copyright (C) 2022 Oliver Borchers 299 | 300 | Citation 301 | ------------- 302 | 303 | If you found this software useful, please cite it in your publication. 304 | 305 | @misc{Borchers2019, 306 | author = {Borchers, Oliver}, 307 | title = {Fast sentence embeddings}, 308 | year = {2019}, 309 | publisher = {GitHub}, 310 | journal = {GitHub Repository}, 311 | howpublished = {\url{https://github.com/oborchers/Fast_Sentence_Embeddings}}, 312 | } 313 | -------------------------------------------------------------------------------- /fse/inputs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | from pathlib import Path 8 | from typing import List, MutableSequence, Union 9 | 10 | from gensim.utils import any2unicode 11 | from numpy import concatenate, ndarray 12 | from smart_open import open 13 | 14 | 15 | class BaseIndexedList(MutableSequence): 16 | def __init__(self, *args: List[Union[list, set, ndarray]]): 17 | """Base object to be used for feeding in-memory stored lists of sentences to the 18 | training routine. 19 | 20 | Parameters 21 | ---------- 22 | args : lists, sets, ndarray 23 | Arguments to be merged into a single contianer. 
Can be single or multiple list/set/ndarray objects. 24 | """ 25 | 26 | self.items = list() 27 | 28 | if len(args) == 1: 29 | self._check_list_type(args[0]) 30 | self.items = args[0] 31 | else: 32 | for arg in args: 33 | self.extend(arg) 34 | 35 | super().__init__() 36 | 37 | def _check_list_type(self, obj: object): 38 | """Checks input validity.""" 39 | if isinstance(obj, (list, set, ndarray)): 40 | return 1 41 | else: 42 | raise TypeError(f"Arg must be list/set type. Got {type(obj)}") 43 | 44 | def _check_str_type(self, obj: object): 45 | """Checks input validity.""" 46 | if isinstance(obj, str): 47 | return 1 48 | else: 49 | raise TypeError(f"Arg must be str type. Got {type(obj)}") 50 | 51 | def __len__(self): 52 | """List length. 53 | 54 | Returns 55 | ------- 56 | int 57 | Length of the IndexedList 58 | """ 59 | return len(self.items) 60 | 61 | def __str__(self): 62 | """Human readable representation of the object's state, used for debugging. 63 | 64 | Returns 65 | ------- 66 | str 67 | Human readable representation of the object's state (words and tags). 68 | """ 69 | return str(self.items) 70 | 71 | def __getitem__(self, i: int) -> tuple: 72 | """Getitem method. 73 | 74 | Returns 75 | ------- 76 | tuple ([str], int) 77 | Returns the core object, a tuple, for every sentence embedding model. 78 | """ 79 | raise NotImplementedError() 80 | 81 | def __delitem__(self, i: int): 82 | """Delete an item.""" 83 | del self.items[i] 84 | 85 | def __setitem__(self, i: int, item: str): 86 | """Sets an item.""" 87 | self._check_str_type(item) 88 | self.items[i] = item 89 | 90 | def insert(self, i: int, item: str): 91 | """Inserts an item at a position.""" 92 | self._check_str_type(item) 93 | self.items.insert(i, item) 94 | 95 | def append(self, item: str): 96 | """Appends item at last position.""" 97 | self._check_str_type(item) 98 | self.insert(len(self.items), item) 99 | 100 | def extend(self, arg: Union[list, set, ndarray]): 101 | """Extens list.""" 102 | self._check_list_type(arg) 103 | 104 | if not isinstance(arg, ndarray): 105 | self.items += arg 106 | else: 107 | self.items = concatenate([self.items, arg], axis=0) 108 | 109 | 110 | class IndexedList(BaseIndexedList): 111 | def __init__(self, *args: Union[list, set, ndarray]): 112 | """Quasi-list to be used for feeding in-memory stored lists of sentences to the 113 | training routine. 114 | 115 | Parameters 116 | ---------- 117 | args : lists, sets, ndarray 118 | Arguments to be merged into a single contianer. Can be single or multiple list/set objects. 119 | """ 120 | super(IndexedList, self).__init__(*args) 121 | 122 | def __getitem__(self, i: int) -> tuple: 123 | """Getitem method. 124 | 125 | Returns 126 | ------- 127 | tuple 128 | Returns the core object, tuple, for every sentence embedding model. 129 | """ 130 | return (self.items.__getitem__(i), i) 131 | 132 | 133 | class CIndexedList(BaseIndexedList): 134 | def __init__( 135 | self, *args: Union[list, set, ndarray], custom_index: Union[list, ndarray] 136 | ): 137 | """Quasi-list with custom indices to be used for feeding in-memory stored lists 138 | of sentences to the training routine. 139 | 140 | Parameters 141 | ---------- 142 | args : lists, sets, ndarray 143 | Arguments to be merged into a single contianer. Can be single or multiple list/set objects. 144 | custom_index : list, ndarray 145 | Custom index to support many to one mappings. 
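Examples
--------
A small sketch of a many-to-one mapping; both input sentences share the custom
index 0 and therefore contribute to the same sentence vector::

    sentences = [["hello", "world"], ["how", "are", "you"]]
    mapped = CIndexedList(sentences, custom_index=[0, 0])
    mapped[1]  # -> (["how", "are", "you"], 0)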
146 | """ 147 | self.custom_index = custom_index 148 | 149 | super(CIndexedList, self).__init__(*args) 150 | 151 | if len(self.items) != len(self.custom_index): 152 | raise RuntimeError( 153 | f"Size of custom_index {len(custom_index)} does not match items {len(self.items)}" 154 | ) 155 | 156 | def __getitem__(self, i: int) -> tuple: 157 | """Getitem method. 158 | 159 | Returns 160 | ------- 161 | tuple 162 | Returns the core object, tuple, for every sentence embedding model. 163 | """ 164 | return (self.items.__getitem__(i), self.custom_index[i]) 165 | 166 | def __delitem__(self, i: int): 167 | raise NotImplementedError("Method currently not supported") 168 | 169 | def __setitem__(self, i: int, item: str): 170 | raise NotImplementedError("Method currently not supported") 171 | 172 | def insert(self, i: int, item: str): 173 | raise NotImplementedError("Method currently not supported") 174 | 175 | def append(self, item: str): 176 | raise NotImplementedError("Method currently not supported") 177 | 178 | def extend(self, arg: Union[list, set, ndarray]): 179 | raise NotImplementedError("Method currently not supported") 180 | 181 | 182 | class SplitIndexedList(BaseIndexedList): 183 | def __init__(self, *args: Union[list, set, ndarray]): 184 | """Quasi-list with string splitting to be used for feeding in-memory stored 185 | lists of sentences to the training routine. 186 | 187 | Parameters 188 | ---------- 189 | args : lists, sets, ndarray 190 | Arguments to be merged into a single contianer. Can be single or multiple list/set objects. 191 | """ 192 | super(SplitIndexedList, self).__init__(*args) 193 | 194 | def __getitem__(self, i: int) -> tuple: 195 | """Getitem method. 196 | 197 | Returns 198 | ------- 199 | tuple 200 | Returns the core object, tuple, for every sentence embedding model. 201 | """ 202 | return (self.items.__getitem__(i).split(), i) 203 | 204 | 205 | class SplitCIndexedList(BaseIndexedList): 206 | def __init__( 207 | self, *args: Union[list, set, ndarray], custom_index: Union[list, ndarray] 208 | ): 209 | """Quasi-list with custom indices and string splitting to be used for feeding 210 | in-memory stored lists of sentences to the training routine. 211 | 212 | Parameters 213 | ---------- 214 | args : lists, sets, ndarray 215 | Arguments to be merged into a single contianer. Can be single or multiple list/set objects. 216 | custom_index : list, ndarray 217 | Custom index to support many to one mappings. 218 | """ 219 | self.custom_index = custom_index 220 | 221 | super(SplitCIndexedList, self).__init__(*args) 222 | 223 | if len(self.items) != len(self.custom_index): 224 | raise RuntimeError( 225 | f"Size of custom_index {len(custom_index)} does not match items {len(self.items)}" 226 | ) 227 | 228 | def __getitem__(self, i: int) -> tuple: 229 | """Getitem method. 230 | 231 | Returns 232 | ------- 233 | tuple 234 | Returns the core object, tuple, for every sentence embedding model. 
235 | """ 236 | return (self.items.__getitem__(i).split(), self.custom_index[i]) 237 | 238 | def __delitem__(self, i: int): 239 | raise NotImplementedError("Method currently not supported") 240 | 241 | def __setitem__(self, i: int, item: str): 242 | raise NotImplementedError("Method currently not supported") 243 | 244 | def insert(self, i: int, item: str): 245 | raise NotImplementedError("Method currently not supported") 246 | 247 | def append(self, item: str): 248 | raise NotImplementedError("Method currently not supported") 249 | 250 | def extend(self, arg: Union[list, set, ndarray]): 251 | raise NotImplementedError("Method currently not supported") 252 | 253 | 254 | class CSplitIndexedList(BaseIndexedList): 255 | def __init__(self, *args: Union[list, set, ndarray], custom_split: callable): 256 | """Quasi-list with custom string splitting to be used for feeding in-memory 257 | stored lists of sentences to the training routine. 258 | 259 | Parameters 260 | ---------- 261 | args : lists, sets, ndarray 262 | Arguments to be merged into a single contianer. Can be single or multiple list/set objects. 263 | custom_split : callable 264 | Split function to be used to convert strings into list of str. 265 | """ 266 | self.custom_split = custom_split 267 | super(CSplitIndexedList, self).__init__(*args) 268 | 269 | def __getitem__(self, i: int) -> tuple: 270 | """Getitem method. 271 | 272 | Returns 273 | ------- 274 | tuple 275 | Returns the core object, tuple, for every sentence embedding model. 276 | """ 277 | return (self.custom_split(self.items.__getitem__(i)), i) 278 | 279 | 280 | class CSplitCIndexedList(BaseIndexedList): 281 | def __init__( 282 | self, 283 | *args: Union[list, set, ndarray], 284 | custom_split: callable, 285 | custom_index: Union[list, ndarray], 286 | ): 287 | """Quasi-list with custom indices and ustom string splitting to be used for 288 | feeding in-memory stored lists of sentences to the training routine. 289 | 290 | Parameters 291 | ---------- 292 | args : lists, sets, ndarray 293 | Arguments to be merged into a single contianer. Can be single or multiple list/set objects. 294 | custom_split : callable 295 | Split function to be used to convert strings into list of str. 296 | custom_index : list, ndarray 297 | Custom index to support many to one mappings. 298 | """ 299 | self.custom_split = custom_split 300 | self.custom_index = custom_index 301 | 302 | super(CSplitCIndexedList, self).__init__(*args) 303 | 304 | if len(self.items) != len(self.custom_index): 305 | raise RuntimeError( 306 | f"Size of custom_index {len(custom_index)} does not match items {len(self.items)}" 307 | ) 308 | 309 | def __getitem__(self, i: int) -> tuple: 310 | """Getitem method. 311 | 312 | Returns 313 | ------- 314 | tuple 315 | Returns the core object, tuple, for every sentence embedding model. 
316 | """ 317 | return (self.custom_split(self.items.__getitem__(i)), self.custom_index[i]) 318 | 319 | def __delitem__(self, i: int): 320 | raise NotImplementedError("Method currently not supported") 321 | 322 | def __setitem__(self, i: int, item: str): 323 | raise NotImplementedError("Method currently not supported") 324 | 325 | def insert(self, i: int, item: str): 326 | raise NotImplementedError("Method currently not supported") 327 | 328 | def append(self, item: str): 329 | raise NotImplementedError("Method currently not supported") 330 | 331 | def extend(self, arg: Union[list, set, ndarray]): 332 | raise NotImplementedError("Method currently not supported") 333 | 334 | 335 | class IndexedLineDocument(object): 336 | def __init__(self, path, get_able=True): 337 | """Iterate over a file that contains sentences: one line = tuple([str], int). 338 | 339 | Words are expected to be already preprocessed and separated by whitespace. Sentence tags are constructed 340 | automatically from the sentence line number. 341 | 342 | Parameters 343 | ---------- 344 | path : str 345 | The path of the file to read and return lines from 346 | get_able : bool, optional 347 | Use to determine if the IndexedLineDocument is indexable. 348 | This functionality is required if you want to pass an indexable to 349 | :meth:`~fse.models.sentencevectors.SentenceVectors.most_similar`. 350 | 351 | """ 352 | self.path = Path(path) 353 | self.line_offset = list() 354 | self.get_able = bool(get_able) 355 | 356 | if self.get_able: 357 | self._build_offsets() 358 | 359 | def _build_offsets(self): 360 | """Builds an offset table to index the file.""" 361 | with open(self.path, "rb") as f: 362 | offset = f.tell() 363 | for line in f: 364 | self.line_offset.append(offset) 365 | offset += len(line) 366 | 367 | def __getitem__(self, i): 368 | """Returns the line indexed by i. Primarily used for. 369 | 370 | :meth:`~fse.models.sentencevectors.SentenceVectors.most_similar` 371 | 372 | Parameters 373 | ---------- 374 | i : int 375 | The line index used to index the file 376 | 377 | Returns 378 | ------- 379 | str 380 | line at the current index 381 | """ 382 | if not self.get_able: 383 | raise RuntimeError( 384 | "To index the lines, you must contruct with get_able=True" 385 | ) 386 | 387 | with open(self.path, "rb") as f: 388 | f.seek(self.line_offset[i]) 389 | output = f.readline() 390 | f.seek(0) 391 | return any2unicode(output).rstrip() 392 | 393 | def __iter__(self): 394 | """Iterate through the lines in the source. 
395 | 396 | Yields 397 | ------ 398 | tuple : (list[str], int) 399 | Tuple of list of string and index 400 | """ 401 | with open(self.path, "rb") as f: 402 | for i, line in enumerate(f): 403 | yield (any2unicode(line).split(), i) 404 | -------------------------------------------------------------------------------- /fse/models/sentencevectors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | 8 | from __future__ import division 9 | 10 | from fse.inputs import IndexedList, IndexedLineDocument 11 | 12 | from fse.models.utils import set_madvise_for_mmap 13 | 14 | from gensim.models.keyedvectors import KeyedVectors 15 | 16 | from numpy import ( 17 | dot, 18 | float32 as REAL, 19 | memmap as np_memmap, 20 | array, 21 | zeros, 22 | vstack, 23 | sqrt, 24 | newaxis, 25 | integer, 26 | ndarray, 27 | ) 28 | 29 | from gensim import utils, matutils 30 | 31 | from typing import List, Tuple, Union 32 | 33 | from pathlib import Path 34 | 35 | import logging 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | class SentenceVectors(utils.SaveLoad): 41 | def __init__(self, vector_size: int, mapfile_path: str = None): 42 | 43 | set_madvise_for_mmap() 44 | 45 | self.vector_size = vector_size # Size of vectors 46 | self.vectors = zeros((0, vector_size), REAL) # Vectors for sentences 47 | self.vectors_norm = None 48 | 49 | # File for numpy memmap 50 | self.mapfile_path = Path(mapfile_path) if mapfile_path is not None else None 51 | self.mapfile_shape = None 52 | 53 | def __getitem__(self, entities: int) -> ndarray: 54 | """Get vector representation of `entities`. 55 | 56 | Parameters 57 | ---------- 58 | entities : {int, list of int} 59 | Index or sequence of entities. 60 | 61 | Returns 62 | ------- 63 | numpy.ndarray 64 | Vector representation for `entities` (1D if `entities` is int, otherwise - 2D). 65 | 66 | """ 67 | 68 | if isinstance( 69 | entities, 70 | ( 71 | int, 72 | integer, 73 | ), 74 | ): 75 | return self.get_vector(entities) 76 | 77 | return vstack([self.get_vector(e) for e in entities]) 78 | 79 | def __contains__(self, index: int) -> bool: 80 | if isinstance( 81 | index, 82 | ( 83 | int, 84 | integer, 85 | ), 86 | ): 87 | return index < len(self) 88 | else: 89 | raise KeyError(f"index {index} is not a valid index") 90 | 91 | def __len__(self) -> int: 92 | return len(self.vectors) 93 | 94 | def _load_all_vectors_from_disk(self, mapfile_path: Path): 95 | """ Reads all vectors from disk """ 96 | path = str(mapfile_path.absolute()) 97 | self.vectors = np_memmap( 98 | f"{path}.vectors", dtype=REAL, mode="r+", shape=self.mapfile_shape 99 | ) 100 | 101 | def save(self, *args, **kwargs): 102 | """Save object. 103 | 104 | Parameters 105 | ---------- 106 | fname : str 107 | Path to the output file. 108 | 109 | See Also 110 | -------- 111 | :meth:`~gensim.models.keyedvectors.Doc2VecKeyedVectors.load` 112 | Load object. 
113 | 114 | """ 115 | self.mapfile_shape = self.vectors.shape 116 | ignore = ["vectors_norm"] 117 | # don't bother storing the cached normalized vectors 118 | if self.mapfile_path is not None: 119 | ignore.append("vectors") 120 | kwargs["ignore"] = kwargs.get("ignore", ignore) 121 | super(SentenceVectors, self).save(*args, **kwargs) 122 | 123 | @classmethod 124 | def load(cls, fname_or_handle, **kwargs): 125 | # TODO: Unittests 126 | sv = super(SentenceVectors, cls).load(fname_or_handle, **kwargs) 127 | path = sv.mapfile_path 128 | if path is not None: 129 | sv._load_all_vectors_from_disk(mapfile_path=path) 130 | set_madvise_for_mmap() 131 | return sv 132 | 133 | def get_vector(self, index: int, use_norm: bool = False) -> ndarray: 134 | """Get sentence representations in vector space, as a 1D numpy array. 135 | 136 | Parameters 137 | ---------- 138 | index : int 139 | Input index 140 | norm : bool, optional 141 | If True - resulting vector will be L2-normalized (unit euclidean length). 142 | 143 | Returns 144 | ------- 145 | numpy.ndarray 146 | Vector representation of index. 147 | 148 | Raises 149 | ------ 150 | KeyError 151 | If index out of bounds. 152 | 153 | """ 154 | if index in self: 155 | if use_norm: 156 | result = self.vectors_norm[index] 157 | else: 158 | result = self.vectors[index] 159 | 160 | result.setflags(write=False) 161 | return result 162 | else: 163 | raise KeyError("index {index} not found") 164 | 165 | def init_sims(self, replace: bool = False): 166 | """Precompute L2-normalized vectors. 167 | 168 | Parameters 169 | ---------- 170 | replace : bool, optional 171 | If True - forget the original vectors and only keep the normalized ones = saves lots of memory! 172 | """ 173 | if getattr(self, "vectors_norm", None) is None or replace: 174 | logger.info("precomputing L2-norms of sentence vectors") 175 | if not replace and self.mapfile_path is not None: 176 | self.vectors_norm = np_memmap( 177 | self.mapfile_path + ".vectors_norm", 178 | dtype=REAL, 179 | mode="w+", 180 | shape=self.vectors.shape, 181 | ) 182 | self.vectors_norm = _l2_norm(self.vectors, replace=replace) 183 | 184 | def similarity(self, d1: int, d2: int) -> float: 185 | """Compute cosine similarity between two sentences from the training set. 186 | 187 | Parameters 188 | ---------- 189 | d1 : int 190 | index of sentence 191 | d2 : int 192 | index of sentence 193 | 194 | Returns 195 | ------- 196 | float 197 | The cosine similarity between the vectors of the two sentences. 198 | 199 | """ 200 | return dot(matutils.unitvec(self[d1]), matutils.unitvec(self[d2])) 201 | 202 | def distance(self, d1: int, d2: int) -> float: 203 | """Compute cosine similarity between two sentences from the training set. 204 | 205 | Parameters 206 | ---------- 207 | d1 : int 208 | index of sentence 209 | d2 : int 210 | index of sentence 211 | 212 | Returns 213 | ------- 214 | float 215 | The cosine distance between the vectors of the two sentences. 216 | 217 | """ 218 | return 1 - self.similarity(d1, d2) 219 | 220 | def most_similar( 221 | self, 222 | positive: Union[int, ndarray] = None, 223 | negative: Union[int, ndarray] = None, 224 | indexable: Union[IndexedList, IndexedLineDocument] = None, 225 | topn: int = 10, 226 | restrict_size: Union[int, Tuple[int, int]] = None, 227 | ) -> List[Tuple[int, float]]: 228 | 229 | """Find the top-N most similar sentences. 230 | Positive sentences contribute positively towards the similarity, negative sentences negatively. 
231 | 232 | This method computes cosine similarity between a simple mean of the projection 233 | weight vectors of the given sentences and the vectors for each sentence in the model. 234 | 235 | Parameters 236 | ---------- 237 | positive : list of int, optional 238 | List of indices that contribute positively. 239 | negative : list of int, optional 240 | List of indices that contribute negatively. 241 | indexable: list, IndexedList, IndexedLineDocument 242 | Provides an indexable object from where the most similar sentences are read 243 | topn : int or None, optional 244 | Number of top-N similar sentences to return, when `topn` is int. When `topn` is None, 245 | then similarities for all sentences are returned. 246 | restrict_size : int or Tuple(int,int), optional 247 | Optional integer which limits the range of vectors which 248 | are searched for most-similar values. For example, restrict_vocab=10000 would 249 | only check the first 10000 sentence vectors. 250 | restrict_vocab=(500, 1000) would search the sentence vectors with indices between 251 | 500 and 1000. 252 | 253 | Returns 254 | ------- 255 | list of (int, float) or list of (str, int, float) 256 | A sequence of (index, similarity) is returned. 257 | When an indexable is provided, returns (str, index, similarity) 258 | When `topn` is None, then similarities for all words are returned as a 259 | one-dimensional numpy array with the size of the vocabulary. 260 | 261 | """ 262 | if indexable is not None and not hasattr(indexable, "__getitem__"): 263 | raise RuntimeError("Indexable must provide __getitem__") 264 | if positive is None: 265 | positive = [] 266 | if negative is None: 267 | negative = [] 268 | 269 | self.init_sims() 270 | 271 | if isinstance(positive, (int, integer)) and not negative: 272 | positive = [positive] 273 | if isinstance(positive, (ndarray)) and not negative: 274 | if len(positive.shape) == 1: 275 | positive = [positive] 276 | 277 | positive = [ 278 | (sent, 1.0) if isinstance(sent, (int, integer, ndarray)) else sent 279 | for sent in positive 280 | ] 281 | negative = [ 282 | (sent, -1.0) if isinstance(sent, (int, integer, ndarray)) else sent 283 | for sent in negative 284 | ] 285 | 286 | all_sents, mean = set(), [] 287 | for sent, weight in positive + negative: 288 | if isinstance(sent, ndarray): 289 | mean.append(weight * sent) 290 | else: 291 | mean.append(weight * self.get_vector(index=sent, use_norm=True)) 292 | if sent in self: 293 | all_sents.add(sent) 294 | if not mean: 295 | raise ValueError("cannot compute similarity with no input") 296 | mean = matutils.unitvec(array(mean).mean(axis=0)).astype(REAL) 297 | 298 | if isinstance(restrict_size, (int, integer)): 299 | lo, hi = 0, restrict_size 300 | elif isinstance(restrict_size, Tuple): 301 | lo, hi = restrict_size 302 | else: 303 | lo, hi = 0, None 304 | 305 | limited = ( 306 | self.vectors_norm if restrict_size is None else self.vectors_norm[lo:hi] 307 | ) 308 | dists = dot(limited, mean) 309 | if not topn: 310 | return dists 311 | best = matutils.argsort(dists, topn=topn + len(all_sents), reverse=True) 312 | best_off = best + lo 313 | 314 | if indexable is not None: 315 | result = [ 316 | (indexable[off_idx], off_idx, float(dists[idx])) 317 | for off_idx, idx in zip(best_off, best) 318 | if off_idx not in all_sents 319 | ] 320 | else: 321 | result = [ 322 | (off_idx, float(dists[idx])) 323 | for off_idx, idx in zip(best_off, best) 324 | if off_idx not in all_sents 325 | ] 326 | return result[:topn] 327 | 328 | def similar_by_word( 329 | self, 330 | 
word: str, 331 | wv: KeyedVectors, 332 | indexable: Union[IndexedList, IndexedLineDocument] = None, 333 | topn: int = 10, 334 | restrict_size: Union[int, Tuple[int, int]] = None, 335 | ) -> List[Tuple[int, float]]: 336 | 337 | """Find the top-N most similar sentences to a given word. 338 | 339 | Parameters 340 | ---------- 341 | word : str 342 | Word 343 | wv : :class:`~gensim.models.keyedvectors.KeyedVectors` 344 | This object essentially contains the mapping between words and embeddings. 345 | indexable: list, IndexedList, IndexedLineDocument 346 | Provides an indexable object from where the most similar sentences are read 347 | topn : int or None, optional 348 | Number of top-N similar sentences to return, when `topn` is int. When `topn` is None, 349 | then similarities for all sentences are returned. 350 | restrict_size : int or Tuple(int,int), optional 351 | Optional integer which limits the range of vectors which 352 | are searched for most-similar values. For example, restrict_vocab=10000 would 353 | only check the first 10000 sentence vectors. 354 | restrict_vocab=(500, 1000) would search the sentence vectors with indices between 355 | 500 and 1000. 356 | 357 | Returns 358 | ------- 359 | list of (int, float) or list of (str, int, float) 360 | A sequence of (index, similarity) is returned. 361 | When an indexable is provided, returns (str, index, similarity) 362 | When `topn` is None, then similarities for all words are returned as a 363 | one-dimensional numpy array with the size of the vocabulary. 364 | 365 | """ 366 | return self.most_similar( 367 | positive=wv[word], 368 | indexable=indexable, 369 | topn=topn, 370 | restrict_size=restrict_size, 371 | ) 372 | 373 | def similar_by_sentence( 374 | self, 375 | sentence: List[str], 376 | model, 377 | indexable: Union[IndexedList, IndexedLineDocument] = None, 378 | topn: int = 10, 379 | restrict_size: Union[int, Tuple[int, int]] = None, 380 | ) -> List[Tuple[int, float]]: 381 | 382 | """Find the top-N most similar sentences to a given sentence. 383 | 384 | Parameters 385 | ---------- 386 | sentence : list of str 387 | Sentence as list of strings 388 | model : :class:`~fse.models.base_s2v.BaseSentence2VecModel` 389 | This object essentially provides the infer method used to transform . 390 | indexable: list, IndexedList, IndexedLineDocument 391 | Provides an indexable object from where the most similar sentences are read 392 | topn : int or None, optional 393 | Number of top-N similar sentences to return, when `topn` is int. When `topn` is None, 394 | then similarities for all sentences are returned. 395 | restrict_size : int or Tuple(int,int), optional 396 | Optional integer which limits the range of vectors which 397 | are searched for most-similar values. For example, restrict_vocab=10000 would 398 | only check the first 10000 sentence vectors. 399 | restrict_vocab=(500, 1000) would search the sentence vectors with indices between 400 | 500 and 1000. 401 | 402 | Returns 403 | ------- 404 | list of (int, float) or list of (str, int, float) 405 | A sequence of (index, similarity) is returned. 406 | When an indexable is provided, returns (str, index, similarity) 407 | When `topn` is None, then similarities for all words are returned as a 408 | one-dimensional numpy array with the size of the vocabulary. 409 | 410 | """ 411 | infer_op = getattr(model, "infer", None) 412 | if not callable(infer_op): 413 | raise RuntimeError( 414 | "Model does not have infer method. 
Make sure to pass a BaseSentence2VecModel" 415 | ) 416 | 417 | vector = model.infer([(sentence, 0)]) 418 | return self.most_similar( 419 | positive=vector, indexable=indexable, topn=topn, restrict_size=restrict_size 420 | ) 421 | 422 | def similar_by_vector( 423 | self, 424 | vector: ndarray, 425 | indexable: Union[IndexedList, IndexedLineDocument] = None, 426 | topn: int = 10, 427 | restrict_size: Union[int, Tuple[int, int]] = None, 428 | ) -> List[Tuple[int, float]]: 429 | 430 | """Find the top-N most similar sentences to a given vector. 431 | 432 | Parameters 433 | ---------- 434 | vector : ndarray 435 | Vectors 436 | indexable: list, IndexedList, IndexedLineDocument 437 | Provides an indexable object from where the most similar sentences are read 438 | topn : int or None, optional 439 | Number of top-N similar sentences to return, when `topn` is int. When `topn` is None, 440 | then similarities for all sentences are returned. 441 | restrict_size : int or Tuple(int,int), optional 442 | Optional integer which limits the range of vectors which 443 | are searched for most-similar values. For example, restrict_vocab=10000 would 444 | only check the first 10000 sentence vectors. 445 | restrict_vocab=(500, 1000) would search the sentence vectors with indices between 446 | 500 and 1000. 447 | 448 | Returns 449 | ------- 450 | list of (int, float) or list of (str, int, float) 451 | A sequence of (index, similarity) is returned. 452 | When an indexable is provided, returns (str, index, similarity) 453 | When `topn` is None, then similarities for all words are returned as a 454 | one-dimensional numpy array with the size of the vocabulary. 455 | 456 | """ 457 | return self.most_similar( 458 | positive=vector, indexable=indexable, topn=topn, restrict_size=restrict_size 459 | ) 460 | 461 | 462 | def _l2_norm(m, replace=False): 463 | """Return an L2-normalized version of a matrix. 464 | 465 | Parameters 466 | ---------- 467 | m : np.array 468 | The matrix to normalize. 469 | replace : boolean, optional 470 | If True, modifies the existing matrix. 471 | 472 | Returns 473 | ------- 474 | The normalized matrix. If replace=True, this will be the same as m. 475 | 476 | NOTE: This part is copied from Gensim and modified as the call 477 | m /= dist somtimes raises an exception and sometimes it does not. 
478 | """ 479 | dist = sqrt((m ** 2).sum(-1))[..., newaxis] 480 | if replace: 481 | m = m / dist 482 | return m 483 | else: 484 | return (m / dist).astype(REAL) 485 | -------------------------------------------------------------------------------- /test/test_base_s2v.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # Author: Oliver Borchers 5 | # Copyright (C) Oliver Borchers 6 | 7 | """Automated tests for checking the base_s2v class.""" 8 | 9 | import logging 10 | import unittest 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | from gensim.models import FastText, Word2Vec 15 | from gensim.models.keyedvectors import KeyedVectors 16 | from wordfreq import get_frequency_dict 17 | 18 | from fse.models.base_s2v import EPS, BaseSentence2VecModel, BaseSentence2VecPreparer 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | TEST_DATA = Path(__file__).parent / "test_data" 23 | CORPUS = TEST_DATA / "test_sentences.txt" 24 | DIM = 5 25 | W2V = Word2Vec(min_count=1, vector_size=DIM) 26 | with open(CORPUS, "r") as file: 27 | SENTENCES = [l.split() for _, l in enumerate(file)] 28 | W2V.build_vocab(SENTENCES) 29 | 30 | 31 | class TestBaseSentence2VecModelFunctions(unittest.TestCase): 32 | def test_init_w_wrong_model(self): 33 | with self.assertRaises(RuntimeError): 34 | BaseSentence2VecModel(int) 35 | 36 | def test_init_w_empty_w2v_model(self): 37 | with self.assertRaises(RuntimeError): 38 | w2v = Word2Vec(min_count=1, vector_size=DIM) 39 | del w2v.wv.vectors 40 | BaseSentence2VecModel(w2v) 41 | 42 | def test_init_w_empty_vocab_model(self): 43 | with self.assertRaises(RuntimeError): 44 | w2v = Word2Vec(min_count=1, vector_size=DIM) 45 | del w2v.wv 46 | BaseSentence2VecModel(w2v) 47 | 48 | def test_init_w_ft_model_wo_vecs(self): 49 | ft = FastText(SENTENCES, vector_size=5) 50 | with self.assertRaises(RuntimeError): 51 | ft.wv.vectors_vocab = None 52 | BaseSentence2VecModel(ft) 53 | with self.assertRaises(RuntimeError): 54 | ft.wv.vectors_ngrams = None 55 | BaseSentence2VecModel(ft) 56 | 57 | def test_init_w_empty_ft_model(self): 58 | ft = FastText(min_count=1, vector_size=DIM) 59 | ft.wv.vectors = np.zeros(10) 60 | ft.wv.vectors_ngrams = None 61 | with self.assertRaises(RuntimeError): 62 | BaseSentence2VecModel(ft) 63 | 64 | def test_init_w_incompatible_ft_model(self): 65 | ft = FastText(min_count=1, vector_size=DIM) 66 | with self.assertRaises(RuntimeError): 67 | BaseSentence2VecModel(ft) 68 | 69 | def test_include_model(self): 70 | se = BaseSentence2VecModel(W2V) 71 | self.assertTrue(isinstance(se.wv, KeyedVectors)) 72 | 73 | def test_model_w_language(self): 74 | se = BaseSentence2VecModel(W2V, lang_freq="en") 75 | freq = int((2 ** 31 - 1) * get_frequency_dict("en", wordlist="best")["help"]) 76 | self.assertEqual(freq, se.wv.get_vecattr("help", "count")) 77 | self.assertEqual(21, se.wv.get_vecattr("79", "count")) 78 | 79 | def test_model_w_wrong_language(self): 80 | with self.assertRaises(ValueError): 81 | BaseSentence2VecModel(W2V, lang_freq="test") 82 | 83 | def test_save_load(self): 84 | se = BaseSentence2VecModel(W2V) 85 | p = TEST_DATA / "test_emb.model" 86 | se.save(str(p.absolute())) 87 | self.assertTrue(p.exists()) 88 | se2 = BaseSentence2VecModel.load(str(p.absolute())) 89 | self.assertTrue((se.wv.vectors == se2.wv.vectors).all()) 90 | self.assertEqual(se.wv.index_to_key, se2.wv.index_to_key) 91 | self.assertEqual(se.workers, se2.workers) 92 | p.unlink() 93 | 94 | def 
test_save_load_with_memmap(self): 95 | ft = FastText(min_count=1, vector_size=5) 96 | ft.build_vocab(SENTENCES) 97 | shape = (1000, 1000) 98 | ft.wv.vectors = np.zeros(shape, np.float32) 99 | 100 | p = TEST_DATA / "test_emb" 101 | p_vecs = TEST_DATA / "test_emb_wv.vectors" 102 | p_ngrams = TEST_DATA / "test_emb_ngrams.vectors" 103 | p_vocab = TEST_DATA / "test_emb_vocab.vectors" 104 | 105 | p_not_exists = TEST_DATA / "test_emb.wv.vectors.npy" 106 | 107 | se = BaseSentence2VecModel(ft, wv_mapfile_path=str(p)) 108 | self.assertTrue(p_vecs.exists()) 109 | self.assertTrue(p_ngrams.exists()) 110 | self.assertTrue(p_vocab.exists()) 111 | 112 | se.save(str(p.absolute())) 113 | self.assertTrue(p.exists()) 114 | self.assertFalse(p_not_exists.exists()) 115 | 116 | se = BaseSentence2VecModel.load(str(p.absolute())) 117 | self.assertFalse(se.wv.vectors_vocab.flags.writeable) 118 | self.assertEqual(shape, se.wv.vectors.shape) 119 | self.assertEqual((2000000, 5), se.wv.vectors_ngrams.shape) 120 | 121 | for p in [p, p_vecs, p_ngrams, p_vocab]: 122 | p.unlink() 123 | 124 | def test_map_all_vectors_to_disk(self): 125 | ft = FastText(min_count=1, vector_size=5) 126 | ft.build_vocab(SENTENCES) 127 | 128 | p = TEST_DATA / "test_emb" 129 | p_vecs = TEST_DATA / "test_emb_wv.vectors" 130 | p_ngrams = TEST_DATA / "test_emb_ngrams.vectors" 131 | p_vocab = TEST_DATA / "test_emb_vocab.vectors" 132 | 133 | BaseSentence2VecModel(ft, wv_mapfile_path=str(p)) 134 | 135 | self.assertTrue(p_vecs.exists()) 136 | self.assertTrue(p_ngrams.exists()) 137 | self.assertTrue(p_vocab.exists()) 138 | 139 | for p in [p_vecs, p_ngrams, p_vocab]: 140 | p.unlink() 141 | 142 | def test_input_check(self): 143 | se = BaseSentence2VecModel(W2V) 144 | 145 | class BadIterator: 146 | def __init__(self): 147 | pass 148 | 149 | with self.assertRaises(TypeError): 150 | se._check_input_data_sanity(data_iterable=None) 151 | with self.assertRaises(TypeError): 152 | se._check_input_data_sanity(data_iterable="Hello there!") 153 | with self.assertRaises(TypeError): 154 | se._check_input_data_sanity(data_iterable=BadIterator()) 155 | 156 | def test_scan_w_list(self): 157 | se = BaseSentence2VecModel(W2V) 158 | with self.assertRaises(TypeError): 159 | se.scan_sentences(SENTENCES) 160 | 161 | def test_str_rep(self): 162 | output = str(BaseSentence2VecModel(W2V)) 163 | self.assertEqual( 164 | "BaseSentence2VecModel based on KeyedVectors, vector_size=0", output 165 | ) 166 | 167 | def test_scan_w_ituple(self): 168 | se = BaseSentence2VecModel(W2V) 169 | id_sent = [(s, i) for i, s in enumerate(SENTENCES)] 170 | stats = se.scan_sentences(id_sent, progress_per=0) 171 | 172 | self.assertEqual(100, stats["total_sentences"]) 173 | self.assertEqual(1417, stats["total_words"]) 174 | self.assertEqual(14, stats["average_length"]) 175 | self.assertEqual(3, stats["empty_sentences"]) 176 | self.assertEqual(100, stats["max_index"]) 177 | 178 | def test_scan_w_wrong_tuple(self): 179 | se = BaseSentence2VecModel(W2V) 180 | id_sent = [(s, str(i)) for i, s in enumerate(SENTENCES)] 181 | with self.assertRaises(TypeError): 182 | se.scan_sentences(id_sent) 183 | 184 | def test_scan_w_empty(self): 185 | se = BaseSentence2VecModel(W2V) 186 | for i in [5, 10, 15]: 187 | SENTENCES[i] = [] 188 | self.assertEqual( 189 | 3, 190 | se.scan_sentences([(s, i) for i, s in enumerate(SENTENCES)])[ 191 | "empty_sentences" 192 | ], 193 | ) 194 | 195 | def test_scan_w_wrong_input(self): 196 | se = BaseSentence2VecModel(W2V) 197 | sentences = ["the dog hit the car", "he was very fast"] 198 | 199 
| with self.assertRaises(TypeError): 200 | se.scan_sentences(sentences) 201 | with self.assertRaises(TypeError): 202 | se.scan_sentences([(s, i) for i, s in enumerate(sentences)]) 203 | with self.assertRaises(TypeError): 204 | se.scan_sentences([list(range(10) for _ in range(2))]) 205 | 206 | with self.assertRaises(RuntimeError): 207 | se.scan_sentences([(s, i + 1) for i, s in enumerate(SENTENCES)]) 208 | with self.assertRaises(ValueError): 209 | se.scan_sentences([(s, i - 1) for i, s in enumerate(SENTENCES)]) 210 | 211 | def test_scan_w_many_to_one_input(self): 212 | se = BaseSentence2VecModel(W2V) 213 | output = se.scan_sentences([(s, 0) for i, s in enumerate(SENTENCES)])[ 214 | "max_index" 215 | ] 216 | self.assertEqual(1, output) 217 | 218 | def test_estimate_memory(self): 219 | ft = FastText(min_count=1, vector_size=5) 220 | ft.build_vocab(SENTENCES) 221 | se = BaseSentence2VecModel(ft) 222 | self.assertEqual(2040025124, se.estimate_memory(int(1e8))["Total"]) 223 | 224 | def test_train(self): 225 | se = BaseSentence2VecModel(W2V) 226 | with self.assertRaises(NotImplementedError): 227 | se.train([(s, i) for i, s in enumerate(SENTENCES)]) 228 | 229 | def test_log_end(self): 230 | se = BaseSentence2VecModel(W2V) 231 | se._log_train_end(eff_sentences=2000, eff_words=4000, overall_time=10) 232 | 233 | def test_child_requirements(self): 234 | se = BaseSentence2VecModel(W2V) 235 | 236 | with self.assertRaises(NotImplementedError): 237 | se._do_train_job(None, None, None) 238 | with self.assertRaises(NotImplementedError): 239 | se._pre_train_calls() 240 | with self.assertRaises(NotImplementedError): 241 | se._post_train_calls() 242 | with self.assertRaises(NotImplementedError): 243 | se._check_parameter_sanity() 244 | with self.assertRaises(NotImplementedError): 245 | se._check_dtype_santiy() 246 | with self.assertRaises(NotImplementedError): 247 | se._post_inference_calls() 248 | 249 | def test_check_pre_train_san_no_wv(self): 250 | ft = FastText(min_count=1, vector_size=5) 251 | ft.build_vocab(SENTENCES) 252 | se = BaseSentence2VecModel(ft) 253 | se.wv = None 254 | with self.assertRaises(RuntimeError): 255 | se._check_pre_training_sanity(1, 1, 1) 256 | 257 | def test_check_pre_train_san_no_wv_len(self): 258 | ft = FastText(min_count=1, vector_size=5) 259 | ft.build_vocab(SENTENCES) 260 | se = BaseSentence2VecModel(ft) 261 | se.wv.vectors = [] 262 | with self.assertRaises(RuntimeError): 263 | se._check_pre_training_sanity(1, 1, 1) 264 | 265 | def test_check_pre_train_san_no_ngrams_vectors(self): 266 | ft = FastText(min_count=1, vector_size=5) 267 | ft.build_vocab(SENTENCES) 268 | se = BaseSentence2VecModel(ft) 269 | se.wv.vectors_ngrams = [] 270 | with self.assertRaises(RuntimeError): 271 | se._check_pre_training_sanity(1, 1, 1) 272 | se.wv.vectors_ngrams = [1] 273 | se.wv.vectors_vocab = [] 274 | with self.assertRaises(RuntimeError): 275 | se._check_pre_training_sanity(1, 1, 1) 276 | 277 | def test_check_pre_train_san_no_sv_vecs(self): 278 | ft = FastText(min_count=1, vector_size=5) 279 | ft.build_vocab(SENTENCES) 280 | se = BaseSentence2VecModel(ft) 281 | se.sv.vectors = None 282 | with self.assertRaises(RuntimeError): 283 | se._check_pre_training_sanity(1, 1, 1) 284 | 285 | def test_check_pre_train_san_no_word_weights(self): 286 | ft = FastText(min_count=1, vector_size=5) 287 | ft.build_vocab(SENTENCES) 288 | se = BaseSentence2VecModel(ft) 289 | se.word_weights = None 290 | with self.assertRaises(RuntimeError): 291 | se._check_pre_training_sanity(1, 1, 1) 292 | 293 | def 
test_check_pre_train_san_incos_len(self): 294 | ft = FastText(min_count=1, vector_size=5) 295 | ft.build_vocab(SENTENCES) 296 | se = BaseSentence2VecModel(ft) 297 | se.word_weights = np.ones(20) 298 | with self.assertRaises(RuntimeError): 299 | se._check_pre_training_sanity(1, 1, 1) 300 | 301 | def test_check_pre_train_dtypes(self): 302 | ft = FastText(min_count=1, vector_size=5) 303 | ft.build_vocab(SENTENCES) 304 | se = BaseSentence2VecModel(ft) 305 | 306 | se.wv.vectors = np.zeros((len(se.wv), 20), dtype=np.float64) 307 | with self.assertRaises(TypeError): 308 | se._check_pre_training_sanity(1, 1, 1) 309 | se.wv.vectors = np.zeros((len(se.wv), 20), dtype=np.float32) 310 | 311 | se.wv.vectors_ngrams = np.ones(len(se.wv), dtype=np.float16) 312 | with self.assertRaises(TypeError): 313 | se._check_pre_training_sanity(1, 1, 1) 314 | se.wv.vectors_ngrams = np.ones(len(se.wv), dtype=np.float32) 315 | 316 | se.wv.vectors_vocab = np.ones(len(se.wv), dtype=np.float16) 317 | with self.assertRaises(TypeError): 318 | se._check_pre_training_sanity(1, 1, 1) 319 | se.wv.vectors_vocab = np.ones(len(se.wv), dtype=np.float32) 320 | 321 | se.sv.vectors = np.zeros((len(se.wv), 20), dtype=int) 322 | with self.assertRaises(TypeError): 323 | se._check_pre_training_sanity(1, 1, 1) 324 | se.sv.vectors = np.zeros((len(se.wv), 20), dtype=np.float32) 325 | 326 | se.word_weights = np.ones(len(se.wv), dtype=bool) 327 | with self.assertRaises(TypeError): 328 | se._check_pre_training_sanity(1, 1, 1) 329 | se.word_weights = np.ones(len(se.wv), dtype=np.float32) 330 | 331 | def test_check_pre_train_statistics(self): 332 | ft = FastText(min_count=1, vector_size=5) 333 | ft.build_vocab(SENTENCES) 334 | se = BaseSentence2VecModel(ft) 335 | 336 | for v in se.wv.key_to_index: 337 | se.wv.set_vecattr(v, "count", 1) 338 | 339 | # Just throws multiple warnings warning 340 | se._check_pre_training_sanity(1, 1, 1) 341 | 342 | with self.assertRaises(ValueError): 343 | se._check_pre_training_sanity(0, 1, 1) 344 | with self.assertRaises(ValueError): 345 | se._check_pre_training_sanity(1, 0, 1) 346 | with self.assertRaises(ValueError): 347 | se._check_pre_training_sanity(1, 1, 0) 348 | 349 | def test_post_training_sanity(self): 350 | w2v = Word2Vec() 351 | w2v.build_vocab(SENTENCES) 352 | se = BaseSentence2VecModel(w2v) 353 | se.prep.prepare_vectors(se.sv, 20) 354 | with self.assertRaises(ValueError): 355 | se._check_post_training_sanity(0, 1) 356 | with self.assertRaises(ValueError): 357 | se._check_post_training_sanity(1, 0) 358 | 359 | def test_move_ndarray_to_disk_w2v(self): 360 | se = BaseSentence2VecModel(W2V) 361 | p = TEST_DATA / "test_vecs" 362 | p_target = TEST_DATA / "test_vecs_wv.vectors" 363 | se.wv.vectors[0, 1] = 10 364 | vecs = se.wv.vectors.copy() 365 | output = se._move_ndarray_to_disk( 366 | se.wv.vectors, name="wv", mapfile_path=str(p.absolute()) 367 | ) 368 | self.assertTrue(p_target.exists()) 369 | self.assertFalse(output.flags.writeable) 370 | self.assertTrue((vecs == output).all()) 371 | p_target.unlink() 372 | 373 | def test_move_w2v_vectors_to_disk_from_init(self): 374 | p = TEST_DATA / "test_vecs" 375 | se = BaseSentence2VecModel(W2V, wv_mapfile_path=str(p.absolute())) 376 | p_target = TEST_DATA / "test_vecs_wv.vectors" 377 | self.assertTrue(p_target.exists()) 378 | self.assertFalse(se.wv.vectors.flags.writeable) 379 | p_target.unlink() 380 | 381 | def test_move_ft_vectors_to_disk_from_init(self): 382 | ft = FastText(min_count=1, vector_size=DIM) 383 | ft.build_vocab(SENTENCES) 384 | 385 | p = TEST_DATA / 
"test_vecs" 386 | p_target_wv = TEST_DATA / "test_vecs_wv.vectors" 387 | p_target_ngram = TEST_DATA / "test_vecs_ngrams.vectors" 388 | p_target_vocab = TEST_DATA / "test_vecs_vocab.vectors" 389 | 390 | se = BaseSentence2VecModel(ft, wv_mapfile_path=str(p.absolute())) 391 | 392 | self.assertTrue(p_target_wv.exists()) 393 | self.assertFalse(se.wv.vectors.flags.writeable) 394 | 395 | self.assertTrue(p_target_ngram.exists()) 396 | self.assertFalse(se.wv.vectors_ngrams.flags.writeable) 397 | 398 | p_target_wv.unlink() 399 | p_target_ngram.unlink() 400 | p_target_vocab.unlink() 401 | 402 | def test_train_manager(self): 403 | se = BaseSentence2VecModel(W2V, workers=2) 404 | 405 | def temp_train_job(data_iterable, target, memory): 406 | v1 = v2 = sum(1 for _ in data_iterable) 407 | return v1 * 2, v2 * 3 408 | 409 | se._do_train_job = temp_train_job 410 | job_output = se._train_manager( 411 | data_iterable=[(s, i) for i, s in enumerate(SENTENCES)], 412 | total_sentences=len(SENTENCES), 413 | report_delay=0.01, 414 | ) 415 | self.assertEqual((100, 200, 300), job_output) 416 | 417 | def test_infer_method(self): 418 | se = BaseSentence2VecModel(W2V) 419 | 420 | def temp_train_job(data_iterable, target, memory): 421 | for i in data_iterable: 422 | target += 1 423 | return target 424 | 425 | def pass_method(**kwargs): 426 | pass 427 | 428 | se._post_inference_calls = pass_method 429 | se._do_train_job = temp_train_job 430 | output = se.infer([(s, i) for i, s in enumerate(SENTENCES)]) 431 | self.assertTrue((100 == output).all()) 432 | 433 | def test_infer_method_cy_overflow(self): 434 | se = BaseSentence2VecModel(W2V) 435 | 436 | from fse.models.average_inner import MAX_WORDS_IN_BATCH, train_average_cy 437 | 438 | def _do_train_job(data_iterable, target, memory): 439 | eff_sentences, eff_words = train_average_cy( 440 | model=se, indexed_sentences=data_iterable, target=target, memory=memory 441 | ) 442 | return eff_sentences, eff_words 443 | 444 | def pass_method(**kwargs): 445 | pass 446 | 447 | se._post_inference_calls = pass_method 448 | se._do_train_job = _do_train_job 449 | tmp = [] 450 | for i in range(20): 451 | tmp.extend(SENTENCES) 452 | bs = 0 453 | for i, s in enumerate(tmp): 454 | if bs >= MAX_WORDS_IN_BATCH: 455 | break 456 | bs += len(s) 457 | sents = [(s, i) for i, s in enumerate(tmp)] 458 | output = se.infer(sents) 459 | output = output[i:] 460 | self.assertTrue((0 != output).all()) 461 | 462 | def test_infer_many_to_one(self): 463 | se = BaseSentence2VecModel(W2V) 464 | 465 | def temp_train_job(data_iterable, target, memory): 466 | for i in data_iterable: 467 | target += 1 468 | return target 469 | 470 | def pass_method(**kwargs): 471 | pass 472 | 473 | se._post_inference_calls = pass_method 474 | se._do_train_job = temp_train_job 475 | output = se.infer([(s, 0) for i, s in enumerate(SENTENCES)]) 476 | self.assertTrue((100 == output).all()) 477 | self.assertEqual((1, 5), output.shape) 478 | 479 | def test_infer_use_norm(self): 480 | se = BaseSentence2VecModel(W2V) 481 | 482 | def temp_train_job(data_iterable, target, memory): 483 | for i in data_iterable: 484 | target += 1 485 | return target 486 | 487 | def pass_method(**kwargs): 488 | pass 489 | 490 | se._post_inference_calls = pass_method 491 | se._do_train_job = temp_train_job 492 | output = se.infer([(s, i) for i, s in enumerate(SENTENCES)], use_norm=True) 493 | 494 | self.assertTrue(np.allclose(1.0, np.sqrt(np.sum(output[0] ** 2)))) 495 | 496 | 497 | class TestBaseSentence2VecPreparerFunctions(unittest.TestCase): 498 | def 
test_reset_vectors(self): 499 | se = BaseSentence2VecModel(W2V) 500 | trainables = BaseSentence2VecPreparer() 501 | trainables.reset_vectors(se.sv, 20) 502 | self.assertEqual((20, DIM), se.sv.vectors.shape) 503 | self.assertEqual(np.float32, se.sv.vectors.dtype) 504 | self.assertTrue((EPS == se.sv.vectors).all()) 505 | self.assertIsNone(se.sv.vectors_norm) 506 | 507 | def test_reset_vectors_memmap(self): 508 | p = TEST_DATA / "test_vectors" 509 | p_target = TEST_DATA / "test_vectors.vectors" 510 | se = BaseSentence2VecModel(W2V, sv_mapfile_path=str(p.absolute())) 511 | trainables = BaseSentence2VecPreparer() 512 | trainables.reset_vectors(se.sv, 20) 513 | self.assertTrue(p_target.exists()) 514 | self.assertEqual((20, DIM), se.sv.vectors.shape) 515 | self.assertEqual(np.float32, se.sv.vectors.dtype) 516 | self.assertTrue((EPS == se.sv.vectors).all()) 517 | self.assertIsNone(se.sv.vectors_norm) 518 | p_target.unlink() 519 | 520 | def test_update_vectors(self): 521 | se = BaseSentence2VecModel(W2V) 522 | trainables = BaseSentence2VecPreparer() 523 | trainables.reset_vectors(se.sv, 20) 524 | se.sv.vectors[:] = 1.0 525 | trainables.update_vectors(se.sv, 10) 526 | self.assertEqual((30, DIM), se.sv.vectors.shape) 527 | self.assertEqual(np.float32, se.sv.vectors.dtype) 528 | self.assertTrue((np.ones((20, DIM)) == se.sv.vectors[:20]).all()) 529 | self.assertTrue((EPS == se.sv.vectors[20:]).all()) 530 | self.assertIsNone(se.sv.vectors_norm) 531 | 532 | def test_update_vectors_memmap(self): 533 | p = TEST_DATA / "test_vectors" 534 | p_target = TEST_DATA / "test_vectors.vectors" 535 | se = BaseSentence2VecModel(W2V, sv_mapfile_path=str(p.absolute())) 536 | trainables = BaseSentence2VecPreparer() 537 | trainables.reset_vectors(se.sv, 20) 538 | se.sv.vectors[:] = 1.0 539 | trainables.update_vectors(se.sv, 10) 540 | self.assertTrue(p_target.exists()) 541 | self.assertEqual((30, DIM), se.sv.vectors.shape) 542 | self.assertEqual(np.float32, se.sv.vectors.dtype) 543 | self.assertTrue((np.ones((20, DIM)) == se.sv.vectors[:20]).all()) 544 | self.assertTrue((EPS == se.sv.vectors[20:]).all()) 545 | self.assertIsNone(se.sv.vectors_norm) 546 | p_target.unlink() 547 | 548 | def test_prepare_vectors(self): 549 | se = BaseSentence2VecModel(W2V) 550 | trainables = BaseSentence2VecPreparer() 551 | trainables.prepare_vectors(se.sv, 20, update=False) 552 | self.assertEqual((20, DIM), se.sv.vectors.shape) 553 | trainables.prepare_vectors(se.sv, 40, update=True) 554 | self.assertEqual((60, DIM), se.sv.vectors.shape) 555 | 556 | 557 | if __name__ == "__main__": 558 | logging.basicConfig( 559 | format="%(asctime)s : %(levelname)s : %(message)s", level=logging.DEBUG 560 | ) 561 | unittest.main() 562 | --------------------------------------------------------------------------------
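
As a closing illustration, the sketch below ties together the pieces shown above: training a sentence-embedding model and querying it through the similarity helpers of fse/models/sentencevectors.py (infer, similar_by_vector). The toy corpus, the choice of SIF, and all variable names are assumptions made for this sketch, not part of the repository:

    from gensim.models import Word2Vec

    from fse import SIF, IndexedList

    # Tiny toy corpus; any tokenized corpus works the same way.
    sentences = [
        ["the", "dog", "hit", "the", "car"],
        ["he", "was", "very", "fast"],
        ["the", "car", "was", "very", "fast"],
    ]

    # The word vectors provide the lookup table for the sentence model.
    w2v = Word2Vec(sentences, min_count=1, vector_size=5)

    # SIF computes smooth-inverse-frequency weighted sentence vectors.
    model = SIF(w2v)
    model.train(IndexedList(sentences))  # sentences are passed as (words, index) pairs

    # Infer a vector for an unseen sentence and look up its nearest neighbours.
    query = model.infer([(["a", "fast", "dog"], 0)])
    print(model.sv.similar_by_vector(query, indexable=IndexedList(sentences), topn=2))
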