├── .editorconfig
├── .gitattributes
├── .gitignore
├── .gitmodules
├── .npmignore
├── .travis.yml
├── LICENSE
├── README.md
├── appveyor.yml
├── binding.gyp
├── index.d.ts
├── index.js
├── package.json
├── src
│   ├── addon.cc
│   ├── classifier.cc
│   ├── classifier.h
│   ├── loadModel.cc
│   ├── loadModel.h
│   ├── nnWorker.cc
│   ├── nnWorker.h
│   ├── node-argument.cc
│   ├── node-argument.h
│   ├── node-util.cc
│   ├── node-util.h
│   ├── predictWorker.cc
│   ├── predictWorker.h
│   ├── quantize.cc
│   ├── quantize.h
│   ├── query.cc
│   ├── query.h
│   ├── train.cc
│   ├── train.h
│   ├── wrapper.cc
│   └── wrapper.h
├── tea.yaml
└── test
    ├── data
    │   ├── cooking.preprocessed.txt
    │   ├── cooking.stackexchange.id
    │   ├── cooking.stackexchange.tar.gz
    │   ├── cooking.stackexchange.txt
    │   ├── cooking.train.txt
    │   └── cooking.valid.txt
    ├── models
    │   ├── lid.176.ftz
    │   ├── model_cooking.bin
    │   └── model_cooking.vec
    ├── specs
    │   ├── fastText.js
    │   ├── langid.js
    │   └── trainer.js
    └── start.js
/.editorconfig: -------------------------------------------------------------------------------- 1 | # top-most EditorConfig file 2 | root = true 3 | 4 | # Unix-style newlines with a newline ending every file 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | indent_style = space 11 | indent_size = 2 12 | 13 | # editorconfig-tools is unable to ignore longs strings or urls 14 | max_line_length = null 15 | 16 | # Tab indentation (no size specified) 17 | [Makefile] 18 | indent_style = tab 19 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | ## GITATTRIBUTES FOR WEB PROJECTS 2 | # 3 | # These settings are for any web project. 4 | # Ref: https://git.io/fxxEi 5 | # 6 | # Details per file setting: 7 | # text These files should be normalized (i.e. convert CRLF => LF). 8 | # binary These files are binary and should be left untouched. 9 | # 10 | # Note that binary is a macro for -text -diff. 11 | ###################################################################### 12 | 13 | ## AUTO-DETECT 14 | ## Handle line endings automatically for files detected as 15 | ## text and leave all files detected as binary untouched. 16 | ## This will handle all files NOT defined below.
17 | * text=auto eol=lf 18 | 19 | ## CSHARP 20 | *.cs text diff=csharp 21 | *.sln text eol=crlf 22 | *.csproj text eol=crlf 23 | 24 | ## SOURCE CODE 25 | *.bat text eol=crlf 26 | *.coffee text 27 | *.css text 28 | *.htm text diff=html 29 | *.html text diff=html 30 | *.inc text 31 | *.ini text 32 | *.js text 33 | *.json text 34 | *.jsx text 35 | *.less text 36 | *.od text 37 | *.onlydata text 38 | *.php text diff=php 39 | *.pl text 40 | *.py text diff=python 41 | *.rb text diff=ruby 42 | *.sass text 43 | *.scm text 44 | *.scss text 45 | *.sh text eol=lf 46 | *.sql text 47 | *.styl text 48 | *.tag text 49 | *.ts text 50 | *.tsx text 51 | *.vue text 52 | *.xml text 53 | *.xhtml text diff=html 54 | 55 | ## DOCKER 56 | *.dockerignore text 57 | Dockerfile text 58 | 59 | ## DOCUMENTATION 60 | *.ipynb text 61 | *.markdown text 62 | *.md text 63 | *.mdwn text 64 | *.mdown text 65 | *.mkd text 66 | *.mkdn text 67 | *.mdtxt text 68 | *.mdtext text 69 | *.txt text 70 | AUTHORS text 71 | CHANGELOG text 72 | CHANGES text 73 | CONTRIBUTING text 74 | COPYING text 75 | copyright text 76 | *COPYRIGHT* text 77 | INSTALL text 78 | license text 79 | LICENSE text 80 | NEWS text 81 | readme text 82 | *README* text 83 | TODO text 84 | 85 | ## TEMPLATES 86 | *.dot text 87 | *.ejs text 88 | *.haml text 89 | *.handlebars text 90 | *.hbs text 91 | *.hbt text 92 | *.jade text 93 | *.latte text 94 | *.mustache text 95 | *.njk text 96 | *.phtml text 97 | *.tmpl text 98 | *.tpl text 99 | *.twig text 100 | 101 | ## LINTERS 102 | .csslintrc text 103 | .eslintrc text 104 | .htmlhintrc text 105 | .jscsrc text 106 | .jshintrc text 107 | .jshintignore text 108 | .stylelintrc text 109 | 110 | ## CONFIGS 111 | *.bowerrc text 112 | *.cnf text 113 | *.conf text 114 | *.config text 115 | .babelrc text 116 | .browserslistrc text 117 | .editorconfig text 118 | .env text 119 | .gitattributes text 120 | .gitconfig text 121 | .gitignore text 122 | .htaccess text 123 | *.lock text 124 | *.npmignore text 125 | *.yaml text 126 | *.yml text 127 | browserslist text 128 | Makefile text 129 | makefile text 130 | 131 | ## HEROKU 132 | Procfile text 133 | .slugignore text 134 | 135 | ## GRAPHICS 136 | *.ai binary 137 | *.bmp binary 138 | *.eps binary 139 | *.gif binary 140 | *.ico binary 141 | *.jng binary 142 | *.jp2 binary 143 | *.jpg binary 144 | *.jpeg binary 145 | *.jpx binary 146 | *.jxr binary 147 | *.pdf binary 148 | *.png binary 149 | *.psb binary 150 | *.psd binary 151 | *.svg text 152 | *.svgz binary 153 | *.tif binary 154 | *.tiff binary 155 | *.wbmp binary 156 | *.webp binary 157 | 158 | ## AUDIO 159 | *.kar binary 160 | *.m4a binary 161 | *.mid binary 162 | *.midi binary 163 | *.mp3 binary 164 | *.ogg binary 165 | *.ra binary 166 | 167 | ## VIDEO 168 | *.3gpp binary 169 | *.3gp binary 170 | *.as binary 171 | *.asf binary 172 | *.asx binary 173 | *.fla binary 174 | *.flv binary 175 | *.m4v binary 176 | *.mng binary 177 | *.mov binary 178 | *.mp4 binary 179 | *.mpeg binary 180 | *.mpg binary 181 | *.ogv binary 182 | *.swc binary 183 | *.swf binary 184 | *.webm binary 185 | 186 | ## ARCHIVES 187 | *.7z binary 188 | *.gz binary 189 | *.jar binary 190 | *.rar binary 191 | *.tar binary 192 | *.zip binary 193 | 194 | ## FONTS 195 | *.ttf binary 196 | *.eot binary 197 | *.otf binary 198 | *.woff binary 199 | *.woff2 binary 200 | 201 | ## EXECUTABLES 202 | *.exe binary 203 | *.pyc binary 204 | 205 | # Project text 206 | .gitignore text 207 | *.gitattributes text 208 | *.md text 209 | 210 | # Project binary 211 | *.lock binary 212 | *.dll 
binary 213 | *.doc binary 214 | *.docx binary 215 | *.xls binary 216 | *.xlsx binary 217 | /dist/* binary 218 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | logs 3 | *.log 4 | npm-debug.log* 5 | yarn-debug.log* 6 | yarn-error.log* 7 | 8 | # Runtime data 9 | pids 10 | *.pid 11 | *.seed 12 | *.pid.lock 13 | 14 | # Directory for instrumented libs generated by jscoverage/JSCover 15 | lib-cov 16 | 17 | # Coverage directory used by tools like istanbul 18 | coverage 19 | 20 | # nyc test coverage 21 | .nyc_output 22 | 23 | # Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files) 24 | .grunt 25 | 26 | # Bower dependency directory (https://bower.io/) 27 | bower_components 28 | 29 | # node-waf configuration 30 | .lock-wscript 31 | 32 | # Compiled binary addons (http://nodejs.org/api/addons.html) 33 | build/Release 34 | 35 | # Dependency directories 36 | node_modules/ 37 | jspm_packages/ 38 | 39 | # Typescript v1 declaration files 40 | typings/ 41 | 42 | # Optional npm cache directory 43 | .npm 44 | 45 | # Optional eslint cache 46 | .eslintcache 47 | 48 | # Optional REPL history 49 | .node_repl_history 50 | 51 | # Output of 'npm pack' 52 | *.tgz 53 | 54 | # Yarn Integrity file 55 | .yarn-integrity 56 | 57 | # dotenv environment variables file 58 | .env 59 | 60 | .vscode 61 | build* 62 | data 63 | pretrained-vectors 64 | fasttext.node 65 | test*.js 66 | package-lock.json 67 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "fastText"] 2 | path = fastText 3 | url = https://github.com/facebookresearch/fastText.git 4 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | fasttext.node 2 | test.js 3 | examples 4 | .vscode 5 | build 6 | docs 7 | test 8 | lib -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: node_js 2 | 3 | compiler: gcc 4 | sudo: false 5 | 6 | os: 7 | - linux 8 | - osx 9 | 10 | env: 11 | global: 12 | - secure: JyNJe+dEpj5feig+xwuopgJZGfih0gzOHj+uh9iLO0LgIILVwUzeUNkMZpaT/ydsLBLTdSTKary3mts2oE878lQImi1va0hB4QKKtf+P0kNmeYaEDc8U1dLOv/0d0veMQD3ig8JzajcU5s6znhPOi/PnNTY3aGSnvE3UNr/9N5sz0Gxw/nfnanK3VcRbDqkDw/f3lUj30YLc0xR4e1HXNS+bqSLRg8pgYHtwModjsX5Y1t3BQOsB4fZu/Joo3Hz5jr6uFHECcVTwW5LPtfNq+96iTdEOk+dU03uZCyr++vBqX0/8acOsY4heFeCKpQDSPxQ/tTBtEcRTRFME4FSCxc0oh9VorgkAjszSTL6Ff64rxEfRX4rqkWgqQaaeA1e8sgJENydFNvjYt1IoSTzdYvof7EPk3kk7rbZmTdaygqt2F2KDzjFC5z4jHYk04cQSm5jIrIVQfaBJ83SFm/cH6Q+totQ48OAJxFz4VfGG5m8Tl1gBjwhAwl3+9D2m2SdebxNjCAJ8My+JlHx1D6DTR+/ir/Wd9tppZc/L2jRhY/qNz5fq+OF+yzM2CQNBjMOOrYoSbEz+rCyaXg/EYAFn5KnA0D6dRKvdiCnbhrE9BlfzcTUVUi9rAXZ03Bl++NU9//ECe30sU/vjpHPiQs280eJkYphoknjpWuZp8FEu0c0= 13 | 14 | node_js: 15 | # support latest Nodejs LTS version 16 | - '10' 17 | 18 | addons: 19 | apt: 20 | sources: 21 | - ubuntu-toolchain-r-test 22 | packages: 23 | - gcc-4.8 24 | - g++-4.8 25 | 26 | sudo: false 27 | 28 | before_install: 29 | - if [ "$TRAVIS_NODE_VERSION" = "0.8" ]; then npm install -g npm@2.7.3; fi; 30 | - if [ $TRAVIS_OS_NAME == "linux" ]; then 31 | export CC="gcc-4.8"; 32 | export CXX="g++-4.8"; 33 | export 
LINK="gcc-4.8"; 34 | export LINKXX="g++-4.8"; 35 | fi 36 | - nvm --version 37 | - node --version 38 | - npm --version 39 | - gcc --version 40 | - g++ --version 41 | 42 | before_script: 43 | # figure out if we should publish 44 | - echo $TRAVIS_BRANCH 45 | - echo `git describe --tags --always HEAD` 46 | - PUBLISH_BINARY=false 47 | - COMMIT_MESSAGE=$(git show -s --format=%B $TRAVIS_COMMIT | tr -d '\n') 48 | - if [[ $TRAVIS_BRANCH == `git describe --tags --always HEAD` || ${COMMIT_MESSAGE} =~ "[publish binary]" ]]; then PUBLISH_BINARY=true; fi; 49 | - echo "Publishing binaries? ->" $PUBLISH_BINARY 50 | 51 | install: 52 | - npm install --build-from-source 53 | 54 | script: 55 | - npm test 56 | - if [[ $PUBLISH_BINARY == true ]]; then node-pre-gyp rebuild package && node-pre-gyp-github publish --release; fi; 57 | 58 | #cache: 59 | # directories: 60 | # - $HOME/.node-gyp 61 | # - $HOME/.npm 62 | # - node_modules 63 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Nhữ Bảo Vũ 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # node-fasttext 2 | 3 | Nodejs binding for fasttext representation and classification. 4 | 5 | [![MIT License](https://img.shields.io/badge/license-MIT_License-green.svg?style=flat-square)](./LICENSE) 6 | [![npm version](https://img.shields.io/npm/v/fasttext.svg?style=flat)](https://www.npmjs.com/package/fasttext) 7 | [![downloads](https://img.shields.io/npm/dm/fasttext.svg)](https://www.npmjs.com/package/fasttext) 8 | [![Travis](https://travis-ci.org/vunb/node-fasttext.svg?branch=master)](https://travis-ci.org/vunb/node-fasttext) 9 | [![Appveyor](https://ci.appveyor.com/api/projects/status/9gd460vxd6jbel14/branch/master?svg=true)](https://ci.appveyor.com/project/vunb/node-fasttext/branch/master) 10 | 11 | > This is a link to the Facebook [fastText](https://github.com/facebookresearch/fastText). A Library for efficient text classification and representation learning. 
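
A quick example of the API in action is language identification with a pre-trained fastText model. This is a minimal sketch only: the `test/models/lid.176.ftz` path and the label/value shown in the comments are illustrative assumptions, not guaranteed output.

```js
const path = require('path');
const fastText = require('fasttext');

// lid.176.ftz ships with this repository's test fixtures (test/models);
// any supervised fastText model (.bin or .ftz) can be used the same way.
const model = path.resolve(__dirname, './test/models/lid.176.ftz');
const classifier = new fastText.Classifier(model);

classifier.predict('bonjour tout le monde', 3)
  .then((res) => {
    // res is an array of { label, value } pairs,
    // e.g. [{ label: '__label__fr', value: 0.99 }, ...]
    console.log(res);
  })
  .catch((err) => console.error(err));
```

The classifier and nearest-neighbor APIs are described in detail below.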
12 | 13 | * FASTTEXT_VERSION = 12; 14 | * FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314; 15 | 16 | # Installation 17 | 18 | Using npm: 19 | 20 | > npm install fasttext --save 21 | 22 | # fastText Classifier 23 | 24 | Following the supervised tutorial on [fasttext.cc](https://fasttext.cc/docs/en/supervised-tutorial.html), here is a simple classifier that predicts tags for `cooking` questions from Stack Exchange: 25 | 26 | ```js 27 | const path = require('path'); 28 | const fastText = require('fasttext'); 29 | 30 | const model = path.resolve(__dirname, './model_cooking.bin'); 31 | const classifier = new fastText.Classifier(model); 32 | 33 | classifier.predict('Why not put knives in the dishwasher?', 5) 34 | .then((res) => { 35 | if (res.length > 0) { 36 | let tag = res[0].label; // __label__knives 37 | let confidence = res[0].value; // 0.8787146210670471 38 | console.log('classify', tag, confidence, res); 39 | } else { 40 | console.log('No matches'); 41 | } 42 | }); 43 | ``` 44 | 45 | The model was trained beforehand with the following parameters: 46 | 47 | ```js 48 | const path = require('path'); 49 | const fastText = require('fasttext'); 50 | 51 | let data = path.resolve(path.join(__dirname, '../data/cooking.train.txt')); 52 | let model = path.resolve(path.join(__dirname, '../data/cooking.model')); 53 | 54 | let classifier = new fastText.Classifier(); 55 | let options = { 56 | input: data, 57 | output: model, 58 | loss: "softmax", 59 | dim: 200, 60 | bucket: 2000000 61 | }; 62 | 63 | classifier.train('supervised', options) 64 | .then((res) => { 65 | console.log('model info after training:', res); 66 | // Input <<<<< C:\projects\node-fasttext\test\data\cooking.train.txt 67 | // Output >>>>> C:\projects\node-fasttext\test\data\cooking.model.bin 68 | // Output >>>>> C:\projects\node-fasttext\test\data\cooking.model.vec 69 | }); 70 | ``` 71 | 72 | Alternatively, you can train directly from the command line with fastText built from the official source: 73 | 74 | ```bash 75 | # Training 76 | ~/fastText/data$ ./fasttext supervised -input cooking.train -output model_cooking -lr 1.0 -epoch 25 -wordNgrams 2 -bucket 200000 -dim 50 -loss hs 77 | Read 0M words 78 | Number of words: 8952 79 | Number of labels: 735 80 | Progress: 100.0% words/sec/thread: 1687554 lr: 0.000000 loss: 5.247591 eta: 0h0m 4m 81 | 82 | # Testing 83 | ~/fastText/data$ ./fasttext test model_cooking.bin cooking.valid 84 | N 3000 85 | P@1 0.587 86 | R@1 0.254 87 | Number of examples: 3000 88 | ``` 89 | 90 | # Nearest neighbor 91 | 92 | A simple class for searching for nearest neighbors: 93 | 94 | ```js 95 | const path = require('path'); 96 | const fastText = require('fasttext'); 97 | 98 | const model = path.resolve(__dirname, './skipgram.bin'); 99 | const query = new fastText.Query(model); 100 | 101 | query.nn('word', 5, (err, res) => { 102 | if (err) { 103 | console.error(err); 104 | } else if (res.length > 0) { 105 | let tag = res[0].label; // letter 106 | let confidence = res[0].value; // 0.99992 107 | console.log('Nearest neighbor', tag, confidence, res); 108 | } else { 109 | console.log('No matches'); 110 | } 111 | }); 112 | ``` 113 | 114 | # Build from source 115 | 116 | See [Installation Prerequisites](https://github.com/nodejs/node-gyp#installation). 117 | 118 | ```bash 119 | # install dependencies and tools 120 | npm install 121 | 122 | # build node-fasttext from source 123 | npm run build 124 | 125 | # run unit tests 126 | npm test 127 | ``` 128 | 129 | # Contributing 130 | 131 | Pull requests and stars are highly welcome.
132 | 133 | For bugs and feature requests, please [create an issue](https://github.com/vunb/node-fasttext/issues/new). 134 | -------------------------------------------------------------------------------- /appveyor.yml: -------------------------------------------------------------------------------- 1 | os: unstable 2 | environment: 3 | NODE_PRE_GYP_GITHUB_TOKEN: 4 | secure: ex3daTUjJEaSvxAKyUiPOiLjEoYhmisgN9YjzbdYvMhdu3YyZo56JyEcTPjKIhQz 5 | 6 | matrix: 7 | # support latest Nodejs LTS version 8 | - nodejs_version: "10" 9 | 10 | platform: 11 | - x64 12 | - x86 13 | 14 | # Install scripts. (runs after repo cloning) 15 | install: 16 | # Get the latest stable version of Node.js or io.js 17 | - ps: Install-Product node $env:nodejs_version $env:platform 18 | - set PATH=%APPDATA%\npm;%PATH% 19 | # install submodules 20 | - git submodule update --init 21 | # install modules 22 | - npm install -g node-gyp node-pre-gyp node-pre-gyp-github 23 | - npm install --build-from-source 24 | # Check if we're building the latest tag, if so 25 | # then we publish the binaries if tests pass. 26 | - ps: > 27 | if ($env:APPVEYOR_REPO_COMMIT_MESSAGE.ToLower().Contains('[publish binary]') -OR $(git describe --tags --always HEAD) -eq $env:APPVEYOR_REPO_BRANCH) { 28 | $env:publish_binary = "true"; 29 | } 30 | if ($env:publish_binary -eq "true") { 31 | "We're publishing a binary!" | Write-Host 32 | } else { 33 | "We're not publishing a binary" | Write-Host 34 | } 35 | true; 36 | 37 | # Post-install test scripts. 38 | test_script: 39 | # Output useful info for debugging. 40 | - node --version 41 | - npm --version 42 | # run tests 43 | - npm test 44 | # publish binary 45 | - ps: if ($env:publish_binary -eq "true") { node-pre-gyp configure clean build package; node-pre-gyp-github publish --release } 46 | 47 | # Don't actually build. 
48 | build: off 49 | -------------------------------------------------------------------------------- /binding.gyp: -------------------------------------------------------------------------------- 1 | { 2 | "targets": [ 3 | { 4 | "target_name": "fasttext", 5 | "sources": [ 6 | "fastText/src/args.h", 7 | "fastText/src/args.cc", 8 | "fastText/src/dictionary.h", 9 | "fastText/src/dictionary.cc", 10 | "fastText/src/fasttext.h", 11 | "fastText/src/fasttext.cc", 12 | "fastText/src/matrix.h", 13 | "fastText/src/matrix.cc", 14 | "fastText/src/model.h", 15 | "fastText/src/model.cc", 16 | "fastText/src/productquantizer.h", 17 | "fastText/src/productquantizer.cc", 18 | "fastText/src/qmatrix.h", 19 | "fastText/src/qmatrix.cc", 20 | "fastText/src/real.h", 21 | "fastText/src/utils.h", 22 | "fastText/src/utils.cc", 23 | "fastText/src/vector.h", 24 | "fastText/src/vector.cc", 25 | "src/node-util.cc", 26 | "src/node-argument.cc", 27 | "src/loadModel.cc", 28 | "src/train.cc", 29 | "src/quantize.cc", 30 | "src/predictWorker.cc", 31 | "src/nnWorker.cc", 32 | "src/wrapper.cc", 33 | "src/classifier.cc", 34 | "src/query.cc", 35 | "src/addon.cc" 36 | ], 37 | "defines": [ 38 | "NAPI_VERSION=<(napi_build_version)", 39 | ], 40 | "include_dirs": [ 41 | "; 4 | predict(sentence: string, k: number, callback?: DoneCallback): Promise>; 5 | train(command: 'supervised' | 'skipgram' | 'cbow' | 'quantize', options: Options, callback?: DoneCallback): Promise; 6 | quantize(options: Options, callback?: DoneCallback); 7 | } 8 | 9 | export declare class Query { 10 | constructor(modelFilename: string); 11 | nn(word: string, neighbors: number): Promise>; 12 | } 13 | 14 | export interface Options { 15 | [key: string]: any; 16 | // The following arguments are mandatory 17 | input: string; // training file path 18 | output: string; // output file path 19 | 20 | // The following arguments are optional 21 | verbose: number; // verbosity level [2] 22 | 23 | // The following arguments for the dictionary are optional 24 | minCount: number; // minimal number of word occurrences [5] 25 | minCountLabel: number; // minimal number of label occurrences [0] 26 | wordNgrams: number; // max length of word ngram [1] 27 | bucket: number; // number of buckets [2000000] 28 | minn: number; // min length of char ngram [3] 29 | maxn: number; // max length of char ngram [6] 30 | t: number; // sampling threshold [0.0001] 31 | label: string; // labels prefix [__label__] 32 | 33 | // The following arguments for training are optional 34 | lr: number; // learning rate [0.05] 35 | lrUpdateRate: number; // change the rate of updates for the learning rate [100] 36 | dim: number; // size of word vectors [100] 37 | ws: number; // size of the context window [5] 38 | epoch: number; // number of epochs [5] 39 | neg: number; // number of negatives sampled [5] 40 | loss: 'softmax' | 'hs' | 'ls' | string; // loss function {ns, hs, softmax} [ns] 41 | thread: number; // number of threads [12] 42 | pretrainedVectors: string; // pretrained word vectors for supervised learning [] 43 | saveOutput: boolean; // whether output params should be saved [0] 44 | 45 | // The following arguments for quantization are optional 46 | cutoff: number; // number of words and ngrams to retain [0] 47 | retrain: boolean; // finetune embeddings if a cutoff is applied [0] 48 | qnorm: boolean; // quantizing the norm separately [0] 49 | qout: boolean; // quantizing the classifier [0] 50 | dsub: number; // size of each sub-vector [2] 51 | } 52 | 53 | export interface DoneCallback { 54 | (error: any, 
result: any): void 55 | } 56 | -------------------------------------------------------------------------------- /index.js: -------------------------------------------------------------------------------- 1 | var binary = require('node-pre-gyp'); 2 | var path = require('path') 3 | 4 | var binaryPath = binary.find(path.resolve(path.join(__dirname, './package.json'))); 5 | var FastText = require(binaryPath); 6 | 7 | module.exports = FastText; -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "fasttext", 3 | "version": "1.0.0", 4 | "description": "Nodejs binding for Fasttext representation and classification", 5 | "main": "index.js", 6 | "types": "index.d.ts", 7 | "engines": { 8 | "node": ">=8.0.0" 9 | }, 10 | "scripts": { 11 | "config": "node-pre-gyp configure", 12 | "build": "node-pre-gyp rebuild", 13 | "install": "node-pre-gyp install --fallback-to-build", 14 | "test": "tape test/start.js | tap-spec", 15 | "publish-binary": "git commit --allow-empty -m \"[publish binary]\"" 16 | }, 17 | "repository": { 18 | "type": "git", 19 | "url": "git+https://github.com/vunb/node-fasttext.git" 20 | }, 21 | "keywords": [ 22 | "vntk", 23 | "fasttext", 24 | "node-fasttext" 25 | ], 26 | "author": "vunb", 27 | "license": "MIT", 28 | "bugs": { 29 | "url": "https://github.com/vunb/node-fasttext/issues" 30 | }, 31 | "homepage": "https://github.com/vunb/node-fasttext#readme", 32 | "binary": { 33 | "module_name": "fasttext", 34 | "module_path": "./lib/binding/{node_napi_label}", 35 | "remote_path": "{version}", 36 | "package_name": "{module_name}-{platform}-{arch}-{node_napi_label}.tar.gz", 37 | "host": "https://github.com/vunb/node-fasttext/releases/download/", 38 | "napi_versions": [ 39 | 1, 40 | 3 41 | ] 42 | }, 43 | "dependencies": { 44 | "node-addon-api": "^1.6.3", 45 | "node-pre-gyp": "^0.13.0" 46 | }, 47 | "devDependencies": { 48 | "node-gyp": "^4.0.0", 49 | "node-pre-gyp-github": "^1.4.3", 50 | "tap-spec": "^5.0.0", 51 | "tape": "^4.11.0" 52 | }, 53 | "files": [ 54 | "src", 55 | "fastText/src", 56 | "index.js", 57 | "index.d.ts", 58 | "binding.gyp" 59 | ] 60 | } 61 | -------------------------------------------------------------------------------- /src/addon.cc: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include "classifier.h" 4 | #include "query.h" 5 | 6 | Napi::Object Initialize(Napi::Env env, Napi::Object exports) 7 | { 8 | FasttextClassifier::Init(env, exports); 9 | FasttextQuery::Init(env, exports); 10 | return exports; 11 | } 12 | 13 | NODE_API_MODULE(addon, Initialize) 14 | -------------------------------------------------------------------------------- /src/classifier.cc: -------------------------------------------------------------------------------- 1 | #include "classifier.h" 2 | #include "loadModel.h" 3 | #include "predictWorker.h" 4 | #include "train.h" 5 | #include "quantize.h" 6 | #include 7 | 8 | Napi::FunctionReference FasttextClassifier::constructor; 9 | 10 | Napi::Object FasttextClassifier::Init(Napi::Env env, Napi::Object exports) 11 | { 12 | Napi::HandleScope scope(env); 13 | Napi::Function func = DefineClass(env, "FasttextClassifier", 14 | {InstanceMethod("loadModel", &FasttextClassifier::LoadModel), 15 | InstanceMethod("predict", &FasttextClassifier::Predict), 16 | InstanceMethod("train", &FasttextClassifier::Train), 17 | InstanceMethod("quantize", &FasttextClassifier::Quantize)}); 
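// The DefineClass call above registers the JS-visible methods of Classifier
// (loadModel, predict, train, quantize) against their C++ implementations.
// The persistent reference taken below keeps the class constructor alive for
// the lifetime of the addon; SuppressDestruct() stops it from being released
// when the environment shuts down.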
18 | 19 | constructor = Napi::Persistent(func); 20 | constructor.SuppressDestruct(); 21 | 22 | exports.Set("Classifier", func); 23 | return exports; 24 | } 25 | 26 | FasttextClassifier::FasttextClassifier(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) 27 | { 28 | Napi::Env env = info.Env(); 29 | Napi::HandleScope scope(env); 30 | std::string modelFileName = ""; 31 | 32 | if (info.Length() > 0 && info[0].IsString()) 33 | { 34 | modelFileName = info[0].As().Utf8Value(); 35 | } 36 | 37 | this->wrapper_ = new Wrapper(modelFileName); 38 | } 39 | 40 | Napi::Value FasttextClassifier::LoadModel(const Napi::CallbackInfo &info) 41 | { 42 | Napi::Env env = info.Env(); 43 | Napi::HandleScope scope(env); 44 | Napi::Function callback; 45 | 46 | if (info.Length() < 1) 47 | { 48 | Napi::TypeError::New(env, "Path to model file is missing!").ThrowAsJavaScriptException(); 49 | } 50 | else if (!info[0].IsString()) 51 | { 52 | Napi::TypeError::New(env, "Model file path must be a string!").ThrowAsJavaScriptException(); 53 | } 54 | 55 | if (info.Length() > 1 && info[1].IsFunction()) 56 | { 57 | callback = info[1].As(); 58 | } 59 | else 60 | { 61 | callback = Napi::Function::New(env, EmptyCallback); 62 | } 63 | 64 | Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); 65 | Napi::String filename = info[0].As(); 66 | 67 | std::cout << "Preparing load model from: " << filename.Utf8Value() << std::endl; 68 | 69 | LoadModelWorker *worker = new LoadModelWorker(filename, this->wrapper_, deferred, callback); 70 | worker->Queue(); 71 | 72 | return worker->deferred_.Promise(); 73 | } 74 | 75 | Napi::Value FasttextClassifier::Predict(const Napi::CallbackInfo &info) 76 | { 77 | Napi::Env env = info.Env(); 78 | Napi::HandleScope scope(env); 79 | Napi::Function callback = Napi::Function::New(env, EmptyCallback); 80 | ; 81 | int32_t k = 1; 82 | 83 | if (info.Length() < 1 || !info[0].IsString()) 84 | { 85 | Napi::TypeError::New(env, "sentence must be a string").ThrowAsJavaScriptException(); 86 | } 87 | 88 | if (info.Length() > 1 && info[1].IsNumber()) 89 | { 90 | k = info[1].As().Int32Value(); 91 | } 92 | 93 | if (info.Length() > 2 && info[2].IsFunction()) 94 | { 95 | callback = info[2].As(); 96 | } 97 | 98 | Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); 99 | Napi::String sentence = info[0].As(); 100 | 101 | PredictWorker *worker = new PredictWorker(sentence, k, this->wrapper_, deferred, callback); 102 | worker->Queue(); 103 | 104 | return worker->deferred_.Promise(); 105 | } 106 | 107 | Napi::Value FasttextClassifier::Train(const Napi::CallbackInfo &info) 108 | { 109 | Napi::Env env = info.Env(); 110 | Napi::HandleScope scope(env); 111 | Napi::Function callback = Napi::Function::New(env, EmptyCallback); 112 | ; 113 | int32_t k = 1; 114 | 115 | if (info.Length() < 2) 116 | { 117 | Napi::TypeError::New(env, "requires at least 2 parameters").ThrowAsJavaScriptException(); 118 | } 119 | else if (!info[0].IsString()) 120 | { 121 | Napi::TypeError::New(env, "command must be a string").ThrowAsJavaScriptException(); 122 | } 123 | else if (!info[1].IsObject()) 124 | { 125 | Napi::TypeError::New(env, "options must be an object").ThrowAsJavaScriptException(); 126 | } 127 | 128 | if (info.Length() > 2 && info[2].IsFunction()) 129 | { 130 | callback = info[2].As(); 131 | } 132 | 133 | std::string command = info[0].As().Utf8Value(); 134 | 135 | if (!(command == "cbow" || command == "quantize" || command == "skipgram" || command == "supervised")) 136 | { 137 | 
Napi::TypeError::New(env, "Permitted command types are ['cbow', 'quantize', 'skipgram', 'supervised").ThrowAsJavaScriptException(); 138 | } 139 | 140 | Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); 141 | NodeArgument::NodeArgument nodeArg; 142 | NodeArgument::CArgument c_argument; 143 | 144 | try 145 | { 146 | Napi::Object confObj = info[1].As(); 147 | c_argument = nodeArg.NapiObjectToCArgument(env, confObj); 148 | } 149 | catch (std::string errorMessage) 150 | { 151 | Napi::TypeError::New(env, errorMessage.c_str()).ThrowAsJavaScriptException(); 152 | } 153 | 154 | int count = c_argument.argc; 155 | char **argument = c_argument.argv; 156 | 157 | std::vector args; 158 | args.push_back("-command"); 159 | args.push_back(command.c_str()); 160 | 161 | for (int j = 0; j < count; j++) 162 | { 163 | args.push_back(argument[j]); 164 | } 165 | 166 | if (command == "quantize") 167 | { 168 | QuantizeWorker *worker = new QuantizeWorker(args, this->wrapper_, deferred, callback); 169 | worker->Queue(); 170 | return worker->deferred_.Promise(); 171 | } 172 | else 173 | { 174 | TrainWorker *worker = new TrainWorker(args, this->wrapper_, deferred, callback); 175 | worker->Queue(); 176 | return worker->deferred_.Promise(); 177 | } 178 | } 179 | 180 | Napi::Value FasttextClassifier::Quantize(const Napi::CallbackInfo &info) 181 | { 182 | Napi::Env env = info.Env(); 183 | Napi::HandleScope scope(env); 184 | Napi::Function callback = Napi::Function::New(env, EmptyCallback); 185 | ; 186 | int32_t k = 1; 187 | 188 | if (info.Length() < 1 || !info[0].IsObject()) 189 | { 190 | Napi::TypeError::New(env, "options must be an object").ThrowAsJavaScriptException(); 191 | } 192 | 193 | if (info.Length() > 1 && info[1].IsFunction()) 194 | { 195 | callback = info[1].As(); 196 | } 197 | 198 | Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); 199 | NodeArgument::NodeArgument nodeArg; 200 | NodeArgument::CArgument c_argument; 201 | 202 | try 203 | { 204 | Napi::Object confObj = info[1].As(); 205 | c_argument = nodeArg.NapiObjectToCArgument(env, confObj); 206 | } 207 | catch (std::string errorMessage) 208 | { 209 | Napi::TypeError::New(env, errorMessage.c_str()).ThrowAsJavaScriptException(); 210 | } 211 | 212 | int count = c_argument.argc; 213 | char **argument = c_argument.argv; 214 | 215 | std::vector args; 216 | args.push_back("-command"); 217 | args.push_back("quantize"); 218 | 219 | for (int j = 0; j < count; j++) 220 | { 221 | args.push_back(argument[j]); 222 | } 223 | 224 | QuantizeWorker *worker = new QuantizeWorker(args, this->wrapper_, deferred, callback); 225 | worker->Queue(); 226 | 227 | return worker->deferred_.Promise(); 228 | } 229 | -------------------------------------------------------------------------------- /src/classifier.h: -------------------------------------------------------------------------------- 1 | #ifndef FASTTEXT_CLASSIFIER_H 2 | #define FASTTEXT_CLASSIFIER_H 3 | 4 | #include 5 | #include "node-util.h" 6 | #include "node-argument.h" 7 | #include "wrapper.h" 8 | 9 | class FasttextClassifier : public Napi::ObjectWrap 10 | { 11 | public: 12 | static Napi::Object Init(Napi::Env env, Napi::Object exports); 13 | FasttextClassifier(const Napi::CallbackInfo &info); 14 | 15 | /** 16 | * Destructor 17 | */ 18 | ~FasttextClassifier() 19 | { 20 | delete wrapper_; 21 | } 22 | 23 | private: 24 | static Napi::FunctionReference constructor; 25 | 26 | Napi::Value LoadModel(const Napi::CallbackInfo &info); 27 | Napi::Value Predict(const 
Napi::CallbackInfo &info); 28 | Napi::Value Train(const Napi::CallbackInfo &info); 29 | Napi::Value Quantize(const Napi::CallbackInfo &info); 30 | 31 | Wrapper *wrapper_; 32 | }; 33 | 34 | #endif 35 | -------------------------------------------------------------------------------- /src/loadModel.cc: -------------------------------------------------------------------------------- 1 | #include "loadModel.h" 2 | #include "node-argument.h" 3 | 4 | void LoadModelWorker::Execute() 5 | { 6 | try 7 | { 8 | result_ = wrapper_->loadModel(filename); 9 | } 10 | catch (std::string errorMessage) 11 | { 12 | SetError(errorMessage.c_str()); 13 | } 14 | } 15 | 16 | void LoadModelWorker::OnOK() 17 | { 18 | NodeArgument::NodeArgument nodeArg; 19 | Napi::Object result = nodeArg.mapToNapiObject(Env(), result_); 20 | 21 | deferred_.Resolve(result); 22 | 23 | // Call empty function 24 | if (!Callback().IsEmpty()) 25 | { 26 | Callback().Call({Env().Null(), result}); 27 | } 28 | } 29 | 30 | void LoadModelWorker::OnError() 31 | { 32 | Napi::HandleScope scope(Env()); 33 | deferred_.Reject(Napi::String::New(Env(), "Can't load model file!")); 34 | 35 | // Call empty function 36 | Callback().Call({}); 37 | } 38 | -------------------------------------------------------------------------------- /src/loadModel.h: -------------------------------------------------------------------------------- 1 | #ifndef LOADMODEL_H 2 | #define LOADMODEL_H 3 | 4 | #include 5 | #include "wrapper.h" 6 | 7 | class LoadModelWorker : public Napi::AsyncWorker 8 | { 9 | public: 10 | LoadModelWorker(std::string filename, Wrapper *wrapper, Napi::Promise::Deferred deferred, Napi::Function &callback) 11 | : Napi::AsyncWorker(callback), 12 | filename(filename), 13 | result_(), 14 | wrapper_(wrapper), 15 | deferred_(deferred){}; 16 | 17 | ~LoadModelWorker(){}; 18 | 19 | Napi::Promise::Deferred deferred_; 20 | void Execute(); 21 | void OnOK(); 22 | void OnError(); 23 | 24 | private: 25 | std::string filename; 26 | std::map result_; 27 | Wrapper *wrapper_; 28 | }; 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /src/nnWorker.cc: -------------------------------------------------------------------------------- 1 | #include "nnWorker.h" 2 | #include "node-util.h" 3 | 4 | void NnWorker::Execute() 5 | { 6 | try 7 | { 8 | wrapper_->loadModel(); 9 | wrapper_->precomputeWordVectors(); 10 | result_ = wrapper_->nn(query_, k_); 11 | } 12 | catch (std::string errorMessage) 13 | { 14 | SetError(errorMessage.c_str()); 15 | } 16 | catch (const char *str) 17 | { 18 | SetError(str); 19 | } 20 | catch (const std::exception &e) 21 | { 22 | SetError(e.what()); 23 | } 24 | } 25 | 26 | void NnWorker::OnError(const Napi::Error &e) 27 | { 28 | Napi::HandleScope scope(Env()); 29 | Napi::String error = Napi::String::New(Env(), e.Message()); 30 | deferred_.Reject(error); 31 | 32 | // Call empty function 33 | Callback().Call({error}); 34 | } 35 | 36 | void NnWorker::OnOK() 37 | { 38 | Napi::Env env = Env(); 39 | Napi::HandleScope scope(env); 40 | Napi::Array result = Napi::Array::New(env, result_.size()); 41 | 42 | for (unsigned int i = 0; i < result_.size(); i++) 43 | { 44 | Napi::Object obj = Napi::Object::New(env); 45 | 46 | obj.Set(Napi::String::New(env, "label"), Napi::String::New(env, result_[i].label)); 47 | obj.Set(Napi::String::New(env, "value"), Napi::Number::New(env, result_[i].value)); 48 | 49 | result.Set(i, obj); 50 | } 51 | 52 | deferred_.Resolve(result); 53 | 54 | // Call empty function 55 | if 
(!Callback().IsEmpty()) 56 | { 57 | Callback().Call({env.Null(), result}); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/nnWorker.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef NN_WORKER_H 3 | #define NN_WORKER_H 4 | 5 | #include 6 | #include "wrapper.h" 7 | 8 | class NnWorker : public Napi::AsyncWorker 9 | { 10 | public: 11 | NnWorker( 12 | std::string query, 13 | int32_t k, 14 | Wrapper *wrapper, 15 | Napi::Promise::Deferred deferred, 16 | Napi::Function &callback) 17 | : Napi::AsyncWorker(callback), 18 | deferred_(deferred), 19 | query_(query), 20 | k_(k), 21 | wrapper_(wrapper), 22 | result_(){}; 23 | 24 | ~NnWorker(){}; 25 | 26 | Napi::Promise::Deferred deferred_; 27 | 28 | void Execute(); 29 | void OnOK(); 30 | void OnError(const Napi::Error &e); 31 | 32 | private: 33 | std::string query_; 34 | int32_t k_; 35 | Wrapper *wrapper_; 36 | std::vector result_; 37 | }; 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /src/node-argument.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Convert node object arugment to standar argv C or C++ argument 3 | * 4 | * Author: Yusuf Syaifudin 5 | * Date: December 6, 2016 10:57 AM 6 | * 7 | */ 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include // for std::find 14 | #include // for std::begin, std::end 15 | 16 | #include "node-argument.h" 17 | 18 | namespace NodeArgument 19 | { 20 | 21 | bool NodeArgument::isOnlyDouble(const char *str) 22 | { 23 | char *endptr = 0; 24 | strtod(str, &endptr); 25 | if (*endptr != '\0' || endptr == str) 26 | { 27 | return false; 28 | } 29 | return true; 30 | } 31 | 32 | /** 33 | * Concenate string 34 | */ 35 | char *NodeArgument::concat(const char *s1, const char *s2) 36 | { 37 | char *result = (char *)malloc(strlen(s1) + strlen(s2) + 1); //+1 for the zero-terminator 38 | //in real code you would check for errors in malloc here 39 | strcpy(result, s1); 40 | strcat(result, s2); 41 | return result; 42 | } 43 | 44 | /** 45 | * Add string to parameter argument 46 | */ 47 | int NodeArgument::AddStringArgument(char ***strings, size_t *count, const char *newStr) 48 | { 49 | char *copy; 50 | char **p; 51 | 52 | if (strings == NULL || newStr == NULL || 53 | (copy = (char *)malloc(strlen(newStr) + 1)) == NULL) 54 | { 55 | 56 | return 0; 57 | } 58 | 59 | strcpy(copy, newStr); 60 | 61 | if ((p = (char **)realloc(*strings, (*count + 1) * sizeof(char *))) == NULL) 62 | { 63 | free(copy); 64 | return 0; 65 | } 66 | 67 | *strings = p; 68 | 69 | (*strings)[(*count)++] = copy; 70 | 71 | return 1; 72 | } 73 | 74 | /** 75 | * Print all argument value 76 | */ 77 | void NodeArgument::PrintArguments(char **strings, size_t count) 78 | { 79 | printf("BEGIN\n"); 80 | if (strings != NULL) 81 | { 82 | while (count--) 83 | { 84 | printf(" %s\n", *strings++); 85 | } 86 | } 87 | printf("END\n"); 88 | } 89 | 90 | CArgument NodeArgument::NapiObjectToCArgument(Napi::Env env, Napi::Object obj) 91 | { 92 | Napi::HandleScope scope(env); 93 | Napi::Array props = obj.GetPropertyNames(); 94 | 95 | char **arguments = NULL; 96 | size_t count = 0; 97 | 98 | uint32_t indexLen = 0; 99 | if (!props.IsEmpty()) 100 | { 101 | indexLen = props.Length(); 102 | 103 | // for validation 104 | std::string permitted_command[] = { 105 | "input", "test", "output", "lr", "lrUpdateRate", 106 | "dim", "ws", "epoch", "minCount", 
"minCountLabel", "neg", 107 | "wordNgrams", "loss", "bucket", "minn", "maxn", 108 | "thread", "t", "label", "verbose", "pretrainedVectors", 109 | "cutoff", "dsub", "qnorm", "qout", "retrain"}; 110 | 111 | for (uint32_t i = 0; i < indexLen; ++i) 112 | { 113 | Napi::String key = props.Get(i).As(); 114 | 115 | std::string keyValue = key.Utf8Value(); 116 | char *theKey = (char *)keyValue.c_str(); 117 | 118 | bool exists = std::find(std::begin(permitted_command), 119 | std::end(permitted_command), keyValue) != std::end(permitted_command); 120 | 121 | if (!exists) 122 | { 123 | throw "Unknown argument: " + keyValue; 124 | } 125 | 126 | Napi::Value value = obj.Get(keyValue); 127 | NodeArgument::AddStringArgument(&arguments, &count, NodeArgument::concat("-", theKey)); 128 | 129 | if (!value.IsBoolean()) 130 | { 131 | std::string valueValue = value.ToString().Utf8Value(); 132 | // std::cout << "OKKKKK!!!" << keyValue << ": " << valueValue << std::endl; 133 | char *theValue = (char *)valueValue.c_str(); 134 | NodeArgument::AddStringArgument(&arguments, &count, theValue); 135 | } 136 | } 137 | } 138 | 139 | CArgument response = {count, arguments}; 140 | return response; 141 | } 142 | 143 | Napi::Object NodeArgument::mapToNapiObject(Napi::Env env, std::map obj) 144 | { 145 | Napi::Object result = Napi::Object::New(env); 146 | 147 | for (auto const &iterator : obj) 148 | { 149 | Napi::Value value; 150 | 151 | if (isOnlyDouble(iterator.second.c_str())) 152 | { 153 | value = Napi::Number::New(env, atof(iterator.second.c_str())); 154 | } 155 | else 156 | { 157 | value = Napi::String::New(env, iterator.second.c_str()); 158 | } 159 | result.Set(Napi::String::New(env, iterator.first.c_str()), value); 160 | } 161 | return result; 162 | } 163 | 164 | } // namespace NodeArgument 165 | -------------------------------------------------------------------------------- /src/node-argument.h: -------------------------------------------------------------------------------- 1 | /** 2 | * This is header file to convert node object arugment to standar argv C or C++ argument 3 | * 4 | * Author: Yusuf Syaifudin 5 | * Date: December 6, 2016 10:57 AM 6 | * 7 | */ 8 | 9 | #ifndef NODEARGUMENT_NODEARGUMENT_H 10 | #define NODEARGUMENT_NODEARGUMENT_H 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include // for std::find 20 | #include // for std::begin, std::end 21 | 22 | namespace NodeArgument 23 | { 24 | struct CArgument 25 | { 26 | size_t argc; 27 | char **argv; 28 | }; 29 | 30 | class NodeArgument 31 | { 32 | 33 | public: 34 | char *concat(const char *s1, const char *s2); 35 | int AddStringArgument(char ***strings, size_t *count, const char *newStr); 36 | void PrintArguments(char **strings, size_t count); 37 | bool isOnlyDouble(const char *str); 38 | CArgument NapiObjectToCArgument(Napi::Env env, Napi::Object obj); 39 | Napi::Object mapToNapiObject(Napi::Env env, std::map obj); 40 | }; 41 | } // namespace NodeArgument 42 | 43 | #endif 44 | -------------------------------------------------------------------------------- /src/node-util.cc: -------------------------------------------------------------------------------- 1 | #include "node-util.h" 2 | 3 | Napi::Value EmptyCallback(const Napi::CallbackInfo &info) 4 | { 5 | Napi::Env env = info.Env(); 6 | Napi::HandleScope scope(env); 7 | 8 | return env.Undefined(); 9 | } 10 | -------------------------------------------------------------------------------- /src/node-util.h: 
-------------------------------------------------------------------------------- 1 | #ifndef NODE_UTIL_H 2 | #define NODE_UTIL_H 3 | 4 | #include 5 | 6 | Napi::Value EmptyCallback(const Napi::CallbackInfo &info); 7 | 8 | #endif 9 | -------------------------------------------------------------------------------- /src/predictWorker.cc: -------------------------------------------------------------------------------- 1 | #include "predictWorker.h" 2 | 3 | void PredictWorker::Execute() 4 | { 5 | try 6 | { 7 | wrapper_->loadModel(); 8 | result_ = wrapper_->predict(sentence_, k_); 9 | } 10 | catch (std::string errorMessage) 11 | { 12 | SetError(errorMessage.c_str()); 13 | } 14 | catch (const char *str) 15 | { 16 | SetError(str); 17 | } 18 | catch (const std::exception &e) 19 | { 20 | SetError(e.what()); 21 | } 22 | } 23 | 24 | void PredictWorker::OnError(const Napi::Error &e) 25 | { 26 | Napi::HandleScope scope(Env()); 27 | Napi::String error = Napi::String::New(Env(), e.Message()); 28 | deferred_.Reject(error); 29 | 30 | // Call empty function 31 | Callback().Call({error}); 32 | } 33 | 34 | void PredictWorker::OnOK() 35 | { 36 | Napi::Env env = Env(); 37 | Napi::HandleScope scope(env); 38 | Napi::Array result = Napi::Array::New(env, result_.size()); 39 | 40 | for (unsigned int i = 0; i < result_.size(); i++) 41 | { 42 | Napi::Object obj = Napi::Object::New(env); 43 | 44 | obj.Set(Napi::String::New(env, "label"), Napi::String::New(env, result_[i].label)); 45 | obj.Set(Napi::String::New(env, "value"), Napi::Number::New(env, result_[i].value)); 46 | 47 | result.Set(i, obj); 48 | } 49 | 50 | deferred_.Resolve(result); 51 | 52 | // Call empty function 53 | if (!Callback().IsEmpty()) 54 | { 55 | Callback().Call({env.Null(), result}); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /src/predictWorker.h: -------------------------------------------------------------------------------- 1 | 2 | #ifndef PREDICT_WORKER_H 3 | #define PREDICT_WORKER_H 4 | 5 | #include 6 | #include "wrapper.h" 7 | 8 | class PredictWorker : public Napi::AsyncWorker 9 | { 10 | public: 11 | PredictWorker( 12 | std::string sentence, 13 | int32_t k, 14 | Wrapper *wrapper, 15 | Napi::Promise::Deferred deferred, 16 | Napi::Function &callback) 17 | : Napi::AsyncWorker(callback), 18 | deferred_(deferred), 19 | sentence_(sentence), 20 | wrapper_(wrapper), 21 | result_(), 22 | k_(k){}; 23 | 24 | ~PredictWorker(){}; 25 | 26 | Napi::Promise::Deferred deferred_; 27 | 28 | void Execute(); 29 | void OnOK(); 30 | void OnError(const Napi::Error &e); 31 | 32 | private: 33 | std::string sentence_; 34 | Wrapper *wrapper_; 35 | std::vector result_; 36 | int32_t k_; 37 | }; 38 | 39 | #endif 40 | -------------------------------------------------------------------------------- /src/quantize.cc: -------------------------------------------------------------------------------- 1 | #include "node-argument.h" 2 | #include "quantize.h" 3 | 4 | void QuantizeWorker::Execute() 5 | { 6 | try 7 | { 8 | result_ = wrapper_->quantize(args_); 9 | } 10 | catch (const std::string errorMessage) 11 | { 12 | SetError(errorMessage.c_str()); 13 | } 14 | catch (const char *str) 15 | { 16 | SetError(str); 17 | } 18 | catch (const std::exception &e) 19 | { 20 | SetError(e.what()); 21 | } 22 | } 23 | 24 | void QuantizeWorker::OnError(const Napi::Error &e) 25 | { 26 | Napi::HandleScope scope(Env()); 27 | Napi::String error = Napi::String::New(Env(), e.Message()); 28 | deferred_.Reject(error); 29 | 30 | // Call empty function 31 
| Callback().Call({error}); 32 | } 33 | 34 | void QuantizeWorker::OnOK() 35 | { 36 | NodeArgument::NodeArgument nodeArg; 37 | Napi::Object result = nodeArg.mapToNapiObject(Env(), result_); 38 | 39 | deferred_.Resolve(result); 40 | 41 | // Call empty function 42 | if (!Callback().IsEmpty()) 43 | { 44 | Callback().Call({Env().Null(), result}); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/quantize.h: -------------------------------------------------------------------------------- 1 | #ifndef QUANTIZE_H 2 | #define QUANTIZE_H 3 | 4 | #include 5 | #include "wrapper.h" 6 | 7 | class QuantizeWorker : public Napi::AsyncWorker 8 | { 9 | public: 10 | QuantizeWorker( 11 | const std::vector args, 12 | Wrapper *wrapper, 13 | Napi::Promise::Deferred deferred, 14 | Napi::Function &callback) 15 | : Napi::AsyncWorker(callback), 16 | deferred_(deferred), 17 | args_(args), 18 | wrapper_(wrapper), 19 | result_(){}; 20 | 21 | ~QuantizeWorker(){}; 22 | 23 | Napi::Promise::Deferred deferred_; 24 | 25 | void Execute(); 26 | void OnOK(); 27 | void OnError(const Napi::Error &e); 28 | 29 | private: 30 | const std::vector args_; 31 | Wrapper *wrapper_; 32 | std::map result_; 33 | }; 34 | 35 | #endif 36 | -------------------------------------------------------------------------------- /src/query.cc: -------------------------------------------------------------------------------- 1 | #include "query.h" 2 | 3 | Napi::FunctionReference FasttextQuery::constructor; 4 | 5 | Napi::Object FasttextQuery::Init(Napi::Env env, Napi::Object exports) 6 | { 7 | Napi::HandleScope scope(env); 8 | Napi::Function func = DefineClass(env, "FasttextQuery", 9 | {InstanceMethod("nn", &FasttextQuery::Nn)}); 10 | 11 | constructor = Napi::Persistent(func); 12 | constructor.SuppressDestruct(); 13 | 14 | exports.Set("Query", func); 15 | return exports; 16 | } 17 | 18 | FasttextQuery::FasttextQuery(const Napi::CallbackInfo &info) : Napi::ObjectWrap(info) 19 | { 20 | Napi::Env env = info.Env(); 21 | Napi::HandleScope scope(env); 22 | 23 | if (info.Length() == 0 || !info[0].IsString()) 24 | { 25 | Napi::TypeError::New(env, "Path to model file is missing!").ThrowAsJavaScriptException(); 26 | } 27 | 28 | std::string modelFileName = info[0].As().Utf8Value(); 29 | this->wrapper_ = new Wrapper(modelFileName); 30 | } 31 | 32 | Napi::Value FasttextQuery::Nn(const Napi::CallbackInfo &info) 33 | { 34 | Napi::Env env = info.Env(); 35 | Napi::HandleScope scope(env); 36 | Napi::Function callback = Napi::Function::New(env, EmptyCallback); 37 | int32_t k = 10; 38 | 39 | if (info.Length() == 0 || !info[0].IsString()) 40 | { 41 | Napi::TypeError::New(env, "query must be a string").ThrowAsJavaScriptException(); 42 | } 43 | 44 | if (info.Length() > 1) 45 | { 46 | if (info[1].IsNumber()) 47 | { 48 | k = info[1].As().Int32Value(); 49 | } 50 | else if (info[1].IsFunction()) 51 | { 52 | callback = info[1].As(); 53 | } 54 | } 55 | 56 | if (info.Length() > 2 && info[2].IsFunction()) 57 | { 58 | callback = info[1].As(); 59 | } 60 | 61 | Napi::Promise::Deferred deferred = Napi::Promise::Deferred::New(info.Env()); 62 | Napi::String query = info[0].As(); 63 | 64 | NnWorker *worker = new NnWorker(query, k, this->wrapper_, deferred, callback); 65 | worker->Queue(); 66 | 67 | return worker->deferred_.Promise(); 68 | } 69 | -------------------------------------------------------------------------------- /src/query.h: -------------------------------------------------------------------------------- 1 | #ifndef 
FASTTEXT_QUERY_H 2 | #define FASTTEXT_QUERY_H 3 | 4 | #include 5 | #include "wrapper.h" 6 | #include "nnWorker.h" 7 | #include "node-util.h" 8 | 9 | class FasttextQuery : public Napi::ObjectWrap 10 | { 11 | public: 12 | static Napi::Object Init(Napi::Env env, Napi::Object exports); 13 | FasttextQuery(const Napi::CallbackInfo &info); 14 | 15 | /** 16 | * Destructor 17 | */ 18 | ~FasttextQuery() 19 | { 20 | delete wrapper_; 21 | } 22 | 23 | private: 24 | static Napi::FunctionReference constructor; 25 | Napi::Value Nn(const Napi::CallbackInfo &info); 26 | 27 | Wrapper *wrapper_; 28 | }; 29 | 30 | #endif 31 | -------------------------------------------------------------------------------- /src/train.cc: -------------------------------------------------------------------------------- 1 | #include "train.h" 2 | #include "node-argument.h" 3 | 4 | void TrainWorker::Execute() 5 | { 6 | try 7 | { 8 | result_ = wrapper_->train(args_); 9 | } 10 | catch (std::string errorMessage) 11 | { 12 | SetError(errorMessage.c_str()); 13 | } 14 | catch (const char *str) 15 | { 16 | SetError(str); 17 | } 18 | catch (const std::exception &e) 19 | { 20 | SetError(e.what()); 21 | } 22 | } 23 | 24 | void TrainWorker::OnError(const Napi::Error &e) 25 | { 26 | Napi::HandleScope scope(Env()); 27 | Napi::String error = Napi::String::New(Env(), e.Message()); 28 | deferred_.Reject(error); 29 | 30 | // Call empty function 31 | Callback().Call({error}); 32 | } 33 | 34 | void TrainWorker::OnOK() 35 | { 36 | Napi::Env env = Env(); 37 | Napi::HandleScope scope(env); 38 | 39 | NodeArgument::NodeArgument nodeArg; 40 | Napi::Object result = nodeArg.mapToNapiObject(env, result_); 41 | deferred_.Resolve(result); 42 | 43 | // Call empty function 44 | if (!Callback().IsEmpty()) 45 | { 46 | Callback().Call({env.Null(), result}); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/train.h: -------------------------------------------------------------------------------- 1 | #ifndef TRAIN_WORKER_H 2 | #define TRAIN_WORKER_H 3 | 4 | #include 5 | #include "wrapper.h" 6 | 7 | class TrainWorker : public Napi::AsyncWorker 8 | { 9 | public: 10 | TrainWorker(const std::vector args, Wrapper *wrapper, Napi::Promise::Deferred deferred, Napi::Function &callback) 11 | : Napi::AsyncWorker(callback), 12 | args_(args), 13 | wrapper_(wrapper), 14 | deferred_(deferred), 15 | result_(){}; 16 | 17 | ~TrainWorker(){}; 18 | 19 | Napi::Promise::Deferred deferred_; 20 | 21 | void Execute(); 22 | void OnOK(); 23 | void OnError(const Napi::Error &e); 24 | 25 | private: 26 | const std::vector args_; 27 | Wrapper *wrapper_; 28 | std::map result_; 29 | }; 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /src/wrapper.cc: -------------------------------------------------------------------------------- 1 | 2 | 3 | #include "wrapper.h" 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | using fasttext::entry_type; 17 | using fasttext::model_name; 18 | 19 | constexpr int32_t FASTTEXT_VERSION = 12; /* Version 1b */ 20 | constexpr int32_t FASTTEXT_FILEFORMAT_MAGIC_INT32 = 793712314; 21 | 22 | Wrapper::Wrapper(std::string modelFilename) 23 | : quant_(false), 24 | modelFilename_(modelFilename), 25 | isLoaded_(false), 26 | isPrecomputed_(false) {} 27 | 28 | bool Wrapper::fileExist(const std::string &filename) 29 | { 30 | if (FILE *file = fopen(filename.c_str(), "r")) 31 | { 32 | 
fclose(file); 33 | return true; 34 | } 35 | else 36 | { 37 | return false; 38 | } 39 | } 40 | 41 | void Wrapper::getVector(Vector &vec, const std::string &word) 42 | { 43 | const std::vector &ngrams = dict_->getSubwords(word); 44 | vec.zero(); 45 | for (auto it = ngrams.begin(); it != ngrams.end(); ++it) 46 | { 47 | vec.addRow(*input_, *it); 48 | } 49 | if (ngrams.size() > 0) 50 | { 51 | vec.mul(1.0 / ngrams.size()); 52 | } 53 | } 54 | 55 | bool Wrapper::checkModel(std::istream &in) 56 | { 57 | int32_t magic; 58 | int32_t version; 59 | in.read((char *)&(magic), sizeof(int32_t)); 60 | if (magic != FASTTEXT_FILEFORMAT_MAGIC_INT32) 61 | { 62 | return false; 63 | } 64 | in.read((char *)&(version), sizeof(int32_t)); 65 | if (version != FASTTEXT_VERSION) 66 | { 67 | return false; 68 | } 69 | return true; 70 | } 71 | 72 | void Wrapper::signModel(std::ostream &out) 73 | { 74 | const int32_t magic = FASTTEXT_FILEFORMAT_MAGIC_INT32; 75 | const int32_t version = FASTTEXT_VERSION; 76 | out.write((char *)&(magic), sizeof(int32_t)); 77 | out.write((char *)&(version), sizeof(int32_t)); 78 | } 79 | 80 | std::map Wrapper::loadModel() 81 | { 82 | return loadModel(this->modelFilename_); 83 | } 84 | 85 | std::map Wrapper::loadModel(std::string filename) 86 | { 87 | if (isLoaded_) 88 | { 89 | return getModelInfo(); 90 | } 91 | mtx_.lock(); 92 | if (isLoaded_) 93 | { 94 | mtx_.unlock(); 95 | return getModelInfo(); 96 | } 97 | std::ifstream ifs(filename, std::ifstream::binary); 98 | if (!ifs.is_open()) 99 | { 100 | throw "Model file cannot be opened: " + filename; 101 | } 102 | if (!checkModel(ifs)) 103 | { 104 | throw "Model file has wrong file format!"; 105 | } 106 | std::map info = loadModel(ifs); 107 | ifs.close(); 108 | isLoaded_ = true; 109 | mtx_.unlock(); 110 | 111 | return info; 112 | } 113 | 114 | std::map Wrapper::loadModel(std::istream &in) 115 | { 116 | args_ = std::make_shared(); 117 | dict_ = std::make_shared(args_); 118 | input_ = std::make_shared(); 119 | output_ = std::make_shared(); 120 | qinput_ = std::make_shared(); 121 | qoutput_ = std::make_shared(); 122 | 123 | args_->load(in); 124 | dict_->load(in); 125 | 126 | bool quant_input; 127 | in.read((char *)&quant_input, sizeof(bool)); 128 | if (quant_input) 129 | { 130 | quant_ = true; 131 | qinput_->load(in); 132 | } 133 | else 134 | { 135 | input_->load(in); 136 | } 137 | 138 | in.read((char *)&args_->qout, sizeof(bool)); 139 | if (quant_ && args_->qout) 140 | { 141 | qoutput_->load(in); 142 | } 143 | else 144 | { 145 | output_->load(in); 146 | } 147 | 148 | model_ = std::make_shared(input_, output_, args_, 0); 149 | model_->quant_ = quant_; 150 | model_->setQuantizePointer(qinput_, qoutput_, args_->qout); 151 | 152 | if (args_->model == model_name::sup) 153 | { 154 | model_->setTargetCounts(dict_->getCounts(entry_type::label)); 155 | } 156 | else 157 | { 158 | model_->setTargetCounts(dict_->getCounts(entry_type::word)); 159 | } 160 | 161 | return getModelInfo(); 162 | } 163 | 164 | std::map Wrapper::getModelInfo() 165 | { 166 | 167 | std::map response; 168 | // dictionary 169 | response["word_count"] = std::to_string(dict_->nwords()); 170 | response["label_count"] = std::to_string(dict_->nlabels()); 171 | response["token_count"] = std::to_string(dict_->ntokens()); 172 | // arguments 173 | response["lr"] = std::to_string(args_->lr); 174 | response["dim"] = std::to_string(args_->dim); 175 | response["ws"] = std::to_string(args_->ws); 176 | response["epoch"] = std::to_string(args_->epoch); 177 | response["minCount"] = 
std::to_string(args_->minCount); 178 | response["minCountLabel"] = std::to_string(args_->minCountLabel); 179 | response["neg"] = std::to_string(args_->neg); 180 | response["wordNgrams"] = std::to_string(args_->wordNgrams); 181 | 182 | std::string loss_name = ""; 183 | if (args_->loss == fasttext::loss_name::hs) 184 | { 185 | loss_name = "hs"; 186 | } 187 | else if (args_->loss == fasttext::loss_name::ns) 188 | { 189 | loss_name = "ns"; 190 | } 191 | else if (args_->loss == fasttext::loss_name::softmax) 192 | { 193 | loss_name = "softmax"; 194 | } 195 | 196 | std::string model_name = ""; 197 | if (args_->model == fasttext::model_name::cbow) 198 | { 199 | model_name = "cbow"; 200 | } 201 | else if (args_->model == fasttext::model_name::sup) 202 | { 203 | model_name = "supervised"; 204 | } 205 | else if (args_->model == fasttext::model_name::sg) 206 | { 207 | model_name = "skipgram"; 208 | } 209 | 210 | response["loss"] = loss_name; 211 | response["model"] = model_name; 212 | response["bucket"] = std::to_string(args_->bucket); 213 | response["minn"] = std::to_string(args_->minn); 214 | response["maxn"] = std::to_string(args_->maxn); 215 | response["thread"] = std::to_string(args_->thread); 216 | response["lrUpdateRate"] = std::to_string(args_->lrUpdateRate); 217 | response["t"] = std::to_string(args_->t); 218 | response["label"] = args_->label; 219 | response["verbose"] = std::to_string(args_->verbose); 220 | response["pretrainedVectors"] = args_->pretrainedVectors; 221 | 222 | // `-quantize` arguments 223 | response["cutoff"] = std::to_string(args_->cutoff); 224 | response["dsub"] = std::to_string(args_->dsub); 225 | response["qnorm"] = std::to_string(args_->qnorm); 226 | response["qout"] = std::to_string(args_->qout); 227 | response["retrain"] = std::to_string(args_->retrain); 228 | 229 | return response; 230 | } 231 | 232 | void Wrapper::precomputeWordVectors() 233 | { 234 | if (isPrecomputed_) 235 | { 236 | return; 237 | } 238 | precomputeMtx_.lock(); 239 | if (isPrecomputed_) 240 | { 241 | precomputeMtx_.unlock(); 242 | return; 243 | } 244 | Matrix wordVectors(dict_->nwords(), args_->dim); 245 | wordVectors_ = wordVectors; 246 | Vector vec(args_->dim); 247 | wordVectors_.zero(); 248 | for (int32_t i = 0; i < dict_->nwords(); i++) 249 | { 250 | std::string word = dict_->getWord(i); 251 | getVector(vec, word); 252 | real norm = vec.norm(); 253 | wordVectors_.addRow(vec, i, 1.0 / norm); 254 | } 255 | isPrecomputed_ = true; 256 | precomputeMtx_.unlock(); 257 | } 258 | 259 | std::vector Wrapper::findNN(const Vector &queryVec, int32_t k, 260 | const std::set &banSet) 261 | { 262 | 263 | real queryNorm = queryVec.norm(); 264 | if (std::abs(queryNorm) < 1e-8) 265 | { 266 | queryNorm = 1; 267 | } 268 | std::priority_queue> heap; 269 | Vector vec(args_->dim); 270 | for (int32_t i = 0; i < dict_->nwords(); i++) 271 | { 272 | std::string word = dict_->getWord(i); 273 | real dp = wordVectors_.dotRow(queryVec, i); 274 | heap.push(std::make_pair(dp / queryNorm, word)); 275 | } 276 | 277 | PredictResult response; 278 | std::vector arr; 279 | int32_t i = 0; 280 | while (i < k && heap.size() > 0) 281 | { 282 | auto it = banSet.find(heap.top().second); 283 | if (it == banSet.end()) 284 | { 285 | response = {heap.top().second, exp(heap.top().first)}; 286 | arr.push_back(response); 287 | i++; 288 | } 289 | heap.pop(); 290 | } 291 | return arr; 292 | } 293 | 294 | std::vector Wrapper::nn(std::string query, int32_t k) 295 | { 296 | Vector queryVec(args_->dim); 297 | std::set banSet; 298 | banSet.clear(); 299 | 
banSet.insert(query); 300 | getVector(queryVec, query); 301 | return findNN(queryVec, k, banSet); 302 | } 303 | 304 | std::vector Wrapper::predict(std::string sentence, int32_t k) 305 | { 306 | 307 | std::vector arr; 308 | std::vector words, labels; 309 | std::istringstream in(sentence); 310 | 311 | dict_->getLine(in, words, labels, model_->rng); 312 | 313 | // std::cerr << "Got line!" << std::endl; 314 | 315 | if (words.empty()) 316 | { 317 | return arr; 318 | } 319 | 320 | Vector hidden(args_->dim); 321 | Vector output(dict_->nlabels()); 322 | std::vector> modelPredictions; 323 | model_->predict(words, k, modelPredictions, hidden, output); 324 | 325 | PredictResult response; 326 | 327 | for (auto it = modelPredictions.cbegin(); it != modelPredictions.cend(); it++) 328 | { 329 | response = {dict_->getLabel(it->second), exp(it->first)}; 330 | arr.push_back(response); 331 | } 332 | 333 | return arr; 334 | } 335 | 336 | std::map Wrapper::train(const std::vector args) 337 | { 338 | std::shared_ptr a = std::make_shared(); 339 | a->parseArgs(args); 340 | 341 | std::string inputFilename = a->input; 342 | if (!fileExist(inputFilename)) 343 | { 344 | throw "Input file is not exist."; 345 | } 346 | 347 | std::cout << "Input <<<<< " << a->input << std::endl; 348 | std::cout << "Output >>>>> " << a->output + ".bin" << std::endl; 349 | 350 | fastText_.train(a); 351 | fastText_.saveModel(); 352 | fastText_.saveVectors(); 353 | return loadModel(a->output + ".bin"); 354 | } 355 | 356 | std::map Wrapper::quantize(const std::vector args) 357 | { 358 | std::shared_ptr a = std::make_shared(); 359 | a->parseArgs(args); 360 | 361 | if (!fileExist(a->input)) 362 | { 363 | throw "Input file is not exist."; 364 | } 365 | 366 | std::cout << "Input: " << a->input << std::endl; 367 | std::cout << "Model: " << a->output + ".bin" << std::endl; 368 | std::cout << "Quantized: " << a->output + ".ftz" << std::endl; 369 | 370 | // parseArgs checks if a->output is given. 
371 | fastText_.loadModel(a->output + ".bin"); 372 | fastText_.quantize(a); 373 | fastText_.saveModel(); 374 | return loadModel(a->output + ".ftz"); 375 | } 376 | -------------------------------------------------------------------------------- /src/wrapper.h: -------------------------------------------------------------------------------- 1 | 2 | 3 | #ifndef WRAPPER_H 4 | #define WRAPPER_H 5 | 6 | // #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "../fastText/src/fasttext.h" 15 | 16 | using fasttext::Args; 17 | using fasttext::Dictionary; 18 | using fasttext::FastText; 19 | using fasttext::Matrix; 20 | using fasttext::Model; 21 | using fasttext::QMatrix; 22 | using fasttext::real; 23 | using fasttext::Vector; 24 | 25 | struct PredictResult 26 | { 27 | std::string label; 28 | double value; 29 | }; 30 | 31 | class Wrapper 32 | { 33 | private: 34 | std::shared_ptr args_; 35 | std::shared_ptr dict_; 36 | 37 | std::shared_ptr input_; 38 | std::shared_ptr output_; 39 | 40 | std::shared_ptr qinput_; 41 | std::shared_ptr qoutput_; 42 | 43 | std::shared_ptr model_; 44 | Matrix wordVectors_; 45 | FastText fastText_; 46 | 47 | // std::atomic tokenCount; 48 | // clock_t start; 49 | 50 | void signModel(std::ostream &); 51 | bool checkModel(std::istream &); 52 | 53 | std::vector findNN(const Vector &, int32_t, 54 | const std::set &); 55 | 56 | std::map loadModel(std::istream &); 57 | 58 | bool quant_; 59 | std::string modelFilename_; 60 | std::mutex mtx_; 61 | std::mutex precomputeMtx_; 62 | 63 | bool isLoaded_; 64 | bool isPrecomputed_; 65 | 66 | bool isModelLoaded() { return isLoaded_; } 67 | bool fileExist(const std::string &filename); 68 | std::map getModelInfo(); 69 | 70 | public: 71 | Wrapper(std::string modelFilename); 72 | 73 | void getVector(Vector &, const std::string &); 74 | 75 | std::vector predict(std::string sentence, int32_t k); 76 | std::vector nn(std::string query, int32_t k); 77 | std::map train(const std::vector args); 78 | std::map quantize(const std::vector args); 79 | 80 | void precomputeWordVectors(); 81 | std::map loadModel(); 82 | std::map loadModel(std::string filename); 83 | }; 84 | 85 | #endif 86 | -------------------------------------------------------------------------------- /tea.yaml: -------------------------------------------------------------------------------- 1 | # https://tea.xyz/what-is-this-file 2 | --- 3 | version: 1.0.0 4 | codeOwners: 5 | - '0x452244cFD2293a8a9270bCc725eFc6924663B1B6' 6 | quorum: 1 7 | -------------------------------------------------------------------------------- /test/data/cooking.stackexchange.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vunb/node-fasttext/215e3b7536669dbfa9c7da6c5bfe8b3d37b66ed1/test/data/cooking.stackexchange.tar.gz -------------------------------------------------------------------------------- /test/models/lid.176.ftz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vunb/node-fasttext/215e3b7536669dbfa9c7da6c5bfe8b3d37b66ed1/test/models/lid.176.ftz -------------------------------------------------------------------------------- /test/models/model_cooking.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vunb/node-fasttext/215e3b7536669dbfa9c7da6c5bfe8b3d37b66ed1/test/models/model_cooking.bin 
-------------------------------------------------------------------------------- /test/specs/fastText.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | const test = require('tape'); 3 | const path = require('path'); 4 | const fastText = require('../../index'); 5 | 6 | test('fastText classifier', function (t) { 7 | t.plan(3) 8 | 9 | let model_path = path.resolve(path.join(__dirname, '../models/model_cooking.bin')) 10 | console.log('File model path: ' + model_path) 11 | let classifier = new fastText.Classifier(); 12 | 13 | classifier.loadModel(model_path) 14 | .then((info) => { 15 | console.log('load model success!!!', info); 16 | t.equal(info.model, 'supervised', 'load supervised model'); 17 | return classifier.predict('Why not put knives in the dishwasher?', 5, 1); 18 | }) 19 | .then((res) => { 20 | t.equal(res.length, 5, 'number of classifications output') 21 | t.equal(res[0].label, '__label__knives', 'output is __label__knives'); 22 | }) 23 | .catch((err) => { 24 | console.log('Result: ', err); 25 | t.fail(err); 26 | }) 27 | }) -------------------------------------------------------------------------------- /test/specs/langid.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | const test = require('tape'); 3 | const path = require('path'); 4 | const fastText = require('../../index'); 5 | 6 | test('fastText language identification', function (t) { 7 | t.plan(3) 8 | 9 | let model_path = path.resolve(path.join(__dirname, '../models/lid.176.ftz')) 10 | console.log('File model path: ' + model_path) 11 | let classifier = new fastText.Classifier(model_path); 12 | 13 | classifier.predict('sử dụng vntk với fastext rất tuyệt?', 5) 14 | .then((res) => { 15 | t.equal(res.length, 5, 'number of classifications output') 16 | t.equal(res[0].label, '__label__vi', 'output is __label__vi'); 17 | t.true(res[0].value > 0.99, 'confidence is 99%'); 18 | console.log('Result: ', res); 19 | }) 20 | .catch((err) => { 21 | t.fail(err); 22 | }); 23 | }) 24 | -------------------------------------------------------------------------------- /test/specs/trainer.js: -------------------------------------------------------------------------------- 1 | 'use strict'; 2 | const test = require('tape'); 3 | const path = require('path'); 4 | const fastText = require('../../index'); 5 | 6 | test('fastText trainer', function (t) { 7 | t.plan(1) 8 | 9 | let data = path.resolve(path.join(__dirname, '../data/cooking.train.txt')); 10 | let model = path.resolve(path.join(__dirname, '../data/cooking.model')); 11 | let classifier = new fastText.Classifier(); 12 | let options = { 13 | input: data, 14 | output: model, 15 | loss: "softmax", 16 | dim: 200, 17 | bucket: 2000000 18 | } 19 | 20 | classifier.train('supervised', options) 21 | .then((res) => { 22 | console.log('model info after training:', res) 23 | t.equal(res.dim, 200, 'dim') 24 | }); 25 | }) 26 | 27 | test('fastText quantize', function (t) { 28 | t.plan(1) 29 | let input = path.resolve(path.join(__dirname, '../data/cooking.train.txt')); 30 | let output = path.resolve(path.join(__dirname, '../data/cooking.model')); 31 | let classifier = new fastText.Classifier(); 32 | let options = { 33 | input, 34 | output, 35 | epoch: 1, 36 | qnorm: true, 37 | qout: true, 38 | retrain: true, 39 | cutoff: 1000, 40 | }; 41 | 42 | classifier.train('quantize', options) 43 | .then((res) => { console.log(res) }) 44 | .catch((e) => { console.error(e) }); 45 | 46 | t.ok(true); 47 | }) 48 | 
-------------------------------------------------------------------------------- /test/start.js: -------------------------------------------------------------------------------- 1 | var path = require('path'); 2 | var dir = '../test/specs/'; 3 | 4 | [ 5 | 'fastText', 6 | 'langid', 7 | 'trainer', 8 | ].forEach((script) => { 9 | require(path.join(dir, script)); 10 | }); --------------------------------------------------------------------------------