├── test ├── support │ ├── unsupervised.txt │ └── supervised.txt ├── test_helper.rb ├── vectorizer_test.rb └── classifier_test.rb ├── lib ├── fasttext │ ├── version.rb │ ├── model.rb │ ├── vectorizer.rb │ └── classifier.rb └── fasttext.rb ├── Gemfile ├── .gitmodules ├── .gitignore ├── .github └── workflows │ └── build.yml ├── Rakefile ├── ext └── fasttext │ ├── extconf.rb │ └── ext.cpp ├── fasttext.gemspec ├── CHANGELOG.md ├── LICENSE.txt └── README.md /test/support/unsupervised.txt: -------------------------------------------------------------------------------- 1 | this is a test 2 | -------------------------------------------------------------------------------- /lib/fasttext/version.rb: -------------------------------------------------------------------------------- 1 | module FastText 2 | VERSION = "0.4.1" 3 | end 4 | -------------------------------------------------------------------------------- /test/test_helper.rb: -------------------------------------------------------------------------------- 1 | require "bundler/setup" 2 | Bundler.require(:default) 3 | require "minitest/autorun" 4 | -------------------------------------------------------------------------------- /Gemfile: -------------------------------------------------------------------------------- 1 | source "https://rubygems.org" 2 | 3 | gemspec 4 | 5 | gem "rake" 6 | gem "rake-compiler" 7 | gem "minitest" 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "vendor/fastText"] 2 | path = vendor/fastText 3 | url = https://github.com/ankane/fastText-fork.git 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.bundle/ 2 | /.yardoc 3 | /_yardoc/ 4 | /coverage/ 5 | /doc/ 6 | /pkg/ 7 | /spec/reports/ 8 | /tmp/ 9 | *.lock 10 | *.bin 11 | *.bundle 12 | Makefile 13 | -------------------------------------------------------------------------------- /test/support/supervised.txt: -------------------------------------------------------------------------------- 1 | __label__ham This is the first document 2 | __label__spam Hello, this is the second document 3 | __label__spam Hello, and this is the third one 4 | __label__ham Is this the first document? 5 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: [push, pull_request] 3 | jobs: 4 | build: 5 | strategy: 6 | fail-fast: false 7 | matrix: 8 | os: [ubuntu-latest, macos-latest, windows-latest] 9 | runs-on: ${{ matrix.os }} 10 | steps: 11 | - uses: actions/checkout@v5 12 | with: 13 | submodules: true 14 | - uses: ruby/setup-ruby@v1 15 | with: 16 | ruby-version: 3.4 17 | bundler-cache: true 18 | - run: bundle exec rake compile 19 | - run: bundle exec rake test 20 | -------------------------------------------------------------------------------- /Rakefile: -------------------------------------------------------------------------------- 1 | require "bundler/gem_tasks" 2 | require "rake/testtask" 3 | require "rake/extensiontask" 4 | 5 | Rake::TestTask.new do |t| 6 | t.pattern = "test/**/*_test.rb" 7 | end 8 | 9 | task default: :test 10 | 11 | Rake::ExtensionTask.new("fasttext") do |ext| 12 | ext.name = "ext" 13 | ext.lib_dir = "lib/fasttext" 14 | end 15 | 16 | task :check_license do 17 | raise "Missing vendor license" unless File.exist?("vendor/fastText/LICENSE") 18 | end 19 | 20 | task :remove_ext do 21 | path = "lib/fasttext/ext.bundle" 22 | File.unlink(path) if File.exist?(path) 23 | end 24 | 25 | Rake::Task["build"].enhance [:check_license, :remove_ext] 26 | -------------------------------------------------------------------------------- /ext/fasttext/extconf.rb: -------------------------------------------------------------------------------- 1 | require "mkmf-rice" 2 | 3 | # -march=native not supported with ARM Mac 4 | default_optflags = RbConfig::CONFIG["host_os"] =~ /darwin/i && RbConfig::CONFIG["host_cpu"] =~ /arm|aarch64/i ? "" : "-march=native" 5 | # -pthread and -O3 set by default 6 | $CXXFLAGS << " -std=c++17 $(optflags) -funroll-loops " << with_config("optflags", default_optflags) 7 | 8 | ext = File.expand_path(".", __dir__) 9 | fasttext = File.expand_path("../../vendor/fastText/src", __dir__) 10 | 11 | $srcs = Dir["{#{ext},#{fasttext}}/*.{cc,cpp}"] 12 | $INCFLAGS << " -I#{fasttext}" 13 | $VPATH << fasttext 14 | 15 | create_makefile("fasttext/ext") 16 | -------------------------------------------------------------------------------- /fasttext.gemspec: -------------------------------------------------------------------------------- 1 | require_relative "lib/fasttext/version" 2 | 3 | Gem::Specification.new do |spec| 4 | spec.name = "fasttext" 5 | spec.version = FastText::VERSION 6 | spec.summary = "Efficient text classification and representation learning for Ruby" 7 | spec.homepage = "https://github.com/ankane/fastText-ruby" 8 | spec.license = "MIT" 9 | 10 | spec.author = "Andrew Kane" 11 | spec.email = "andrew@ankane.org" 12 | 13 | spec.files = Dir["*.{md,txt}", "{lib,ext}/**/*", "vendor/fastText/{LICENSE,README.md}", "vendor/fastText/src/**/*"] 14 | spec.require_path = "lib" 15 | spec.extensions = ["ext/fasttext/extconf.rb"] 16 | 17 | spec.required_ruby_version = ">= 3.1" 18 | 19 | spec.add_dependency "rice", ">= 4.7" 20 | end 21 | -------------------------------------------------------------------------------- /test/vectorizer_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class VectorizerTest < Minitest::Test 4 | def test_works 5 | x = [ 6 | "this is a test" 7 | ] 8 | 9 | model = FastText::Vectorizer.new 10 | model.fit(x) 11 | 12 | assert model.nearest_neighbors("asparagus") 13 | assert model.analogies("berlin", "germany", "france") 14 | end 15 | 16 | def test_input_file 17 | model = FastText::Vectorizer.new 18 | model.fit("test/support/unsupervised.txt") 19 | 20 | assert model.nearest_neighbors("asparagus") 21 | assert model.analogies("berlin", "germany", "france") 22 | end 23 | 24 | def test_train_unsupervised 25 | model = FastText.train_unsupervised(input: "test/support/unsupervised.txt") 26 | 27 | assert model.nearest_neighbors("asparagus") 28 | assert model.analogies("berlin", "germany", "france") 29 | end 30 | end 31 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## 0.4.1 (2025-10-26) 2 | 3 | - Fixed error with Rice 4.7 4 | 5 | ## 0.4.0 (2024-10-22) 6 | 7 | - Dropped support for Ruby < 3.1 8 | 9 | ## 0.3.0 (2023-07-24) 10 | 11 | - Fixed installation error on Windows 12 | - Dropped support for Ruby < 3 13 | 14 | ## 0.2.4 (2022-01-16) 15 | 16 | - Improved ARM detection 17 | 18 | ## 0.2.3 (2021-11-15) 19 | 20 | - Fixed installation error with ARM Mac 21 | 22 | ## 0.2.2 (2021-10-16) 23 | 24 | - Fixed `file cannot be opened` errors 25 | 26 | ## 0.2.1 (2021-05-23) 27 | 28 | - Improved performance 29 | 30 | ## 0.2.0 (2021-05-17) 31 | 32 | - Updated to Rice 4 33 | - Dropped support for Ruby < 2.6 34 | 35 | ## 0.1.3 (2020-04-28) 36 | 37 | - Updated fastText to 0.9.2 38 | - Added support for autotune 39 | - Added `--with-optflags` option 40 | 41 | ## 0.1.2 (2020-01-10) 42 | 43 | - Fixed installation error with Ruby 2.7 44 | 45 | ## 0.1.1 (2019-10-26) 46 | 47 | - Fixed installation 48 | - Reduced gem size 49 | - Added support for multiple documents to `predict` method 50 | 51 | ## 0.1.0 (2019-10-26) 52 | 53 | - First release 54 | -------------------------------------------------------------------------------- /lib/fasttext.rb: -------------------------------------------------------------------------------- 1 | # ext 2 | require "fasttext/ext" 3 | 4 | # stdlib 5 | require "tempfile" 6 | 7 | # modules 8 | require_relative "fasttext/model" 9 | require_relative "fasttext/classifier" 10 | require_relative "fasttext/vectorizer" 11 | require_relative "fasttext/version" 12 | 13 | module FastText 14 | class Error < StandardError; end 15 | 16 | class << self 17 | def load_model(path) 18 | m = Ext::Model.new 19 | m.load_model(path) 20 | model = 21 | if m.supervised? 22 | FastText::Classifier.new 23 | else 24 | FastText::Vectorizer.new 25 | end 26 | model.instance_variable_set("@m", m) 27 | model 28 | end 29 | 30 | def train_supervised(**options) 31 | input = options.delete(:input) 32 | model = FastText::Classifier.new(**options) 33 | model.fit(input) 34 | model 35 | end 36 | 37 | def train_unsupervised(**options) 38 | input = options.delete(:input) 39 | model = FastText::Vectorizer.new(**options) 40 | model.fit(input) 41 | model 42 | end 43 | end 44 | end 45 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016-present, Facebook, Inc. 4 | Copyright (c) 2019-2024 Andrew Kane 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /lib/fasttext/model.rb: -------------------------------------------------------------------------------- 1 | module FastText 2 | class Model 3 | def initialize(**options) 4 | @options = options 5 | end 6 | 7 | def dimension 8 | m.dimension 9 | end 10 | 11 | def quantized? 12 | m.quantized? 13 | end 14 | 15 | def save_model(path) 16 | m.save_model(path) 17 | end 18 | 19 | def words(include_freq: false) 20 | words, freqs = m.words 21 | if include_freq 22 | words.zip(freqs).to_h 23 | else 24 | words 25 | end 26 | end 27 | 28 | def word_vector(word) 29 | m.word_vector(word) 30 | end 31 | 32 | def sentence_vector(text) 33 | m.sentence_vector(prep_text(text)) 34 | end 35 | 36 | def word_id(word) 37 | m.word_id(word) 38 | end 39 | 40 | def subword_id(subword) 41 | m.subword_id(subword) 42 | end 43 | 44 | def subwords(word) 45 | m.subwords(word) 46 | end 47 | 48 | private 49 | 50 | # text must end in newline for prediction to match Python and CLI 51 | def prep_text(text) 52 | text = "#{text}\n" unless text.end_with?("\n") 53 | text 54 | end 55 | 56 | def m 57 | @m || (raise Error, "Not fit") 58 | end 59 | 60 | def build_args(default_options) 61 | a = Ext::Args.new 62 | opts = @options.dup 63 | default_options.each do |k, v| 64 | a.send("#{k}=", opts.delete(k) || v) 65 | end 66 | raise ArgumentError, "Unknown argument: #{opts.keys.first}" if opts.any? 67 | a 68 | end 69 | end 70 | end 71 | -------------------------------------------------------------------------------- /lib/fasttext/vectorizer.rb: -------------------------------------------------------------------------------- 1 | module FastText 2 | class Vectorizer < Model 3 | DEFAULT_OPTIONS = { 4 | lr: 0.5, 5 | lr_update_rate: 100, 6 | dim: 100, 7 | ws: 5, 8 | epoch: 5, 9 | min_count: 1, 10 | min_count_label: 0, 11 | neg: 5, 12 | word_ngrams: 1, 13 | loss: "ns", 14 | model: "skipgram", 15 | bucket: 2000000, 16 | minn: 3, 17 | maxn: 6, 18 | thread: 3, 19 | t: 0.0001, 20 | verbose: 2, 21 | pretrained_vectors: "", 22 | save_output: false, 23 | seed: 0, 24 | autotune_validation_file: "", 25 | autotune_metric: "f1", 26 | autotune_predictions: 1, 27 | autotune_duration: 60 * 5, 28 | autotune_model_size: "" 29 | } 30 | 31 | def fit(x) 32 | @m ||= Ext::Model.new 33 | a = build_args(DEFAULT_OPTIONS) 34 | a.input, _ref = input_path(x) 35 | m.train(a) 36 | end 37 | 38 | def nearest_neighbors(word, k: 10) 39 | m.nearest_neighbors(word, k).map(&:reverse).to_h 40 | end 41 | 42 | def analogies(word_a, word_b, word_c, k: 10) 43 | m.analogies(k, word_a, word_b, word_c).map(&:reverse).to_h 44 | end 45 | 46 | private 47 | 48 | # separate example by newlines 49 | # https://github.com/facebookresearch/fastText/issues/518 50 | def input_path(x) 51 | if x.is_a?(String) 52 | [x, nil] 53 | else 54 | tempfile = Tempfile.new("fasttext") 55 | x.each do |xi| 56 | tempfile.write(xi.gsub("\n", " ")) # replace newlines in document 57 | tempfile.write("\n") 58 | end 59 | tempfile.close 60 | [tempfile.path, tempfile] 61 | end 62 | end 63 | end 64 | end 65 | -------------------------------------------------------------------------------- /lib/fasttext/classifier.rb: -------------------------------------------------------------------------------- 1 | module FastText 2 | class Classifier < Model 3 | DEFAULT_OPTIONS = { 4 | lr: 0.1, 5 | lr_update_rate: 100, 6 | dim: 100, 7 | ws: 5, 8 | epoch: 5, 9 | min_count: 1, 10 | min_count_label: 0, 11 | neg: 5, 12 | word_ngrams: 1, 13 | loss: "softmax", 14 | model: "supervised", 15 | bucket: 2000000, 16 | minn: 0, 17 | maxn: 0, 18 | thread: 3, 19 | t: 0.0001, 20 | label_prefix: "__label__", 21 | verbose: 2, 22 | pretrained_vectors: "", 23 | save_output: false, 24 | seed: 0, 25 | autotune_validation_file: "", 26 | autotune_metric: "f1", 27 | autotune_predictions: 1, 28 | autotune_duration: 60 * 5, 29 | autotune_model_size: "" 30 | } 31 | 32 | def fit(x, y = nil, autotune_set: nil) 33 | input, _ref = input_path(x, y) 34 | @m ||= Ext::Model.new 35 | a = build_args(DEFAULT_OPTIONS) 36 | a.input = input 37 | a.model = "supervised" 38 | if autotune_set 39 | x, y = autotune_set 40 | a.autotune_validation_file, _autotune_ref = input_path(x, y) 41 | end 42 | m.train(a) 43 | end 44 | 45 | def predict(text, k: 1, threshold: 0.0) 46 | multiple = text.is_a?(Array) 47 | text = [text] unless multiple 48 | 49 | # TODO predict multiple in C++ for performance 50 | result = 51 | text.map do |t| 52 | m.predict(prep_text(t), k, threshold).to_h do |v| 53 | [remove_prefix(v[1]), v[0]] 54 | end 55 | end 56 | 57 | multiple ? result : result.first 58 | end 59 | 60 | def test(x, y = nil, k: 1) 61 | input, _ref = input_path(x, y) 62 | res = m.test(input, k) 63 | { 64 | examples: res[0], 65 | precision: res[1], 66 | recall: res[2] 67 | } 68 | end 69 | 70 | # TODO support options 71 | def quantize 72 | a = Ext::Args.new 73 | m.quantize(a) 74 | end 75 | 76 | def labels(include_freq: false) 77 | labels, freqs = m.labels 78 | labels.map! { |v| remove_prefix(v) } 79 | if include_freq 80 | labels.zip(freqs).to_h 81 | else 82 | labels 83 | end 84 | end 85 | 86 | private 87 | 88 | def input_path(x, y) 89 | if x.is_a?(String) 90 | raise ArgumentError, "Cannot pass y with file" if y 91 | [x, nil] 92 | else 93 | tempfile = Tempfile.new("fasttext") 94 | x.zip(y) do |xi, yi| 95 | parts = Array(yi).map { |label| "__label__" + label } 96 | parts << xi.gsub("\n", " ") # replace newlines in document 97 | tempfile.write(parts.join(" ")) 98 | tempfile.write("\n") 99 | end 100 | tempfile.close 101 | [tempfile.path, tempfile] 102 | end 103 | end 104 | 105 | def remove_prefix(label) 106 | label.sub(label_prefix, "") 107 | end 108 | 109 | def label_prefix 110 | m.label_prefix 111 | end 112 | end 113 | end 114 | -------------------------------------------------------------------------------- /test/classifier_test.rb: -------------------------------------------------------------------------------- 1 | require_relative "test_helper" 2 | 3 | class ClassifierTest < Minitest::Test 4 | def test_works 5 | x = [ 6 | "This is the first document", 7 | "Hello, this is the second document", 8 | "Hello, and this is the third one", 9 | "Is this the first document?" 10 | ] 11 | y = ["ham", "spam", "spam", "ham"] 12 | 13 | model = FastText::Classifier.new(seed: 123, thread: 1) 14 | model.fit(x, y) 15 | 16 | assert_equal 100, model.dimension 17 | assert !model.quantized? 18 | assert_equal 14, model.words.size 19 | assert_equal model.words, model.words(include_freq: true).keys 20 | assert_equal ["ham", "spam"].sort, model.labels.sort 21 | assert_equal model.labels, model.labels(include_freq: true).keys 22 | 23 | assert model.word_id("first") 24 | assert model.subwords("first") 25 | assert model.word_vector("first") 26 | assert model.sentence_vector("first document") 27 | 28 | assert model.predict("First document") 29 | assert model.predict(["First document", "Second document"], k: 3) 30 | 31 | pred = model.predict("First document").first 32 | assert_equal "ham", pred[0] 33 | assert_in_delta 0.50003284, pred[1] 34 | 35 | result = model.test(x, y) 36 | assert_equal 4, result[:examples] 37 | assert_equal 1.0, result[:precision] 38 | assert_equal 1.0, result[:recall] 39 | 40 | model_path = "#{Dir.mktmpdir}/model.bin" 41 | model.save_model(model_path) 42 | model = FastText.load_model(model_path) 43 | 44 | assert_equal 100, model.dimension 45 | assert !model.quantized? 46 | assert_equal 14, model.words.size 47 | 48 | # takes a while 49 | # model.quantize 50 | # model.save_model("#{Dir.mktmpdir}/model.ftz") 51 | # assert model.quantized? 52 | end 53 | 54 | def test_input_file 55 | model = FastText::Classifier.new 56 | model.fit("test/support/supervised.txt") 57 | 58 | assert_equal 100, model.dimension 59 | assert !model.quantized? 60 | assert_equal 14, model.words.size 61 | assert_equal model.words, model.words(include_freq: true).keys 62 | assert_equal ["ham", "spam"].sort, model.labels.sort 63 | assert_equal model.labels, model.labels(include_freq: true).keys 64 | end 65 | 66 | def test_autotune 67 | skip "Takes too much memory" if ci? 68 | 69 | x = [ 70 | "This is the first document", 71 | "Hello, this is the second document", 72 | "Hello, and this is the third one", 73 | "Is this the first document?" 74 | ] 75 | y = ["ham", "spam", "spam", "ham"] 76 | 77 | model = FastText::Classifier.new(autotune_duration: 2) 78 | model.fit(x, y, autotune_set: [x, y]) 79 | end 80 | 81 | def test_autotune_file 82 | skip "Takes too much memory" if ci? 83 | 84 | model = FastText::Classifier.new(autotune_duration: 2) 85 | model.fit("test/support/supervised.txt", autotune_set: "test/support/supervised.txt") 86 | end 87 | 88 | def test_train_supervised 89 | model = FastText.train_supervised(input: "test/support/supervised.txt") 90 | 91 | assert_equal 100, model.dimension 92 | assert !model.quantized? 93 | assert_equal 14, model.words.size 94 | assert_equal model.words, model.words(include_freq: true).keys 95 | assert_equal ["ham", "spam"].sort, model.labels.sort 96 | assert_equal model.labels, model.labels(include_freq: true).keys 97 | end 98 | 99 | def test_train_supervised_autotune 100 | skip "Takes too much memory" if ci? 101 | 102 | FastText.train_supervised( 103 | input: "test/support/supervised.txt", 104 | autotune_validation_file: "test/support/supervised.txt", 105 | autotune_duration: 2 106 | ) 107 | end 108 | 109 | def test_untrained 110 | model = FastText::Classifier.new 111 | error = assert_raises FastText::Error do 112 | model.dimension 113 | end 114 | assert_equal "Not fit", error.message 115 | end 116 | 117 | def test_language_identification 118 | skip "Don't want to include lid.176 with project" 119 | 120 | model = FastText.load_model("path/to/lid.176.ftz") 121 | assert_equal ["fr"], model.predict("bon appétit").keys 122 | end 123 | 124 | def ci? 125 | ENV["CI"] 126 | end 127 | end 128 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fastText Ruby 2 | 3 | [fastText](https://fasttext.cc) - efficient text classification and representation learning - for Ruby 4 | 5 | [![Build Status](https://github.com/ankane/fastText-ruby/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/fastText-ruby/actions) 6 | 7 | ## Installation 8 | 9 | Add this line to your application’s Gemfile: 10 | 11 | ```ruby 12 | gem "fasttext" 13 | ``` 14 | 15 | ## Getting Started 16 | 17 | fastText has two primary use cases: 18 | 19 | - [text classification](#text-classification) 20 | - [word representations](#word-representations) 21 | 22 | ## Text Classification 23 | 24 | Prep your data 25 | 26 | ```ruby 27 | # documents 28 | x = [ 29 | "text from document one", 30 | "text from document two", 31 | "text from document three" 32 | ] 33 | 34 | # labels 35 | y = ["ham", "ham", "spam"] 36 | ``` 37 | 38 | > Use an array if a document has multiple labels 39 | 40 | Train a model 41 | 42 | ```ruby 43 | model = FastText::Classifier.new 44 | model.fit(x, y) 45 | ``` 46 | 47 | Get predictions 48 | 49 | ```ruby 50 | model.predict(x) 51 | ``` 52 | 53 | Save the model to a file 54 | 55 | ```ruby 56 | model.save_model("model.bin") 57 | ``` 58 | 59 | Load the model from a file 60 | 61 | ```ruby 62 | model = FastText.load_model("model.bin") 63 | ``` 64 | 65 | Evaluate the model 66 | 67 | ```ruby 68 | model.test(x_test, y_test) 69 | ``` 70 | 71 | Get words and labels 72 | 73 | ```ruby 74 | model.words 75 | model.labels 76 | ``` 77 | 78 | > Use `include_freq: true` to get their frequency 79 | 80 | Search for the best hyperparameters 81 | 82 | ```ruby 83 | model.fit(x, y, autotune_set: [x_valid, y_valid]) 84 | ``` 85 | 86 | Compress the model - significantly reduces size but sacrifices a little performance 87 | 88 | ```ruby 89 | model.quantize 90 | model.save_model("model.ftz") 91 | ``` 92 | 93 | ## Word Representations 94 | 95 | Prep your data 96 | 97 | ```ruby 98 | x = [ 99 | "text from document one", 100 | "text from document two", 101 | "text from document three" 102 | ] 103 | ``` 104 | 105 | Train a model 106 | 107 | ```ruby 108 | model = FastText::Vectorizer.new 109 | model.fit(x) 110 | ``` 111 | 112 | Get nearest neighbors 113 | 114 | ```ruby 115 | model.nearest_neighbors("asparagus") 116 | ``` 117 | 118 | Get analogies 119 | 120 | ```ruby 121 | model.analogies("berlin", "germany", "france") 122 | ``` 123 | 124 | Get a word vector 125 | 126 | ```ruby 127 | model.word_vector("carrot") 128 | ``` 129 | 130 | Get a sentence vector 131 | 132 | ```ruby 133 | model.sentence_vector("sentence text") 134 | ``` 135 | 136 | Get words 137 | 138 | ```ruby 139 | model.words 140 | ``` 141 | 142 | Save the model to a file 143 | 144 | ```ruby 145 | model.save_model("model.bin") 146 | ``` 147 | 148 | Load the model from a file 149 | 150 | ```ruby 151 | model = FastText.load_model("model.bin") 152 | ``` 153 | 154 | Use continuous bag-of-words 155 | 156 | ```ruby 157 | model = FastText::Vectorizer.new(model: "cbow") 158 | ``` 159 | 160 | ## Parameters 161 | 162 | Text classification 163 | 164 | ```ruby 165 | FastText::Classifier.new( 166 | lr: 0.1, # learning rate 167 | dim: 100, # size of word vectors 168 | ws: 5, # size of the context window 169 | epoch: 5, # number of epochs 170 | min_count: 1, # minimal number of word occurrences 171 | min_count_label: 1, # minimal number of label occurrences 172 | minn: 0, # min length of char ngram 173 | maxn: 0, # max length of char ngram 174 | neg: 5, # number of negatives sampled 175 | word_ngrams: 1, # max length of word ngram 176 | loss: "softmax", # loss function {ns, hs, softmax, ova} 177 | bucket: 2000000, # number of buckets 178 | thread: 3, # number of threads 179 | lr_update_rate: 100, # change the rate of updates for the learning rate 180 | t: 0.0001, # sampling threshold 181 | label_prefix: "__label__", # label prefix 182 | verbose: 2, # verbose 183 | pretrained_vectors: nil, # pretrained word vectors (.vec file) 184 | autotune_metric: "f1", # autotune optimization metric 185 | autotune_predictions: 1, # autotune predictions 186 | autotune_duration: 300, # autotune search time in seconds 187 | autotune_model_size: nil # autotune model size, like 2M 188 | ) 189 | ``` 190 | 191 | Word representations 192 | 193 | ```ruby 194 | FastText::Vectorizer.new( 195 | model: "skipgram", # unsupervised fasttext model {cbow, skipgram} 196 | lr: 0.05, # learning rate 197 | dim: 100, # size of word vectors 198 | ws: 5, # size of the context window 199 | epoch: 5, # number of epochs 200 | min_count: 5, # minimal number of word occurrences 201 | minn: 3, # min length of char ngram 202 | maxn: 6, # max length of char ngram 203 | neg: 5, # number of negatives sampled 204 | word_ngrams: 1, # max length of word ngram 205 | loss: "ns", # loss function {ns, hs, softmax, ova} 206 | bucket: 2000000, # number of buckets 207 | thread: 3, # number of threads 208 | lr_update_rate: 100, # change the rate of updates for the learning rate 209 | t: 0.0001, # sampling threshold 210 | verbose: 2 # verbose 211 | ) 212 | ``` 213 | 214 | ## Input Files 215 | 216 | Input can be read directly from files 217 | 218 | ```ruby 219 | model.fit("train.txt", autotune_set: "valid.txt") 220 | model.test("test.txt") 221 | ``` 222 | 223 | Each line should be a document 224 | 225 | ```txt 226 | text from document one 227 | text from document two 228 | text from document three 229 | ``` 230 | 231 | For text classification, lines should start with a list of labels prefixed with `__label__` 232 | 233 | ```txt 234 | __label__ham text from document one 235 | __label__ham text from document two 236 | __label__spam text from document three 237 | ``` 238 | 239 | ## Pretrained Models 240 | 241 | There are a number of [pretrained models](https://fasttext.cc/docs/en/supervised-models.html) you can download 242 | 243 | ### Language Identification 244 | 245 | Download one of the [pretrained models](https://fasttext.cc/docs/en/language-identification.html) and load it 246 | 247 | ```ruby 248 | model = FastText.load_model("lid.176.ftz") 249 | ``` 250 | 251 | Get language predictions 252 | 253 | ```ruby 254 | model.predict("bon appétit") 255 | ``` 256 | 257 | ## History 258 | 259 | View the [changelog](https://github.com/ankane/fastText-ruby/blob/master/CHANGELOG.md) 260 | 261 | ## Contributing 262 | 263 | Everyone is encouraged to help improve this project. Here are a few ways you can help: 264 | 265 | - [Report bugs](https://github.com/ankane/fastText-ruby/issues) 266 | - Fix bugs and [submit pull requests](https://github.com/ankane/fastText-ruby/pulls) 267 | - Write, clarify, or fix documentation 268 | - Suggest or add new features 269 | 270 | To get started with development: 271 | 272 | ```sh 273 | git clone --recursive https://github.com/ankane/fastText-ruby.git 274 | cd fastText-ruby 275 | bundle install 276 | bundle exec rake compile 277 | bundle exec rake test 278 | ``` 279 | -------------------------------------------------------------------------------- /ext/fasttext/ext.cpp: -------------------------------------------------------------------------------- 1 | // stdlib 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // fasttext 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | // rice 20 | #include 21 | #include 22 | 23 | using fasttext::Args; 24 | using fasttext::FastText; 25 | 26 | using Rice::Array; 27 | 28 | namespace Rice::detail { 29 | template<> 30 | class To_Ruby>> { 31 | public: 32 | explicit To_Ruby(Arg* arg) : arg_(arg) { } 33 | 34 | VALUE convert(std::vector> const & x) { 35 | Array ret; 36 | for (const auto& v : x) { 37 | Array a; 38 | a.push(v.first, false); 39 | a.push(v.second, false); 40 | ret.push(a, false); 41 | } 42 | return ret; 43 | } 44 | 45 | private: 46 | Arg* arg_ = nullptr; 47 | }; 48 | } // namespace Rice::detail 49 | 50 | extern "C" 51 | void Init_ext() { 52 | Rice::Module rb_mFastText = Rice::define_module("FastText"); 53 | Rice::Module rb_mExt = Rice::define_module_under(rb_mFastText, "Ext"); 54 | 55 | Rice::define_class_under(rb_mExt, "Args") 56 | .define_constructor(Rice::Constructor()) 57 | .define_attr("input", &Args::input) 58 | .define_attr("output", &Args::output) 59 | .define_attr("lr", &Args::lr) 60 | .define_attr("lr_update_rate", &Args::lrUpdateRate) 61 | .define_attr("dim", &Args::dim) 62 | .define_attr("ws", &Args::ws) 63 | .define_attr("epoch", &Args::epoch) 64 | .define_attr("min_count", &Args::minCount) 65 | .define_attr("min_count_label", &Args::minCountLabel) 66 | .define_attr("neg", &Args::neg) 67 | .define_attr("word_ngrams", &Args::wordNgrams) 68 | .define_method( 69 | "loss=", 70 | [](Args& a, const std::string& str) { 71 | if (str == "softmax") { 72 | a.loss = fasttext::loss_name::softmax; 73 | } else if (str == "ns") { 74 | a.loss = fasttext::loss_name::ns; 75 | } else if (str == "hs") { 76 | a.loss = fasttext::loss_name::hs; 77 | } else if (str == "ova") { 78 | a.loss = fasttext::loss_name::ova; 79 | } else { 80 | throw std::invalid_argument("Unknown loss: " + str); 81 | } 82 | }) 83 | .define_method( 84 | "model=", 85 | [](Args& a, const std::string& str) { 86 | if (str == "supervised") { 87 | a.model = fasttext::model_name::sup; 88 | } else if (str == "skipgram") { 89 | a.model = fasttext::model_name::sg; 90 | } else if (str == "cbow") { 91 | a.model = fasttext::model_name::cbow; 92 | } else { 93 | throw std::invalid_argument("Unknown model: " + str); 94 | } 95 | }) 96 | .define_attr("bucket", &Args::bucket) 97 | .define_attr("minn", &Args::minn) 98 | .define_attr("maxn", &Args::maxn) 99 | .define_attr("thread", &Args::thread) 100 | .define_attr("t", &Args::t) 101 | .define_attr("label_prefix", &Args::label) 102 | .define_attr("verbose", &Args::verbose) 103 | .define_attr("pretrained_vectors", &Args::pretrainedVectors) 104 | .define_attr("save_output", &Args::saveOutput) 105 | .define_attr("seed", &Args::seed) 106 | .define_attr("autotune_validation_file", &Args::autotuneValidationFile) 107 | .define_attr("autotune_metric", &Args::autotuneMetric) 108 | .define_attr("autotune_predictions", &Args::autotunePredictions) 109 | .define_attr("autotune_duration", &Args::autotuneDuration) 110 | .define_attr("autotune_model_size", &Args::autotuneModelSize); 111 | 112 | Rice::define_class_under(rb_mExt, "Model") 113 | .define_constructor(Rice::Constructor()) 114 | .define_method( 115 | "words", 116 | [](FastText& m) { 117 | std::shared_ptr d = m.getDictionary(); 118 | std::vector freq = d->getCounts(fasttext::entry_type::word); 119 | 120 | Array vocab_list; 121 | Array vocab_freq; 122 | for (int32_t i = 0; i < d->nwords(); i++) { 123 | vocab_list.push(d->getWord(i), false); 124 | vocab_freq.push(freq[i], false); 125 | } 126 | 127 | Array ret; 128 | ret.push(vocab_list, false); 129 | ret.push(vocab_freq, false); 130 | return ret; 131 | }) 132 | .define_method( 133 | "labels", 134 | [](FastText& m) { 135 | std::shared_ptr d = m.getDictionary(); 136 | std::vector freq = d->getCounts(fasttext::entry_type::label); 137 | 138 | Array vocab_list; 139 | Array vocab_freq; 140 | for (int32_t i = 0; i < d->nlabels(); i++) { 141 | vocab_list.push(d->getLabel(i), false); 142 | vocab_freq.push(freq[i], false); 143 | } 144 | 145 | Array ret; 146 | ret.push(vocab_list, false); 147 | ret.push(vocab_freq, false); 148 | return ret; 149 | }) 150 | .define_method( 151 | "test", 152 | [](FastText& m, const std::string& filename, int32_t k) { 153 | std::ifstream ifs(filename); 154 | if (!ifs.is_open()) { 155 | throw std::invalid_argument("Test file cannot be opened!"); 156 | } 157 | fasttext::Meter meter(false); 158 | m.test(ifs, k, 0.0, meter); 159 | ifs.close(); 160 | 161 | Array ret; 162 | ret.push(meter.nexamples(), false); 163 | ret.push(meter.precision(), false); 164 | ret.push(meter.recall(), false); 165 | return ret; 166 | }) 167 | .define_method( 168 | "load_model", 169 | [](FastText& m, const std::string& s) { 170 | m.loadModel(s); 171 | }) 172 | .define_method( 173 | "save_model", 174 | [](FastText& m, const std::string& s) { 175 | m.saveModel(s); 176 | }) 177 | .define_method("dimension", &FastText::getDimension) 178 | .define_method("quantized?", &FastText::isQuant) 179 | .define_method("word_id", &FastText::getWordId) 180 | .define_method("subword_id", &FastText::getSubwordId) 181 | .define_method( 182 | "predict", 183 | [](FastText& m, const std::string& text, int32_t k, float threshold) { 184 | std::stringstream ioss(text); 185 | std::vector> predictions; 186 | m.predictLine(ioss, predictions, k, threshold); 187 | return predictions; 188 | }) 189 | .define_method( 190 | "nearest_neighbors", 191 | [](FastText& m, const std::string& word, int32_t k) { 192 | return m.getNN(word, k); 193 | }) 194 | .define_method("analogies", &FastText::getAnalogies) 195 | // .define_method("ngram_vectors", &FastText::getNgramVectors) 196 | .define_method( 197 | "word_vector", 198 | [](FastText& m, const std::string& word) { 199 | auto dimension = m.getDimension(); 200 | fasttext::Vector vec = fasttext::Vector(dimension); 201 | m.getWordVector(vec, word); 202 | Array ret; 203 | for (size_t i = 0; i < vec.size(); i++) { 204 | ret.push(vec[i], false); 205 | } 206 | return ret; 207 | }) 208 | .define_method( 209 | "subwords", 210 | [](FastText& m, const std::string& word) { 211 | std::vector subwords; 212 | std::vector ngrams; 213 | std::shared_ptr d = m.getDictionary(); 214 | d->getSubwords(word, ngrams, subwords); 215 | 216 | Array ret; 217 | for (const auto& subword : subwords) { 218 | ret.push(subword, false); 219 | } 220 | return ret; 221 | }) 222 | .define_method( 223 | "sentence_vector", 224 | [](FastText& m, const std::string& text) { 225 | std::istringstream in(text); 226 | auto dimension = m.getDimension(); 227 | fasttext::Vector vec = fasttext::Vector(dimension); 228 | m.getSentenceVector(in, vec); 229 | Array ret; 230 | for (size_t i = 0; i < vec.size(); i++) { 231 | ret.push(vec[i], false); 232 | } 233 | return ret; 234 | }) 235 | .define_method( 236 | "train", 237 | [](FastText& m, Args& a) { 238 | if (a.hasAutotune()) { 239 | fasttext::Autotune autotune(std::shared_ptr(&m, [](fasttext::FastText*) {})); 240 | autotune.train(a); 241 | } else { 242 | m.train(a); 243 | } 244 | }) 245 | .define_method( 246 | "quantize", 247 | [](FastText& m, Args& a) { 248 | m.quantize(a); 249 | }) 250 | .define_method( 251 | "supervised?", 252 | [](FastText& m) { 253 | return m.getArgs().model == fasttext::model_name::sup; 254 | }) 255 | .define_method( 256 | "label_prefix", 257 | [](FastText& m) { 258 | return m.getArgs().label; 259 | }); 260 | } 261 | --------------------------------------------------------------------------------