├── .github └── workflows │ └── wheels.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── README.rst ├── build-wheels.sh ├── pyproject.toml ├── python └── fasttext_module │ └── fasttext │ ├── FastText.py │ ├── __init__.py │ └── pybind │ └── fasttext_pybind.cc ├── setup.py └── src ├── args.cc ├── args.h ├── densematrix.cc ├── densematrix.h ├── dictionary.cc ├── dictionary.h ├── fasttext.cc ├── fasttext.h ├── loss.cc ├── loss.h ├── matrix.cc ├── matrix.h ├── model.cc ├── model.h ├── productquantizer.cc ├── productquantizer.h ├── quantmatrix.cc ├── quantmatrix.h ├── real.h ├── utils.h ├── vector.cc └── vector.h /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Wheel build 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | tags: 7 | - "v*.*.*" 8 | pull_request: 9 | branches: ["main"] 10 | 11 | jobs: 12 | sdist: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.x" 22 | 23 | - name: Install Python dependencies 24 | run: python -m pip install -U pip setuptools wheel 25 | 26 | - name: Build docs and sdist 27 | run: make sdist 28 | env: { STATIC_DEPS: false } 29 | 30 | - name: Upload sdist 31 | uses: actions/upload-artifact@v4 32 | with: 33 | name: sdist 34 | path: dist/*.tar.gz 35 | 36 | build_wheels: 37 | name: Build wheels on ${{ matrix.os }} 38 | strategy: 39 | fail-fast: false 40 | matrix: 41 | include: 42 | - os: ubuntu-latest 43 | arch: auto i686 aarch64 armv7l 44 | skip: cp36-* cp37-* cp38-* pp37-* pp38-* pp39-* 45 | - os: windows-latest 46 | arch: AMD64 ARM64 47 | skip: cp36-* cp37-* cp38-* pp37-* pp38-* pp39-* 48 | - os: macos-13 49 | arch: x86_64 arm64 universal2 50 | skip: cp36-* cp37-* cp38-* pp* 51 | - os: macos-14 52 | arch: x86_64 arm64 universal2 53 | skip: cp36-* cp37-* cp38-* pp* 54 | 55 | 
runs-on: ${{ matrix.os }} 56 | 57 | steps: 58 | - uses: actions/checkout@v4 59 | 60 | - uses: actions/setup-python@v5 61 | with: 62 | python-version: '3.12' 63 | 64 | - name: Install cibuildwheel 65 | run: python -m pip install cibuildwheel==2.22.0 66 | 67 | - name: Set up QEMU 68 | if: runner.os == 'Linux' 69 | uses: docker/setup-qemu-action@v3 70 | with: 71 | platforms: all 72 | 73 | - name: Build wheels 74 | run: python -m cibuildwheel --output-dir wheelhouse 75 | env: 76 | CIBW_ARCHS_LINUX: ${{ matrix.arch }} 77 | CIBW_SKIP: ${{ matrix.skip }} 78 | CIBW_ENABLE: cpython-freethreading 79 | 80 | - uses: actions/upload-artifact@v4 81 | with: 82 | name: cibw-wheels-${{ matrix.os }}-${{ strategy.job-index }} 83 | path: ./wheelhouse/*.whl 84 | 85 | release: 86 | name: Release 87 | runs-on: ubuntu-latest 88 | if: startsWith(github.ref, 'refs/tags/') 89 | needs: [ sdist, build_wheels ] 90 | steps: 91 | - uses: actions/download-artifact@v4 92 | with: 93 | path: artifacts 94 | - uses: actions/setup-python@v4 95 | with: 96 | python-version: '3.12' 97 | - name: Display structure of downloaded files 98 | run: | 99 | ls -R 100 | mkdir dist 101 | mv artifacts/sdist/*.tar.gz dist 102 | mv artifacts/*/*.whl dist 103 | - name: Publish to PyPi 104 | env: 105 | TWINE_USERNAME: __token__ 106 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 107 | run: | 108 | pip install --upgrade wheel pip setuptools twine 109 | twine upload --skip-existing dist/* 110 | - name: Release 111 | uses: softprops/action-gh-release@v2 112 | with: 113 | files: dist/* 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .*.swp 2 | *.o 3 | *.bin 4 | *.vec 5 | *.bc 6 | .DS_Store 7 | data 8 | fasttext 9 | result 10 | website/node_modules/ 11 | package-lock.json 12 | node_modules/ 13 | 14 | build/ 15 | dist/ 16 | wheelhouse/ 17 | python/fasttext_module/fasttext_predict.egg-info/ 18 | 
python/fasttext_module/fasttext_pybind.*.so 19 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | 78 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to fastText 2 | We want to make contributing to this project as easy and transparent as possible. 3 | 4 | ## Issues 5 | We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. 6 | 7 | ### Reproducing issues 8 | Please make sure that the issue you mention is not a result of one of the existing third-party libraries. For example, please do not post an issue if you encountered an error within a third-party Python library. We can only help you with errors which can be directly reproduced either with our C++ code or the corresponding Python bindings. If you do find an error, please post detailed steps to reproduce it. If we can't reproduce your error, we can't help you fix it. 9 | 10 | ## Pull Requests 11 | Please post an Issue before submitting a pull request. This might save you some time as it is possible we can't support your contribution, albeit we try our best to accommodate your (planned) work and highly appreciate your time. Generally, it is best to have a pull request emerge from an issue rather than the other way around. 12 | 13 | To create a pull request: 14 | 15 | 1. Fork the repo and create your branch from `master`. 16 | 2. If you've added code that should be tested, add tests. 17 | 3. If you've changed APIs, update the documentation. 18 | 4. Ensure the test suite passes. 
19 | 5. Make sure your code lints. 20 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 21 | 22 | ## Tests 23 | First, you will need to make sure you have the required data. For that, please have a look at the fetch_test_data.sh script under tests. Next run the tests using the runtests.py script passing a path to the directory containing the datasets. 24 | 25 | ## Contributor License Agreement ("CLA") 26 | In order to accept your pull request, we need you to submit a CLA. You only need 27 | to do this once to work on any of Facebook's open source projects. 28 | 29 | Complete your CLA here: 30 | 31 | ## License 32 | By contributing to fastText, you agree that your contributions will be licensed under its MIT license. 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016-present, Facebook, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | 3 | recursive-include python *.md *.rst 4 | recursive-include src *.h 5 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2016-present, Facebook, Inc. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | # 8 | 9 | CXX = c++ 10 | CXXFLAGS = -std=c++11 -march=native 11 | OBJS = args.o matrix.o dictionary.o loss.o productquantizer.o densematrix.o quantmatrix.o vector.o model.o fasttext.o 12 | INCLUDES = -I. 
13 | 14 | args.o: src/args.cc src/args.h 15 | $(CXX) $(CXXFLAGS) -c src/args.cc 16 | 17 | matrix.o: src/matrix.cc src/matrix.h 18 | $(CXX) $(CXXFLAGS) -c src/matrix.cc 19 | 20 | dictionary.o: src/dictionary.cc src/dictionary.h src/args.h 21 | $(CXX) $(CXXFLAGS) -c src/dictionary.cc 22 | 23 | loss.o: src/loss.cc src/loss.h src/matrix.h src/real.h 24 | $(CXX) $(CXXFLAGS) -c src/loss.cc 25 | 26 | productquantizer.o: src/productquantizer.cc src/productquantizer.h src/utils.h 27 | $(CXX) $(CXXFLAGS) -c src/productquantizer.cc 28 | 29 | densematrix.o: src/densematrix.cc src/densematrix.h src/utils.h src/matrix.h 30 | $(CXX) $(CXXFLAGS) -c src/densematrix.cc 31 | 32 | quantmatrix.o: src/quantmatrix.cc src/quantmatrix.h src/utils.h src/matrix.h 33 | $(CXX) $(CXXFLAGS) -c src/quantmatrix.cc 34 | 35 | vector.o: src/vector.cc src/vector.h src/utils.h 36 | $(CXX) $(CXXFLAGS) -c src/vector.cc 37 | 38 | model.o: src/model.cc src/model.h src/args.h 39 | $(CXX) $(CXXFLAGS) -c src/model.cc 40 | 41 | fasttext.o: src/fasttext.cc src/*.h 42 | $(CXX) $(CXXFLAGS) -c src/fasttext.cc 43 | 44 | clean: 45 | rm -rf *.o *.gcno *.gcda 46 | 47 | ####### 48 | # from https://github.com/lxml/lxml/blob/2cd510258d03887dfad69e77edc47f8bf28773ae/Makefile 49 | 50 | FTFVERSION:=$(shell python setup.py -V' ) 51 | PYTHON_BUILD_VERSION ?= * 52 | 53 | .PHONY: sdist build require-cython wheel_manylinux wheel 54 | 55 | dist/fasttext-predict-$(FTFVERSION).tar.gz: 56 | pip install build 57 | python -m build --sdist 58 | 59 | sdist: dist/fasttext-predict-$(FTFVERSION).tar.gz 60 | 61 | qemu-user-static: 62 | docker run --rm --privileged multiarch/qemu-user-static --reset -p yes 63 | 64 | wheel_%: dist/fasttext-predict-$(FTFVERSION).tar.gz qemu-user-static 65 | time docker run --rm -t \ 66 | -v $(shell pwd):/io \ 67 | -e AR=gcc-ar \ 68 | -e NM=gcc-nm \ 69 | -e RANLIB=gcc-ranlib \ 70 | -e PYTHON_BUILD_VERSION="$(PYTHON_BUILD_VERSION)" \ 71 | -e WHEELHOUSE=$(subst wheel_,wheelhouse/,$@) \ 72 | quay.io/pypa/$(subst 
wheel_,,$@) \ 73 | bash /io/build-wheels.sh /io/$< 74 | 75 | wheel: 76 | pip install build 77 | python -m build --wheel 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fasttext-predict 2 | 3 | Python package for [fasttext](https://github.com/facebookresearch/fastText): 4 | 5 | * keep only the `predict` method, all other features are removed 6 | * standalone package without external dependency (numpy is not a dependency) 7 | * wheels for various architectures using GitHub workflows. The script is inspired by lxml build scripts. 8 | 9 | ## Usage 10 | 11 | ```sh 12 | pip install fasttext-predict 13 | ``` 14 | 15 | ```python 16 | import fasttext 17 | model = fasttext.load_model('lid.176.ftz') 18 | result = model.predict('Fondant au chocolat et tarte aux myrtilles') 19 | ``` 20 | 21 | See https://fasttext.cc/docs/en/language-identification.html 22 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | fasttext-predict 2 | ================ 3 | 4 | Python package for 5 | `fasttext <https://github.com/facebookresearch/fastText>`__: 6 | 7 | - keep only the ``predict`` method, all other features are removed 8 | - standalone package without external dependency (numpy is not a 9 | dependency) 10 | - wheels for various architectures using GitHub workflows. The script 11 | is inspired by lxml build scripts. 12 | 13 | Usage 14 | ----- 15 | 16 | .. code:: sh 17 | 18 | pip install fasttext-predict 19 | 20 | .. 
code:: python 21 | 22 | import fasttext 23 | model = fasttext.load_model('lid.176.ftz') 24 | result = model.predict('Fondant au chocolat et tarte aux myrtilles') 25 | 26 | See https://fasttext.cc/docs/en/language-identification.html 27 | -------------------------------------------------------------------------------- /build-wheels.sh: --------------------------------------------------------------------------------
#!/bin/bash
#
# Called inside the manylinux image
#
# based on https://github.com/lxml/lxml/blob/b224e0f69dde58425d1077e07d193d19d3f803a9/tools/manylinux/build-wheels.sh
#
# Usage: build-wheels.sh [SDIST]
#   SDIST                 optional path to the source distribution to build from;
#                         when it does not exist, /io is built instead.
# Environment:
#   WHEELHOUSE            output directory below /io (default: wheelhouse)
#   PYTHON_BUILD_VERSION  glob selecting /opt/python/<version> (default: *)
echo "Started $0 $*"

set -e -x
[ -n "$WHEELHOUSE" ] || WHEELHOUSE=wheelhouse
SDIST=$1
SDIST_PREFIX='fasttext_predict'
if [ -z "$PYTHON_BUILD_VERSION" ]; then
    PYTHON_BUILD_VERSION="*"
fi

# Build one wheel with the interpreter found in directory $1 from source $2
# (falls back to /io when $2 is empty).
build_wheel() {
    echo "===[ build_wheels $1 $2 ]==="
    local pybin="$1"
    local source="$2"
    [ -n "$source" ] || source=/io

    "${pybin}/pip" install --upgrade pip

    # Keep the architecture flags local so they are not appended to the
    # global CFLAGS again for every interpreter.
    local cflags="$CFLAGS"
    case $( uname -m ) in
        x86_64|i686|amd64) cflags="$cflags -march=core2";;
        aarch64) cflags="$cflags -march=armv8-a -mtune=cortex-a72";;
    esac

    rm -rf /io/build
    env STATIC_DEPS=true \
        RUN_TESTS=true \
        LDFLAGS="$LDFLAGS -fPIC" \
        CFLAGS="$cflags -fPIC" \
        ACLOCAL_PATH=/usr/share/aclocal/ \
        "${pybin}/pip" \
            wheel \
            -v \
            "$source" \
            -w "/io/$WHEELHOUSE"
}

# Remove the interpreters that must not be built for and print tool versions.
prepare_system() {
    echo "===[ prepare_system ]==="
    rm -fr /opt/python/cp27-*
    rm -fr /opt/python/cp34-*
    rm -fr /opt/python/cp35-*
    rm -fr /opt/python/cp36-*
    echo "Python versions found: $(cd /opt/python && echo cp* | sed -e 's|[^ ]*-||g')"
    ${CC:-gcc} --version
}

# Build a wheel for every interpreter matching PYTHON_BUILD_VERSION.
build_wheels() {
    echo "===[ build_wheels ]==="
    # Compile wheels for all python versions
    local source=
    if test -e "$SDIST"; then
        source="$SDIST"
    fi
    # The version part is deliberately unquoted: it is a glob pattern.
    for PYBIN in /opt/python/${PYTHON_BUILD_VERSION}/bin; do
        echo "Starting build with $("$PYBIN/python" -V)"
        build_wheel "$PYBIN" "$source"
    done
}

# Run auditwheel over every wheel that was produced.
repair_wheels() {
    echo "===[ repair_wheels ]==="
    # Bundle external shared libraries into the wheels
    local whl
    local -a opts
    for whl in "/io/$WHEELHOUSE/${SDIST_PREFIX}"-*.whl; do
        opts=(--strip)
        # NOTE(review): a manylinux_2_24 wheel is re-tagged as manylinux_2_34
        # here; kept as found, confirm that the version mismatch is intended.
        if [[ "$whl" == *x86_64.whl && "$whl" == *manylinux_2_24_x86_64* ]]; then
            opts+=(--plat manylinux_2_34_x86_64)
        fi
        if [[ "$whl" == *x86_64.whl && "$whl" == *manylinux1_x86_64* ]]; then
            opts+=(--plat manylinux1_x86_64)
        fi
        auditwheel show "$whl"
        auditwheel repair "$whl" "${opts[@]}" -w "/io/$WHEELHOUSE" || exit 1
    done
}

# List the wheels present in the wheelhouse.
show_wheels() {
    echo "===[ show_wheels ]==="
    ls -l "/io/$WHEELHOUSE/${SDIST_PREFIX}"-*.whl
}

prepare_system
build_wheels
repair_wheels
show_wheels
86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # based on https://github.com/pybind/python_example/blob/master/pyproject.toml 2 | [build-system] 3 | requires = [ 4 | "setuptools>=42", 5 | "pybind11>=2.10.0", 6 | "wheel", 7 | ] 8 | build-backend = "setuptools.build_meta" 9 | -------------------------------------------------------------------------------- /python/fasttext_module/fasttext/FastText.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from __future__ import unicode_literals 11 | 12 | import fasttext_pybind as fasttext 13 | 14 | 15 | class _FastText(object): 16 | """ 17 | This class defines the API to inspect models and should not be used to 18 | create objects. It will be returned by functions such as load_model or 19 | train. 20 | 21 | In general this API assumes to be given only unicode for Python2 and the 22 | Python3 equvalent called str for any string-like arguments. All unicode 23 | strings are then encoded as UTF-8 and fed to the fastText C++ API. 24 | """ 25 | 26 | def __init__(self, model_path=None, args=None): 27 | self.f = fasttext.fasttext() 28 | if model_path is not None: 29 | self.f.loadModel(model_path) 30 | if args is not None: 31 | raise RuntimeError('args argument is not supported') 32 | 33 | def predict(self, text, k=1, threshold=0.0, on_unicode_error='strict'): 34 | """ 35 | Given a string, get a list of labels and a list of 36 | corresponding probabilities. k controls the number 37 | of returned labels. A choice of 5, will return the 5 38 | most probable labels. By default this returns only 39 | the most likely label and probability. threshold filters 40 | the returned labels by a threshold on probability. A 41 | choice of 0.5 will return labels with at least 0.5 42 | probability. k and threshold will be applied together to 43 | determine the returned labels. 44 | 45 | This function assumes to be given 46 | a single line of text. We split words on whitespace (space, 47 | newline, tab, vertical tab) and the control characters carriage 48 | return, formfeed and the null character. 49 | 50 | If the model is not supervised, this function will throw a ValueError. 51 | 52 | If given a list of strings, it will return a list of results as usually 53 | received for a single line of text. 
54 | """ 55 | 56 | def check(entry): 57 | if entry.find('\n') != -1: 58 | raise ValueError( 59 | "predict processes one line at a time (remove \'\\n\')" 60 | ) 61 | entry += "\n" 62 | return entry 63 | 64 | if type(text) == list: 65 | text = [check(entry) for entry in text] 66 | all_labels, all_probs = self.f.multilinePredict( 67 | text, k, threshold, on_unicode_error) 68 | 69 | return all_labels, all_probs 70 | else: 71 | text = check(text) 72 | predictions = self.f.predict(text, k, threshold, on_unicode_error) 73 | if predictions: 74 | probs, labels = zip(*predictions) 75 | else: 76 | probs, labels = ([], ()) 77 | return (labels, probs) 78 | 79 | 80 | def load_model(path): 81 | """Load a model given a filepath and return a model object.""" 82 | return _FastText(model_path=path) 83 | -------------------------------------------------------------------------------- /python/fasttext_module/fasttext/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | from __future__ import unicode_literals 11 | 12 | from .FastText import load_model 13 | -------------------------------------------------------------------------------- /python/fasttext_module/fasttext/pybind/fasttext_pybind.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2017-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | using namespace pybind11::literals; 21 | namespace py = pybind11; 22 | 23 | py::str castToPythonString(const std::string& s, const char* onUnicodeError) { 24 | PyObject* handle = PyUnicode_DecodeUTF8(s.data(), s.length(), onUnicodeError); 25 | if (!handle) { 26 | throw py::error_already_set(); 27 | } 28 | 29 | // py::str's constructor from a PyObject assumes the string has been encoded 30 | // for python 2 and not encoded for python 3 : 31 | // https://github.com/pybind/pybind11/blob/ccbe68b084806dece5863437a7dc93de20bd9b15/include/pybind11/pytypes.h#L930 32 | #if PY_MAJOR_VERSION < 3 33 | PyObject* handle_encoded = 34 | PyUnicode_AsEncodedString(handle, "utf-8", onUnicodeError); 35 | Py_DECREF(handle); 36 | handle = handle_encoded; 37 | #endif 38 | 39 | py::str handle_str = py::str(handle); 40 | Py_DECREF(handle); 41 | return handle_str; 42 | } 43 | 44 | std::vector> castToPythonString( 45 | const std::vector>& predictions, 46 | const char* onUnicodeError) { 47 | std::vector> transformedPredictions; 48 | 49 | for (const auto& prediction : predictions) { 50 | transformedPredictions.emplace_back( 51 | prediction.first, 52 | castToPythonString(prediction.second, onUnicodeError)); 53 | } 54 | 55 | return transformedPredictions; 56 | } 57 | 58 | PYBIND11_MODULE(fasttext_pybind, m) { 59 | py::class_(m, "fasttext") 60 | .def(py::init<>()) 61 | .def( 62 | "loadModel", 63 | [](fasttext::FastText& m, std::string s) { m.loadModel(s); }) 64 | .def( 65 | "predict", 66 | // NOTE: text needs to end in a newline 67 | // to exactly mimic the behavior of the cli 68 | [](fasttext::FastText& m, 69 | const std::string text, 70 | int32_t k, 71 | fasttext::real threshold, 72 | const char* onUnicodeError) { 73 | std::stringstream ioss(text); 74 | std::vector> predictions; 75 | m.predictLine(ioss, predictions, k, threshold); 76 | 
77 | return castToPythonString(predictions, onUnicodeError); 78 | }) 79 | .def( 80 | "multilinePredict", 81 | // NOTE: text needs to end in a newline 82 | // to exactly mimic the behavior of the cli 83 | [](fasttext::FastText& m, 84 | const std::vector& lines, 85 | int32_t k, 86 | fasttext::real threshold, 87 | const char* onUnicodeError) { 88 | std::vector> allLabels; 89 | std::vector> predictions; 90 | 91 | for (const std::string& text : lines) { 92 | std::stringstream ioss(text); 93 | m.predictLine(ioss, predictions, k, threshold); 94 | std::vector probabilities; 95 | std::vector labels; 96 | 97 | for (const auto& prediction : predictions) { 98 | probabilities.push_back(prediction.first); 99 | labels.push_back( 100 | castToPythonString(prediction.second, onUnicodeError)); 101 | } 102 | allLabels.push_back(labels); 103 | } 104 | 105 | return allLabels; 106 | }); 107 | } 108 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) 2017-present, Facebook, Inc. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the MIT license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # 9 | 10 | import sys 11 | import os 12 | import sysconfig 13 | import io 14 | 15 | from pybind11.setup_helpers import Pybind11Extension, ParallelCompile 16 | from setuptools import setup 17 | 18 | ParallelCompile().install() 19 | 20 | __version__ = '0.9.2.4' 21 | FASTTEXT_SRC = "src" 22 | 23 | WIN = sys.platform.startswith("win32") and "mingw" not in sysconfig.get_platform() 24 | 25 | fasttext_src_files = map(str, os.listdir(FASTTEXT_SRC)) 26 | fasttext_src_cc = list(filter(lambda x: x.endswith('.cc'), fasttext_src_files)) 27 | 28 | fasttext_src_cc = list( 29 | map(lambda x: str(os.path.join(FASTTEXT_SRC, x)), fasttext_src_cc) 30 | ) 31 | 32 | extra_compile_args = [] 33 | if WIN: 34 | extra_compile_args.append('/DVERSION_INFO=\\"%s\\"' % __version__) 35 | else: 36 | extra_compile_args.append('-DVERSION_INFO="%s"' % __version__) 37 | extra_compile_args.extend(["-O3", "-flto"]) 38 | 39 | 40 | def _get_readme(): 41 | """ 42 | Use pandoc to generate rst from md. 43 | pandoc --from=markdown --to=rst --output=python/README.rst python/README.md 44 | """ 45 | with io.open("README.rst", encoding='utf-8') as fid: 46 | return fid.read() 47 | 48 | 49 | setup( 50 | name='fasttext-predict', 51 | version=__version__, 52 | author='Alexandre Flament', 53 | author_email='alex.andre@al-f.net', 54 | keywords=['fasttext', 'language detection', 'language identification'], 55 | description='fasttext with wheels and no external dependency, but only the predict method (<1MB)', 56 | long_description=_get_readme(), 57 | url='https://github.com/searxng/fasttext-predict/', 58 | license='MIT', 59 | classifiers=[ 60 | 'Development Status :: 3 - Alpha', 61 | 'Intended Audience :: Developers', 62 | 'Intended Audience :: Science/Research', 63 | 'License :: OSI Approved :: MIT License', 64 | 'Programming Language :: Python :: 3.9', 65 | 'Programming Language :: Python :: 3.10', 66 | 'Programming Language :: Python :: 3.11', 67 | 'Programming Language :: Python :: 3.12', 68 | 'Programming 
Language :: Python :: 3.13', 69 | 'Topic :: Software Development', 70 | 'Topic :: Scientific/Engineering', 71 | 'Operating System :: Microsoft :: Windows', 72 | 'Operating System :: POSIX', 73 | 'Operating System :: Unix', 74 | 'Operating System :: MacOS', 75 | ], 76 | packages=[ 77 | 'fasttext', 78 | ], 79 | package_dir={ 80 | '': 'python/fasttext_module' 81 | }, 82 | zip_safe=False, 83 | ext_modules=[ 84 | Pybind11Extension( 85 | "fasttext_pybind", 86 | ["python/fasttext_module/fasttext/pybind/fasttext_pybind.cc"] + fasttext_src_cc, 87 | include_dirs=[ 88 | FASTTEXT_SRC, 89 | ], 90 | cxx_std=17, 91 | extra_compile_args=extra_compile_args, 92 | ) 93 | ], 94 | ) 95 | -------------------------------------------------------------------------------- /src/args.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "args.h" 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace fasttext { 19 | 20 | Args::Args() { 21 | lr = 0.05; 22 | dim = 100; 23 | ws = 5; 24 | epoch = 5; 25 | minCount = 5; 26 | minCountLabel = 0; 27 | neg = 5; 28 | wordNgrams = 1; 29 | loss = loss_name::ns; 30 | model = model_name::sg; 31 | bucket = 2000000; 32 | minn = 3; 33 | maxn = 6; 34 | thread = 12; 35 | lrUpdateRate = 100; 36 | t = 1e-4; 37 | label = "__label__"; 38 | verbose = 2; 39 | pretrainedVectors = ""; 40 | saveOutput = false; 41 | seed = 0; 42 | 43 | qout = false; 44 | retrain = false; 45 | qnorm = false; 46 | cutoff = 0; 47 | dsub = 2; 48 | } 49 | 50 | void Args::load(std::istream& in) { 51 | in.read((char*)&(dim), sizeof(int)); 52 | in.read((char*)&(ws), sizeof(int)); 53 | in.read((char*)&(epoch), sizeof(int)); 54 | in.read((char*)&(minCount), sizeof(int)); 55 | in.read((char*)&(neg), sizeof(int)); 56 | in.read((char*)&(wordNgrams), sizeof(int)); 57 | in.read((char*)&(loss), sizeof(loss_name)); 58 | in.read((char*)&(model), sizeof(model_name)); 59 | in.read((char*)&(bucket), sizeof(int)); 60 | in.read((char*)&(minn), sizeof(int)); 61 | in.read((char*)&(maxn), sizeof(int)); 62 | in.read((char*)&(lrUpdateRate), sizeof(int)); 63 | in.read((char*)&(t), sizeof(double)); 64 | } 65 | 66 | } // namespace fasttext 67 | -------------------------------------------------------------------------------- /src/args.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace fasttext { 18 | 19 | enum class model_name : int { cbow = 1, sg, sup }; 20 | enum class loss_name : int { hs = 1, ns, softmax, ova }; 21 | enum class metric_name : int { 22 | f1score = 1, 23 | f1scoreLabel, 24 | precisionAtRecall, 25 | precisionAtRecallLabel, 26 | recallAtPrecision, 27 | recallAtPrecisionLabel 28 | }; 29 | 30 | class Args { 31 | public: 32 | Args(); 33 | std::string input; 34 | std::string output; 35 | double lr; 36 | int lrUpdateRate; 37 | int dim; 38 | int ws; 39 | int epoch; 40 | int minCount; 41 | int minCountLabel; 42 | int neg; 43 | int wordNgrams; 44 | loss_name loss; 45 | model_name model; 46 | int bucket; 47 | int minn; 48 | int maxn; 49 | int thread; 50 | double t; 51 | std::string label; 52 | int verbose; 53 | std::string pretrainedVectors; 54 | bool saveOutput; 55 | int seed; 56 | 57 | bool qout; 58 | bool retrain; 59 | bool qnorm; 60 | size_t cutoff; 61 | size_t dsub; 62 | 63 | void load(std::istream&); 64 | 65 | }; 66 | } // namespace fasttext 67 | -------------------------------------------------------------------------------- /src/densematrix.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "densematrix.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include "utils.h" 15 | #include "vector.h" 16 | 17 | namespace fasttext { 18 | 19 | DenseMatrix::DenseMatrix() : DenseMatrix(0, 0) {} 20 | 21 | DenseMatrix::DenseMatrix(int64_t m, int64_t n) : Matrix(m, n), data_(m * n) {} 22 | 23 | DenseMatrix::DenseMatrix(DenseMatrix&& other) noexcept 24 | : Matrix(other.m_, other.n_), data_(std::move(other.data_)) {} 25 | 26 | DenseMatrix::DenseMatrix(int64_t m, int64_t n, real* dataPtr) 27 | : Matrix(m, n), data_(dataPtr, dataPtr + (m * n)) {} 28 | 29 | real DenseMatrix::dotRow(const Vector& vec, int64_t i) const { 30 | assert(i >= 0); 31 | assert(i < m_); 32 | assert(vec.size() == n_); 33 | real d = 0.0; 34 | for (int64_t j = 0; j < n_; j++) { 35 | d += at(i, j) * vec[j]; 36 | } 37 | if (std::isnan(d)) { 38 | throw EncounteredNaNError(); 39 | } 40 | return d; 41 | } 42 | 43 | void DenseMatrix::addRowToVector(Vector& x, int32_t i) const { 44 | assert(i >= 0); 45 | assert(i < this->size(0)); 46 | assert(x.size() == this->size(1)); 47 | for (int64_t j = 0; j < n_; j++) { 48 | x[j] += at(i, j); 49 | } 50 | } 51 | 52 | void DenseMatrix::addRowToVector(Vector& x, int32_t i, real a) const { 53 | assert(i >= 0); 54 | assert(i < this->size(0)); 55 | assert(x.size() == this->size(1)); 56 | for (int64_t j = 0; j < n_; j++) { 57 | x[j] += a * at(i, j); 58 | } 59 | } 60 | 61 | void DenseMatrix::load(std::istream& in) { 62 | in.read((char*)&m_, sizeof(int64_t)); 63 | in.read((char*)&n_, sizeof(int64_t)); 64 | data_ = std::vector(m_ * n_); 65 | in.read((char*)data_.data(), m_ * n_ * sizeof(real)); 66 | } 67 | 68 | } // namespace fasttext 69 | -------------------------------------------------------------------------------- /src/densematrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | #include "matrix.h" 19 | #include "real.h" 20 | 21 | namespace fasttext { 22 | 23 | class Vector; 24 | 25 | class DenseMatrix : public Matrix { 26 | protected: 27 | std::vector data_; 28 | 29 | public: 30 | DenseMatrix(); 31 | explicit DenseMatrix(int64_t, int64_t); 32 | explicit DenseMatrix(int64_t m, int64_t n, real* dataPtr); 33 | DenseMatrix(const DenseMatrix&) = default; 34 | DenseMatrix(DenseMatrix&&) noexcept; 35 | DenseMatrix& operator=(const DenseMatrix&) = delete; 36 | DenseMatrix& operator=(DenseMatrix&&) = delete; 37 | virtual ~DenseMatrix() noexcept override = default; 38 | 39 | inline real* data() { 40 | return data_.data(); 41 | } 42 | inline const real* data() const { 43 | return data_.data(); 44 | } 45 | 46 | inline const real& at(int64_t i, int64_t j) const { 47 | assert(i * n_ + j < data_.size()); 48 | return data_[i * n_ + j]; 49 | }; 50 | inline real& at(int64_t i, int64_t j) { 51 | return data_[i * n_ + j]; 52 | }; 53 | 54 | inline int64_t rows() const { 55 | return m_; 56 | } 57 | inline int64_t cols() const { 58 | return n_; 59 | } 60 | 61 | real dotRow(const Vector&, int64_t) const override; 62 | void addRowToVector(Vector& x, int32_t i) const override; 63 | void addRowToVector(Vector& x, int32_t i, real a) const override; 64 | void load(std::istream&) override; 65 | 66 | class EncounteredNaNError : public std::runtime_error { 67 | public: 68 | EncounteredNaNError() : std::runtime_error("Encountered NaN.") {} 69 | }; 70 | }; 71 | } // namespace fasttext 72 | -------------------------------------------------------------------------------- /src/dictionary.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, 
Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "dictionary.h" 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | namespace fasttext { 21 | 22 | const std::string Dictionary::EOS = ""; 23 | const std::string Dictionary::BOW = "<"; 24 | const std::string Dictionary::EOW = ">"; 25 | 26 | Dictionary::Dictionary(std::shared_ptr args) 27 | : args_(args), 28 | word2int_(MAX_VOCAB_SIZE, -1), 29 | size_(0), 30 | nwords_(0), 31 | nlabels_(0), 32 | ntokens_(0), 33 | pruneidx_size_(-1) {} 34 | 35 | Dictionary::Dictionary(std::shared_ptr args, std::istream& in) 36 | : args_(args), 37 | size_(0), 38 | nwords_(0), 39 | nlabels_(0), 40 | ntokens_(0), 41 | pruneidx_size_(-1) { 42 | load(in); 43 | } 44 | 45 | int32_t Dictionary::find(const std::string& w) const { 46 | return find(w, hash(w)); 47 | } 48 | 49 | int32_t Dictionary::find(const std::string& w, uint32_t h) const { 50 | int32_t word2intsize = word2int_.size(); 51 | int32_t id = h % word2intsize; 52 | while (word2int_[id] != -1 && words_[word2int_[id]].word != w) { 53 | id = (id + 1) % word2intsize; 54 | } 55 | return id; 56 | } 57 | 58 | int32_t Dictionary::nwords() const { 59 | return nwords_; 60 | } 61 | 62 | int32_t Dictionary::nlabels() const { 63 | return nlabels_; 64 | } 65 | 66 | int64_t Dictionary::ntokens() const { 67 | return ntokens_; 68 | } 69 | 70 | const std::vector& Dictionary::getSubwords(int32_t i) const { 71 | assert(i >= 0); 72 | assert(i < nwords_); 73 | return words_[i].subwords; 74 | } 75 | 76 | bool Dictionary::discard(int32_t id, real rand) const { 77 | assert(id >= 0); 78 | assert(id < nwords_); 79 | if (args_->model == model_name::sup) { 80 | return false; 81 | } 82 | return rand > pdiscard_[id]; 83 | } 84 | 85 | int32_t Dictionary::getId(const std::string& w, uint32_t h) const 
{ 86 | int32_t id = find(w, h); 87 | return word2int_[id]; 88 | } 89 | 90 | entry_type Dictionary::getType(int32_t id) const { 91 | assert(id >= 0); 92 | assert(id < size_); 93 | return words_[id].type; 94 | } 95 | 96 | entry_type Dictionary::getType(const std::string& w) const { 97 | return (w.find(args_->label) == 0) ? entry_type::label : entry_type::word; 98 | } 99 | 100 | // The correct implementation of fnv should be: 101 | // h = h ^ uint32_t(uint8_t(str[i])); 102 | // Unfortunately, earlier version of fasttext used 103 | // h = h ^ uint32_t(str[i]); 104 | // which is undefined behavior (as char can be signed or unsigned). 105 | // Since all fasttext models that were already released were trained 106 | // using signed char, we fixed the hash function to make models 107 | // compatible whatever compiler is used. 108 | uint32_t Dictionary::hash(const std::string& str) const { 109 | uint32_t h = 2166136261; 110 | for (size_t i = 0; i < str.size(); i++) { 111 | h = h ^ uint32_t(int8_t(str[i])); 112 | h = h * 16777619; 113 | } 114 | return h; 115 | } 116 | 117 | void Dictionary::computeSubwords( 118 | const std::string& word, 119 | std::vector& ngrams, 120 | std::vector* substrings) const { 121 | for (size_t i = 0; i < word.size(); i++) { 122 | std::string ngram; 123 | if ((word[i] & 0xC0) == 0x80) { 124 | continue; 125 | } 126 | for (size_t j = i, n = 1; j < word.size() && n <= args_->maxn; n++) { 127 | ngram.push_back(word[j++]); 128 | while (j < word.size() && (word[j] & 0xC0) == 0x80) { 129 | ngram.push_back(word[j++]); 130 | } 131 | if (n >= args_->minn && !(n == 1 && (i == 0 || j == word.size()))) { 132 | int32_t h = hash(ngram) % args_->bucket; 133 | pushHash(ngrams, h); 134 | if (substrings) { 135 | substrings->push_back(ngram); 136 | } 137 | } 138 | } 139 | } 140 | } 141 | 142 | void Dictionary::initNgrams() { 143 | for (size_t i = 0; i < size_; i++) { 144 | std::string word = BOW + words_[i].word + EOW; 145 | words_[i].subwords.clear(); 146 | 
words_[i].subwords.push_back(i); 147 | if (words_[i].word != EOS) { 148 | computeSubwords(word, words_[i].subwords); 149 | } 150 | } 151 | } 152 | 153 | bool Dictionary::readWord(std::istream& in, std::string& word) const { 154 | int c; 155 | std::streambuf& sb = *in.rdbuf(); 156 | word.clear(); 157 | while ((c = sb.sbumpc()) != EOF) { 158 | if (c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == '\v' || 159 | c == '\f' || c == '\0') { 160 | if (word.empty()) { 161 | if (c == '\n') { 162 | word += EOS; 163 | return true; 164 | } 165 | continue; 166 | } else { 167 | if (c == '\n') 168 | sb.sungetc(); 169 | return true; 170 | } 171 | } 172 | word.push_back(c); 173 | } 174 | // trigger eofbit 175 | in.get(); 176 | return !word.empty(); 177 | } 178 | 179 | void Dictionary::initTableDiscard() { 180 | pdiscard_.resize(size_); 181 | for (size_t i = 0; i < size_; i++) { 182 | real f = real(words_[i].count) / real(ntokens_); 183 | pdiscard_[i] = std::sqrt(args_->t / f) + args_->t / f; 184 | } 185 | } 186 | 187 | std::vector Dictionary::getCounts(entry_type type) const { 188 | std::vector counts; 189 | for (auto& w : words_) { 190 | if (w.type == type) { 191 | counts.push_back(w.count); 192 | } 193 | } 194 | return counts; 195 | } 196 | 197 | void Dictionary::addWordNgrams( 198 | std::vector& line, 199 | const std::vector& hashes, 200 | int32_t n) const { 201 | for (int32_t i = 0; i < hashes.size(); i++) { 202 | uint64_t h = hashes[i]; 203 | for (int32_t j = i + 1; j < hashes.size() && j < i + n; j++) { 204 | h = h * 116049371 + hashes[j]; 205 | pushHash(line, h % args_->bucket); 206 | } 207 | } 208 | } 209 | 210 | void Dictionary::addSubwords( 211 | std::vector& line, 212 | const std::string& token, 213 | int32_t wid) const { 214 | if (wid < 0) { // out of vocab 215 | if (token != EOS) { 216 | computeSubwords(BOW + token + EOW, line); 217 | } 218 | } else { 219 | if (args_->maxn <= 0) { // in vocab w/o subwords 220 | line.push_back(wid); 221 | } else { // in vocab w/ 
subwords 222 | const std::vector& ngrams = getSubwords(wid); 223 | line.insert(line.end(), ngrams.cbegin(), ngrams.cend()); 224 | } 225 | } 226 | } 227 | 228 | void Dictionary::reset(std::istream& in) const { 229 | if (in.eof()) { 230 | in.clear(); 231 | in.seekg(std::streampos(0)); 232 | } 233 | } 234 | 235 | int32_t Dictionary::getLine( 236 | std::istream& in, 237 | std::vector& words, 238 | std::minstd_rand& rng) const { 239 | std::uniform_real_distribution<> uniform(0, 1); 240 | std::string token; 241 | int32_t ntokens = 0; 242 | 243 | reset(in); 244 | words.clear(); 245 | while (readWord(in, token)) { 246 | int32_t h = find(token); 247 | int32_t wid = word2int_[h]; 248 | if (wid < 0) { 249 | continue; 250 | } 251 | 252 | ntokens++; 253 | if (getType(wid) == entry_type::word && !discard(wid, uniform(rng))) { 254 | words.push_back(wid); 255 | } 256 | if (ntokens > MAX_LINE_SIZE || token == EOS) { 257 | break; 258 | } 259 | } 260 | return ntokens; 261 | } 262 | 263 | int32_t Dictionary::getLine( 264 | std::istream& in, 265 | std::vector& words, 266 | std::vector& labels) const { 267 | std::vector word_hashes; 268 | std::string token; 269 | int32_t ntokens = 0; 270 | 271 | reset(in); 272 | words.clear(); 273 | labels.clear(); 274 | while (readWord(in, token)) { 275 | uint32_t h = hash(token); 276 | int32_t wid = getId(token, h); 277 | entry_type type = wid < 0 ? 
getType(token) : getType(wid); 278 | 279 | ntokens++; 280 | if (type == entry_type::word) { 281 | addSubwords(words, token, wid); 282 | word_hashes.push_back(h); 283 | } else if (type == entry_type::label && wid >= 0) { 284 | labels.push_back(wid - nwords_); 285 | } 286 | if (token == EOS) { 287 | break; 288 | } 289 | } 290 | addWordNgrams(words, word_hashes, args_->wordNgrams); 291 | return ntokens; 292 | } 293 | 294 | void Dictionary::pushHash(std::vector& hashes, int32_t id) const { 295 | if (pruneidx_size_ == 0 || id < 0) { 296 | return; 297 | } 298 | if (pruneidx_size_ > 0) { 299 | if (pruneidx_.count(id)) { 300 | id = pruneidx_.at(id); 301 | } else { 302 | return; 303 | } 304 | } 305 | hashes.push_back(nwords_ + id); 306 | } 307 | 308 | std::string Dictionary::getLabel(int32_t lid) const { 309 | if (lid < 0 || lid >= nlabels_) { 310 | throw std::invalid_argument( 311 | "Label id is out of range [0, " + std::to_string(nlabels_) + "]"); 312 | } 313 | return words_[lid + nwords_].word; 314 | } 315 | 316 | void Dictionary::load(std::istream& in) { 317 | words_.clear(); 318 | in.read((char*)&size_, sizeof(int32_t)); 319 | in.read((char*)&nwords_, sizeof(int32_t)); 320 | in.read((char*)&nlabels_, sizeof(int32_t)); 321 | in.read((char*)&ntokens_, sizeof(int64_t)); 322 | in.read((char*)&pruneidx_size_, sizeof(int64_t)); 323 | for (int32_t i = 0; i < size_; i++) { 324 | char c; 325 | entry e; 326 | while ((c = in.get()) != 0) { 327 | e.word.push_back(c); 328 | } 329 | in.read((char*)&e.count, sizeof(int64_t)); 330 | in.read((char*)&e.type, sizeof(entry_type)); 331 | words_.push_back(e); 332 | } 333 | pruneidx_.clear(); 334 | for (int32_t i = 0; i < pruneidx_size_; i++) { 335 | int32_t first; 336 | int32_t second; 337 | in.read((char*)&first, sizeof(int32_t)); 338 | in.read((char*)&second, sizeof(int32_t)); 339 | pruneidx_[first] = second; 340 | } 341 | initTableDiscard(); 342 | initNgrams(); 343 | 344 | int32_t word2intsize = std::ceil(size_ / 0.7); 345 | 
word2int_.assign(word2intsize, -1); 346 | for (int32_t i = 0; i < size_; i++) { 347 | word2int_[find(words_[i].word)] = i; 348 | } 349 | } 350 | 351 | } // namespace fasttext 352 | -------------------------------------------------------------------------------- /src/dictionary.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | #include "args.h" 20 | #include "real.h" 21 | 22 | namespace fasttext { 23 | 24 | typedef int32_t id_type; 25 | enum class entry_type : int8_t { word = 0, label = 1 }; 26 | 27 | struct entry { 28 | std::string word; 29 | int64_t count; 30 | entry_type type; 31 | std::vector subwords; 32 | }; 33 | 34 | class Dictionary { 35 | protected: 36 | static const int32_t MAX_VOCAB_SIZE = 30000000; 37 | static const int32_t MAX_LINE_SIZE = 1024; 38 | 39 | int32_t find(const std::string&) const; 40 | int32_t find(const std::string&, uint32_t h) const; 41 | void initTableDiscard(); 42 | void initNgrams(); 43 | void reset(std::istream&) const; 44 | void pushHash(std::vector&, int32_t) const; 45 | void addSubwords(std::vector&, const std::string&, int32_t) const; 46 | 47 | std::shared_ptr args_; 48 | std::vector word2int_; 49 | std::vector words_; 50 | 51 | std::vector pdiscard_; 52 | int32_t size_; 53 | int32_t nwords_; 54 | int32_t nlabels_; 55 | int64_t ntokens_; 56 | 57 | int64_t pruneidx_size_; 58 | std::unordered_map pruneidx_; 59 | void addWordNgrams( 60 | std::vector& line, 61 | const std::vector& hashes, 62 | int32_t n) const; 63 | 64 | public: 65 | static const std::string EOS; 66 | static const std::string BOW; 67 | static const std::string EOW; 68 | 69 | 
explicit Dictionary(std::shared_ptr); 70 | explicit Dictionary(std::shared_ptr, std::istream&); 71 | int32_t nwords() const; 72 | int32_t nlabels() const; 73 | int64_t ntokens() const; 74 | int32_t getId(const std::string&, uint32_t h) const; 75 | entry_type getType(int32_t) const; 76 | entry_type getType(const std::string&) const; 77 | bool discard(int32_t, real) const; 78 | const std::vector& getSubwords(int32_t) const; 79 | void computeSubwords( 80 | const std::string&, 81 | std::vector&, 82 | std::vector* substrings = nullptr) const; 83 | uint32_t hash(const std::string& str) const; 84 | bool readWord(std::istream&, std::string&) const; 85 | std::string getLabel(int32_t) const; 86 | void load(std::istream&); 87 | std::vector getCounts(entry_type) const; 88 | int32_t getLine(std::istream&, std::vector&, std::vector&) 89 | const; 90 | int32_t getLine(std::istream&, std::vector&, std::minstd_rand&) 91 | const; 92 | void threshold(int64_t, int64_t); 93 | bool isPruned() { 94 | return pruneidx_size_ >= 0; 95 | } 96 | }; 97 | 98 | } // namespace fasttext 99 | -------------------------------------------------------------------------------- /src/fasttext.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "fasttext.h" 10 | #include "loss.h" 11 | #include "quantmatrix.h" 12 | 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | namespace fasttext { 23 | 24 | constexpr int32_t FASTTEXT_VERSION = 12; /* Version 1b */ 25 | constexpr int32_t FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314; 26 | 27 | 28 | std::shared_ptr FastText::createLoss(std::shared_ptr& output) { 29 | loss_name lossName = args_->loss; 30 | switch (lossName) { 31 | case loss_name::hs: 32 | return std::make_shared( 33 | output, getTargetCounts()); 34 | case loss_name::ns: 35 | return std::make_shared( 36 | output, args_->neg, getTargetCounts()); 37 | case loss_name::softmax: 38 | return std::make_shared(output); 39 | case loss_name::ova: 40 | return std::make_shared(output); 41 | default: 42 | throw std::runtime_error("Unknown loss"); 43 | } 44 | } 45 | 46 | bool FastText::checkModel(std::istream& in) { 47 | int32_t magic; 48 | in.read((char*)&(magic), sizeof(int32_t)); 49 | if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) { 50 | return false; 51 | } 52 | in.read((char*)&(version), sizeof(int32_t)); 53 | if (version > FASTTEXT_VERSION) { 54 | return false; 55 | } 56 | return true; 57 | } 58 | 59 | void FastText::loadModel(const std::string& filename) { 60 | std::ifstream ifs(filename, std::ifstream::binary); 61 | if (!ifs.is_open()) { 62 | throw std::invalid_argument(filename + " cannot be opened for loading!"); 63 | } 64 | if (!checkModel(ifs)) { 65 | throw std::invalid_argument(filename + " has wrong file format!"); 66 | } 67 | loadModel(ifs); 68 | ifs.close(); 69 | } 70 | 71 | std::vector FastText::getTargetCounts() const { 72 | if (args_->model == model_name::sup) { 73 | return dict_->getCounts(entry_type::label); 74 | } else { 75 | return dict_->getCounts(entry_type::word); 76 | } 77 | } 78 | 79 | void FastText::buildModel() { 80 | auto loss = createLoss(output_); 81 | bool normalizeGradient = (args_->model == 
model_name::sup); 82 | model_ = std::make_shared(input_, output_, loss, normalizeGradient); 83 | } 84 | 85 | void FastText::loadModel(std::istream& in) { 86 | args_ = std::make_shared(); 87 | input_ = std::make_shared(); 88 | output_ = std::make_shared(); 89 | args_->load(in); 90 | if (version == 11 && args_->model == model_name::sup) { 91 | // backward compatibility: old supervised models do not use char ngrams. 92 | args_->maxn = 0; 93 | } 94 | dict_ = std::make_shared(args_, in); 95 | 96 | bool quant_input; 97 | in.read((char*)&quant_input, sizeof(bool)); 98 | if (quant_input) { 99 | quant_ = true; 100 | input_ = std::make_shared(); 101 | } 102 | input_->load(in); 103 | 104 | if (!quant_input && dict_->isPruned()) { 105 | throw std::invalid_argument( 106 | "Invalid model file.\n" 107 | "Please download the updated model from www.fasttext.cc.\n" 108 | "See issue #332 on Github for more information.\n"); 109 | } 110 | 111 | in.read((char*)&args_->qout, sizeof(bool)); 112 | if (quant_ && args_->qout) { 113 | output_ = std::make_shared(); 114 | } 115 | output_->load(in); 116 | 117 | buildModel(); 118 | } 119 | 120 | void FastText::predict( 121 | int32_t k, 122 | const std::vector& words, 123 | Predictions& predictions, 124 | real threshold) const { 125 | if (words.empty()) { 126 | return; 127 | } 128 | Model::State state(args_->dim, dict_->nlabels(), 0); 129 | if (args_->model != model_name::sup) { 130 | throw std::invalid_argument("Model needs to be supervised for prediction!"); 131 | } 132 | model_->predict(words, k, threshold, predictions, state); 133 | } 134 | 135 | bool FastText::predictLine( 136 | std::istream& in, 137 | std::vector>& predictions, 138 | int32_t k, 139 | real threshold) const { 140 | predictions.clear(); 141 | if (in.peek() == EOF) { 142 | return false; 143 | } 144 | 145 | std::vector words, labels; 146 | dict_->getLine(in, words, labels); 147 | Predictions linePredictions; 148 | predict(k, words, linePredictions, threshold); 149 | for (const 
auto& p : linePredictions) { 150 | predictions.push_back( 151 | std::make_pair(std::exp(p.first), dict_->getLabel(p.second))); 152 | } 153 | 154 | return true; 155 | } 156 | 157 | } // namespace fasttext 158 | -------------------------------------------------------------------------------- /src/fasttext.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "args.h" 16 | #include "densematrix.h" 17 | #include "dictionary.h" 18 | #include "matrix.h" 19 | #include "model.h" 20 | #include "real.h" 21 | #include "utils.h" 22 | #include "vector.h" 23 | 24 | namespace fasttext { 25 | 26 | class FastText { 27 | public: 28 | using TrainCallback = 29 | std::function; 30 | 31 | protected: 32 | std::shared_ptr args_; 33 | std::shared_ptr dict_; 34 | std::shared_ptr input_; 35 | std::shared_ptr output_; 36 | std::shared_ptr model_; 37 | std::atomic loss_{}; 38 | bool quant_; 39 | int32_t version; 40 | std::unique_ptr wordVectors_; 41 | std::exception_ptr trainException_; 42 | 43 | bool checkModel(std::istream&); 44 | std::vector getTargetCounts() const; 45 | std::shared_ptr createLoss(std::shared_ptr& output); 46 | void buildModel(); 47 | 48 | public: 49 | void loadModel(std::istream& in); 50 | 51 | void loadModel(const std::string& filename); 52 | 53 | void predict( 54 | int32_t k, 55 | const std::vector& words, 56 | Predictions& predictions, 57 | real threshold = 0.0) const; 58 | 59 | bool predictLine( 60 | std::istream& in, 61 | std::vector>& predictions, 62 | int32_t k, 63 | real threshold) const; 64 | }; 65 | } // namespace fasttext 66 | -------------------------------------------------------------------------------- /src/loss.cc: 
-------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "loss.h" 10 | #include "utils.h" 11 | 12 | #include 13 | 14 | namespace fasttext { 15 | 16 | constexpr int64_t SIGMOID_TABLE_SIZE = 512; 17 | constexpr int64_t MAX_SIGMOID = 8; 18 | constexpr int64_t LOG_TABLE_SIZE = 512; 19 | 20 | bool comparePairs( 21 | const std::pair& l, 22 | const std::pair& r) { 23 | return l.first > r.first; 24 | } 25 | 26 | real std_log(real x) { 27 | return std::log(x + 1e-5); 28 | } 29 | 30 | Loss::Loss(std::shared_ptr& wo) : wo_(wo) { 31 | t_sigmoid_.reserve(SIGMOID_TABLE_SIZE + 1); 32 | for (int i = 0; i < SIGMOID_TABLE_SIZE + 1; i++) { 33 | real x = real(i * 2 * MAX_SIGMOID) / SIGMOID_TABLE_SIZE - MAX_SIGMOID; 34 | t_sigmoid_.push_back(1.0 / (1.0 + std::exp(-x))); 35 | } 36 | 37 | t_log_.reserve(LOG_TABLE_SIZE + 1); 38 | for (int i = 0; i < LOG_TABLE_SIZE + 1; i++) { 39 | real x = (real(i) + 1e-5) / LOG_TABLE_SIZE; 40 | t_log_.push_back(std::log(x)); 41 | } 42 | } 43 | 44 | real Loss::log(real x) const { 45 | if (x > 1.0) { 46 | return 0.0; 47 | } 48 | int64_t i = int64_t(x * LOG_TABLE_SIZE); 49 | return t_log_[i]; 50 | } 51 | 52 | real Loss::sigmoid(real x) const { 53 | if (x < -MAX_SIGMOID) { 54 | return 0.0; 55 | } else if (x > MAX_SIGMOID) { 56 | return 1.0; 57 | } else { 58 | int64_t i = 59 | int64_t((x + MAX_SIGMOID) * SIGMOID_TABLE_SIZE / MAX_SIGMOID / 2); 60 | return t_sigmoid_[i]; 61 | } 62 | } 63 | 64 | void Loss::predict( 65 | int32_t k, 66 | real threshold, 67 | Predictions& heap, 68 | Model::State& state) const { 69 | computeOutput(state); 70 | findKBest(k, threshold, heap, state.output); 71 | std::sort_heap(heap.begin(), heap.end(), comparePairs); 72 | } 73 | 74 | void Loss::findKBest( 75 | int32_t k, 
76 | real threshold, 77 | Predictions& heap, 78 | const Vector& output) const { 79 | for (int32_t i = 0; i < output.size(); i++) { 80 | if (output[i] < threshold) { 81 | continue; 82 | } 83 | if (heap.size() == k && std_log(output[i]) < heap.front().first) { 84 | continue; 85 | } 86 | heap.push_back(std::make_pair(std_log(output[i]), i)); 87 | std::push_heap(heap.begin(), heap.end(), comparePairs); 88 | if (heap.size() > k) { 89 | std::pop_heap(heap.begin(), heap.end(), comparePairs); 90 | heap.pop_back(); 91 | } 92 | } 93 | } 94 | 95 | BinaryLogisticLoss::BinaryLogisticLoss(std::shared_ptr& wo) 96 | : Loss(wo) {} 97 | 98 | void BinaryLogisticLoss::computeOutput(Model::State& state) const { 99 | Vector& output = state.output; 100 | output.mul(*wo_, state.hidden); 101 | int32_t osz = output.size(); 102 | for (int32_t i = 0; i < osz; i++) { 103 | output[i] = sigmoid(output[i]); 104 | } 105 | } 106 | 107 | OneVsAllLoss::OneVsAllLoss(std::shared_ptr& wo) 108 | : BinaryLogisticLoss(wo) {} 109 | 110 | NegativeSamplingLoss::NegativeSamplingLoss( 111 | std::shared_ptr& wo, 112 | int neg, 113 | const std::vector& targetCounts) 114 | : BinaryLogisticLoss(wo), neg_(neg), negatives_(), uniform_() { 115 | real z = 0.0; 116 | for (size_t i = 0; i < targetCounts.size(); i++) { 117 | z += pow(targetCounts[i], 0.5); 118 | } 119 | for (size_t i = 0; i < targetCounts.size(); i++) { 120 | real c = pow(targetCounts[i], 0.5); 121 | for (size_t j = 0; j < c * NegativeSamplingLoss::NEGATIVE_TABLE_SIZE / z; 122 | j++) { 123 | negatives_.push_back(i); 124 | } 125 | } 126 | uniform_ = std::uniform_int_distribution(0, negatives_.size() - 1); 127 | } 128 | 129 | HierarchicalSoftmaxLoss::HierarchicalSoftmaxLoss( 130 | std::shared_ptr& wo, 131 | const std::vector& targetCounts) 132 | : BinaryLogisticLoss(wo), 133 | paths_(), 134 | codes_(), 135 | tree_(), 136 | osz_(targetCounts.size()) { 137 | buildTree(targetCounts); 138 | } 139 | 140 | void HierarchicalSoftmaxLoss::buildTree(const 
std::vector& counts) { 141 | tree_.resize(2 * osz_ - 1); 142 | for (int32_t i = 0; i < 2 * osz_ - 1; i++) { 143 | tree_[i].parent = -1; 144 | tree_[i].left = -1; 145 | tree_[i].right = -1; 146 | tree_[i].count = 1e15; 147 | tree_[i].binary = false; 148 | } 149 | for (int32_t i = 0; i < osz_; i++) { 150 | tree_[i].count = counts[i]; 151 | } 152 | int32_t leaf = osz_ - 1; 153 | int32_t node = osz_; 154 | for (int32_t i = osz_; i < 2 * osz_ - 1; i++) { 155 | int32_t mini[2] = {0}; 156 | for (int32_t j = 0; j < 2; j++) { 157 | if (leaf >= 0 && tree_[leaf].count < tree_[node].count) { 158 | mini[j] = leaf--; 159 | } else { 160 | mini[j] = node++; 161 | } 162 | } 163 | tree_[i].left = mini[0]; 164 | tree_[i].right = mini[1]; 165 | tree_[i].count = tree_[mini[0]].count + tree_[mini[1]].count; 166 | tree_[mini[0]].parent = i; 167 | tree_[mini[1]].parent = i; 168 | tree_[mini[1]].binary = true; 169 | } 170 | for (int32_t i = 0; i < osz_; i++) { 171 | std::vector path; 172 | std::vector code; 173 | int32_t j = i; 174 | while (tree_[j].parent != -1) { 175 | path.push_back(tree_[j].parent - osz_); 176 | code.push_back(tree_[j].binary); 177 | j = tree_[j].parent; 178 | } 179 | paths_.push_back(path); 180 | codes_.push_back(code); 181 | } 182 | } 183 | 184 | void HierarchicalSoftmaxLoss::predict( 185 | int32_t k, 186 | real threshold, 187 | Predictions& heap, 188 | Model::State& state) const { 189 | dfs(k, threshold, 2 * osz_ - 2, 0.0, heap, state.hidden); 190 | std::sort_heap(heap.begin(), heap.end(), comparePairs); 191 | } 192 | 193 | void HierarchicalSoftmaxLoss::dfs( 194 | int32_t k, 195 | real threshold, 196 | int32_t node, 197 | real score, 198 | Predictions& heap, 199 | const Vector& hidden) const { 200 | if (score < std_log(threshold)) { 201 | return; 202 | } 203 | if (heap.size() == k && score < heap.front().first) { 204 | return; 205 | } 206 | 207 | if (tree_[node].left == -1 && tree_[node].right == -1) { 208 | heap.push_back(std::make_pair(score, node)); 209 | 
std::push_heap(heap.begin(), heap.end(), comparePairs); 210 | if (heap.size() > k) { 211 | std::pop_heap(heap.begin(), heap.end(), comparePairs); 212 | heap.pop_back(); 213 | } 214 | return; 215 | } 216 | 217 | real f = wo_->dotRow(hidden, node - osz_); 218 | f = 1. / (1 + std::exp(-f)); 219 | 220 | dfs(k, threshold, tree_[node].left, score + std_log(1.0 - f), heap, hidden); 221 | dfs(k, threshold, tree_[node].right, score + std_log(f), heap, hidden); 222 | } 223 | 224 | SoftmaxLoss::SoftmaxLoss(std::shared_ptr& wo) : Loss(wo) {} 225 | 226 | void SoftmaxLoss::computeOutput(Model::State& state) const { 227 | Vector& output = state.output; 228 | output.mul(*wo_, state.hidden); 229 | real max = output[0], z = 0.0; 230 | int32_t osz = output.size(); 231 | for (int32_t i = 0; i < osz; i++) { 232 | max = std::max(output[i], max); 233 | } 234 | for (int32_t i = 0; i < osz; i++) { 235 | output[i] = exp(output[i] - max); 236 | z += output[i]; 237 | } 238 | for (int32_t i = 0; i < osz; i++) { 239 | output[i] /= z; 240 | } 241 | } 242 | 243 | } // namespace fasttext 244 | -------------------------------------------------------------------------------- /src/loss.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "matrix.h" 16 | #include "model.h" 17 | #include "real.h" 18 | #include "utils.h" 19 | #include "vector.h" 20 | 21 | namespace fasttext { 22 | 23 | class Loss { 24 | private: 25 | void findKBest( 26 | int32_t k, 27 | real threshold, 28 | Predictions& heap, 29 | const Vector& output) const; 30 | 31 | protected: 32 | std::vector t_sigmoid_; 33 | std::vector t_log_; 34 | std::shared_ptr& wo_; 35 | 36 | real log(real x) const; 37 | real sigmoid(real x) const; 38 | 39 | public: 40 | explicit Loss(std::shared_ptr& wo); 41 | virtual ~Loss() = default; 42 | 43 | virtual void computeOutput(Model::State& state) const = 0; 44 | 45 | virtual void predict( 46 | int32_t /*k*/, 47 | real /*threshold*/, 48 | Predictions& /*heap*/, 49 | Model::State& /*state*/) const; 50 | }; 51 | 52 | class BinaryLogisticLoss : public Loss { 53 | public: 54 | explicit BinaryLogisticLoss(std::shared_ptr& wo); 55 | virtual ~BinaryLogisticLoss() noexcept override = default; 56 | void computeOutput(Model::State& state) const override; 57 | }; 58 | 59 | class OneVsAllLoss : public BinaryLogisticLoss { 60 | public: 61 | explicit OneVsAllLoss(std::shared_ptr& wo); 62 | ~OneVsAllLoss() noexcept override = default; 63 | }; 64 | 65 | class NegativeSamplingLoss : public BinaryLogisticLoss { 66 | protected: 67 | static const int32_t NEGATIVE_TABLE_SIZE = 10000000; 68 | 69 | int neg_; 70 | std::vector negatives_; 71 | std::uniform_int_distribution uniform_; 72 | 73 | public: 74 | explicit NegativeSamplingLoss( 75 | std::shared_ptr& wo, 76 | int neg, 77 | const std::vector& targetCounts); 78 | ~NegativeSamplingLoss() noexcept override = default; 79 | 80 | }; 81 | 82 | class HierarchicalSoftmaxLoss : public BinaryLogisticLoss { 83 | protected: 84 | struct Node { 85 | int32_t parent; 86 | int32_t left; 87 | int32_t right; 88 | int64_t count; 89 | bool binary; 90 | }; 91 | 92 | std::vector> paths_; 93 | std::vector> 
codes_; 94 | std::vector tree_; 95 | int32_t osz_; 96 | void buildTree(const std::vector& counts); 97 | void dfs( 98 | int32_t k, 99 | real threshold, 100 | int32_t node, 101 | real score, 102 | Predictions& heap, 103 | const Vector& hidden) const; 104 | 105 | public: 106 | explicit HierarchicalSoftmaxLoss( 107 | std::shared_ptr& wo, 108 | const std::vector& counts); 109 | ~HierarchicalSoftmaxLoss() noexcept override = default; 110 | void predict( 111 | int32_t k, 112 | real threshold, 113 | Predictions& heap, 114 | Model::State& state) const override; 115 | }; 116 | 117 | class SoftmaxLoss : public Loss { 118 | public: 119 | explicit SoftmaxLoss(std::shared_ptr& wo); 120 | ~SoftmaxLoss() noexcept override = default; 121 | void computeOutput(Model::State& state) const override; 122 | }; 123 | 124 | } // namespace fasttext 125 | -------------------------------------------------------------------------------- /src/matrix.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "matrix.h" 10 | 11 | namespace fasttext { 12 | 13 | Matrix::Matrix() : m_(0), n_(0) {} 14 | 15 | Matrix::Matrix(int64_t m, int64_t n) : m_(m), n_(n) {} 16 | 17 | int64_t Matrix::size(int64_t dim) const { 18 | assert(dim == 0 || dim == 1); 19 | if (dim == 0) { 20 | return m_; 21 | } 22 | return n_; 23 | } 24 | 25 | } // namespace fasttext 26 | -------------------------------------------------------------------------------- /src/matrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 
4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include "real.h" 18 | 19 | namespace fasttext { 20 | 21 | class Vector; 22 | 23 | class Matrix { 24 | protected: 25 | int64_t m_; 26 | int64_t n_; 27 | 28 | public: 29 | Matrix(); 30 | explicit Matrix(int64_t, int64_t); 31 | virtual ~Matrix() = default; 32 | 33 | int64_t size(int64_t dim) const; 34 | 35 | virtual real dotRow(const Vector&, int64_t) const = 0; 36 | virtual void addRowToVector(Vector& x, int32_t i) const = 0; 37 | virtual void addRowToVector(Vector& x, int32_t i, real a) const = 0; 38 | virtual void load(std::istream&) = 0; 39 | }; 40 | 41 | } // namespace fasttext 42 | -------------------------------------------------------------------------------- /src/model.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "model.h" 10 | #include "loss.h" 11 | #include "utils.h" 12 | 13 | #include 14 | #include 15 | 16 | namespace fasttext { 17 | 18 | Model::State::State(int32_t hiddenSize, int32_t outputSize, int32_t seed) 19 | : lossValue_(0.0), 20 | nexamples_(0), 21 | hidden(hiddenSize), 22 | output(outputSize), 23 | grad(hiddenSize), 24 | rng(seed) {} 25 | 26 | Model::Model( 27 | std::shared_ptr wi, 28 | std::shared_ptr wo, 29 | std::shared_ptr loss, 30 | bool normalizeGradient) 31 | : wi_(wi), wo_(wo), loss_(loss), normalizeGradient_(normalizeGradient) {} 32 | 33 | void Model::computeHidden(const std::vector& input, State& state) 34 | const { 35 | Vector& hidden = state.hidden; 36 | hidden.zero(); 37 | for (auto it = input.cbegin(); it != input.cend(); ++it) { 38 | hidden.addRow(*wi_, *it); 39 | } 40 | hidden.mul(1.0 / input.size()); 41 | } 42 | 43 | void Model::predict( 44 | const std::vector& input, 45 | int32_t k, 46 | real threshold, 47 | Predictions& heap, 48 | State& state) const { 49 | if (k == Model::kUnlimitedPredictions) { 50 | k = wo_->size(0); // output size 51 | } else if (k <= 0) { 52 | throw std::invalid_argument("k needs to be 1 or higher!"); 53 | } 54 | heap.reserve(k + 1); 55 | computeHidden(input, state); 56 | 57 | loss_->predict(k, threshold, heap, state); 58 | } 59 | 60 | } // namespace fasttext 61 | -------------------------------------------------------------------------------- /src/model.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "matrix.h" 17 | #include "real.h" 18 | #include "utils.h" 19 | #include "vector.h" 20 | 21 | namespace fasttext { 22 | 23 | class Loss; 24 | 25 | class Model { 26 | protected: 27 | std::shared_ptr wi_; 28 | std::shared_ptr wo_; 29 | std::shared_ptr loss_; 30 | bool normalizeGradient_; 31 | 32 | public: 33 | Model( 34 | std::shared_ptr wi, 35 | std::shared_ptr wo, 36 | std::shared_ptr loss, 37 | bool normalizeGradient); 38 | Model(const Model& model) = delete; 39 | Model(Model&& model) = delete; 40 | Model& operator=(const Model& other) = delete; 41 | Model& operator=(Model&& other) = delete; 42 | 43 | class State { 44 | private: 45 | real lossValue_; 46 | int64_t nexamples_; 47 | 48 | public: 49 | Vector hidden; 50 | Vector output; 51 | Vector grad; 52 | std::minstd_rand rng; 53 | 54 | State(int32_t hiddenSize, int32_t outputSize, int32_t seed); 55 | }; 56 | 57 | void predict( 58 | const std::vector& input, 59 | int32_t k, 60 | real threshold, 61 | Predictions& heap, 62 | State& state) const; 63 | void computeHidden(const std::vector& input, State& state) const; 64 | 65 | static const int32_t kUnlimitedPredictions = -1; 66 | }; 67 | 68 | } // namespace fasttext 69 | -------------------------------------------------------------------------------- /src/productquantizer.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #include "productquantizer.h" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace fasttext { 18 | 19 | ProductQuantizer::ProductQuantizer(int32_t dim, int32_t dsub) 20 | : dim_(dim), 21 | nsubq_(dim / dsub), 22 | dsub_(dsub), 23 | centroids_(dim * ksub_), 24 | rng(seed_) { 25 | lastdsub_ = dim_ % dsub; 26 | if (lastdsub_ == 0) { 27 | lastdsub_ = dsub_; 28 | } else { 29 | nsubq_++; 30 | } 31 | } 32 | 33 | const real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) const { 34 | if (m == nsubq_ - 1) { 35 | return ¢roids_[m * ksub_ * dsub_ + i * lastdsub_]; 36 | } 37 | return ¢roids_[(m * ksub_ + i) * dsub_]; 38 | } 39 | 40 | real* ProductQuantizer::get_centroids(int32_t m, uint8_t i) { 41 | if (m == nsubq_ - 1) { 42 | return ¢roids_[m * ksub_ * dsub_ + i * lastdsub_]; 43 | } 44 | return ¢roids_[(m * ksub_ + i) * dsub_]; 45 | } 46 | 47 | real ProductQuantizer::mulcode( 48 | const Vector& x, 49 | const uint8_t* codes, 50 | int32_t t, 51 | real alpha) const { 52 | real res = 0.0; 53 | auto d = dsub_; 54 | const uint8_t* code = codes + nsubq_ * t; 55 | for (auto m = 0; m < nsubq_; m++) { 56 | const real* c = get_centroids(m, code[m]); 57 | if (m == nsubq_ - 1) { 58 | d = lastdsub_; 59 | } 60 | for (auto n = 0; n < d; n++) { 61 | res += x[m * dsub_ + n] * c[n]; 62 | } 63 | } 64 | return res * alpha; 65 | } 66 | 67 | void ProductQuantizer::addcode( 68 | Vector& x, 69 | const uint8_t* codes, 70 | int32_t t, 71 | real alpha) const { 72 | auto d = dsub_; 73 | const uint8_t* code = codes + nsubq_ * t; 74 | for (auto m = 0; m < nsubq_; m++) { 75 | const real* c = get_centroids(m, code[m]); 76 | if (m == nsubq_ - 1) { 77 | d = lastdsub_; 78 | } 79 | for (auto n = 0; n < d; n++) { 80 | x[m * dsub_ + n] += alpha * c[n]; 81 | } 82 | } 83 | } 84 | 85 | void ProductQuantizer::load(std::istream& in) { 86 | in.read((char*)&dim_, sizeof(dim_)); 87 | in.read((char*)&nsubq_, sizeof(nsubq_)); 88 | in.read((char*)&dsub_, 
sizeof(dsub_)); 89 | in.read((char*)&lastdsub_, sizeof(lastdsub_)); 90 | centroids_.resize(dim_ * ksub_); 91 | for (auto i = 0; i < centroids_.size(); i++) { 92 | in.read((char*)¢roids_[i], sizeof(real)); 93 | } 94 | } 95 | 96 | } // namespace fasttext 97 | -------------------------------------------------------------------------------- /src/productquantizer.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include "real.h" 18 | #include "vector.h" 19 | 20 | namespace fasttext { 21 | 22 | class ProductQuantizer { 23 | protected: 24 | const int32_t nbits_ = 8; 25 | const int32_t ksub_ = 1 << nbits_; 26 | const int32_t max_points_per_cluster_ = 256; 27 | const int32_t max_points_ = max_points_per_cluster_ * ksub_; 28 | const int32_t seed_ = 1234; 29 | const int32_t niter_ = 25; 30 | const real eps_ = 1e-7; 31 | 32 | int32_t dim_; 33 | int32_t nsubq_; 34 | int32_t dsub_; 35 | int32_t lastdsub_; 36 | 37 | std::vector centroids_; 38 | 39 | std::minstd_rand rng; 40 | 41 | public: 42 | ProductQuantizer() {} 43 | ProductQuantizer(int32_t, int32_t); 44 | 45 | real* get_centroids(int32_t, uint8_t); 46 | const real* get_centroids(int32_t, uint8_t) const; 47 | 48 | real mulcode(const Vector&, const uint8_t*, int32_t, real) const; 49 | void addcode(Vector&, const uint8_t*, int32_t, real) const; 50 | 51 | void load(std::istream&); 52 | }; 53 | 54 | } // namespace fasttext 55 | -------------------------------------------------------------------------------- /src/quantmatrix.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 
3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #include "quantmatrix.h" 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | namespace fasttext { 16 | 17 | QuantMatrix::QuantMatrix() : Matrix(), qnorm_(false), codesize_(0) {} 18 | 19 | real QuantMatrix::dotRow(const Vector& vec, int64_t i) const { 20 | assert(i >= 0); 21 | assert(i < m_); 22 | assert(vec.size() == n_); 23 | real norm = 1; 24 | if (qnorm_) { 25 | norm = npq_->get_centroids(0, norm_codes_[i])[0]; 26 | } 27 | return pq_->mulcode(vec, codes_.data(), i, norm); 28 | } 29 | 30 | void QuantMatrix::addRowToVector(Vector& x, int32_t i, real a) const { 31 | real norm = 1; 32 | if (qnorm_) { 33 | norm = npq_->get_centroids(0, norm_codes_[i])[0]; 34 | } 35 | pq_->addcode(x, codes_.data(), i, a * norm); 36 | } 37 | 38 | void QuantMatrix::addRowToVector(Vector& x, int32_t i) const { 39 | real norm = 1; 40 | if (qnorm_) { 41 | norm = npq_->get_centroids(0, norm_codes_[i])[0]; 42 | } 43 | pq_->addcode(x, codes_.data(), i, norm); 44 | } 45 | 46 | void QuantMatrix::load(std::istream& in) { 47 | in.read((char*)&qnorm_, sizeof(qnorm_)); 48 | in.read((char*)&m_, sizeof(m_)); 49 | in.read((char*)&n_, sizeof(n_)); 50 | in.read((char*)&codesize_, sizeof(codesize_)); 51 | codes_ = std::vector(codesize_); 52 | in.read((char*)codes_.data(), codesize_ * sizeof(uint8_t)); 53 | pq_ = std::unique_ptr(new ProductQuantizer()); 54 | pq_->load(in); 55 | if (qnorm_) { 56 | norm_codes_ = std::vector(m_); 57 | in.read((char*)norm_codes_.data(), m_ * sizeof(uint8_t)); 58 | npq_ = std::unique_ptr(new ProductQuantizer()); 59 | npq_->load(in); 60 | } 61 | } 62 | 63 | } // namespace fasttext 64 | -------------------------------------------------------------------------------- /src/quantmatrix.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 
2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | #include "real.h" 19 | 20 | #include "densematrix.h" 21 | #include "matrix.h" 22 | #include "vector.h" 23 | 24 | #include "productquantizer.h" 25 | 26 | namespace fasttext { 27 | 28 | class QuantMatrix : public Matrix { 29 | protected: 30 | std::unique_ptr pq_; 31 | std::unique_ptr npq_; 32 | 33 | std::vector codes_; 34 | std::vector norm_codes_; 35 | 36 | bool qnorm_; 37 | int32_t codesize_; 38 | 39 | public: 40 | QuantMatrix(); 41 | QuantMatrix(const QuantMatrix&) = delete; 42 | QuantMatrix(QuantMatrix&&) = delete; 43 | QuantMatrix& operator=(const QuantMatrix&) = delete; 44 | QuantMatrix& operator=(QuantMatrix&&) = delete; 45 | virtual ~QuantMatrix() noexcept override = default; 46 | 47 | real dotRow(const Vector&, int64_t) const override; 48 | void addRowToVector(Vector& x, int32_t i) const override; 49 | void addRowToVector(Vector& x, int32_t i, real a) const override; 50 | void load(std::istream&) override; 51 | }; 52 | 53 | } // namespace fasttext 54 | -------------------------------------------------------------------------------- /src/real.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
// ---- src/real.h ----
namespace fasttext {

/// Floating-point type used throughout the library.
using real = float;

} // namespace fasttext

// ---- src/utils.h ----
// Expands to the compiler-specific "deprecated" annotation carrying `msg`, or
// to nothing on compilers that are not recognised.
#if defined(__clang__) || defined(__GNUC__)
#define FASTTEXT_DEPRECATED(msg) __attribute__((__deprecated__(msg)))
#elif defined(_MSC_VER)
#define FASTTEXT_DEPRECATED(msg) __declspec(deprecated(msg))
#else
#define FASTTEXT_DEPRECATED(msg)
#endif

namespace fasttext {

/// Container that predict() fills with results.
// NOTE(review): the template arguments were lost in extraction; the
// (real, int32_t) pair is restored from upstream fastText -- confirm.
using Predictions = std::vector<std::pair<real, int32_t>>;

} // namespace fasttext
7 | */ 8 | 9 | #include "vector.h" 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include "matrix.h" 17 | 18 | namespace fasttext { 19 | 20 | Vector::Vector(int64_t m) : data_(m) {} 21 | 22 | void Vector::zero() { 23 | std::fill(data_.begin(), data_.end(), 0.0); 24 | } 25 | 26 | void Vector::mul(real a) { 27 | for (int64_t i = 0; i < size(); i++) { 28 | data_[i] *= a; 29 | } 30 | } 31 | 32 | void Vector::addRow(const Matrix& A, int64_t i, real a) { 33 | assert(i >= 0); 34 | assert(i < A.size(0)); 35 | assert(size() == A.size(1)); 36 | A.addRowToVector(*this, i, a); 37 | } 38 | 39 | void Vector::addRow(const Matrix& A, int64_t i) { 40 | assert(i >= 0); 41 | assert(i < A.size(0)); 42 | assert(size() == A.size(1)); 43 | A.addRowToVector(*this, i); 44 | } 45 | 46 | void Vector::mul(const Matrix& A, const Vector& vec) { 47 | assert(A.size(0) == size()); 48 | assert(A.size(1) == vec.size()); 49 | for (int64_t i = 0; i < size(); i++) { 50 | data_[i] = A.dotRow(vec, i); 51 | } 52 | } 53 | 54 | std::ostream& operator<<(std::ostream& os, const Vector& v) { 55 | os << std::setprecision(5); 56 | for (int64_t j = 0; j < v.size(); j++) { 57 | os << v[j] << ' '; 58 | } 59 | return os; 60 | } 61 | 62 | } // namespace fasttext 63 | -------------------------------------------------------------------------------- /src/vector.h: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (c) 2016-present, Facebook, Inc. 3 | * All rights reserved. 4 | * 5 | * This source code is licensed under the MIT license found in the 6 | * LICENSE file in the root directory of this source tree. 
7 | */ 8 | 9 | #pragma once 10 | 11 | #include 12 | #include 13 | #include 14 | 15 | #include "real.h" 16 | 17 | namespace fasttext { 18 | 19 | class Matrix; 20 | 21 | class Vector { 22 | protected: 23 | std::vector data_; 24 | 25 | public: 26 | explicit Vector(int64_t); 27 | Vector(const Vector&) = default; 28 | Vector(Vector&&) noexcept = default; 29 | Vector& operator=(const Vector&) = default; 30 | Vector& operator=(Vector&&) = default; 31 | 32 | inline real* data() { 33 | return data_.data(); 34 | } 35 | inline const real* data() const { 36 | return data_.data(); 37 | } 38 | inline real& operator[](int64_t i) { 39 | return data_[i]; 40 | } 41 | inline const real& operator[](int64_t i) const { 42 | return data_[i]; 43 | } 44 | 45 | inline int64_t size() const { 46 | return data_.size(); 47 | } 48 | void zero(); 49 | void mul(real); 50 | void addRow(const Matrix&, int64_t); 51 | void addRow(const Matrix&, int64_t, real); 52 | void mul(const Matrix&, const Vector&); 53 | }; 54 | 55 | std::ostream& operator<<(std::ostream&, const Vector&); 56 | 57 | } // namespace fasttext 58 | --------------------------------------------------------------------------------