├── .gitignore ├── testdata ├── chunking.txt └── dep_chunking.txt ├── .gitattributes ├── man ├── Makefile ├── finalfrontier.1.md ├── finalfrontier.1 ├── finalfrontier-skipgram.1.md ├── finalfrontier-deps.1.md ├── finalfrontier-skipgram.1 └── finalfrontier-deps.1 ├── src ├── subcommands │ ├── traits.rs │ ├── mod.rs │ ├── progress.rs │ ├── skipgram.rs │ └── deps.rs ├── lib.rs ├── main.rs ├── idx.rs ├── hogwild.rs ├── util.rs ├── loss.rs ├── vocab │ ├── simple.rs │ ├── mod.rs │ └── subword.rs ├── dep_trainer.rs ├── sampling.rs ├── sgd.rs ├── skipgram_trainer.rs ├── config.rs ├── io.rs ├── opts.rs └── train_model.rs ├── docs ├── INSTALL.md └── QUICKSTART.md ├── COPYRIGHT.md ├── LICENSE-MIT ├── Cargo.toml ├── .github └── workflows │ ├── rust.yml │ └── release.yml ├── benches └── dot_product.rs ├── README.md └── LICENSE-APACHE /.gitignore: -------------------------------------------------------------------------------- 1 | # Editor files 2 | .* 3 | *~ 4 | 5 | # Rust files 6 | /target 7 | **/*.rs.bk 8 | -------------------------------------------------------------------------------- /testdata/chunking.txt: -------------------------------------------------------------------------------- 1 | a b c 2 | d e f 3 | g h i 4 | j k l 5 | m n o 6 | p q r 7 | s t u 8 | v w x 9 | y z 10 | -------------------------------------------------------------------------------- /testdata/dep_chunking.txt: -------------------------------------------------------------------------------- 1 | a b c 2 | d e f 3 | 4 | g h i 5 | j k l 6 | 7 | m n o 8 | p q r 9 | 10 | s t u 11 | v w x 12 | y z 13 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Let Git decide which files are text files. 2 | * text=auto 3 | 4 | # Avoid that unit test files are converted to CR-LF on Windows. 5 | *.txt -text -------------------------------------------------------------------------------- /man/Makefile: -------------------------------------------------------------------------------- 1 | all: finalfrontier.1 finalfrontier-skipgram.1 finalfrontier-deps.1 2 | 3 | clean: 4 | rm -f *.1 *.5 5 | 6 | %: %.md 7 | pandoc -s -w man -o $@ $< 8 | -------------------------------------------------------------------------------- /src/subcommands/traits.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | 3 | pub trait FinalfrontierApp 4 | where 5 | Self: Sized, 6 | { 7 | fn run(&self) -> Result<()>; 8 | } 9 | -------------------------------------------------------------------------------- /src/subcommands/mod.rs: -------------------------------------------------------------------------------- 1 | mod deps; 2 | pub use self::deps::DepsApp; 3 | 4 | mod progress; 5 | pub use self::progress::show_progress; 6 | 7 | mod skipgram; 8 | pub use self::skipgram::SkipgramApp; 9 | 10 | mod traits; 11 | pub use self::traits::FinalfrontierApp; 12 | -------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## From source 4 | 5 | There are currently no pre-compiled finalfrontier binaries. Compilation 6 | requires a working [Rust](https://www.rust-lang.org/) toolchain. 
After 7 | installing Rust, you can compile finalfrontier using cargo: 8 | 9 | ~~~shell 10 | $ cargo install finalfrontier 11 | ~~~ 12 | 13 | Afterwards, the `finalfrontier` binary is available in your 14 | `~/.cargo/bin`. 15 | -------------------------------------------------------------------------------- /COPYRIGHT.md: -------------------------------------------------------------------------------- 1 | ## finalfrontier 2 | 3 | Copyright 2018-2021 The finalfrontier contributors 4 | 5 | Licensed under the [Apache License, Version 6 | 2.0](http://www.apache.org/licenses/LICENSE-2.0) or the [MIT 7 | license](http://opensource.org/licenses/MIT), at your option. 8 | 9 | Contributors: 10 | 11 | * Daniël de Kok 12 | * Sebastian Pütz 13 | * Nianheng Wu 14 | -------------------------------------------------------------------------------- /docs/QUICKSTART.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | Train a model with 300-dimensional word embeddings, the structured skip-gram 4 | model, discarding words that occur fewer than 10 times: 5 | 6 | 7 | finalfrontier skipgram --dims 300 --model structgram --epochs 10 --mincount 10 \ 8 | --threads 16 corpus.txt corpus-embeddings.fifu 9 | 10 | The format of the input file is simple: tokens are separated by spaces, 11 | sentences by newlines (`\n`). 12 | 13 | After training, you can use and query the embeddings with 14 | [finalfusion](https://github.com/finalfusion/finalfusion-rust) and 15 | `finalfusion-utils`: 16 | 17 | finalfusion similar corpus-embeddings.fifu 18 | -------------------------------------------------------------------------------- /man/finalfrontier.1.md: -------------------------------------------------------------------------------- 1 | % FINALFRONTIER(1) 2 | % Daniel de Kok 3 | % Nov 1, 2019 4 | 5 | NAME 6 | ==== 7 | 8 | **finalfrontier** -- train finalfusion word embeddings 9 | 10 | SYNOPSIS 11 | ======== 12 | 13 | **finalfrontier** *command* [*options*] [*args*] 14 | **finalfrontier** completions *shell* 15 | **finalfrontier** help *command* 16 | 17 | DESCRIPTION 18 | =========== 19 | 20 | finalfrontier is a utility for training finalfusion word embeddings. 21 | 22 | COMMANDS 23 | ======== 24 | 25 | `finalfrontier-deps`(1) 26 | 27 | : Train word embeddings using the dependency model (Levy & Goldberg, 2014) 28 | 29 | `finalfrontier-skipgram`(1) 30 | 31 | : Train word embeddings using the skipgram model (Mikolov et al, 2013) 32 | 33 | SEE ALSO 34 | ======== 35 | 36 | `finalfrontier-deps`(1), `finalfrontier-skipgram`(1) 37 | -------------------------------------------------------------------------------- /man/finalfrontier.1: -------------------------------------------------------------------------------- 1 | .\" Automatically generated by Pandoc 2.7.3 2 | .\" 3 | .TH "FINALFRONTIER" "1" "Nov 1, 2019" "" "" 4 | .hy 5 | .SH NAME 6 | .PP 7 | \f[B]finalfrontier\f[R] \[en] train finalfusion word embeddings 8 | .SH SYNOPSIS 9 | .PP 10 | \f[B]finalfrontier\f[R] \f[I]command\f[R] [\f[I]options\f[R]] 11 | [\f[I]args\f[R]] 12 | .PD 0 13 | .P 14 | .PD 15 | \f[B]finalfrontier\f[R] completions \f[I]shell\f[R] 16 | .PD 0 17 | .P 18 | .PD 19 | \f[B]finalfrontier\f[R] help \f[I]command\f[R] 20 | .SH DESCRIPTION 21 | .PP 22 | finalfrontier is a utility for training finalfusion word embeddings. 
23 | .SH COMMANDS 24 | .TP 25 | .B \f[C]finalfrontier-deps\f[R](1) 26 | Train word embeddings using the dependency model (Levy & Goldberg, 2014) 27 | .TP 28 | .B \f[C]finalfrontier-skipgram\f[R](1) 29 | Train word embeddings using the skipgram model (Mikolov et al, 2013) 30 | .SH SEE ALSO 31 | .PP 32 | \f[C]finalfrontier-deps\f[R](1), \f[C]finalfrontier-skipgram\f[R](1) 33 | .SH AUTHORS 34 | Daniel de Kok. 35 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "finalfrontier" 3 | version = "0.10.0" 4 | edition = "2021" 5 | description = "Train/use word embeddings with subword units" 6 | documentation = "https://docs.rs/finalfrontier/" 7 | homepage = "https://finalfusion.github.io/finalfrontier" 8 | repository = "https://github.com/finalfusion/finalfrontier.git" 9 | rust-version = "1.70" 10 | license = "MIT OR Apache-2.0" 11 | 12 | [dependencies] 13 | anyhow = "1" 14 | chrono = "0.4" 15 | clap = { version = "4", features = ["derive"] } 16 | clap_complete = "4" 17 | conllu = "0.8" 18 | finalfusion = "0.18" 19 | fnv = "1" 20 | indicatif = "0.17" 21 | memmap = "0.7" 22 | ndarray = "0.15" 23 | ndarray-rand = "0.14" 24 | num_cpus = "1" 25 | rand = "0.8" 26 | rand_core = "0.6" 27 | rand_xorshift = "0.3" 28 | serde = { version = "1", features = ["derive"] } 29 | superslice = "1" 30 | toml = "0.8" 31 | udgraph = "0.8" 32 | udgraph-projectivize = "0.8" 33 | 34 | [build-dependencies] 35 | git2 = "0.18" 36 | rustversion = "1" 37 | 38 | [dev-dependencies] 39 | criterion = "0.5" 40 | lazy_static = "1" 41 | maplit = "1" 42 | 43 | [[bench]] 44 | name = "dot_product" 45 | harness = false 46 | -------------------------------------------------------------------------------- /src/subcommands/progress.rs: -------------------------------------------------------------------------------- 1 | use std::thread; 2 | use std::time::Duration; 3 | 4 | use finalfrontier::{CommonConfig, Sgd, Trainer, Vocab}; 5 | use indicatif::{ProgressBar, ProgressStyle}; 6 | 7 | pub fn show_progress(config: &CommonConfig, sgd: &Sgd, update_interval: Duration) 8 | where 9 | T: Trainer, 10 | V: Vocab, 11 | { 
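    // The bar length below is the total number of token observations for the
    // whole run (epochs * n_tokens); the learning rate shown in the message
    // decays linearly from `config.lr` towards zero as that total is consumed.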
12 | let n_tokens = sgd.model().input_vocab().n_types(); 13 | 14 | let pb = ProgressBar::new(u64::from(config.epochs) * n_tokens as u64); 15 | let style = ProgressStyle::default_bar() 16 | .template("{bar:30} {percent}% {msg} ETA: {eta_precise}") 17 | .expect("template string expected is to be valid"); 18 | pb.set_style(style); 19 | 20 | while sgd.n_tokens_processed() < n_tokens * config.epochs as usize { 21 | let lr = (1.0 22 | - (sgd.n_tokens_processed() as f32 / (config.epochs as usize * n_tokens) as f32)) 23 | * config.lr; 24 | 25 | pb.set_position(sgd.n_tokens_processed() as u64); 26 | pb.set_message(format!("loss: {:.*} lr: {:.*}", 5, sgd.train_loss(), 5, lr)); 27 | 28 | thread::sleep(update_interval); 29 | } 30 | 31 | pb.finish(); 32 | } 33 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | fmt: 7 | name: Rustfmt 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v1 11 | - uses: actions-rs/toolchain@v1 12 | with: 13 | profile: minimal 14 | toolchain: stable 15 | override: true 16 | - run: rustup component add rustfmt 17 | - uses: actions-rs/cargo@v1 18 | with: 19 | command: fmt 20 | args: --all -- --check 21 | 22 | tests: 23 | name: Test 24 | runs-on: ubuntu-latest 25 | steps: 26 | - uses: actions/checkout@v1 27 | - uses: actions/cache@v2 28 | with: 29 | path: | 30 | ~/.cargo/registry 31 | ~/.cargo/git 32 | target 33 | key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 34 | - uses: actions-rs/toolchain@v1 35 | with: 36 | profile: minimal 37 | toolchain: stable 38 | override: true 39 | components: clippy 40 | - uses: actions-rs/cargo@v1 41 | name: Clippy 42 | with: 43 | command: clippy 44 | args: -- -D warnings 45 | - uses: actions-rs/cargo@v1 46 | name: Test 47 | with: 48 | command: test 49 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg_attr(feature = "neon", feature(aarch64_target_feature))] 2 | #![cfg_attr(feature = "neon", feature(stdsimd))] 3 | 4 | mod config; 5 | pub use crate::config::{ 6 | BucketConfig, BucketIndexerType, CommonConfig, DepembedsConfig, FloretConfig, LossType, 7 | NGramConfig, SimpleVocabConfig, SkipGramConfig, SkipgramModelType, SubwordVocabConfig, 8 | VocabConfig, 9 | }; 10 | 11 | mod deps; 12 | pub use crate::deps::{DepIter, Dependency, DependencyIterator}; 13 | 14 | pub(crate) mod dep_trainer; 15 | pub use crate::dep_trainer::DepembedsTrainer; 16 | 17 | pub(crate) mod hogwild; 18 | 19 | pub mod idx; 20 | 21 | pub mod io; 22 | pub use io::{SentenceIterator, WriteModelBinary, WriteModelText, WriteModelWord2Vec}; 23 | 24 | pub(crate) mod loss; 25 | 26 | mod opts; 27 | pub use opts::{ 28 | BucketIndexerArg, EmbeddingFormatArg, ModelConfig, ModelSubcommand, Opts, TrainOpts, 29 | VocabSubcommand, 30 | }; 31 | 32 | pub(crate) mod sampling; 33 | 34 | mod sgd; 35 | pub use crate::sgd::Sgd; 36 | 37 | mod train_model; 38 | pub use crate::train_model::{TrainModel, Trainer}; 39 | 40 | pub(crate) mod skipgram_trainer; 41 | pub use crate::skipgram_trainer::SkipgramTrainer; 42 | 43 | pub(crate) mod util; 44 | 45 | #[doc(hidden)] 46 | pub mod vec_simd; 47 | 48 | mod vocab; 49 | pub use crate::vocab::{ 50 | simple::SimpleVocab, subword::SubwordVocab, CountedType, Cutoff, Vocab, VocabBuilder, Word, 51 | 
}; 52 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | jobs: 9 | release: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v1 13 | - name: Get release version 14 | run: echo ::set-env name=TAG_NAME::$(echo ${GITHUB_REF:10}) 15 | - uses: actions-rs/toolchain@v1 16 | with: 17 | profile: minimal 18 | toolchain: stable 19 | target: x86_64-unknown-linux-musl 20 | override: true 21 | - uses: actions-rs/cargo@v1 22 | with: 23 | command: build 24 | args: --release 25 | - name: Create release archive 26 | id: create_archive 27 | run: | 28 | ARCHIVE=finalfrontier-${TAG_NAME}-x86_64-unknown-linux-musl.tar.gz 29 | strip target/release/finalfrontier 30 | tar -czvf ${ARCHIVE} -C target/release finalfrontier 31 | echo ::set-output name=ASSET::$ARCHIVE 32 | - uses: actions/create-release@v1.0.0 33 | id: create_release 34 | env: 35 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 36 | with: 37 | tag_name: ${{ github.ref }} 38 | release_name: Release ${{ github.ref }} 39 | draft: true 40 | prerelease: false 41 | - uses: actions/upload-release-asset@v1.0.1 42 | env: 43 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 44 | with: 45 | upload_url: ${{ steps.create_release.outputs.upload_url }} 46 | asset_path: ${{ steps.create_archive.outputs.ASSET }} 47 | asset_name: ${{ steps.create_archive.outputs.ASSET }} 48 | asset_content_type: application/gzip 49 | -------------------------------------------------------------------------------- /benches/dot_product.rs: -------------------------------------------------------------------------------- 1 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 2 | use finalfrontier::vec_simd; 3 | use ndarray::Array1; 4 | use ndarray_rand::rand_distr::Normal; 5 | use ndarray_rand::RandomExt; 6 | 7 | const ARRAY_SIZE: usize = 512; 8 | 9 | fn random_array(n: usize) -> Array1 { 10 | Array1::random((n,), Normal::new(0.0, 0.5).unwrap()) 11 | } 12 | 13 | fn dot_avx(c: &mut Criterion) { 14 | let u = random_array(ARRAY_SIZE); 15 | let v = random_array(ARRAY_SIZE); 16 | c.bench_function("dot_avx", move |b| { 17 | b.iter(|| black_box(unsafe { vec_simd::avx::dot(u.view(), v.view()) })) 18 | }); 19 | } 20 | 21 | fn dot_fma(c: &mut Criterion) { 22 | let u = random_array(ARRAY_SIZE); 23 | let v = random_array(ARRAY_SIZE); 24 | c.bench_function("dot_fma", move |b| { 25 | b.iter(|| black_box(unsafe { vec_simd::avx_fma::dot(u.view(), v.view()) })) 26 | }); 27 | } 28 | 29 | fn dot_ndarray(c: &mut Criterion) { 30 | let u = random_array(ARRAY_SIZE); 31 | let v = random_array(ARRAY_SIZE); 32 | c.bench_function("dot_ndarray", move |b| b.iter(|| black_box(u.dot(&v)))); 33 | } 34 | 35 | fn dot_sse(c: &mut Criterion) { 36 | let u = random_array(ARRAY_SIZE); 37 | let v = random_array(ARRAY_SIZE); 38 | c.bench_function("dot_sse", move |b| { 39 | b.iter(|| black_box(unsafe { vec_simd::sse::dot(u.view(), v.view()) })) 40 | }); 41 | } 42 | 43 | fn dot_unvectorized(c: &mut Criterion) { 44 | let u = random_array(ARRAY_SIZE); 45 | let v = random_array(ARRAY_SIZE); 46 | c.bench_function("dot_unvectorized", move |b| { 47 | b.iter(|| { 48 | black_box(vec_simd::dot_unvectorized( 49 | u.as_slice().unwrap(), 50 | v.as_slice().unwrap(), 51 | )) 52 | }); 53 | }); 54 | } 55 | 56 | criterion_group!( 57 | benches, 58 | dot_avx, 59 | dot_fma, 60 | dot_ndarray, 61 | dot_sse, 62 | 
dot_unvectorized 63 | ); 64 | criterion_main!(benches); 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Crate](https://img.shields.io/crates/v/finalfrontier.svg)](https://crates.io/crates/finalfrontier) 2 | [![Docs](https://docs.rs/finalfrontier/badge.svg)](https://docs.rs/finalfrontier/) 3 | [![Build Status](https://travis-ci.org/finalfusion/finalfrontier.svg?branch=master)](https://travis-ci.org/finalfusion/finalfrontier) 4 | 5 | # finalfrontier 6 | 7 | ## Introduction 8 | 9 | finalfrontier is a Rust program for training word embeddings. 10 | finalfrontier currently has the following features: 11 | 12 | * Models: 13 | - skip-gram (Mikolov et al., 2013) 14 | - structured skip-gram (Ling et al., 2015) 15 | - directional skip-gram (Song et al., 2018) 16 | - dependency (Levy and Goldberg, 2014) 17 | * Output formats: 18 | - [finalfusion](https://finalfusion.github.io) 19 | - fastText 20 | - word2vec binary 21 | - word2vec text 22 | - GloVe text 23 | * Noise contrastive estimation (Gutmann and Hyvärinen, 2012) 24 | * Subword representations (Bojanowski et al., 2016) 25 | * Hogwild SGD (Recht et al., 2011) 26 | * Quantized embeddings through the [`finalfusion 27 | quantize`](https://github.com/finalfusion/finalfusion-utils) 28 | command. 29 | 30 | The trained embeddings can be stored in the versatile `finalfusion` 31 | format, which can be read and used with the 32 | [finalfusion](https://github.com/finalfusion/finalfusion-rust) crate 33 | and the 34 | [finalfusion](https://github.com/finalfusion/finalfusion-python) 35 | Python module. 36 | 37 | The minimum required Rust version is currently 1.70. 38 | 39 | ## Where to go from here 40 | 41 | * [Installation](docs/INSTALL.md) 42 | * [Quickstart](docs/QUICKSTART.md) 43 | * Manual pages: 44 | - [finalfrontier-skipgram(1)](man/finalfrontier-skipgram.1.md) — train word 45 | embeddings with the (structured) skip-gram model 46 | - [finalfrontier-deps(1)](man/finalfrontier-deps.1.md) — train word embeddings with dependency contexts 47 | * [finalfusion crate](https://github.com/finalfusion/finalfusion-rust) 48 | * [Python module](https://github.com/finalfusion/finalfusion-python) 49 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use std::io; 2 | 3 | use anyhow::{bail, Result}; 4 | use clap::{Command, CommandFactory, FromArgMatches}; 5 | use clap_complete::{generate, Generator}; 6 | use finalfrontier::{ 7 | BucketIndexerArg, EmbeddingFormatArg, ModelConfig, Opts, TrainOpts, VocabSubcommand, 8 | }; 9 | 10 | mod subcommands; 11 | use subcommands::{DepsApp, FinalfrontierApp, SkipgramApp}; 12 | 13 | const FASTTEXT_FORMAT_ERROR: &str = "Only embeddings trained with: 14 | 15 | train [skipgram/deps] buckets --hash-indexer fasttext 16 | 17 | can be stored in fastText format."; 18 | 19 | const FLORET_FORMAT_ERROR: &str = "Only embeddings trained with 20 | 21 | train [skipgram/deps] floret 22 | 23 | can be stored in floret format."; 24 | 25 | fn extra_opts_validation(train_opts: &TrainOpts) -> Result<()> { 26 | let vocab_args = train_opts.vocab_args(); 27 | 28 | if train_opts.common.format == EmbeddingFormatArg::FastText 29 | && !matches!( 30 | vocab_args, 31 | VocabSubcommand::Buckets { 32 | hash_indexer: BucketIndexerArg::FastText, 33 | .. 
34 | } 35 | ) 36 | { 37 | bail!(FASTTEXT_FORMAT_ERROR); 38 | } 39 | 40 | if train_opts.common.format == EmbeddingFormatArg::Floret 41 | && !matches!(vocab_args, VocabSubcommand::Floret { .. }) 42 | { 43 | bail!(FLORET_FORMAT_ERROR); 44 | } 45 | 46 | if train_opts.common.format != EmbeddingFormatArg::Floret 47 | && matches!(vocab_args, VocabSubcommand::Floret { .. }) 48 | { 49 | bail!("Floret embeddings can only be written in the 'floret' format."); 50 | } 51 | 52 | Ok(()) 53 | } 54 | 55 | fn print_completions(gen: G, app: &mut Command) { 56 | generate(gen, app, app.get_name().to_string(), &mut io::stdout()); 57 | } 58 | 59 | /// Get features that will be used by SIMD code paths. 60 | #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] 61 | fn simd_features() -> Vec<&'static str> { 62 | let mut features = vec![]; 63 | 64 | if is_x86_feature_detected!("sse") { 65 | features.push("+sse"); 66 | } else { 67 | features.push("-sse"); 68 | } 69 | 70 | if is_x86_feature_detected!("avx") { 71 | features.push("+avx"); 72 | } else { 73 | features.push("-avx"); 74 | } 75 | 76 | if is_x86_feature_detected!("fma") { 77 | features.push("+fma"); 78 | } else { 79 | features.push("-fma"); 80 | } 81 | 82 | features 83 | } 84 | 85 | fn main() -> Result<()> { 86 | let mut app = Opts::command(); 87 | let matches = app.clone().get_matches(); 88 | let opts = Opts::from_arg_matches(&matches).unwrap(); 89 | 90 | let train_opts = match opts { 91 | Opts::Completions { shell } => { 92 | print_completions(shell, &mut app); 93 | std::process::exit(0) 94 | } 95 | Opts::Train(train_opts) => train_opts, 96 | }; 97 | 98 | extra_opts_validation(&train_opts)?; 99 | 100 | let config = train_opts.to_train_config(); 101 | 102 | #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] 103 | eprintln!("SIMD features: {}", simd_features().join(" ")); 104 | 105 | match config.model_config { 106 | ModelConfig::SkipGram(skipgram_config) => SkipgramApp::new( 107 | config.train_info, 108 | config.common_config, 109 | skipgram_config, 110 | config.vocab_config, 111 | ) 112 | .run(), 113 | ModelConfig::DepEmbeds(deps_config) => DepsApp::new( 114 | config.train_info, 115 | config.common_config, 116 | deps_config, 117 | config.vocab_config, 118 | ) 119 | .run(), 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /src/idx.rs: -------------------------------------------------------------------------------- 1 | use std::iter::FusedIterator; 2 | use std::{option, slice}; 3 | 4 | /// A single lookup index. 5 | #[derive(Copy, Clone)] 6 | pub struct SingleIdx { 7 | idx: u64, 8 | } 9 | 10 | impl SingleIdx { 11 | pub fn new(idx: u64) -> Self { 12 | SingleIdx { idx } 13 | } 14 | 15 | pub fn idx(self) -> u64 { 16 | self.idx 17 | } 18 | } 19 | 20 | /// A lookup index with associated subword indices. 21 | #[derive(Clone)] 22 | pub struct WordWithSubwordsIdx { 23 | word_idx: u64, 24 | subwords: Vec, 25 | } 26 | 27 | impl WordWithSubwordsIdx { 28 | pub fn new(word_idx: u64, subwords: impl Into>) -> Self { 29 | WordWithSubwordsIdx { 30 | word_idx, 31 | subwords: subwords.into(), 32 | } 33 | } 34 | } 35 | 36 | /// Vocabulary indexing trait. 37 | /// 38 | /// This trait defines methods shared by indexing types. 39 | pub trait WordIdx: Clone { 40 | /// Return the unique word index for the WordIdx. 41 | fn word_idx(&self) -> u64; 42 | 43 | /// Build a new WordIdx containing only a single index. 44 | fn from_word_idx(word_idx: u64) -> Self; 45 | 46 | /// Return the number of indices. 
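    ///
    /// A `SingleIdx` always reports 1; a `WordWithSubwordsIdx` reports 1 plus
    /// the number of subword indices.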
47 | fn len(&self) -> usize; 48 | 49 | /// Return whether this is empty. 50 | fn is_empty(&self) -> bool { 51 | self.len() == 0 52 | } 53 | } 54 | 55 | impl WordIdx for SingleIdx { 56 | fn word_idx(&self) -> u64 { 57 | self.idx 58 | } 59 | 60 | fn from_word_idx(word_idx: u64) -> Self { 61 | SingleIdx::new(word_idx) 62 | } 63 | 64 | fn len(&self) -> usize { 65 | 1 66 | } 67 | } 68 | 69 | impl<'a> IntoIterator for &'a SingleIdx { 70 | type Item = u64; 71 | type IntoIter = option::IntoIter; 72 | 73 | fn into_iter(self) -> Self::IntoIter { 74 | Some(self.idx).into_iter() 75 | } 76 | } 77 | 78 | impl WordIdx for WordWithSubwordsIdx { 79 | fn word_idx(&self) -> u64 { 80 | self.word_idx 81 | } 82 | 83 | fn from_word_idx(word_idx: u64) -> Self { 84 | WordWithSubwordsIdx { 85 | word_idx, 86 | subwords: Vec::new(), 87 | } 88 | } 89 | 90 | fn len(&self) -> usize { 91 | 1 + self.subwords.len() 92 | } 93 | } 94 | 95 | impl<'a> IntoIterator for &'a WordWithSubwordsIdx { 96 | type Item = u64; 97 | type IntoIter = IdxIter<'a>; 98 | 99 | fn into_iter(self) -> Self::IntoIter { 100 | IdxIter { 101 | word_idx: Some(self.word_idx), 102 | subwords: self.subwords.iter(), 103 | } 104 | } 105 | } 106 | 107 | /// Iterator over Indices. 108 | pub struct IdxIter<'a> { 109 | word_idx: Option, 110 | subwords: slice::Iter<'a, u64>, 111 | } 112 | 113 | impl<'a> Iterator for IdxIter<'a> { 114 | type Item = u64; 115 | 116 | fn next(&mut self) -> Option { 117 | if let Some(idx) = self.subwords.next() { 118 | Some(*idx) 119 | } else { 120 | self.word_idx.take() 121 | } 122 | } 123 | } 124 | 125 | impl<'a> FusedIterator for IdxIter<'a> {} 126 | 127 | #[cfg(test)] 128 | mod test { 129 | use crate::idx::{SingleIdx, WordIdx, WordWithSubwordsIdx}; 130 | 131 | #[test] 132 | fn test_idx_iter() { 133 | let with_subwords = WordWithSubwordsIdx::new(0, vec![24, 4, 42]); 134 | let mut idx_iter = (&with_subwords).into_iter(); 135 | assert_eq!(24, idx_iter.next().unwrap()); 136 | assert_eq!(4, idx_iter.next().unwrap()); 137 | assert_eq!(42, idx_iter.next().unwrap()); 138 | assert_eq!(0, idx_iter.next().unwrap()); 139 | assert_eq!(0, with_subwords.word_idx()); 140 | 141 | let single = SingleIdx::from_word_idx(0); 142 | let mut idx_iter = (&single).into_iter(); 143 | assert_eq!(0, idx_iter.next().unwrap()); 144 | assert_eq!(0, single.word_idx()); 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/hogwild.rs: -------------------------------------------------------------------------------- 1 | use std::cell::UnsafeCell; 2 | use std::ops::{Deref, DerefMut}; 3 | use std::sync::Arc; 4 | 5 | use ndarray::{Array, ArrayView, ArrayViewMut, Axis, Dimension, Ix, Ix2, RemoveAxis}; 6 | 7 | /// Array for Hogwild parallel optimization. 8 | /// 9 | /// This array type can be used for the Hogwild (Niu, et al. 2011) method 10 | /// of parallel Stochastic Gradient descent. In Hogwild different threads 11 | /// share the same parameters without locking. If SGD is performed on a 12 | /// sparse optimization problem, where only a small subset of parameters 13 | /// is updated in each gradient descent, the impact of data races is 14 | /// negligible. 15 | /// 16 | /// In order to use Hogwild in Rust, we have to subvert the ownership 17 | /// system. This is what the `HogwildArray` type does. It uses reference 18 | /// counting to share an *ndarray* `Array` type between multiple 19 | /// `HogwildArray` instances. 
Views of the underling `Array` can be borrowed 20 | /// mutably from each instance, without mutual exclusion between mutable 21 | /// borrows in different `HogwildArray` instances. 22 | #[derive(Clone)] 23 | pub struct HogwildArray(Arc>>); 24 | 25 | impl HogwildArray { 26 | #[inline] 27 | fn as_mut(&mut self) -> &mut Array { 28 | let ptr = self.0.as_ref().get(); 29 | unsafe { &mut *ptr } 30 | } 31 | 32 | #[inline] 33 | fn as_ref(&self) -> &Array { 34 | let ptr = self.0.as_ref().get(); 35 | unsafe { &*ptr } 36 | } 37 | 38 | pub fn into_inner(self) -> Arc>> { 39 | self.0 40 | } 41 | } 42 | 43 | impl HogwildArray 44 | where 45 | D: Dimension + RemoveAxis, 46 | { 47 | /// Get an immutable subview of the Hogwild array. 48 | #[inline] 49 | pub fn subview(&self, axis: Axis, index: Ix) -> ArrayView { 50 | self.as_ref().index_axis(axis, index) 51 | } 52 | 53 | /// Get a mutable subview of the Hogwild array. 54 | #[inline] 55 | pub fn subview_mut(&mut self, axis: Axis, index: Ix) -> ArrayViewMut { 56 | self.as_mut().index_axis_mut(axis, index) 57 | } 58 | } 59 | 60 | impl HogwildArray 61 | where 62 | D: Dimension, 63 | { 64 | /// Get an immutable view of the Hogwild array. 65 | #[inline] 66 | pub fn view(&self) -> ArrayView { 67 | self.as_ref().view() 68 | } 69 | } 70 | 71 | impl From> for HogwildArray { 72 | fn from(a: Array) -> Self { 73 | HogwildArray(Arc::new(UnsafeCell::new(a))) 74 | } 75 | } 76 | 77 | unsafe impl Send for HogwildArray {} 78 | 79 | unsafe impl Sync for HogwildArray {} 80 | 81 | /// Two-dimensional Hogwild array. 82 | pub type HogwildArray2 = HogwildArray; 83 | 84 | /// Hogwild for arbitrary data types. 85 | /// 86 | /// `Hogwild` subverts Rust's type system by allowing concurrent modification 87 | /// of values. This should only be used for data types that cannot end up in 88 | /// an inconsistent state due to data races. For arrays `HogwildArray` should 89 | /// be preferred. 90 | #[derive(Clone)] 91 | pub struct Hogwild(Arc>); 92 | 93 | impl Default for Hogwild 94 | where 95 | T: Default, 96 | { 97 | fn default() -> Self { 98 | Hogwild(Arc::new(UnsafeCell::new(T::default()))) 99 | } 100 | } 101 | 102 | impl Deref for Hogwild { 103 | type Target = T; 104 | 105 | fn deref(&self) -> &Self::Target { 106 | let ptr = self.0.as_ref().get(); 107 | unsafe { &*ptr } 108 | } 109 | } 110 | 111 | impl DerefMut for Hogwild { 112 | fn deref_mut(&mut self) -> &mut T { 113 | let ptr = self.0.as_ref().get(); 114 | unsafe { &mut *ptr } 115 | } 116 | } 117 | 118 | unsafe impl Send for Hogwild {} 119 | 120 | unsafe impl Sync for Hogwild {} 121 | 122 | #[cfg(test)] 123 | mod test { 124 | use ndarray::Array2; 125 | 126 | use super::{Hogwild, HogwildArray2}; 127 | 128 | #[test] 129 | pub fn hogwild_test() { 130 | let mut a1: Hogwild = Hogwild::default(); 131 | let mut a2 = a1.clone(); 132 | 133 | *a1 = 1; 134 | assert_eq!(*a2, 1); 135 | *a2 = 2; 136 | assert_eq!(*a1, 2); 137 | } 138 | 139 | #[test] 140 | pub fn hogwild_array_test() { 141 | let mut a1: HogwildArray2 = Array2::zeros((2, 2)).into(); 142 | let mut a2 = a1.clone(); 143 | 144 | let mut a1_view = a1.as_mut().view_mut(); 145 | 146 | let c00 = &mut a1_view[(0, 0)]; 147 | *c00 = 1.0; 148 | 149 | // Two simultaneous mutable borrows of the underlying array. 
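        // A plain `Array2` would be rejected by the borrow checker here;
        // `HogwildArray` permits it by design, tolerating the data race.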
150 | a2.as_mut().view_mut()[(1, 1)] = *c00 * 2.0; 151 | 152 | assert_eq!(&[1f32, 0f32, 0f32, 2f32], a2.as_ref().as_slice().unwrap()); 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/util.rs: -------------------------------------------------------------------------------- 1 | use rand::SeedableRng; 2 | use rand_core::{self, RngCore}; 3 | use serde::Serialize; 4 | 5 | /// Tolerance for small negative values. 6 | const NEGATIVE_TOLERANCE: f32 = 1e-5; 7 | 8 | /// Add a small value, to prevent returning Inf on underflow. 9 | #[inline] 10 | pub fn safe_ln(v: f32) -> f32 { 11 | (v + NEGATIVE_TOLERANCE).ln() 12 | } 13 | 14 | /// RNG that reseeds on clone. 15 | /// 16 | /// This is a wrapper struct for RNGs implementing the `RngCore` 17 | /// trait. It adds the following simple behavior: when a 18 | /// `ReseedOnCloneRng` is cloned, the clone is constructed using fresh 19 | /// entropy. This assures that the state of the clone is not related 20 | /// to the cloned RNG. 21 | /// 22 | /// The `rand` crate provides similar behavior in the `ReseedingRng` 23 | /// struct. However, `ReseedingRng` requires that the RNG is 24 | /// `BlockRngCore`. 25 | pub struct ReseedOnCloneRng(pub R) 26 | where 27 | R: RngCore + SeedableRng; 28 | 29 | impl RngCore for ReseedOnCloneRng 30 | where 31 | R: RngCore + SeedableRng, 32 | { 33 | #[inline] 34 | fn next_u32(&mut self) -> u32 { 35 | self.0.next_u32() 36 | } 37 | 38 | #[inline] 39 | fn next_u64(&mut self) -> u64 { 40 | self.0.next_u64() 41 | } 42 | 43 | #[inline] 44 | fn fill_bytes(&mut self, dest: &mut [u8]) { 45 | self.0.fill_bytes(dest) 46 | } 47 | 48 | #[inline] 49 | fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { 50 | self.0.try_fill_bytes(dest) 51 | } 52 | } 53 | 54 | impl Clone for ReseedOnCloneRng 55 | where 56 | R: RngCore + SeedableRng, 57 | { 58 | fn clone(&self) -> Self { 59 | ReseedOnCloneRng(R::from_entropy()) 60 | } 61 | } 62 | 63 | #[derive(Serialize)] 64 | pub(crate) struct VersionInfo { 65 | finalfusion_version: &'static str, 66 | git_desc: Option<&'static str>, 67 | } 68 | 69 | impl VersionInfo { 70 | pub(crate) fn new() -> Self { 71 | VersionInfo { 72 | finalfusion_version: env!("CARGO_PKG_VERSION"), 73 | git_desc: option_env!("MAYBE_FINALFRONTIER_GIT_DESC"), 74 | } 75 | } 76 | } 77 | 78 | #[cfg(test)] 79 | pub use self::test::*; 80 | 81 | #[cfg(test)] 82 | mod test { 83 | #[cfg(any(feature = "neon", target_feature = "sse", target_feature = "avx"))] 84 | use ndarray::{ArrayView, Dimension}; 85 | use rand::SeedableRng; 86 | use rand_core::{self, impls, le, RngCore}; 87 | 88 | use super::ReseedOnCloneRng; 89 | 90 | #[derive(Clone)] 91 | struct BogusRng(pub u64); 92 | 93 | impl RngCore for BogusRng { 94 | fn next_u32(&mut self) -> u32 { 95 | self.next_u64() as u32 96 | } 97 | 98 | fn next_u64(&mut self) -> u64 { 99 | self.0 += 1; 100 | self.0 101 | } 102 | 103 | fn fill_bytes(&mut self, dest: &mut [u8]) { 104 | impls::fill_bytes_via_next(self, dest) 105 | } 106 | 107 | fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { 108 | self.fill_bytes(dest); 109 | Ok(()) 110 | } 111 | } 112 | 113 | impl SeedableRng for BogusRng { 114 | type Seed = [u8; 8]; 115 | 116 | fn from_seed(seed: Self::Seed) -> Self { 117 | let mut state = [0u64; 1]; 118 | le::read_u64_into(&seed, &mut state); 119 | BogusRng(state[0]) 120 | } 121 | } 122 | 123 | pub fn close(a: f32, b: f32, eps: f32) -> bool { 124 | let diff = (a - b).abs(); 125 | if diff > eps { 
126 | return false; 127 | } 128 | 129 | true 130 | } 131 | 132 | pub fn all_close(a: &[f32], b: &[f32], eps: f32) -> bool { 133 | for (&av, &bv) in a.iter().zip(b) { 134 | if !close(av, bv, eps) { 135 | return false; 136 | } 137 | } 138 | 139 | true 140 | } 141 | 142 | #[cfg(any(feature = "neon", target_feature = "sse", target_feature = "avx"))] 143 | pub fn array_all_close(a: ArrayView, b: ArrayView, eps: f32) -> bool 144 | where 145 | Ix: Dimension, 146 | { 147 | for (&av, &bv) in a.iter().zip(b) { 148 | if !close(av, bv, eps) { 149 | return false; 150 | } 151 | } 152 | 153 | true 154 | } 155 | 156 | #[test] 157 | fn reseed_on_clone_rng() { 158 | let bogus_rng = BogusRng::from_entropy(); 159 | let bogus_rng_clone = bogus_rng.clone(); 160 | assert_eq!(bogus_rng.0, bogus_rng_clone.0); 161 | 162 | let reseed = ReseedOnCloneRng(bogus_rng); 163 | let reseed_clone = reseed.clone(); 164 | // One in 2^64 probability of collision given good entropy source. 165 | assert_ne!((reseed.0).0, (reseed_clone.0).0); 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /src/loss.rs: -------------------------------------------------------------------------------- 1 | use ndarray::ArrayView1; 2 | 3 | use crate::util; 4 | use crate::vec_simd::dot; 5 | 6 | /// Absolute activations to round in logistic regression. 7 | /// 8 | /// Since the logistic function is asymptotic, there is always (a small) 9 | /// gradient for larger activations. As a result, optimization of logistic 10 | /// regression will not converge without e.g. regularization. In the 11 | /// training of embeddings, this has the result of ever-increasing weights 12 | /// (amplified by the optimization two vectors). 13 | /// 14 | /// A simpler solution than regularization is to round the output of the 15 | /// logistic function to 0 (negative activation) or 1 (positive activiation) 16 | /// for large activations, to kill gradient. 17 | /// 18 | /// This constant controls at what activation the logistic function should 19 | /// round. 20 | const LOGISTIC_ROUND_ACTIVATION: f32 = 10.0; 21 | 22 | /// Return the loss and gradient of the co-occurence classification. 23 | /// 24 | /// This function returns the negative log likelihood and gradient of 25 | /// a training instance using the probability function *P(1|x) = 26 | /// σ(u·v)*. `u` and `v` are word embeddings and `label` is the 27 | /// target label, where a label of `1` means that the words co-occur 28 | /// and a label of `0` that they do not. 29 | /// 30 | /// This model is very similar to logistic regression, except that we 31 | /// optimize both u and v. 
32 | /// 33 | /// The loss is as follows (y is used as the label): 34 | /// 35 | /// log(P(y|x)) = 36 | /// y log(P(1|x)) + (1-y) log(P(0|x)) = 37 | /// y log(P(1|x)) + (1-y) log(1 - P(1|x)) = 38 | /// y log(σ(u·v)) + (1-y) log(1 - σ(u·v)) = 39 | /// y log(σ(u·v)) + (1-y) log(σ(-u·v)) 40 | /// 41 | /// We can simplify the first term: 42 | /// 43 | /// y log(σ(u·v)) = 44 | /// y log(1/(1+e^{-u·v})) = 45 | /// -y log(1+e^{-u·v}) 46 | /// 47 | /// Then we find the derivative with respect to v_1: 48 | /// 49 | /// ∂/∂v_1 -y log(1+e^{-u·v}) = 50 | /// -y σ(u·v) ∂/∂v_1(1+e^{-u·v}) = 51 | /// -y σ(u·v) e^{-u·v} -u_1 = 52 | /// y σ(-u·v) u_1 = 53 | /// y (1 - σ(u·v)) u_1 = 54 | /// (y - yσ(u·v)) u_1 55 | /// 56 | /// Iff y = 1, then: 57 | /// 58 | /// 1 - σ(u·v) 59 | /// 60 | /// For the second term above, we also find the derivative: 61 | /// 62 | /// ∂/∂v_1 -(1 - y) log(1+e^{u·v}) = 63 | /// -(1 - y) σ(-u·v) ∂/∂v_1(1+e^{u·v}) = 64 | /// -(1 - y) σ(-u·v) e^{u·v} ∂/∂v_1 u·v= 65 | /// -(1 - y) σ(-u·v) e^{u·v} u_1 = 66 | /// -(1 - y) σ(u·v) u_1 = 67 | /// (-σ(u·v) + yσ(u·v)) u_1 68 | /// 69 | /// When y = 0 then: 70 | /// 71 | /// -σ(u·v)u_1 72 | /// 73 | /// Combining both, the partial derivative of v_1 is: y - σ(u·v)u_1 74 | /// 75 | /// We return y - σ(u·v) as the gradient, so that the caller can compute 76 | /// the gradient for all components of u and v. 77 | pub fn log_logistic_loss(u: ArrayView1, v: ArrayView1, label: bool) -> (f32, f32) { 78 | let dp = dot(u, v); 79 | let lf = logistic_function(dp); 80 | let grad = (label as usize) as f32 - lf; 81 | let loss = if label { 82 | -util::safe_ln(lf) 83 | } else { 84 | -util::safe_ln(1.0 - lf) 85 | }; 86 | 87 | (loss, grad) 88 | } 89 | 90 | /// Compute the logistic function. 91 | /// 92 | /// **σ(a) = 1 / (1 + e^{-a})** 93 | fn logistic_function(a: f32) -> f32 { 94 | if a > LOGISTIC_ROUND_ACTIVATION { 95 | 1.0 96 | } else if a < -LOGISTIC_ROUND_ACTIVATION { 97 | 0.0 98 | } else { 99 | 1.0 / (1.0 + (-a).exp()) 100 | } 101 | } 102 | 103 | #[cfg(test)] 104 | mod tests { 105 | use ndarray::Array1; 106 | 107 | use crate::util::{all_close, close}; 108 | 109 | use super::{log_logistic_loss, logistic_function}; 110 | 111 | #[test] 112 | fn logistic_function_test() { 113 | let activations = &[ 114 | -11.0, -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 11.0, 115 | ]; 116 | let outputs: Vec<_> = activations.iter().map(|&a| logistic_function(a)).collect(); 117 | assert!(all_close( 118 | &[ 119 | 0.0, 0.00669, 0.01799, 0.04743, 0.11920, 0.26894, 0.5, 0.73106, 0.88080, 0.95257, 120 | 0.982014, 0.99331, 1.0 121 | ], 122 | outputs.as_slice(), 123 | 1e-5 124 | )); 125 | } 126 | 127 | #[test] 128 | fn log_logistic_loss_test() { 129 | let a = Array1::from_shape_vec((6,), vec![1., 1., 1., 0., 0., 0.]).unwrap(); 130 | let a_orth = Array1::from_shape_vec((6,), vec![0., 0., 0., 1., 1., 1.]).unwrap(); 131 | let a_opp = Array1::from_shape_vec((6,), vec![-1., -1., -1., 0., 0., 0.]).unwrap(); 132 | 133 | let (loss, gradient) = log_logistic_loss(a.view(), a_orth.view(), true); 134 | assert!(close(loss, 0.69312, 1e-5)); 135 | assert!(close(gradient, 0.5, 1e-5)); 136 | 137 | let (loss, gradient) = log_logistic_loss(a.view(), a_orth.view(), false); 138 | assert!(close(loss, 0.69312, 1e-5)); 139 | assert!(close(gradient, -0.5, 1e-5)); 140 | 141 | let (loss, gradient) = log_logistic_loss(a.view(), a.view(), true); 142 | assert!(close(loss, 0.04858, 1e-5)); 143 | assert!(close(gradient, 0.04742, 1e-5)); 144 | 145 | let (loss, gradient) = log_logistic_loss(a.view(), 
a_opp.view(), false); 146 | assert!(close(loss, 0.04858, 1e-5)); 147 | assert!(close(gradient, -0.04743, 1e-5)); 148 | 149 | let (loss, gradient) = log_logistic_loss(a.view(), a_opp.view(), true); 150 | assert!(close(loss, 3.04838, 1e-5)); 151 | assert!(close(gradient, 0.95257, 1e-5)); 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/vocab/simple.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Borrow; 2 | use std::collections::HashMap; 3 | use std::hash::Hash; 4 | 5 | use finalfusion::vocab::{SimpleVocab as FiFuSimpleVocab, VocabWrap}; 6 | 7 | use crate::idx::{SingleIdx, WordIdx}; 8 | use crate::vocab::{create_discards, create_indices}; 9 | use crate::{CountedType, SimpleVocabConfig, Vocab, VocabBuilder}; 10 | 11 | /// Generic corpus vocabulary type. 12 | /// 13 | /// Can be used as an input or output lookup. 14 | #[derive(Clone)] 15 | pub struct SimpleVocab { 16 | config: SimpleVocabConfig, 17 | types: Vec>, 18 | index: HashMap, 19 | n_types: usize, 20 | discards: Vec, 21 | } 22 | 23 | impl SimpleVocab 24 | where 25 | T: Hash + Eq + Clone + Ord, 26 | { 27 | /// Constructor only used by the Vocabbuilder 28 | pub(crate) fn new( 29 | config: SimpleVocabConfig, 30 | types: Vec>, 31 | n_types: usize, 32 | ) -> Self { 33 | let discards = create_discards(config.discard_threshold, &types, n_types); 34 | let index = create_indices(&types); 35 | SimpleVocab { 36 | config, 37 | types, 38 | index, 39 | n_types, 40 | discards, 41 | } 42 | } 43 | 44 | /// Get a specific context 45 | pub fn get(&self, context: &Q) -> Option<&CountedType> 46 | where 47 | T: Borrow, 48 | Q: Hash + ?Sized + Eq, 49 | { 50 | self.idx(context) 51 | .map(|idx| &self.types[idx.word_idx() as usize]) 52 | } 53 | } 54 | 55 | impl From> for VocabWrap { 56 | fn from(vocab: SimpleVocab) -> VocabWrap { 57 | FiFuSimpleVocab::new( 58 | vocab 59 | .types 60 | .iter() 61 | .map(|l| l.label().to_owned()) 62 | .collect::>(), 63 | ) 64 | .into() 65 | } 66 | } 67 | 68 | impl Vocab for SimpleVocab 69 | where 70 | T: Hash + Eq, 71 | { 72 | type VocabType = T; 73 | type IdxType = SingleIdx; 74 | type Config = SimpleVocabConfig; 75 | 76 | fn config(&self) -> SimpleVocabConfig { 77 | self.config 78 | } 79 | 80 | fn idx(&self, key: &Q) -> Option 81 | where 82 | Self::VocabType: Borrow, 83 | Q: Hash + ?Sized + Eq, 84 | { 85 | self.index 86 | .get(key) 87 | .cloned() 88 | .map(|idx| SingleIdx::from_word_idx(idx as u64)) 89 | } 90 | 91 | fn discard(&self, idx: usize) -> f32 { 92 | self.discards[idx] 93 | } 94 | 95 | fn n_input_types(&self) -> usize { 96 | self.len() 97 | } 98 | 99 | fn types(&self) -> &[CountedType] { 100 | &self.types 101 | } 102 | 103 | fn n_types(&self) -> usize { 104 | self.n_types 105 | } 106 | } 107 | 108 | /// Constructs a `SimpleVocab` from a `VocabBuilder` where `T: Into`. 
109 | impl From> for SimpleVocab 110 | where 111 | T: Hash + Eq + Into, 112 | S: Hash + Eq + Clone + Ord, 113 | { 114 | fn from(builder: VocabBuilder) -> Self { 115 | let types = builder.config.cutoff.filter(builder.items); 116 | SimpleVocab::new(builder.config, types, builder.n_items) 117 | } 118 | } 119 | 120 | #[cfg(test)] 121 | mod tests { 122 | use super::{SimpleVocab, Vocab, VocabBuilder}; 123 | use crate::idx::WordIdx; 124 | use crate::{util, Cutoff, SimpleVocabConfig}; 125 | 126 | const TEST_SIMPLECONFIG: SimpleVocabConfig = SimpleVocabConfig { 127 | discard_threshold: 1e-4, 128 | cutoff: Cutoff::MinCount(2), 129 | }; 130 | 131 | #[test] 132 | pub fn types_are_sorted_simple_vocab() { 133 | let mut builder: VocabBuilder = 134 | VocabBuilder::new(TEST_SIMPLECONFIG); 135 | for _ in 0..5 { 136 | builder.count("a"); 137 | } 138 | for _ in 0..2 { 139 | builder.count("b"); 140 | } 141 | for _ in 0..10 { 142 | builder.count("d"); 143 | } 144 | builder.count("c"); 145 | 146 | let vocab: SimpleVocab<&str> = builder.into(); 147 | let contexts = vocab.types(); 148 | for idx in 1..contexts.len() { 149 | assert!( 150 | contexts[idx - 1].count >= contexts[idx].count, 151 | "Types are not frequency-sorted" 152 | ); 153 | } 154 | } 155 | 156 | #[test] 157 | pub fn test_simple_vocab_builder() { 158 | let mut builder: VocabBuilder = 159 | VocabBuilder::new(TEST_SIMPLECONFIG); 160 | for _ in 0..5 { 161 | builder.count("a"); 162 | } 163 | for _ in 0..2 { 164 | builder.count("b"); 165 | } 166 | for _ in 0..10 { 167 | builder.count("d"); 168 | } 169 | builder.count("c"); 170 | 171 | let vocab: SimpleVocab<&str> = builder.into(); 172 | 173 | assert_eq!(vocab.len(), 3); 174 | assert_eq!(vocab.get("c"), None); 175 | 176 | assert_eq!(vocab.n_types(), 18); 177 | let a = vocab.get("a").unwrap(); 178 | assert_eq!("a", a.label); 179 | assert_eq!(5, a.count()); 180 | // 0.0001 / 5/18 + (0.0001 / 5/18).sqrt() = 0.019334 181 | assert!(util::close( 182 | 0.019334, 183 | vocab.discard(vocab.idx("a").unwrap().word_idx() as usize), 184 | 1e-5 185 | )); 186 | } 187 | } 188 | -------------------------------------------------------------------------------- /src/dep_trainer.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Borrow; 2 | use std::sync::Arc; 3 | 4 | use anyhow::{bail, Result}; 5 | use rand::{Rng, SeedableRng}; 6 | use serde::Serialize; 7 | use udgraph::graph::Sentence; 8 | 9 | use crate::idx::WordIdx; 10 | use crate::sampling::DiscountFrequencyRangeGenerator; 11 | use crate::train_model::{NegativeSamples, TrainIterFrom}; 12 | use crate::util::ReseedOnCloneRng; 13 | use crate::{ 14 | CommonConfig, DepembedsConfig, Dependency, DependencyIterator, SimpleVocab, SimpleVocabConfig, 15 | Trainer, Vocab, 16 | }; 17 | 18 | /// Dependency embeddings Trainer. 19 | /// 20 | /// The `DepembedsTrainer` holds the information and logic necessary to transform a 21 | /// `conllu::Sentence` into an iterator of focus and context tuples. The struct is cheap to clone 22 | /// because the vocabulary is shared between clones. 
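/// When constructed via `DepembedsTrainer::new`, the RNG is wrapped in
/// `ReseedOnCloneRng`, so cloning reseeds it from fresh entropy and clones do
/// not share random state.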
23 | #[derive(Clone)] 24 | pub struct DepembedsTrainer { 25 | dep_config: DepembedsConfig, 26 | common_config: CommonConfig, 27 | input_vocab: Arc, 28 | output_vocab: Arc>, 29 | range_gen: DiscountFrequencyRangeGenerator, 30 | rng: R, 31 | } 32 | 33 | impl DepembedsTrainer { 34 | pub fn dep_config(&self) -> DepembedsConfig { 35 | self.dep_config 36 | } 37 | } 38 | 39 | impl DepembedsTrainer, V> 40 | where 41 | R: Rng + Clone + SeedableRng, 42 | { 43 | /// Constructs a new `DepTrainer`. 44 | pub fn new( 45 | input_vocab: V, 46 | output_vocab: SimpleVocab, 47 | common_config: CommonConfig, 48 | dep_config: DepembedsConfig, 49 | rng: R, 50 | ) -> Self { 51 | let rng = ReseedOnCloneRng(rng); 52 | let range_gen = DiscountFrequencyRangeGenerator::new_with_default_table_size( 53 | rng.clone(), 54 | output_vocab.types(), 55 | ); 56 | DepembedsTrainer { 57 | common_config, 58 | dep_config, 59 | input_vocab: Arc::new(input_vocab), 60 | output_vocab: Arc::new(output_vocab), 61 | range_gen, 62 | rng, 63 | } 64 | } 65 | } 66 | 67 | impl NegativeSamples for DepembedsTrainer 68 | where 69 | R: Rng, 70 | { 71 | fn negative_sample(&mut self, output: usize) -> usize { 72 | loop { 73 | let negative = self.range_gen.next().unwrap(); 74 | if negative != output { 75 | return negative; 76 | } 77 | } 78 | } 79 | } 80 | 81 | impl<'a, R, V> TrainIterFrom<'a, Sentence> for DepembedsTrainer 82 | where 83 | R: Rng, 84 | V: Vocab, 85 | V::VocabType: Borrow, 86 | V::IdxType: WordIdx + 'a, 87 | { 88 | type Iter = Box)> + 'a>; 89 | type Focus = V::IdxType; 90 | type Contexts = Vec; 91 | 92 | fn train_iter_from(&mut self, sentence: &Sentence) -> Self::Iter { 93 | let invalid_idx = self.input_vocab.len() as u64; 94 | let mut tokens = vec![WordIdx::from_word_idx(invalid_idx); sentence.len() - 1]; 95 | for (idx, token) in sentence.iter().filter_map(|node| node.token()).enumerate() { 96 | if let Some(vocab_idx) = self.input_vocab.idx(token.form()) { 97 | if self.rng.gen_range(0f32..1f32) 98 | < self.input_vocab.discard(vocab_idx.word_idx() as usize) 99 | { 100 | tokens[idx] = vocab_idx 101 | } 102 | } 103 | } 104 | 105 | let mut contexts = vec![Vec::new(); sentence.len() - 1]; 106 | let graph = sentence.dep_graph(); 107 | for (focus, dep) in DependencyIterator::new_from_config(&graph, self.dep_config) 108 | .filter(|(focus, _dep)| tokens[*focus].word_idx() != invalid_idx) 109 | { 110 | if let Some(dep_id) = self.output_vocab.idx(&dep) { 111 | if self.rng.gen_range(0f32..1f32) < self.output_vocab.discard(dep_id.idx() as usize) 112 | { 113 | contexts[focus].push(dep_id.idx() as usize) 114 | } 115 | } 116 | } 117 | Box::new( 118 | tokens 119 | .into_iter() 120 | .zip(contexts) 121 | .filter(move |(focus, _)| focus.word_idx() != invalid_idx), 122 | ) 123 | } 124 | } 125 | 126 | impl Trainer for DepembedsTrainer 127 | where 128 | R: Rng, 129 | V: Vocab, 130 | V::Config: Serialize, 131 | { 132 | type InputVocab = V; 133 | type Metadata = DepembedsMetadata; 134 | 135 | fn input_vocab(&self) -> &Self::InputVocab { 136 | &self.input_vocab 137 | } 138 | 139 | fn try_into_input_vocab(self) -> Result { 140 | match Arc::try_unwrap(self.input_vocab) { 141 | Ok(vocab) => Ok(vocab), 142 | Err(_) => bail!("Cannot unwrap input vocab."), 143 | } 144 | } 145 | 146 | fn n_input_types(&self) -> usize { 147 | self.input_vocab.n_input_types() 148 | } 149 | 150 | fn n_output_types(&self) -> usize { 151 | self.output_vocab.len() 152 | } 153 | 154 | fn config(&self) -> &CommonConfig { 155 | &self.common_config 156 | } 157 | 158 | fn to_metadata(&self) 
-> Self::Metadata { 159 | DepembedsMetadata { 160 | common_config: self.common_config, 161 | dep_config: self.dep_config, 162 | input_vocab_config: self.input_vocab.config(), 163 | output_vocab_config: self.output_vocab.config(), 164 | } 165 | } 166 | } 167 | 168 | /// Metadata for dependency embeddings. 169 | #[derive(Clone, Copy, Debug, Serialize)] 170 | pub struct DepembedsMetadata { 171 | common_config: CommonConfig, 172 | #[serde(rename = "model_config")] 173 | dep_config: DepembedsConfig, 174 | input_vocab_config: IC, 175 | output_vocab_config: OC, 176 | } 177 | -------------------------------------------------------------------------------- /man/finalfrontier-skipgram.1.md: -------------------------------------------------------------------------------- 1 | % FINALFRONTIER-SKIPGRAM(1) 2 | % Daniel de Kok 3 | % Sep 8, 2018 4 | 5 | NAME 6 | ==== 7 | 8 | **finalfrontier skipgram** -- train word embeddings with subword representations 9 | 10 | SYNOPSIS 11 | ======== 12 | 13 | **finalfrontier skipgram** [*options*] *corpus* *output* 14 | 15 | DESCRIPTION 16 | =========== 17 | 18 | The **finalfrontier skipgram** subcommand trains word embeddings using data 19 | from a *corpus*. The corpus should have tokens separated by spaces and 20 | sentences separated by newlines. After training, the embeddings are written to 21 | *output* in the finalfusion format. 22 | 23 | OPTIONS 24 | ======= 25 | 26 | `--buckets` *EXP* 27 | 28 | : The bucket exponent. finalfrontier uses 2^*EXP* buckets to store subword 29 | representations. Each subword representation (n-gram) is hashed and 30 | mapped to a bucket based on this hash. Using more buckets will result 31 | in fewer bucket collisions between subword representations at the cost 32 | of memory use. The default bucket exponent is *21* (approximately 2 33 | million buckets). 34 | 35 | `--context` *CONTEXT_SIZE* 36 | 37 | : Words within the *CONTEXT_SIZE* of a focus word will be used to learn 38 | the representation of the focus word. The default context size is *10*. 39 | 40 | `--dims` *DIMENSIONS* 41 | 42 | : The dimensionality of the trained word embeddings. The default 43 | dimensionality is 300. 44 | 45 | `--discard` *THRESHOLD* 46 | 47 | : The discard threshold influences how often frequent words are discarded 48 | from training. The default discard threshold is *1e-4*. 49 | 50 | `--epochs` *N* 51 | 52 | : The number of training epochs. The number of necessary training epochs 53 | typically decreases with the corpus size. The default number of epochs 54 | is *15*. 55 | 56 | `-f`, `--format` *FORMAT* 57 | 58 | : The output format. This must be one of *fasttext*, *finalfusion*, 59 | *word2vec*, *text*, and *textdims*. 60 | 61 | All formats, except *finalfusion*, result in a loss of 62 | information: *word2vec*, *text*, and *textdims* do not store 63 | subword embeddings, nor hyperparameters. The *fastText* format 64 | does not store all hyperparemeters. 65 | 66 | The *fasttext* format can only be used in conjunction with 67 | `--subwords buckets` and `--hash-indexer fasttext`. 68 | 69 | `--hash-indexer` *INDEXER* 70 | 71 | : The indexer to use when bucket-based subwords are used (see 72 | `--subwords`). The possible values are *finalfusion* or 73 | *fasttext*. Default: finalfusion 74 | 75 | *finalfusion* uses the FNV-1a hasher, whereas *fasttext* emulates 76 | the (broken) implementation of FNV-1a in fastText. Use of 77 | *finalfusion* is recommended, unless the resulting embeddings 78 | should be compatible with fastText. 
79 | 80 | `--lr` *LEARNING_RATE* 81 | 82 | : The learning rate determines what fraction of a gradient is used for 83 | parameter updates. The default initial learning rate is *0.05*, the 84 | learning rate decreases monotonically during training. 85 | 86 | `--maxn` *LEN* 87 | 88 | : The maximum n-gram length for subword representations. Default: 6 89 | 90 | `--mincount` *FREQ* 91 | 92 | : The minimum count controls discarding of infrequent. Words occuring 93 | fewer than *FREQ* times are not considered during training. The 94 | default minimum count is 5. 95 | 96 | `--minn` *LEN* 97 | 98 | : The minimum n-gram length for subword representations. Default: 3 99 | 100 | `--model` *MODEL* 101 | 102 | : The model to use for training word embeddings. The choices here are: 103 | *dirgram* for the directional skip-gram model (Song et al., 2018), 104 | *skipgram* for the skip-gram model (Mikolov et al., 2013), and 105 | *structgram* for the stuctured skip-gram model (Ling et al. 2015). 106 | 107 | The structured skip-gram model takes the position of a context word 108 | into account and results in embeddings that are typically better 109 | suited for syntax-oriented tasks. 110 | 111 | The dependency embeddings model is supported by the separate 112 | `finalfrontier deps`(1) subcommand. 113 | 114 | The default model is *skipgram*. 115 | 116 | `--ngram-mincount` *FREQ* 117 | 118 | : The minimum n-gram frequency. n-grams occurring fewer than *FREQ* 119 | times are excluded from training. This option is only applicable 120 | with the *ngrams* argument of the `subwords` option. 121 | 122 | `--ngram-target-size` *SIZE* 123 | 124 | : The target size for the n-gram vocabulary. At most *SIZE* n-ngrams are 125 | included for training. Only n-grams appearing more frequently than the 126 | n-gram at *SIZE* are included. This option is only applicable with the 127 | *ngrams* argument of the `subwords` option. 128 | 129 | `--ns` *FREQ* 130 | 131 | : The number of negatives to sample per positive example. Default: 5 132 | 133 | `--subwords` *SUBWORDS* 134 | 135 | : The type of subword embeddings to train. The possible types are 136 | *buckets*, *ngrams*, and *none*. Subword embeddings are used to 137 | compute embeddings for unknown words by summing embeddings of 138 | n-grams within unknown words. 139 | 140 | The *none* type does not use subwords. The resulting model will 141 | not be able assign an embeddings to unknown words. 142 | 143 | The *ngrams* type stores subword n-grams explicitly. The included 144 | n-gram lengths are specified using the `minn` and `maxn` 145 | options. The frequency threshold for n-grams is configured with 146 | the `ngram-mincount` option. 147 | 148 | The *buckets* type maps n-grams to buckets using the FNV1 hash. 149 | The considered n-gram lengths are specified using the `minn` and 150 | `maxn` options. The number of buckets is controlled with the 151 | `buckets` option. 152 | 153 | `--target-size` *SIZE* 154 | 155 | : The target size for the token vocabulary. At most *SIZE* tokens are 156 | included for training. Only tokens appearing more frequently than the token 157 | at *SIZE* are included. 158 | 159 | `--threads` *N* 160 | 161 | : The number of thread to use during training for 162 | parallelization. The default is to use half of the logical CPUs of 163 | the machine, capped at 20 threads. Increasing the number of 164 | threads increases the probability of update collisions, requiring 165 | more epochs to reach the same loss. 
166 | 167 | `--zipf` *EXP* 168 | 169 | : Exponent *s* used in the Zipf distribution `p(k) = 1 / (k^s H_N)` for 170 | negative sampling. Default: 0.5 171 | 172 | EXAMPLES 173 | ======== 174 | 175 | Train embeddings on *dewiki.txt* using the skip-gram model: 176 | 177 | finalfrontier skipgram dewiki.txt dewiki-skipgram.bin 178 | 179 | Train embeddings with dimensionality 200 on *dewiki.txt* using the 180 | structured skip-gram model with a context window of 5 tokens: 181 | 182 | finalfrontier skipgram --model structgram --context 5 --dims 200 \ 183 | dewiki.txt dewiki-structgram.bin 184 | 185 | SEE ALSO 186 | ======== 187 | 188 | `finalfrontier`(1), `finalfrontier-deps`(1) 189 | -------------------------------------------------------------------------------- /man/finalfrontier-deps.1.md: -------------------------------------------------------------------------------- 1 | % FINALFRONTIER-DEPS(1) % Daniel de Kok, Sebastian Pütz % Apr 6, 2019 2 | 3 | NAME 4 | ==== 5 | 6 | **finalfrontier deps** -- train dependency-based word embeddings with subword 7 | representations 8 | 9 | SYNOPSIS 10 | ======== 11 | 12 | **finalfrontier deps** [*options*] *corpus* *output* 13 | 14 | DESCRIPTION 15 | =========== 16 | 17 | The **finalfrontier-deps** subcommand trains dependency based word embeddings 18 | (Levy and Goldberg, 2014) using data from a *corpus* in CONLL-U format. The 19 | corpus contains sentences seperated by empty lines. Each sentence needs to be 20 | annotated with a dependency graph. After training, the embeddings are written 21 | to *output* in the finalfusion format. 22 | 23 | OPTIONS 24 | ======= 25 | 26 | `--buckets` *EXP* 27 | 28 | : The bucket exponent. finalfrontier uses 2^*EXP* buckets to store subword 29 | representations. Each subword representation (n-gram) is hashed and mapped to a 30 | bucket based on this hash. Using more buckets will result in fewer bucket 31 | collisions between subword representations at the cost of memory use. The 32 | default bucket exponent is *21* (approximately 2 million buckets). 33 | 34 | `--context-discard` *THRESHOLD* 35 | 36 | : The context discard threshold influences how often frequent contexts are 37 | discarded during training. The default context discard threshold is *1e-4*. 38 | 39 | `--context-mincount` *FREQ* 40 | 41 | : The minimum count controls discarding of infrequent contexts. Contexts 42 | occuring fewer than *FREQ* times are not considered during training. The 43 | default minimum count is 5. 44 | 45 | `--context-target-size` *SIZE* 46 | 47 | : The target size for the context vocabulary. At most *SIZE* contexts are 48 | included for training. Only contexts appearing more frequently than the 49 | context at *SIZE* are included. 50 | 51 | `--dependency-depth` *DEPTH* 52 | 53 | : Dependency contexts up to *DEPTH* distance from the focus word in the 54 | dependency graph will be used to learn the representation of the focus word. The 55 | default depth is *1*. 56 | 57 | `--dims` *DIMS* 58 | 59 | : The dimensionality of the trained word embeddings. The default 60 | dimensionality is 300. 61 | 62 | `--discard` *THRESHOLD* 63 | 64 | : The discard threshold influences how often frequent focus words are 65 | discarded from training. The default discard threshold is *1e-4*. 66 | 67 | `--epochs` *N* 68 | 69 | : The number of training epochs. The number of necessary training 70 | epochs typically decreases with the corpus size. The default 71 | number of epochs is *15*. 
72 | 73 | `--hash-indexer` *INDEXER* 74 | 75 | : The indexer to use when bucket-based subwords are used (see 76 | `--subwords`). The possible values are *finalfusion* or 77 | *fasttext*. Default: finalfusion 78 | 79 | *finalfusion* uses the FNV-1a hasher, whereas *fasttext* emulates 80 | the (broken) implementation of FNV-1a in fastText. Use of 81 | *finalfusion* is recommended, unless the resulting embeddings 82 | should be compatible with fastText. 83 | 84 | `-f`, `--format` *FORMAT* 85 | 86 | : The output format. This must be one of *fasttext*, *finalfusion*, 87 | *word2vec*, *text*, and *textdims*. 88 | 89 | All formats, except *finalfusion*, result in a loss of 90 | information: *word2vec*, *text*, and *textdims* do not store 91 | subword embeddings, nor hyperparameters. The *fastText* format 92 | does not store all hyperparemeters. 93 | 94 | The *fasttext* format can only be used in conjunction with 95 | `--subwords buckets` and `--hash-indexer fasttext`. 96 | 97 | `--lr` *LEARNING_RATE* 98 | 99 | : The learning rate determines what fraction of a gradient is used for 100 | parameter updates. The default initial learning rate is *0.05*, the learning 101 | rate decreases monotonically during training. 102 | 103 | `--maxn` *LEN* 104 | 105 | : The maximum n-gram length for subword representations. Default: 6 106 | 107 | `--mincount` *FREQ* 108 | 109 | : The minimum count controls discarding of infrequent focus words. Focus words 110 | occuring fewer than *FREQ* times are not considered during training. The default 111 | minimum count is 5. 112 | 113 | `--minn` *LEN* 114 | 115 | : The minimum n-gram length for subword representations. Default: 3 116 | 117 | `--ngram-mincount` *FREQ* 118 | 119 | : The minimum n-gram frequency. n-grams occurring fewer than *FREQ* 120 | times are excluded from training. This option is only applicable 121 | with the *ngrams* argument of the `subwords` option. 122 | 123 | `--ngram-target-size` *SIZE* 124 | 125 | : The target size for the n-gram vocabulary. At most *SIZE* n-ngrams are 126 | included for training. Only n-grams appearing more frequently than the 127 | n-gram at *SIZE* are included. This option is only applicable with the 128 | *ngrams* argument of the `subwords` option. 129 | 130 | `--normalize-contexts` 131 | 132 | : Normalize the attached form in the dependency contexts. 133 | 134 | `--ns` *FREQ* 135 | 136 | : The number of negatives to sample per positive example. Default: 5 137 | 138 | `--projectivize` 139 | 140 | : Projectivize dependency graphs before training embeddings. 141 | 142 | `--threads` *N* 143 | 144 | : The number of thread to use during training for 145 | parallelization. The default is to use half of the logical CPUs of 146 | the machine, capped at 20 threads. Increasing the number of 147 | threads increases the probability of update collisions, requiring 148 | more epochs to reach the same loss. 149 | 150 | `--subwords` *SUBWORDS* 151 | 152 | : The type of subword embeddings to train. The possible types are 153 | *buckets*, *ngrams*, and *none*. Subword embeddings are used to 154 | compute embeddings for unknown words by summing embeddings of 155 | n-grams within unknown words. 156 | 157 | The *none* type does not use subwords. The resulting model will 158 | not be able assign an embeddings to unknown words. 159 | 160 | The *ngrams* type stores subword n-grams explicitly. The included 161 | n-gram lengths are specified using the `minn` and `maxn` 162 | options. 
The frequency threshold for n-grams is configured with 163 | the `ngram-mincount` option. 164 | 165 | The *buckets* type maps n-grams to buckets using the FNV1 hash. 166 | The considered n-gram lengths are specified using the `minn` and 167 | `maxn` options. The number of buckets is controlled with the 168 | `buckets` option. 169 | 170 | `--target-size` *SIZE* 171 | 172 | : The target size for the token vocabulary. At most *SIZE* tokens are 173 | included for training. Only tokens appearing more frequently than the token 174 | at *SIZE* are included. 175 | 176 | `--untyped-deps` 177 | 178 | : Only use the word of the attached token in the dependency relation as 179 | contexts to train the representation of the focus word. 180 | 181 | `--use-root` 182 | 183 | : Include the abstract root node in the dependency graph as contexts during 184 | training. 185 | 186 | `--zipf` *EXP* 187 | 188 | : Exponent *s* used in the Zipf distribution `p(k) = 1 / (k^s H_N)` for 189 | negative sampling. Default: 0.5 190 | 191 | EXAMPLES 192 | ======== 193 | 194 | Train embeddings on *dewiki.txt* using the dependency model with default 195 | parameters: 196 | 197 | finalfrontier deps dewiki.conll dewiki-deps.bin 198 | 199 | Train embeddings with dimensionality 200 on *dewiki.conll* using the dependency 200 | model from contexts with depth up to 2: 201 | 202 | finalfrontier deps --depth 2 --normalize --dims 200 \ 203 | dewiki.conll dewiki-deps.bin 204 | 205 | SEE ALSO 206 | ======== 207 | 208 | `finalfrontier`(1), `finalfrontier-skipgram`(1) 209 | -------------------------------------------------------------------------------- /src/sampling.rs: -------------------------------------------------------------------------------- 1 | use rand::distributions::{Distribution, Uniform}; 2 | use rand::Rng; 3 | 4 | use crate::CountedType; 5 | 6 | pub trait RangeGenerator: Iterator { 7 | /// Get the upper bound in *[0, upper_bound)*. 8 | fn upper_bound(&self) -> usize; 9 | } 10 | 11 | /// Default discount table size. 12 | const DISCOUNT_FREQ_TABLE_SIZE: usize = 10_000_000; 13 | 14 | /// Range generator that items based on their discounted frequencies. 15 | /// 16 | /// This range generator samples items with the probability 17 | /// 18 | /// *P(i) = f(i)^d / sum_j f(j)^d* 19 | /// 20 | /// where the power *d* is set to 0.5 to shrink the probability of 21 | /// frequent events more than infrequent events. Using shrunk frequencies 22 | /// was proposed by Mikolov et al., 2013. Use of the square root was 23 | /// proposed by Bojanowski et al., 2017. 24 | #[derive(Clone)] 25 | pub struct DiscountFrequencyRangeGenerator { 26 | rng: R, 27 | uniform: Uniform, 28 | vocab_indices: Vec, 29 | } 30 | 31 | impl DiscountFrequencyRangeGenerator 32 | where 33 | R: Rng, 34 | { 35 | /// Construct a range generator for the given vocabulary item counts. 36 | /// 37 | /// The `table_size` parameter is used to determine how large the sampling 38 | /// table should be. A larger table makes sampling more accurate, at the 39 | /// cost of larger memory use and increased cache misses. 
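    ///
    /// As a worked example (mirrored by the unit test in this module): for
    /// counts `[1000, 100, 10, 1]` the discounted frequencies are the square
    /// roots `[31.6, 10.0, 3.2, 1.0]`, so with a table size of 1000 the
    /// sampling table holds roughly 690, 218, 69 and 21 entries for the
    /// respective items.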
40 | pub fn new(rng: R, vocab: &[CountedType], table_size: usize) -> Self { 41 | let sum = vocab.iter().map(|t| (t.count() as f32).sqrt()).sum::(); 42 | 43 | let mut vocab_indices = Vec::new(); 44 | for (i, counted_type) in vocab.iter().enumerate() { 45 | let discounted_freq = (counted_type.count() as f32).sqrt(); 46 | let type_n_entries = (discounted_freq / sum) * table_size as f32; 47 | vocab_indices.extend(std::iter::repeat(i).take(type_n_entries as usize)); 48 | } 49 | 50 | Self { 51 | rng, 52 | uniform: Uniform::new(0, vocab_indices.len()), 53 | vocab_indices, 54 | } 55 | } 56 | 57 | /// Construct a range generator for the given vocabulary item counts. 58 | /// 59 | /// Uses a table size that provides a good trade-off between accuracy and 60 | /// memory use. 61 | pub fn new_with_default_table_size(rng: R, vocab: &[CountedType]) -> Self { 62 | Self::new(rng, vocab, DISCOUNT_FREQ_TABLE_SIZE) 63 | } 64 | } 65 | 66 | impl Iterator for DiscountFrequencyRangeGenerator 67 | where 68 | R: Rng, 69 | { 70 | type Item = usize; 71 | 72 | fn next(&mut self) -> Option { 73 | Some(self.vocab_indices[self.uniform.sample(&mut self.rng)]) 74 | } 75 | } 76 | 77 | impl RangeGenerator for DiscountFrequencyRangeGenerator 78 | where 79 | R: Rng, 80 | { 81 | fn upper_bound(&self) -> usize { 82 | self.vocab_indices.last().map(|l| l + 1).unwrap_or(0) 83 | } 84 | } 85 | 86 | /// A banded range generator. 87 | /// 88 | /// This range generator assumes that the overall range consists of 89 | /// bands with a probability distribution implied by another range 90 | /// generator and items within that band with a uniform distribution. 91 | #[derive(Clone)] 92 | pub struct BandedRangeGenerator { 93 | uniform: Uniform, 94 | band_size: usize, 95 | inner: G, 96 | rng: R, 97 | } 98 | 99 | impl BandedRangeGenerator 100 | where 101 | R: Rng, 102 | G: RangeGenerator, 103 | { 104 | #[allow(dead_code)] 105 | pub fn new(rng: R, band_range_gen: G, band_size: usize) -> Self { 106 | BandedRangeGenerator { 107 | uniform: Uniform::new(0, band_size), 108 | band_size, 109 | inner: band_range_gen, 110 | rng, 111 | } 112 | } 113 | } 114 | 115 | impl Iterator for BandedRangeGenerator 116 | where 117 | R: Rng, 118 | G: RangeGenerator, 119 | { 120 | type Item = usize; 121 | 122 | fn next(&mut self) -> Option { 123 | let band = self.inner.next().unwrap(); 124 | if self.band_size == 1 { 125 | // Every band consist of one item, return the only item of the band. 
126 | Some(band) 127 | } else { 128 | let band_item = self.uniform.sample(&mut self.rng); 129 | Some(band * self.band_size + band_item) 130 | } 131 | } 132 | } 133 | 134 | impl RangeGenerator for BandedRangeGenerator 135 | where 136 | R: Rng, 137 | G: RangeGenerator, 138 | { 139 | fn upper_bound(&self) -> usize { 140 | self.inner.upper_bound() * self.band_size 141 | } 142 | } 143 | 144 | #[cfg(test)] 145 | mod tests { 146 | use crate::sampling::DiscountFrequencyRangeGenerator; 147 | use crate::CountedType; 148 | use rand::SeedableRng; 149 | use rand_xorshift::XorShiftRng; 150 | 151 | use super::{BandedRangeGenerator, RangeGenerator}; 152 | use crate::util::{all_close, close}; 153 | 154 | const SEED: [u8; 16] = [ 155 | 0xe9, 0xfe, 0xf0, 0xfb, 0x6a, 0x23, 0x2a, 0xb3, 0x7c, 0xce, 0x27, 0x9b, 0x56, 0xac, 0xdb, 156 | 0xf8, 157 | ]; 158 | 159 | const SEED2: [u8; 16] = [ 160 | 0xc8, 0xae, 0xa3, 0x99, 0x28, 0x5a, 0xbb, 0x27, 0x90, 0xe9, 0x61, 0x60, 0xe5, 0xca, 0xfe, 161 | 0x22, 162 | ]; 163 | 164 | #[test] 165 | fn discount_frequency_range_generator() { 166 | let types = vec![ 167 | CountedType::new("foo", 1000), 168 | CountedType::new("bar", 100), 169 | CountedType::new("baz", 10), 170 | CountedType::new("quux", 1), 171 | ]; 172 | 173 | let generator = 174 | DiscountFrequencyRangeGenerator::new(XorShiftRng::from_seed(SEED), &types, 1000); 175 | 176 | assert_eq!(generator.upper_bound(), types.len()); 177 | 178 | let mut check_indices = Vec::new(); 179 | check_indices.extend(std::iter::repeat(0).take(690)); 180 | check_indices.extend(std::iter::repeat(1).take(218)); 181 | check_indices.extend(std::iter::repeat(2).take(69)); 182 | check_indices.extend(std::iter::repeat(3).take(21)); 183 | 184 | assert_eq!(generator.vocab_indices, check_indices); 185 | } 186 | 187 | #[test] 188 | fn banded_range_generator_test() { 189 | const DRAWS: usize = 20_000; 190 | 191 | let rng = XorShiftRng::from_seed(SEED); 192 | let inner_gen = DiscountFrequencyRangeGenerator::new( 193 | rng, 194 | &[ 195 | CountedType::new("a", 100), 196 | CountedType::new("b", 50), 197 | CountedType::new("c", 33), 198 | CountedType::new("d", 25), 199 | ], 200 | 208, 201 | ); 202 | 203 | let rng = XorShiftRng::from_seed(SEED2); 204 | let weighted_gen = BandedRangeGenerator::new(rng, inner_gen, 4); 205 | 206 | // Sample using the given weights. 207 | let mut hits = vec![0; weighted_gen.upper_bound()]; 208 | for idx in weighted_gen.take(DRAWS) { 209 | hits[idx] += 1; 210 | } 211 | 212 | // Convert counts to a probability distribution. 213 | let probs: Vec<_> = hits 214 | .into_iter() 215 | .map(|count| count as f32 / DRAWS as f32) 216 | .collect(); 217 | 218 | // Probabilities should be proportional to weights. 
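        // With counts [100, 50, 33, 25], the discounted band weights are the
        // square roots [10.0, 7.07, 5.74, 5.0], giving band probabilities of
        // about [0.36, 0.25, 0.21, 0.18]. Each band spreads its probability
        // uniformly over band_size = 4 items, so the per-item probabilities
        // should come out at roughly [0.090, 0.064, 0.051, 0.045].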
219 | assert!(all_close( 220 | &[ 221 | 0.09155, 0.0912, 0.0906, 0.08845, 0.0636, 0.0621, 0.0643, 0.0669, 0.0491, 0.0521, 222 | 0.0492, 0.05055, 0.0465, 0.04455, 0.044, 0.0453 223 | ], 224 | probs.as_slice(), 225 | 1e-2 226 | )); 227 | assert!(close(1.0f32, probs.iter().cloned().sum(), 1e-2)); 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/subcommands/skipgram.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufReader, BufWriter}; 3 | use std::path::{Path, PathBuf}; 4 | use std::thread; 5 | use std::time::Duration; 6 | 7 | use anyhow::{Context, Result}; 8 | use finalfrontier::io::{thread_data_text, FileProgress, TrainInfo}; 9 | use finalfrontier::{ 10 | BucketIndexerType, CommonConfig, SentenceIterator, Sgd, SimpleVocab, SkipGramConfig, 11 | SkipgramTrainer, SubwordVocab, Vocab, VocabBuilder, VocabConfig, WriteModelBinary, 12 | }; 13 | use finalfusion::compat::fasttext::FastTextIndexer; 14 | use finalfusion::prelude::VocabWrap; 15 | use finalfusion::subword::FinalfusionHashIndexer; 16 | use rand::{Rng, SeedableRng}; 17 | use rand_xorshift::XorShiftRng; 18 | use serde::Serialize; 19 | 20 | use crate::subcommands::{show_progress, FinalfrontierApp}; 21 | 22 | const PROGRESS_UPDATE_INTERVAL: u64 = 200; 23 | 24 | /// Subcommand for training skipgram models. 25 | pub struct SkipgramApp { 26 | train_info: TrainInfo, 27 | common_config: CommonConfig, 28 | skipgram_config: SkipGramConfig, 29 | vocab_config: VocabConfig, 30 | } 31 | 32 | impl SkipgramApp { 33 | pub fn new( 34 | train_info: TrainInfo, 35 | common_config: CommonConfig, 36 | skipgram_config: SkipGramConfig, 37 | vocab_config: VocabConfig, 38 | ) -> Self { 39 | Self { 40 | train_info, 41 | common_config, 42 | skipgram_config, 43 | vocab_config, 44 | } 45 | } 46 | 47 | /// Get the corpus path. 48 | pub fn corpus(&self) -> &str { 49 | self.train_info.corpus() 50 | } 51 | 52 | /// Get the output path. 53 | pub fn output(&self) -> &str { 54 | self.train_info.output() 55 | } 56 | 57 | /// Get the number of threads. 58 | pub fn n_threads(&self) -> usize { 59 | self.train_info.n_threads() 60 | } 61 | 62 | /// Get the common config. 63 | pub fn common_config(&self) -> CommonConfig { 64 | self.common_config 65 | } 66 | 67 | /// Get the skipgram config. 68 | pub fn skipgram_config(&self) -> SkipGramConfig { 69 | self.skipgram_config 70 | } 71 | 72 | /// Get the vocab config. 73 | pub fn vocab_config(&self) -> VocabConfig { 74 | self.vocab_config 75 | } 76 | 77 | /// Get the train information. 
78 | pub fn train_info(&self) -> &TrainInfo { 79 | &self.train_info 80 | } 81 | } 82 | 83 | impl FinalfrontierApp for SkipgramApp { 84 | fn run(&self) -> Result<()> { 85 | match self.vocab_config() { 86 | VocabConfig::SubwordVocab(config) => match config.indexer.indexer_type { 87 | BucketIndexerType::Finalfusion => { 88 | let vocab: SubwordVocab<_, FinalfusionHashIndexer> = 89 | build_vocab(config, self.corpus())?; 90 | train(vocab, self) 91 | } 92 | BucketIndexerType::FastText => { 93 | let vocab: SubwordVocab<_, FastTextIndexer> = 94 | build_vocab(config, self.corpus())?; 95 | train(vocab, self) 96 | } 97 | }, 98 | VocabConfig::FloretVocab(config) => { 99 | let vocab: SubwordVocab<_, _> = build_vocab(config, self.corpus())?; 100 | train(vocab, self) 101 | } 102 | VocabConfig::SimpleVocab(config) => { 103 | let vocab: SimpleVocab = build_vocab(config, self.corpus())?; 104 | train(vocab, self) 105 | } 106 | VocabConfig::NGramVocab(config) => { 107 | let vocab: SubwordVocab<_, _> = build_vocab(config, self.corpus())?; 108 | train(vocab, self) 109 | } 110 | } 111 | } 112 | } 113 | 114 | fn train(vocab: V, app: &SkipgramApp) -> Result<()> 115 | where 116 | V: Vocab + Into + Clone + Send + Sync + 'static, 117 | V::Config: Serialize, 118 | for<'a> &'a V::IdxType: IntoIterator, 119 | { 120 | let common_config = app.common_config(); 121 | let n_threads = app.n_threads(); 122 | let corpus = app.corpus(); 123 | let mut output_writer = 124 | BufWriter::new(File::create(app.output()).context("Cannot open output file for writing.")?); 125 | let trainer = SkipgramTrainer::new( 126 | vocab, 127 | XorShiftRng::from_entropy(), 128 | common_config, 129 | app.skipgram_config(), 130 | ); 131 | let sgd = Sgd::new(trainer.into()); 132 | 133 | let mut children = Vec::with_capacity(n_threads); 134 | for thread in 0..n_threads { 135 | let corpus = corpus.to_owned(); 136 | let sgd = sgd.clone(); 137 | 138 | children.push(thread::spawn(move || { 139 | do_work( 140 | corpus, 141 | sgd, 142 | thread, 143 | n_threads, 144 | common_config.epochs, 145 | common_config.lr, 146 | ) 147 | })); 148 | } 149 | 150 | show_progress( 151 | &common_config, 152 | &sgd, 153 | Duration::from_millis(PROGRESS_UPDATE_INTERVAL), 154 | ); 155 | 156 | // Wait until all threads have finished. 157 | for child in children { 158 | child.join().expect("Thread panicked")?; 159 | } 160 | 161 | sgd.into_model() 162 | .write_model_binary( 163 | &mut output_writer, 164 | app.train_info().clone(), 165 | app.common_config.format, 166 | ) 167 | .context("Cannot write model") 168 | } 169 | 170 | fn do_work( 171 | corpus_path: P, 172 | mut sgd: Sgd>, 173 | thread: usize, 174 | n_threads: usize, 175 | epochs: u32, 176 | start_lr: f32, 177 | ) -> Result<()> 178 | where 179 | P: Into, 180 | R: Clone + Rng, 181 | V: Vocab, 182 | V::Config: Serialize, 183 | for<'a> &'a V::IdxType: IntoIterator, 184 | { 185 | let n_tokens = sgd.model().input_vocab().n_types(); 186 | 187 | let f = File::open(corpus_path.into()).context("Cannot open corpus for reading")?; 188 | let (data, start) = 189 | thread_data_text(&f, thread, n_threads).context("Could not get thread-specific data")?; 190 | 191 | let mut sentences = SentenceIterator::new(&data[start..]); 192 | while sgd.n_tokens_processed() < epochs as usize * n_tokens { 193 | let sentence = if let Some(sentence) = sentences.next() { 194 | sentence 195 | } else { 196 | sentences = SentenceIterator::new(&*data); 197 | sentences 198 | .next() 199 | .context("Iterator does not provide sentences")? 
200 | } 201 | .context("Cannot read sentence")?; 202 | 203 | let lr = (1.0 - (sgd.n_tokens_processed() as f32 / (epochs as usize * n_tokens) as f32)) 204 | * start_lr; 205 | 206 | sgd.update_sentence(&sentence, lr); 207 | } 208 | 209 | Ok(()) 210 | } 211 | 212 | fn build_vocab(config: C, corpus_path: P) -> Result 213 | where 214 | P: AsRef, 215 | V: Vocab + From>, 216 | VocabBuilder: Into, 217 | { 218 | let f = File::open(corpus_path).context("Cannot open corpus for reading")?; 219 | let file_progress = FileProgress::new(f).context("Cannot create progress bar")?; 220 | 221 | let sentences = SentenceIterator::new(BufReader::new(file_progress)); 222 | 223 | let mut builder = VocabBuilder::new(config); 224 | for sentence in sentences { 225 | let sentence = sentence.context("Cannot read sentence")?; 226 | 227 | for token in sentence { 228 | builder.count(token); 229 | } 230 | } 231 | 232 | Ok(builder.into()) 233 | } 234 | -------------------------------------------------------------------------------- /man/finalfrontier-skipgram.1: -------------------------------------------------------------------------------- 1 | .\" Automatically generated by Pandoc 2.7.3 2 | .\" 3 | .TH "FINALFRONTIER-SKIPGRAM" "1" "Sep 8, 2018" "" "" 4 | .hy 5 | .SH NAME 6 | .PP 7 | \f[B]finalfrontier skipgram\f[R] \[en] train word embeddings with 8 | subword representations 9 | .SH SYNOPSIS 10 | .PP 11 | \f[B]finalfrontier skipgram\f[R] [\f[I]options\f[R]] \f[I]corpus\f[R] 12 | \f[I]output\f[R] 13 | .SH DESCRIPTION 14 | .PP 15 | The \f[B]finalfrontier skipgram\f[R] subcommand trains word embeddings 16 | using data from a \f[I]corpus\f[R]. 17 | The corpus should have tokens separated by spaces and sentences 18 | separated by newlines. 19 | After training, the embeddings are written to \f[I]output\f[R] in the 20 | finalfusion format. 21 | .SH OPTIONS 22 | .TP 23 | .B \f[C]--buckets\f[R] \f[I]EXP\f[R] 24 | The bucket exponent. 25 | finalfrontier uses 2\[ha]\f[I]EXP\f[R] buckets to store subword 26 | representations. 27 | Each subword representation (n-gram) is hashed and mapped to a bucket 28 | based on this hash. 29 | Using more buckets will result in fewer bucket collisions between 30 | subword representations at the cost of memory use. 31 | The default bucket exponent is \f[I]21\f[R] (approximately 2 million 32 | buckets). 33 | .TP 34 | .B \f[C]--context\f[R] \f[I]CONTEXT_SIZE\f[R] 35 | Words within the \f[I]CONTEXT_SIZE\f[R] of a focus word will be used to 36 | learn the representation of the focus word. 37 | The default context size is \f[I]10\f[R]. 38 | .TP 39 | .B \f[C]--dims\f[R] \f[I]DIMENSIONS\f[R] 40 | The dimensionality of the trained word embeddings. 41 | The default dimensionality is 300. 42 | .TP 43 | .B \f[C]--discard\f[R] \f[I]THRESHOLD\f[R] 44 | The discard threshold influences how often frequent words are discarded 45 | from training. 46 | The default discard threshold is \f[I]1e-4\f[R]. 47 | .TP 48 | .B \f[C]--epochs\f[R] \f[I]N\f[R] 49 | The number of training epochs. 50 | The number of necessary training epochs typically decreases with the 51 | corpus size. 52 | The default number of epochs is \f[I]15\f[R]. 53 | .TP 54 | .B \f[C]-f\f[R], \f[C]--format\f[R] \f[I]FORMAT\f[R] 55 | The output format. 56 | This must be one of \f[I]fasttext\f[R], \f[I]finalfusion\f[R], 57 | \f[I]word2vec\f[R], \f[I]text\f[R], and \f[I]textdims\f[R]. 
58 | .RS 59 | .PP 60 | All formats, except \f[I]finalfusion\f[R], result in a loss of 61 | information: \f[I]word2vec\f[R], \f[I]text\f[R], and \f[I]textdims\f[R] 62 | do not store subword embeddings, nor hyperparameters. 63 | The \f[I]fastText\f[R] format does not store all hyperparemeters. 64 | .PP 65 | The \f[I]fasttext\f[R] format can only be used in conjunction with 66 | \f[C]--subwords buckets\f[R] and \f[C]--hash-indexer fasttext\f[R]. 67 | .RE 68 | .TP 69 | .B \f[C]--hash-indexer\f[R] \f[I]INDEXER\f[R] 70 | The indexer to use when bucket-based subwords are used (see 71 | \f[C]--subwords\f[R]). 72 | The possible values are \f[I]finalfusion\f[R] or \f[I]fasttext\f[R]. 73 | Default: finalfusion 74 | .RS 75 | .PP 76 | \f[I]finalfusion\f[R] uses the FNV-1a hasher, whereas \f[I]fasttext\f[R] 77 | emulates the (broken) implementation of FNV-1a in fastText. 78 | Use of \f[I]finalfusion\f[R] is recommended, unless the resulting 79 | embeddings should be compatible with fastText. 80 | .RE 81 | .TP 82 | .B \f[C]--lr\f[R] \f[I]LEARNING_RATE\f[R] 83 | The learning rate determines what fraction of a gradient is used for 84 | parameter updates. 85 | The default initial learning rate is \f[I]0.05\f[R], the learning rate 86 | decreases monotonically during training. 87 | .TP 88 | .B \f[C]--maxn\f[R] \f[I]LEN\f[R] 89 | The maximum n-gram length for subword representations. 90 | Default: 6 91 | .TP 92 | .B \f[C]--mincount\f[R] \f[I]FREQ\f[R] 93 | The minimum count controls discarding of infrequent. 94 | Words occuring fewer than \f[I]FREQ\f[R] times are not considered during 95 | training. 96 | The default minimum count is 5. 97 | .TP 98 | .B \f[C]--minn\f[R] \f[I]LEN\f[R] 99 | The minimum n-gram length for subword representations. 100 | Default: 3 101 | .TP 102 | .B \f[C]--model\f[R] \f[I]MODEL\f[R] 103 | The model to use for training word embeddings. 104 | The choices here are: \f[I]dirgram\f[R] for the directional skip-gram 105 | model (Song et al., 2018), \f[I]skipgram\f[R] for the skip-gram model 106 | (Mikolov et al., 2013), and \f[I]structgram\f[R] for the stuctured 107 | skip-gram model (Ling et al.\ 2015). 108 | .RS 109 | .PP 110 | The structured skip-gram model takes the position of a context word into 111 | account and results in embeddings that are typically better suited for 112 | syntax-oriented tasks. 113 | .PP 114 | The dependency embeddings model is supported by the separate 115 | \f[C]finalfrontier deps\f[R](1) subcommand. 116 | .PP 117 | The default model is \f[I]skipgram\f[R]. 118 | .RE 119 | .TP 120 | .B \f[C]--ngram-mincount\f[R] \f[I]FREQ\f[R] 121 | The minimum n-gram frequency. 122 | n-grams occurring fewer than \f[I]FREQ\f[R] times are excluded from 123 | training. 124 | This option is only applicable with the \f[I]ngrams\f[R] argument of the 125 | \f[C]subwords\f[R] option. 126 | .TP 127 | .B \f[C]--ngram-target-size\f[R] \f[I]SIZE\f[R] 128 | The target size for the n-gram vocabulary. 129 | At most \f[I]SIZE\f[R] n-ngrams are included for training. 130 | Only n-grams appearing more frequently than the n-gram at \f[I]SIZE\f[R] 131 | are included. 132 | This option is only applicable with the \f[I]ngrams\f[R] argument of the 133 | \f[C]subwords\f[R] option. 134 | .TP 135 | .B \f[C]--ns\f[R] \f[I]FREQ\f[R] 136 | The number of negatives to sample per positive example. 137 | Default: 5 138 | .TP 139 | .B \f[C]--subwords\f[R] \f[I]SUBWORDS\f[R] 140 | The type of subword embeddings to train. 141 | The possible types are \f[I]buckets\f[R], \f[I]ngrams\f[R], and 142 | \f[I]none\f[R]. 
143 | Subword embeddings are used to compute embeddings for unknown words by 144 | summing embeddings of n-grams within unknown words. 145 | .RS 146 | .PP 147 | The \f[I]none\f[R] type does not use subwords. 148 | The resulting model will not be able assign an embeddings to unknown 149 | words. 150 | .PP 151 | The \f[I]ngrams\f[R] type stores subword n-grams explicitly. 152 | The included n-gram lengths are specified using the \f[C]minn\f[R] and 153 | \f[C]maxn\f[R] options. 154 | The frequency threshold for n-grams is configured with the 155 | \f[C]ngram-mincount\f[R] option. 156 | .PP 157 | The \f[I]buckets\f[R] type maps n-grams to buckets using the FNV1 hash. 158 | The considered n-gram lengths are specified using the \f[C]minn\f[R] and 159 | \f[C]maxn\f[R] options. 160 | The number of buckets is controlled with the \f[C]buckets\f[R] option. 161 | .RE 162 | .TP 163 | .B \f[C]--target-size\f[R] \f[I]SIZE\f[R] 164 | The target size for the token vocabulary. 165 | At most \f[I]SIZE\f[R] tokens are included for training. 166 | Only tokens appearing more frequently than the token at \f[I]SIZE\f[R] 167 | are included. 168 | .TP 169 | .B \f[C]--threads\f[R] \f[I]N\f[R] 170 | The number of thread to use during training for parallelization. 171 | The default is to use half of the logical CPUs of the machine, capped at 172 | 20 threads. 173 | Increasing the number of threads increases the probability of update 174 | collisions, requiring more epochs to reach the same loss. 175 | .TP 176 | .B \f[C]--zipf\f[R] \f[I]EXP\f[R] 177 | Exponent \f[I]s\f[R] used in the Zipf distribution 178 | \f[C]p(k) = 1 / (k\[ha]s H_N)\f[R] for negative sampling. 179 | Default: 0.5 180 | .SH EXAMPLES 181 | .PP 182 | Train embeddings on \f[I]dewiki.txt\f[R] using the skip-gram model: 183 | .IP 184 | .nf 185 | \f[C] 186 | finalfrontier skipgram dewiki.txt dewiki-skipgram.bin 187 | \f[R] 188 | .fi 189 | .PP 190 | Train embeddings with dimensionality 200 on \f[I]dewiki.txt\f[R] using 191 | the structured skip-gram model with a context window of 5 tokens: 192 | .IP 193 | .nf 194 | \f[C] 195 | finalfrontier skipgram --model structgram --context 5 --dims 200 \[rs] 196 | dewiki.txt dewiki-structgram.bin 197 | \f[R] 198 | .fi 199 | .SH SEE ALSO 200 | .PP 201 | \f[C]finalfrontier\f[R](1), \f[C]finalfrontier-deps\f[R](1) 202 | .SH AUTHORS 203 | Daniel de Kok. 204 | -------------------------------------------------------------------------------- /src/sgd.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array1, ArrayView1, ArrayViewMut1}; 2 | 3 | use crate::hogwild::Hogwild; 4 | use crate::idx::WordIdx; 5 | use crate::loss::log_logistic_loss; 6 | use crate::train_model::{NegativeSamples, TrainIterFrom, TrainModel, Trainer}; 7 | use crate::vec_simd::scaled_add; 8 | 9 | /// Stochastic gradient descent 10 | /// 11 | /// This data type applies stochastic gradient descent on sentences. 
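///
/// A rough usage sketch (the trainer construction mirrors the skipgram
/// subcommand; the sentence source and the learning-rate value are
/// placeholders):
///
/// ```ignore
/// let trainer = SkipgramTrainer::new(vocab, rng, common_config, skipgram_config);
/// let mut sgd = Sgd::new(trainer.into());
/// for sentence in sentences {
///     // The caller is responsible for decaying `lr` as training progresses.
///     sgd.update_sentence(&sentence, lr);
/// }
/// ```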
12 | #[derive(Clone)] 13 | pub struct Sgd { 14 | loss: Hogwild, 15 | model: TrainModel, 16 | n_examples: Hogwild, 17 | n_tokens_processed: Hogwild, 18 | sgd_impl: NegativeSamplingSgd, 19 | } 20 | 21 | impl Sgd 22 | where 23 | T: Trainer, 24 | { 25 | pub fn into_model(self) -> TrainModel { 26 | self.model 27 | } 28 | 29 | /// Construct a new SGD instance, 30 | pub fn new(model: TrainModel) -> Self { 31 | let sgd_impl = NegativeSamplingSgd::new(model.config().negative_samples as usize); 32 | 33 | Sgd { 34 | loss: Hogwild::default(), 35 | model, 36 | n_examples: Hogwild::default(), 37 | n_tokens_processed: Hogwild::default(), 38 | sgd_impl, 39 | } 40 | } 41 | /// Get the training model associated with this SGD. 42 | pub fn model(&self) -> &TrainModel { 43 | &self.model 44 | } 45 | 46 | /// Get the number of tokens that are processed by this SGD. 47 | pub fn n_tokens_processed(&self) -> usize { 48 | *self.n_tokens_processed 49 | } 50 | 51 | /// Get the average training loss of this SGD. 52 | /// 53 | /// This returns the average training loss over all instances seen by 54 | /// this SGD instance since its construction. 55 | pub fn train_loss(&self) -> f32 { 56 | *self.loss / *self.n_examples as f32 57 | } 58 | 59 | /// Update the model parameters using the given sentence. 60 | /// 61 | /// This applies a gradient descent step on the sentence, with the given 62 | /// learning rate. 63 | pub fn update_sentence<'b, S>(&mut self, sentence: &S, lr: f32) 64 | where 65 | S: ?Sized, 66 | T: TrainIterFrom<'b, S> + Trainer + NegativeSamples, 67 | for<'a> &'a T::Focus: IntoIterator, 68 | T::Focus: WordIdx, 69 | { 70 | for (focus, contexts) in self.model.trainer().train_iter_from(sentence) { 71 | // Update parameters for the token focus token i and the 72 | // context token j. 73 | let input_embed = self.model.mean_input_embedding(&focus); 74 | 75 | for context in contexts { 76 | *self.loss += self.sgd_impl.sgd_step( 77 | &mut self.model, 78 | (&focus).into_iter(), 79 | input_embed.view(), 80 | context, 81 | lr, 82 | ); 83 | *self.n_examples += 1; 84 | } 85 | *self.n_tokens_processed += 1; 86 | } 87 | } 88 | } 89 | 90 | /// Log-logistic loss SGD with negative sampling. 91 | /// 92 | /// This type implements gradient descent for log-logistic loss with negative 93 | /// sampling (Mikolov, 2013). 94 | /// 95 | /// In this approach, word embeddings training is shaped as a 96 | /// prediction task. The word vectors should be parametrized such that 97 | /// words that co-occur with a given input get an estimated probability 98 | /// of 1.0, whereas words that do not co-occur with the input get an 99 | /// estimated probability of 0.0. 100 | /// 101 | /// The probability is computed from the inner product of two word 102 | /// vectors by applying the logistic function. The loss is the negative 103 | /// log likelihood. 104 | /// 105 | /// Due to the vocabulary sizes, it is not possible to update the vectors 106 | /// for all words that do not co-occur in every step. Instead, such 107 | /// negatives are sampled, weighted by word frequency. 108 | #[derive(Clone)] 109 | pub struct NegativeSamplingSgd { 110 | negative_samples: usize, 111 | } 112 | 113 | impl NegativeSamplingSgd { 114 | /// Create a new loss function. 115 | pub fn new(negative_samples: usize) -> Self { 116 | NegativeSamplingSgd { negative_samples } 117 | } 118 | 119 | /// Perform a step of gradient descent. 120 | /// 121 | /// This method will estimate the probability of `output` and randomly 122 | /// chosen negative samples, given the input. 
It will then update the 123 | /// embeddings of the positive/negative outputs and the input (and its 124 | /// subwords). 125 | /// 126 | /// The function returns the sum of losses. 127 | pub fn sgd_step( 128 | &mut self, 129 | model: &mut TrainModel, 130 | input: impl IntoIterator, 131 | input_embed: ArrayView1, 132 | output: usize, 133 | lr: f32, 134 | ) -> f32 135 | where 136 | T: NegativeSamples, 137 | { 138 | let mut loss = 0.0; 139 | let mut input_delta = Array1::zeros(input_embed.shape()[0]); 140 | 141 | // Update the output embedding of the positive instance. 142 | loss += self.update_output( 143 | model, 144 | input_embed.view(), 145 | input_delta.view_mut(), 146 | output, 147 | true, 148 | lr, 149 | ); 150 | 151 | // Pick the negative examples and update their output embeddings. 152 | loss += self.negative_samples(model, input_embed, input_delta.view_mut(), output, lr); 153 | 154 | // Update the input embeddings with the accumulated gradient. 155 | for idx in input { 156 | let input_embed = model.input_embedding_mut(idx as usize); 157 | scaled_add(input_embed, input_delta.view(), 1.0); 158 | } 159 | 160 | loss 161 | } 162 | 163 | /// Pick, predict and update negative samples. 164 | fn negative_samples( 165 | &mut self, 166 | model: &mut TrainModel, 167 | input_embed: ArrayView1, 168 | mut input_delta: ArrayViewMut1, 169 | output: usize, 170 | lr: f32, 171 | ) -> f32 172 | where 173 | T: NegativeSamples, 174 | { 175 | let mut loss = 0f32; 176 | 177 | for _ in 0..self.negative_samples { 178 | let negative = model.trainer().negative_sample(output); 179 | // Update input and output for this negative sample. 180 | loss += self.update_output( 181 | model, 182 | input_embed.view(), 183 | input_delta.view_mut(), 184 | negative, 185 | false, 186 | lr, 187 | ); 188 | } 189 | 190 | loss 191 | } 192 | 193 | /// Update an output embedding. 194 | /// 195 | /// This also accumulates an update for the input embedding. 196 | /// 197 | /// The method returns the loss for predicting the output. 198 | fn update_output( 199 | &mut self, 200 | model: &mut TrainModel, 201 | input_embed: ArrayView1, 202 | input_delta: ArrayViewMut1, 203 | output: usize, 204 | label: bool, 205 | lr: f32, 206 | ) -> f32 { 207 | let (loss, part_gradient) = 208 | log_logistic_loss(input_embed.view(), model.output_embedding(output), label); 209 | 210 | // Update the input weight: u_n += lr * u_n' v_n. We are not updating 211 | // the weight immediately, but accumulating the weight updates in 212 | // input_delta. 213 | scaled_add( 214 | input_delta, 215 | model.output_embedding(output), 216 | lr * part_gradient, 217 | ); 218 | 219 | // Update the output weight: v_n += lr * v_n' u_n. 
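        // Note that both updates reuse the same scalar `part_gradient`: the
        // input delta accumulated above is a scaled copy of the output
        // embedding, while here the output embedding is adjusted along the
        // direction of the (mean) input embedding.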
220 | scaled_add( 221 | model.output_embedding_mut(output), 222 | input_embed.view(), 223 | lr * part_gradient, 224 | ); 225 | 226 | loss 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /src/skipgram_trainer.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Borrow; 2 | use std::hash::Hash; 3 | use std::iter::FusedIterator; 4 | use std::sync::Arc; 5 | use std::{cmp, mem}; 6 | 7 | use anyhow::{bail, Result}; 8 | use rand::{Rng, SeedableRng}; 9 | use serde::Serialize; 10 | 11 | use crate::idx::WordIdx; 12 | use crate::sampling::{BandedRangeGenerator, DiscountFrequencyRangeGenerator}; 13 | use crate::train_model::{NegativeSamples, TrainIterFrom, Trainer}; 14 | use crate::util::ReseedOnCloneRng; 15 | use crate::{CommonConfig, SkipGramConfig, SkipgramModelType, Vocab}; 16 | 17 | /// Skipgram Trainer 18 | /// 19 | /// The `SkipgramTrainer` holds the information and logic necessary to transform a tokenized 20 | /// sentence into an iterator of focus and context tuples. The struct is cheap to clone because 21 | /// the vocabulary is shared between clones. 22 | #[derive(Clone)] 23 | pub struct SkipgramTrainer { 24 | vocab: Arc, 25 | rng: R, 26 | range_gen: BandedRangeGenerator>, 27 | common_config: CommonConfig, 28 | skipgram_config: SkipGramConfig, 29 | } 30 | 31 | impl SkipgramTrainer, V> 32 | where 33 | R: Rng + Clone + SeedableRng, 34 | V: Vocab, 35 | { 36 | /// Constructs a new `SkipgramTrainer`. 37 | pub fn new( 38 | vocab: V, 39 | rng: R, 40 | common_config: CommonConfig, 41 | skipgram_config: SkipGramConfig, 42 | ) -> Self { 43 | let vocab = Arc::new(vocab); 44 | let rng = ReseedOnCloneRng(rng); 45 | let band_size = match skipgram_config.model { 46 | SkipgramModelType::SkipGram => 1, 47 | SkipgramModelType::StructuredSkipGram => skipgram_config.context_size * 2, 48 | SkipgramModelType::DirectionalSkipgram => 2, 49 | }; 50 | 51 | let range_gen = BandedRangeGenerator::new( 52 | rng.clone(), 53 | DiscountFrequencyRangeGenerator::new_with_default_table_size( 54 | rng.clone(), 55 | vocab.types(), 56 | ), 57 | band_size as usize, 58 | ); 59 | SkipgramTrainer { 60 | vocab, 61 | rng, 62 | range_gen, 63 | common_config, 64 | skipgram_config, 65 | } 66 | } 67 | } 68 | 69 | impl<'a, S, R, V, I> TrainIterFrom<'a, [S]> for SkipgramTrainer 70 | where 71 | S: Hash + Eq, 72 | R: Rng + Clone, 73 | V: Vocab, 74 | V::VocabType: Borrow, 75 | I: WordIdx, 76 | { 77 | type Iter = SkipGramIter; 78 | type Focus = I; 79 | type Contexts = Vec; 80 | 81 | fn train_iter_from(&mut self, sequence: &[S]) -> Self::Iter { 82 | let mut ids = Vec::new(); 83 | for t in sequence { 84 | if let Some(idx) = self.vocab.idx(t) { 85 | if self.rng.gen_range(0f32..1f32) < self.vocab.discard(idx.word_idx() as usize) { 86 | ids.push(idx); 87 | } 88 | } 89 | } 90 | SkipGramIter::new(self.rng.clone(), ids, self.skipgram_config) 91 | } 92 | } 93 | 94 | impl NegativeSamples for SkipgramTrainer 95 | where 96 | R: Rng, 97 | { 98 | fn negative_sample(&mut self, output: usize) -> usize { 99 | loop { 100 | let negative = self.range_gen.next().unwrap(); 101 | if negative != output { 102 | return negative; 103 | } 104 | } 105 | } 106 | } 107 | 108 | impl Trainer for SkipgramTrainer 109 | where 110 | R: Rng + Clone, 111 | V: Vocab, 112 | V::Config: Serialize, 113 | { 114 | type InputVocab = V; 115 | type Metadata = SkipgramMetadata; 116 | 117 | fn input_vocab(&self) -> &V { 118 | &self.vocab 119 | } 120 | 121 | fn try_into_input_vocab(self) 
-> Result { 122 | match Arc::try_unwrap(self.vocab) { 123 | Ok(vocab) => Ok(vocab), 124 | Err(_) => bail!("Cannot unwrap input vocab."), 125 | } 126 | } 127 | 128 | fn n_input_types(&self) -> usize { 129 | self.input_vocab().n_input_types() 130 | } 131 | 132 | fn n_output_types(&self) -> usize { 133 | match self.skipgram_config.model { 134 | SkipgramModelType::StructuredSkipGram => { 135 | self.vocab.len() * 2 * self.skipgram_config.context_size as usize 136 | } 137 | SkipgramModelType::SkipGram => self.vocab.len(), 138 | SkipgramModelType::DirectionalSkipgram => self.vocab.len() * 2, 139 | } 140 | } 141 | 142 | fn config(&self) -> &CommonConfig { 143 | &self.common_config 144 | } 145 | 146 | fn to_metadata(&self) -> SkipgramMetadata { 147 | SkipgramMetadata { 148 | common_config: self.common_config, 149 | skipgram_config: self.skipgram_config, 150 | vocab_config: self.vocab.config(), 151 | } 152 | } 153 | } 154 | 155 | /// Iterator over focus identifier and associated context identifiers in a sentence. 156 | pub struct SkipGramIter { 157 | ids: Vec, 158 | rng: R, 159 | i: usize, 160 | model_type: SkipgramModelType, 161 | ctx_size: usize, 162 | } 163 | 164 | impl SkipGramIter 165 | where 166 | R: Rng + Clone, 167 | I: WordIdx, 168 | { 169 | /// Constructs a new `SkipGramIter`. 170 | /// 171 | /// The `rng` is used to determine the window size for each focus token. 172 | pub fn new(rng: R, ids: Vec, skip_config: SkipGramConfig) -> Self { 173 | SkipGramIter { 174 | ids, 175 | rng, 176 | i: 0, 177 | model_type: skip_config.model, 178 | ctx_size: skip_config.context_size as usize, 179 | } 180 | } 181 | 182 | fn output_(&self, token: usize, focus_idx: usize, offset_idx: usize) -> usize { 183 | match self.model_type { 184 | SkipgramModelType::StructuredSkipGram => { 185 | let offset = if offset_idx < focus_idx { 186 | (offset_idx + self.ctx_size) - focus_idx 187 | } else { 188 | (offset_idx - focus_idx - 1) + self.ctx_size 189 | }; 190 | 191 | (token * self.ctx_size * 2) + offset 192 | } 193 | SkipgramModelType::SkipGram => token, 194 | SkipgramModelType::DirectionalSkipgram => { 195 | let offset = if offset_idx < focus_idx { 0 } else { 1 }; 196 | 197 | (token * 2) + offset 198 | } 199 | } 200 | } 201 | } 202 | 203 | impl Iterator for SkipGramIter 204 | where 205 | R: Rng + Clone, 206 | I: WordIdx, 207 | { 208 | type Item = (I, Vec); 209 | 210 | fn next(&mut self) -> Option { 211 | if self.i < self.ids.len() { 212 | // Bojanowski, et al., 2017 uniformly sample the context size between 1 and c. 213 | let context_size = self.rng.gen_range(1..self.ctx_size + 1); 214 | let left = self.i - cmp::min(self.i, context_size); 215 | let right = cmp::min(self.i + context_size + 1, self.ids.len()); 216 | let contexts = (left..right) 217 | .filter(|&idx| idx != self.i) 218 | .map(|idx| self.output_(self.ids[idx].word_idx() as usize, self.i, idx)) 219 | .fold(Vec::with_capacity(right - left), |mut contexts, idx| { 220 | contexts.push(idx); 221 | contexts 222 | }); 223 | 224 | // swap the representation possibly containing multiple indices with one that only 225 | // contains the distinct word index since we need the word index for context lookups. 
226 | let mut word_idx = WordIdx::from_word_idx(self.ids[self.i].word_idx()); 227 | mem::swap(&mut self.ids[self.i], &mut word_idx); 228 | self.i += 1; 229 | return Some((word_idx, contexts)); 230 | } 231 | None 232 | } 233 | } 234 | 235 | impl FusedIterator for SkipGramIter 236 | where 237 | R: Rng + Clone, 238 | I: WordIdx, 239 | { 240 | } 241 | 242 | /// Metadata for Skipgramlike training algorithms. 243 | #[derive(Clone, Copy, Debug, Serialize)] 244 | pub struct SkipgramMetadata { 245 | common_config: CommonConfig, 246 | #[serde(rename = "model_config")] 247 | skipgram_config: SkipGramConfig, 248 | vocab_config: V, 249 | } 250 | -------------------------------------------------------------------------------- /src/config.rs: -------------------------------------------------------------------------------- 1 | use std::convert::TryFrom; 2 | use std::str::FromStr; 3 | 4 | use anyhow::{bail, Error, Result}; 5 | use serde::Serialize; 6 | 7 | use crate::io::EmbeddingFormat; 8 | use crate::vocab::Cutoff; 9 | 10 | /// Model types. 11 | #[derive(Copy, Clone, Debug, Serialize)] 12 | pub enum SkipgramModelType { 13 | // The skip-gram model (Mikolov, 2013). 14 | SkipGram, 15 | 16 | // The structured skip-gram model (Ling et al., 2015). 17 | StructuredSkipGram, 18 | 19 | // The directional skip-gram model (Song et al., 2018). 20 | DirectionalSkipgram, 21 | } 22 | 23 | impl TryFrom for SkipgramModelType { 24 | type Error = Error; 25 | 26 | fn try_from(model: u8) -> Result { 27 | match model { 28 | 0 => Ok(SkipgramModelType::SkipGram), 29 | 1 => Ok(SkipgramModelType::StructuredSkipGram), 30 | 2 => Ok(SkipgramModelType::DirectionalSkipgram), 31 | _ => bail!("Unknown model type: {}", model), 32 | } 33 | } 34 | } 35 | 36 | impl TryFrom<&str> for SkipgramModelType { 37 | type Error = Error; 38 | 39 | fn try_from(model: &str) -> Result { 40 | match model { 41 | "skipgram" => Ok(SkipgramModelType::SkipGram), 42 | "structgram" => Ok(SkipgramModelType::StructuredSkipGram), 43 | "dirgram" => Ok(SkipgramModelType::DirectionalSkipgram), 44 | _ => bail!("Unknown model type: {}", model), 45 | } 46 | } 47 | } 48 | 49 | impl FromStr for SkipgramModelType { 50 | type Err = Error; 51 | 52 | fn from_str(s: &str) -> Result { 53 | match s { 54 | "skipgram" => Ok(SkipgramModelType::SkipGram), 55 | "structgram" => Ok(SkipgramModelType::StructuredSkipGram), 56 | "dirgram" => Ok(SkipgramModelType::DirectionalSkipgram), 57 | _ => bail!("Unknown model type: {}", s), 58 | } 59 | } 60 | } 61 | 62 | /// Losses. 63 | #[derive(Copy, Clone, Debug, Serialize)] 64 | pub enum LossType { 65 | /// Logistic regression with negative sampling. 66 | LogisticNegativeSampling, 67 | } 68 | 69 | impl TryFrom for LossType { 70 | type Error = Error; 71 | 72 | fn try_from(model: u8) -> Result { 73 | match model { 74 | 0 => Ok(LossType::LogisticNegativeSampling), 75 | _ => bail!("Unknown model type: {}", model), 76 | } 77 | } 78 | } 79 | 80 | /// Bucket Indexer Types 81 | #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize)] 82 | pub enum BucketIndexerType { 83 | /// FinalfusionBucketIndexer 84 | Finalfusion, 85 | /// FastTextIndexer 86 | FastText, 87 | } 88 | 89 | /// Common embedding model hyperparameters. 90 | #[derive(Clone, Copy, Debug, Serialize)] 91 | pub struct CommonConfig { 92 | /// The loss function used for the model. 93 | pub loss: LossType, 94 | 95 | /// Word embedding dimensionality. 96 | pub dims: u32, 97 | 98 | /// The number of training epochs. 99 | pub epochs: u32, 100 | 101 | /// The output format. 
102 | #[serde(skip)] 103 | pub format: EmbeddingFormat, 104 | 105 | /// Number of negative samples to use for each context word. 106 | pub negative_samples: u32, 107 | 108 | /// The initial learning rate. 109 | pub lr: f32, 110 | } 111 | 112 | /// Hyperparameters for Dependency Embeddings. 113 | #[derive(Clone, Copy, Debug, Serialize)] 114 | #[serde(tag = "type")] 115 | #[serde(rename = "Depembeds")] 116 | pub struct DepembedsConfig { 117 | /// Maximum depth to extract dependency contexts from. 118 | pub depth: u32, 119 | 120 | /// Include the ROOT as dependency context. 121 | pub use_root: bool, 122 | 123 | /// Lowercase all tokens when used as context. 124 | pub normalize: bool, 125 | 126 | /// Projectivize dependency graphs before training. 127 | pub projectivize: bool, 128 | 129 | /// Extract untyped dependency contexts. 130 | /// 131 | /// Only takes the attached word-form into account. 132 | pub untyped: bool, 133 | } 134 | 135 | /// Hyperparameters for Subword vocabs. 136 | #[derive(Clone, Copy, Debug, Serialize)] 137 | #[serde(rename = "SubwordVocab")] 138 | #[serde(tag = "type")] 139 | pub struct SubwordVocabConfig { 140 | /// Token cutoff. 141 | /// 142 | /// No word-specific embeddings will be trained for tokens excluded by the 143 | /// cutoff. 144 | pub cutoff: Cutoff, 145 | 146 | /// Discard threshold. 147 | /// 148 | /// The discard threshold is used to compute the discard probability of 149 | /// a token. E.g. with a threshold of 0.00001 tokens with approximately 150 | /// that probability will never be discarded. 151 | pub discard_threshold: f32, 152 | 153 | /// Minimum n-gram length for subword units (inclusive). 154 | pub min_n: u32, 155 | 156 | /// Maximum n-gram length for subword units (inclusive). 157 | pub max_n: u32, 158 | 159 | /// Indexer specific parameters. 160 | pub indexer: V, 161 | } 162 | 163 | /// Hyperparameters for bucket-vocabs. 164 | #[derive(Clone, Copy, Debug, Serialize)] 165 | #[serde(rename = "Buckets")] 166 | #[serde(tag = "type")] 167 | pub struct BucketConfig { 168 | /// Bucket exponent. The model will use 2^bucket_exp buckets. 169 | /// 170 | /// A typical value for this parameter is 21, which gives roughly 2M 171 | /// buckets. 172 | pub buckets_exp: u32, 173 | 174 | pub indexer_type: BucketIndexerType, 175 | } 176 | 177 | /// Hyperparameters for Floret-vocabs. 178 | #[derive(Clone, Copy, Debug, Serialize)] 179 | #[serde(rename = "Floret")] 180 | #[serde(tag = "type")] 181 | pub struct FloretConfig { 182 | /// Number of buckets. 183 | pub buckets: u64, 184 | 185 | /// Number of hashes. 186 | pub n_hashes: u32, 187 | 188 | /// Seed. 189 | pub seed: u32, 190 | } 191 | 192 | /// Hyperparameters for ngram-vocabs. 193 | #[derive(Clone, Copy, Debug, Serialize)] 194 | #[serde(rename = "NGrams")] 195 | #[serde(tag = "type")] 196 | pub struct NGramConfig { 197 | /// NGram cutoff. 198 | /// 199 | /// NGrams excluded by the cutoff will be ignored during training. 200 | pub cutoff: Cutoff, 201 | } 202 | 203 | /// Hyperparameters for simple vocabs. 204 | #[derive(Clone, Copy, Debug, Serialize)] 205 | #[serde(rename = "SimpleVocab")] 206 | #[serde(tag = "type")] 207 | pub struct SimpleVocabConfig { 208 | /// Token cutoff. 209 | /// 210 | /// No word-specific embeddings will be trained for tokens excluded by the 211 | /// cutoff. 212 | pub cutoff: Cutoff, 213 | 214 | /// Discard threshold. 215 | /// 216 | /// The discard threshold is used to compute the discard probability of 217 | /// a token. E.g. 
with a threshold of 0.00001 tokens with approximately 218 | /// that probability will never be discarded. 219 | pub discard_threshold: f32, 220 | } 221 | 222 | /// Hyperparameters for SkipGram-like models. 223 | #[derive(Clone, Copy, Debug, Serialize)] 224 | #[serde(tag = "type")] 225 | #[serde(rename = "SkipGramLike")] 226 | pub struct SkipGramConfig { 227 | /// The model type. 228 | pub model: SkipgramModelType, 229 | 230 | /// The number of preceding and succeeding tokens that will be consider 231 | /// as context during training. 232 | /// 233 | /// For example, a context size of 5 will consider the 5 tokens preceding 234 | /// and the 5 tokens succeeding the focus token. 235 | pub context_size: u32, 236 | } 237 | 238 | #[derive(Copy, Clone)] 239 | pub enum VocabConfig { 240 | FloretVocab(SubwordVocabConfig), 241 | NGramVocab(SubwordVocabConfig), 242 | SimpleVocab(SimpleVocabConfig), 243 | SubwordVocab(SubwordVocabConfig), 244 | } 245 | 246 | impl VocabConfig { 247 | pub fn cutoff(&self) -> Cutoff { 248 | match self { 249 | VocabConfig::FloretVocab(config) => config.cutoff, 250 | VocabConfig::NGramVocab(config) => config.cutoff, 251 | VocabConfig::SimpleVocab(config) => config.cutoff, 252 | VocabConfig::SubwordVocab(config) => config.cutoff, 253 | } 254 | } 255 | pub fn discard_threshold(&self) -> f32 { 256 | match self { 257 | VocabConfig::FloretVocab(config) => config.discard_threshold, 258 | VocabConfig::NGramVocab(config) => config.discard_threshold, 259 | VocabConfig::SimpleVocab(config) => config.discard_threshold, 260 | VocabConfig::SubwordVocab(config) => config.discard_threshold, 261 | } 262 | } 263 | } 264 | -------------------------------------------------------------------------------- /man/finalfrontier-deps.1: -------------------------------------------------------------------------------- 1 | .\" Automatically generated by Pandoc 2.7.3 2 | .\" 3 | .TH "FINALFRONTIER-DEPS" "1" "" "% Daniel de Kok, Sebastian P\[:u]tz % Apr 6, 2019" "" 4 | .hy 5 | .SH NAME 6 | .PP 7 | \f[B]finalfrontier deps\f[R] \[en] train dependency-based word 8 | embeddings with subword representations 9 | .SH SYNOPSIS 10 | .PP 11 | \f[B]finalfrontier deps\f[R] [\f[I]options\f[R]] \f[I]corpus\f[R] 12 | \f[I]output\f[R] 13 | .SH DESCRIPTION 14 | .PP 15 | The \f[B]finalfrontier-deps\f[R] subcommand trains dependency based word 16 | embeddings (Levy and Goldberg, 2014) using data from a \f[I]corpus\f[R] 17 | in CONLL-U format. 18 | The corpus contains sentences seperated by empty lines. 19 | Each sentence needs to be annotated with a dependency graph. 20 | After training, the embeddings are written to \f[I]output\f[R] in the 21 | finalfusion format. 22 | .SH OPTIONS 23 | .TP 24 | .B \f[C]--buckets\f[R] \f[I]EXP\f[R] 25 | The bucket exponent. 26 | finalfrontier uses 2\[ha]\f[I]EXP\f[R] buckets to store subword 27 | representations. 28 | Each subword representation (n-gram) is hashed and mapped to a bucket 29 | based on this hash. 30 | Using more buckets will result in fewer bucket collisions between 31 | subword representations at the cost of memory use. 32 | The default bucket exponent is \f[I]21\f[R] (approximately 2 million 33 | buckets). 34 | .TP 35 | .B \f[C]--context-discard\f[R] \f[I]THRESHOLD\f[R] 36 | The context discard threshold influences how often frequent contexts are 37 | discarded during training. 38 | The default context discard threshold is \f[I]1e-4\f[R]. 39 | .TP 40 | .B \f[C]--context-mincount\f[R] \f[I]FREQ\f[R] 41 | The minimum count controls discarding of infrequent contexts. 
42 | Contexts occuring fewer than \f[I]FREQ\f[R] times are not considered 43 | during training. 44 | The default minimum count is 5. 45 | .TP 46 | .B \f[C]--context-target-size\f[R] \f[I]SIZE\f[R] 47 | The target size for the context vocabulary. 48 | At most \f[I]SIZE\f[R] contexts are included for training. 49 | Only contexts appearing more frequently than the context at 50 | \f[I]SIZE\f[R] are included. 51 | .TP 52 | .B \f[C]--dependency-depth\f[R] \f[I]DEPTH\f[R] 53 | Dependency contexts up to \f[I]DEPTH\f[R] distance from the focus word 54 | in the dependency graph will be used to learn the representation of the 55 | focus word. 56 | The default depth is \f[I]1\f[R]. 57 | .TP 58 | .B \f[C]--dims\f[R] \f[I]DIMS\f[R] 59 | The dimensionality of the trained word embeddings. 60 | The default dimensionality is 300. 61 | .TP 62 | .B \f[C]--discard\f[R] \f[I]THRESHOLD\f[R] 63 | The discard threshold influences how often frequent focus words are 64 | discarded from training. 65 | The default discard threshold is \f[I]1e-4\f[R]. 66 | .TP 67 | .B \f[C]--epochs\f[R] \f[I]N\f[R] 68 | The number of training epochs. 69 | The number of necessary training epochs typically decreases with the 70 | corpus size. 71 | The default number of epochs is \f[I]15\f[R]. 72 | .TP 73 | .B \f[C]--hash-indexer\f[R] \f[I]INDEXER\f[R] 74 | The indexer to use when bucket-based subwords are used (see 75 | \f[C]--subwords\f[R]). 76 | The possible values are \f[I]finalfusion\f[R] or \f[I]fasttext\f[R]. 77 | Default: finalfusion 78 | .RS 79 | .PP 80 | \f[I]finalfusion\f[R] uses the FNV-1a hasher, whereas \f[I]fasttext\f[R] 81 | emulates the (broken) implementation of FNV-1a in fastText. 82 | Use of \f[I]finalfusion\f[R] is recommended, unless the resulting 83 | embeddings should be compatible with fastText. 84 | .RE 85 | .TP 86 | .B \f[C]-f\f[R], \f[C]--format\f[R] \f[I]FORMAT\f[R] 87 | The output format. 88 | This must be one of \f[I]fasttext\f[R], \f[I]finalfusion\f[R], 89 | \f[I]word2vec\f[R], \f[I]text\f[R], and \f[I]textdims\f[R]. 90 | .RS 91 | .PP 92 | All formats, except \f[I]finalfusion\f[R], result in a loss of 93 | information: \f[I]word2vec\f[R], \f[I]text\f[R], and \f[I]textdims\f[R] 94 | do not store subword embeddings, nor hyperparameters. 95 | The \f[I]fastText\f[R] format does not store all hyperparemeters. 96 | .PP 97 | The \f[I]fasttext\f[R] format can only be used in conjunction with 98 | \f[C]--subwords buckets\f[R] and \f[C]--hash-indexer fasttext\f[R]. 99 | .RE 100 | .TP 101 | .B \f[C]--lr\f[R] \f[I]LEARNING_RATE\f[R] 102 | The learning rate determines what fraction of a gradient is used for 103 | parameter updates. 104 | The default initial learning rate is \f[I]0.05\f[R], the learning rate 105 | decreases monotonically during training. 106 | .TP 107 | .B \f[C]--maxn\f[R] \f[I]LEN\f[R] 108 | The maximum n-gram length for subword representations. 109 | Default: 6 110 | .TP 111 | .B \f[C]--mincount\f[R] \f[I]FREQ\f[R] 112 | The minimum count controls discarding of infrequent focus words. 113 | Focus words occuring fewer than \f[I]FREQ\f[R] times are not considered 114 | during training. 115 | The default minimum count is 5. 116 | .TP 117 | .B \f[C]--minn\f[R] \f[I]LEN\f[R] 118 | The minimum n-gram length for subword representations. 119 | Default: 3 120 | .TP 121 | .B \f[C]--ngram-mincount\f[R] \f[I]FREQ\f[R] 122 | The minimum n-gram frequency. 123 | n-grams occurring fewer than \f[I]FREQ\f[R] times are excluded from 124 | training. 
125 | This option is only applicable with the \f[I]ngrams\f[R] argument of the 126 | \f[C]subwords\f[R] option. 127 | .TP 128 | .B \f[C]--ngram-target-size\f[R] \f[I]SIZE\f[R] 129 | The target size for the n-gram vocabulary. 130 | At most \f[I]SIZE\f[R] n-ngrams are included for training. 131 | Only n-grams appearing more frequently than the n-gram at \f[I]SIZE\f[R] 132 | are included. 133 | This option is only applicable with the \f[I]ngrams\f[R] argument of the 134 | \f[C]subwords\f[R] option. 135 | .TP 136 | .B \f[C]--normalize-contexts\f[R] 137 | Normalize the attached form in the dependency contexts. 138 | .TP 139 | .B \f[C]--ns\f[R] \f[I]FREQ\f[R] 140 | The number of negatives to sample per positive example. 141 | Default: 5 142 | .TP 143 | .B \f[C]--projectivize\f[R] 144 | Projectivize dependency graphs before training embeddings. 145 | .TP 146 | .B \f[C]--threads\f[R] \f[I]N\f[R] 147 | The number of thread to use during training for parallelization. 148 | The default is to use half of the logical CPUs of the machine, capped at 149 | 20 threads. 150 | Increasing the number of threads increases the probability of update 151 | collisions, requiring more epochs to reach the same loss. 152 | .TP 153 | .B \f[C]--subwords\f[R] \f[I]SUBWORDS\f[R] 154 | The type of subword embeddings to train. 155 | The possible types are \f[I]buckets\f[R], \f[I]ngrams\f[R], and 156 | \f[I]none\f[R]. 157 | Subword embeddings are used to compute embeddings for unknown words by 158 | summing embeddings of n-grams within unknown words. 159 | .RS 160 | .PP 161 | The \f[I]none\f[R] type does not use subwords. 162 | The resulting model will not be able assign an embeddings to unknown 163 | words. 164 | .PP 165 | The \f[I]ngrams\f[R] type stores subword n-grams explicitly. 166 | The included n-gram lengths are specified using the \f[C]minn\f[R] and 167 | \f[C]maxn\f[R] options. 168 | The frequency threshold for n-grams is configured with the 169 | \f[C]ngram-mincount\f[R] option. 170 | .PP 171 | The \f[I]buckets\f[R] type maps n-grams to buckets using the FNV1 hash. 172 | The considered n-gram lengths are specified using the \f[C]minn\f[R] and 173 | \f[C]maxn\f[R] options. 174 | The number of buckets is controlled with the \f[C]buckets\f[R] option. 175 | .RE 176 | .TP 177 | .B \f[C]--target-size\f[R] \f[I]SIZE\f[R] 178 | The target size for the token vocabulary. 179 | At most \f[I]SIZE\f[R] tokens are included for training. 180 | Only tokens appearing more frequently than the token at \f[I]SIZE\f[R] 181 | are included. 182 | .TP 183 | .B \f[C]--untyped-deps\f[R] 184 | Only use the word of the attached token in the dependency relation as 185 | contexts to train the representation of the focus word. 186 | .TP 187 | .B \f[C]--use-root\f[R] 188 | Include the abstract root node in the dependency graph as contexts 189 | during training. 190 | .TP 191 | .B \f[C]--zipf\f[R] \f[I]EXP\f[R] 192 | Exponent \f[I]s\f[R] used in the Zipf distribution 193 | \f[C]p(k) = 1 / (k\[ha]s H_N)\f[R] for negative sampling. 
194 | Default: 0.5 195 | .SH EXAMPLES 196 | .PP 197 | Train embeddings on \f[I]dewiki.txt\f[R] using the dependency model with 198 | default parameters: 199 | .IP 200 | .nf 201 | \f[C] 202 | finalfrontier deps dewiki.conll dewiki-deps.bin 203 | \f[R] 204 | .fi 205 | .PP 206 | Train embeddings with dimensionality 200 on \f[I]dewiki.conll\f[R] using 207 | the dependency model from contexts with depth up to 2: 208 | .IP 209 | .nf 210 | \f[C] 211 | finalfrontier deps --depth 2 --normalize --dims 200 \[rs] 212 | dewiki.conll dewiki-deps.bin 213 | \f[R] 214 | .fi 215 | .SH SEE ALSO 216 | .PP 217 | \f[C]finalfrontier\f[R](1), \f[C]finalfrontier-skipgram\f[R](1) 218 | -------------------------------------------------------------------------------- /src/vocab/mod.rs: -------------------------------------------------------------------------------- 1 | pub(crate) mod simple; 2 | pub(crate) mod subword; 3 | 4 | use std::borrow::Borrow; 5 | use std::collections::HashMap; 6 | use std::hash::Hash; 7 | 8 | use serde::Serialize; 9 | use superslice::Ext; 10 | 11 | use crate::idx::WordIdx; 12 | use std::cmp::Reverse; 13 | 14 | const BOW: char = '<'; 15 | const EOW: char = '>'; 16 | 17 | pub type Word = CountedType; 18 | #[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord)] 19 | pub struct CountedType { 20 | count: usize, 21 | label: T, 22 | } 23 | 24 | impl CountedType { 25 | /// Construct a new type. 26 | pub fn new(label: T, count: usize) -> Self { 27 | CountedType { count, label } 28 | } 29 | pub fn count(&self) -> usize { 30 | self.count 31 | } 32 | pub fn label(&self) -> &T { 33 | &self.label 34 | } 35 | } 36 | 37 | impl CountedType { 38 | /// The string representation of the word. 39 | pub fn word(&self) -> &str { 40 | &self.label 41 | } 42 | } 43 | 44 | /// Trait for lookup of indices. 45 | pub trait Vocab { 46 | type VocabType: Hash + Eq; 47 | type IdxType: WordIdx; 48 | type Config; 49 | 50 | /// Return this vocabulary's config. 51 | fn config(&self) -> Self::Config; 52 | 53 | fn is_empty(&self) -> bool { 54 | self.len() == 0 55 | } 56 | 57 | /// Get the number of entries in the vocabulary. 58 | fn len(&self) -> usize { 59 | self.types().len() 60 | } 61 | 62 | /// Get the index of the entry, will return None if the item is not present. 63 | fn idx(&self, key: &Q) -> Option 64 | where 65 | Self::VocabType: Borrow, 66 | Q: Hash + ?Sized + Eq; 67 | 68 | /// Get the discard probability of the entry with the given index. 69 | fn discard(&self, idx: usize) -> f32; 70 | 71 | /// Get the number of possible input types. 72 | fn n_input_types(&self) -> usize; 73 | 74 | /// Get all types in the vocabulary. 75 | fn types(&self) -> &[CountedType]; 76 | 77 | /// Get the number of types in the corpus. 78 | /// 79 | /// This returns the number of types in the corpus that the vocabulary 80 | /// was constructed from, **before** removing types that are below the 81 | /// minimum count. 82 | fn n_types(&self) -> usize; 83 | } 84 | 85 | /// Generic builder struct to count types. 86 | /// 87 | /// Items are added to the vocabulary and counted using the `count` method. 88 | /// There is no explicit build method, conversion is done via implementing 89 | /// `From>`. 
90 | pub struct VocabBuilder { 91 | config: C, 92 | items: HashMap, 93 | n_items: usize, 94 | } 95 | 96 | impl VocabBuilder 97 | where 98 | T: Hash + Eq, 99 | { 100 | pub fn new(config: C) -> Self { 101 | VocabBuilder { 102 | config, 103 | items: HashMap::new(), 104 | n_items: 0, 105 | } 106 | } 107 | 108 | pub fn count(&mut self, item: S) 109 | where 110 | S: Into, 111 | { 112 | self.n_items += 1; 113 | let cnt = self.items.entry(item.into()).or_insert(0); 114 | *cnt += 1; 115 | } 116 | } 117 | 118 | /// Create discard probabilities based on threshold, specific counts and total counts. 119 | pub(crate) fn create_discards( 120 | discard_threshold: f32, 121 | types: &[CountedType], 122 | n_tokens: usize, 123 | ) -> Vec { 124 | let mut discards = Vec::with_capacity(types.len()); 125 | 126 | for item in types { 127 | let p = item.count() as f32 / n_tokens as f32; 128 | let p_discard = discard_threshold / p + (discard_threshold / p).sqrt(); 129 | 130 | // Not a proper probability, upper bound at 1.0. 131 | discards.push(1f32.min(p_discard)); 132 | } 133 | 134 | discards 135 | } 136 | 137 | /// Create lookup. 138 | pub(crate) fn create_indices(types: &[CountedType]) -> HashMap 139 | where 140 | S: Hash + Eq + Clone, 141 | { 142 | let mut token_indices = HashMap::new(); 143 | 144 | for (idx, item) in types.iter().enumerate() { 145 | token_indices.insert(item.label.clone(), idx); 146 | } 147 | 148 | // Invariant: The index size should be the same as the number of 149 | // types. 150 | assert_eq!(types.len(), token_indices.len()); 151 | 152 | token_indices 153 | } 154 | 155 | /// Add begin/end-of-word brackets. 156 | pub(crate) fn bracket(word: &str) -> String { 157 | let mut bracketed = String::new(); 158 | bracketed.push(BOW); 159 | bracketed.push_str(word); 160 | bracketed.push(EOW); 161 | 162 | bracketed 163 | } 164 | 165 | /// Cutoff to determine vocabulary size. 166 | #[derive(Copy, Clone, Debug, Eq, PartialEq, Serialize)] 167 | #[serde(tag = "type", content = "value")] 168 | pub enum Cutoff { 169 | /// Cutoff based on minimum frequency, items appearing less than 170 | /// `min_count` times are discarded. 171 | MinCount(usize), 172 | /// Cutoff based on a target size, up to `target_size` items are kept 173 | /// in the vocabulary. If the item at `target_size+1` appears `n` times, 174 | /// all items with frequency `n` and smaller are discarded. 
175 | TargetSize(usize), 176 | } 177 | 178 | impl Cutoff { 179 | pub(crate) fn filter( 180 | &self, 181 | items: impl IntoIterator, 182 | ) -> Vec> 183 | where 184 | T: Hash + Eq + Into, 185 | S: Hash + Eq + Clone + Ord, 186 | { 187 | match self { 188 | Cutoff::MinCount(min_count) => filter_minfreq(items, *min_count), 189 | Cutoff::TargetSize(target_size) => filter_targetsize(items, *target_size), 190 | } 191 | } 192 | } 193 | 194 | fn filter_minfreq( 195 | items: impl IntoIterator, 196 | min_count: usize, 197 | ) -> Vec> 198 | where 199 | T: Hash + Eq + Into, 200 | S: Hash + Eq + Clone + Ord, 201 | { 202 | let mut types: Vec<_> = items 203 | .into_iter() 204 | .filter(|(_, count)| *count >= min_count) 205 | .map(|(item, count)| CountedType::new(item.into(), count)) 206 | .collect(); 207 | types.sort_unstable_by(|w1, w2| w2.cmp(w1)); 208 | types 209 | } 210 | 211 | fn filter_targetsize( 212 | items: impl IntoIterator, 213 | target_size: usize, 214 | ) -> Vec> 215 | where 216 | T: Hash + Eq + Into, 217 | S: Hash + Eq + Clone + Ord, 218 | { 219 | let mut items = items 220 | .into_iter() 221 | .map(|(item, count)| CountedType::new(item.into(), count)) 222 | .collect::>(); 223 | items.sort_unstable_by(|i1, i2| i2.cmp(i1)); 224 | 225 | if target_size > items.len() { 226 | return items; 227 | } 228 | 229 | let cutoff_idx = 230 | items.lower_bound_by_key(&Reverse(items[target_size].count), |key| Reverse(key.count)); 231 | items.truncate(cutoff_idx); 232 | items 233 | } 234 | 235 | #[cfg(test)] 236 | mod test { 237 | use crate::{Cutoff, Word}; 238 | 239 | #[test] 240 | pub fn target_size_unique_counts() { 241 | let cutoff = Cutoff::TargetSize(3); 242 | let items = vec![("a", 10), ("b", 3), ("c", 12), ("d", 5)]; 243 | let filtered: Vec = cutoff.filter(items); 244 | let target_items = vec![ 245 | Word::new("c".to_string(), 12), 246 | Word::new("a".to_string(), 10), 247 | Word::new("d".to_string(), 5), 248 | ]; 249 | assert!( 250 | filtered == target_items, 251 | "{:#?}\n != \n {:#?}", 252 | filtered, 253 | target_items 254 | ); 255 | } 256 | 257 | #[test] 258 | pub fn target_size_discard_equal() { 259 | let cutoff = Cutoff::TargetSize(3); 260 | let items = vec![("a", 10), ("b", 3), ("c", 12), ("e", 12), ("d", 10)]; 261 | let filtered: Vec = cutoff.filter(items); 262 | let target_items = vec![ 263 | Word::new("e".to_string(), 12), 264 | Word::new("c".to_string(), 12), 265 | ]; 266 | assert!( 267 | filtered == target_items, 268 | "{:#?}\n != \n {:#?}", 269 | filtered, 270 | target_items 271 | ); 272 | } 273 | 274 | #[test] 275 | pub fn target_size_0() { 276 | let cutoff = Cutoff::TargetSize(0); 277 | let items = vec![("a", 10), ("b", 3), ("c", 12), ("e", 12), ("d", 10)]; 278 | let filtered: Vec = cutoff.filter(items); 279 | let target_items = vec![]; 280 | assert!( 281 | filtered == target_items, 282 | "{:#?}\n != \n {:#?}", 283 | filtered, 284 | target_items 285 | ); 286 | } 287 | 288 | #[test] 289 | pub fn target_size_large() { 290 | let cutoff = Cutoff::TargetSize(10); 291 | let items = vec![("a", 10), ("b", 3), ("c", 12), ("e", 12), ("d", 10)]; 292 | let filtered: Vec = cutoff.filter(items); 293 | let target_items = vec![ 294 | Word::new("e".to_string(), 12), 295 | Word::new("c".to_string(), 12), 296 | Word::new("d".to_string(), 10), 297 | Word::new("a".to_string(), 10), 298 | Word::new("b".to_string(), 3), 299 | ]; 300 | assert!( 301 | filtered == target_items, 302 | "{:#?}\n != \n {:#?}", 303 | filtered, 304 | target_items 305 | ); 306 | } 307 | 308 | #[test] 309 | pub fn 
target_size_all_equal_too_many() { 310 | let cutoff = Cutoff::TargetSize(3); 311 | let items = vec![("a", 10), ("b", 10), ("c", 10), ("e", 10), ("d", 10)]; 312 | let filtered: Vec = cutoff.filter(items); 313 | let target_items = vec![]; 314 | assert!( 315 | filtered == target_items, 316 | "{:#?}\n != \n {:#?}", 317 | filtered, 318 | target_items 319 | ); 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/subcommands/deps.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{BufRead, BufReader, BufWriter}; 3 | use std::path::{Path, PathBuf}; 4 | use std::thread; 5 | use std::time::Duration; 6 | 7 | use anyhow::{Context, Result}; 8 | use conllu::io::{ReadSentence, Reader, Sentences}; 9 | use finalfrontier::io::{thread_data_conllu, FileProgress, TrainInfo}; 10 | use finalfrontier::{ 11 | BucketIndexerType, CommonConfig, DepembedsConfig, DepembedsTrainer, Dependency, 12 | DependencyIterator, Sgd, SimpleVocab, SimpleVocabConfig, SubwordVocab, Vocab, VocabBuilder, 13 | VocabConfig, WriteModelBinary, 14 | }; 15 | use finalfusion::compat::fasttext::FastTextIndexer; 16 | use finalfusion::compat::floret::FloretIndexer; 17 | use finalfusion::prelude::VocabWrap; 18 | use finalfusion::subword::FinalfusionHashIndexer; 19 | use rand::{Rng, SeedableRng}; 20 | use rand_xorshift::XorShiftRng; 21 | use serde::Serialize; 22 | use udgraph::graph::{Node, Sentence}; 23 | use udgraph_projectivize::{HeadProjectivizer, Projectivize}; 24 | 25 | use crate::subcommands::show_progress; 26 | use crate::FinalfrontierApp; 27 | 28 | const PROGRESS_UPDATE_INTERVAL: u64 = 200; 29 | 30 | /// Dependency embeddings subcommand. 31 | pub struct DepsApp { 32 | train_info: TrainInfo, 33 | common_config: CommonConfig, 34 | depembeds_config: DepembedsConfig, 35 | input_vocab_config: VocabConfig, 36 | output_vocab_config: SimpleVocabConfig, 37 | } 38 | 39 | impl DepsApp { 40 | pub fn new( 41 | train_info: TrainInfo, 42 | common_config: CommonConfig, 43 | depembeds_config: DepembedsConfig, 44 | vocab_config: VocabConfig, 45 | ) -> Self { 46 | Self { 47 | train_info, 48 | common_config, 49 | depembeds_config, 50 | input_vocab_config: vocab_config, 51 | output_vocab_config: SimpleVocabConfig { 52 | cutoff: vocab_config.cutoff(), 53 | discard_threshold: vocab_config.discard_threshold(), 54 | }, 55 | } 56 | } 57 | 58 | /// Get the corpus path. 59 | pub fn corpus(&self) -> &str { 60 | self.train_info().corpus() 61 | } 62 | 63 | /// Get the output path. 64 | pub fn output(&self) -> &str { 65 | self.train_info().output() 66 | } 67 | 68 | /// Get the number of threads. 69 | pub fn n_threads(&self) -> usize { 70 | self.train_info().n_threads() 71 | } 72 | 73 | /// Get the common config. 74 | pub fn common_config(&self) -> CommonConfig { 75 | self.common_config 76 | } 77 | 78 | /// Get the depembeds config. 79 | pub fn depembeds_config(&self) -> DepembedsConfig { 80 | self.depembeds_config 81 | } 82 | 83 | /// Get the input vocab config. 84 | pub fn input_vocab_config(&self) -> VocabConfig { 85 | self.input_vocab_config 86 | } 87 | 88 | /// Get the output vocab config. 89 | pub fn output_vocab_config(&self) -> SimpleVocabConfig { 90 | self.output_vocab_config 91 | } 92 | 93 | /// Get the train information. 
94 | pub fn train_info(&self) -> &TrainInfo { 95 | &self.train_info 96 | } 97 | } 98 | 99 | impl FinalfrontierApp for DepsApp { 100 | fn run(&self) -> Result<()> { 101 | match self.input_vocab_config() { 102 | VocabConfig::SimpleVocab(config) => { 103 | let (input_vocab, output_vocab) = build_vocab::<_, SimpleVocab, _>( 104 | config, 105 | self.output_vocab_config(), 106 | self.depembeds_config(), 107 | self.corpus(), 108 | )?; 109 | train(input_vocab, output_vocab, self)?; 110 | } 111 | VocabConfig::FloretVocab(config) => { 112 | let (input_vocab, output_vocab) = 113 | build_vocab::<_, SubwordVocab<_, FloretIndexer>, _>( 114 | config, 115 | self.output_vocab_config(), 116 | self.depembeds_config, 117 | self.corpus(), 118 | )?; 119 | train(input_vocab, output_vocab, self)? 120 | } 121 | VocabConfig::SubwordVocab(config) => match config.indexer.indexer_type { 122 | BucketIndexerType::Finalfusion => { 123 | let (input_vocab, output_vocab) = 124 | build_vocab::<_, SubwordVocab<_, FinalfusionHashIndexer>, _>( 125 | config, 126 | self.output_vocab_config(), 127 | self.depembeds_config(), 128 | self.corpus(), 129 | )?; 130 | train(input_vocab, output_vocab, self)? 131 | } 132 | BucketIndexerType::FastText => { 133 | let (input_vocab, output_vocab) = 134 | build_vocab::<_, SubwordVocab<_, FastTextIndexer>, _>( 135 | config, 136 | self.output_vocab_config(), 137 | self.depembeds_config(), 138 | self.corpus(), 139 | )?; 140 | train(input_vocab, output_vocab, self)?; 141 | } 142 | }, 143 | VocabConfig::NGramVocab(config) => { 144 | let (input_vocab, output_vocab) = build_vocab::<_, SubwordVocab<_, _>, _>( 145 | config, 146 | self.output_vocab_config(), 147 | self.depembeds_config(), 148 | self.corpus(), 149 | )?; 150 | train(input_vocab, output_vocab, self)?; 151 | } 152 | } 153 | 154 | Ok(()) 155 | } 156 | } 157 | 158 | fn train(input_vocab: V, output_vocab: SimpleVocab, app: &DepsApp) -> Result<()> 159 | where 160 | V: Vocab + Into + Clone + Send + Sync + 'static, 161 | V::Config: Serialize, 162 | for<'a> &'a V::IdxType: IntoIterator, 163 | { 164 | let corpus = app.corpus(); 165 | let common_config = app.common_config(); 166 | let n_threads = app.n_threads(); 167 | 168 | let mut output_writer = 169 | BufWriter::new(File::create(app.output()).context("Cannot open output file for writing.")?); 170 | let trainer = DepembedsTrainer::new( 171 | input_vocab, 172 | output_vocab, 173 | app.common_config(), 174 | app.depembeds_config(), 175 | XorShiftRng::from_entropy(), 176 | ); 177 | let sgd = Sgd::new(trainer.into()); 178 | 179 | let projectivize = app.depembeds_config().projectivize; 180 | let mut children = Vec::with_capacity(n_threads); 181 | for thread in 0..n_threads { 182 | let corpus = corpus.to_owned(); 183 | let sgd = sgd.clone(); 184 | 185 | children.push(thread::spawn(move || { 186 | do_work( 187 | corpus, 188 | sgd, 189 | thread, 190 | n_threads, 191 | common_config.epochs, 192 | common_config.lr, 193 | projectivize, 194 | ) 195 | })); 196 | } 197 | 198 | show_progress( 199 | &app.common_config(), 200 | &sgd, 201 | Duration::from_millis(PROGRESS_UPDATE_INTERVAL), 202 | ); 203 | 204 | // Wait until all threads have finished. 
205 | for child in children { 206 | child.join().expect("Thread panicked")?; 207 | } 208 | 209 | sgd.into_model() 210 | .write_model_binary( 211 | &mut output_writer, 212 | app.train_info().clone(), 213 | app.common_config.format, 214 | ) 215 | .context("Cannot write model") 216 | } 217 | 218 | fn do_work( 219 | corpus_path: P, 220 | mut sgd: Sgd>, 221 | thread: usize, 222 | n_threads: usize, 223 | epochs: u32, 224 | start_lr: f32, 225 | projectivize: bool, 226 | ) -> Result<()> 227 | where 228 | P: Into, 229 | R: Clone + Rng, 230 | V: Vocab, 231 | V::Config: Serialize, 232 | for<'a> &'a V::IdxType: IntoIterator, 233 | { 234 | let n_tokens = sgd.model().input_vocab().n_types(); 235 | 236 | let f = File::open(corpus_path.into()).context("Cannot open corpus for reading")?; 237 | let (data, start) = 238 | thread_data_conllu(&f, thread, n_threads).context("Could not get thread-specific data")?; 239 | let projectivizer = if projectivize { 240 | Some(HeadProjectivizer::new()) 241 | } else { 242 | None 243 | }; 244 | 245 | let mut sentences = SentenceIter::new(BufReader::new(&data[start..]), projectivizer); 246 | while sgd.n_tokens_processed() < epochs as usize * n_tokens { 247 | let sentence = sentences 248 | .next() 249 | .or_else(|| { 250 | sentences = SentenceIter::new(BufReader::new(&*data), projectivizer); 251 | sentences.next() 252 | }) 253 | .transpose()? 254 | .context("Cannot read sentence")?; 255 | 256 | let lr = (1.0 - (sgd.n_tokens_processed() as f32 / (epochs as usize * n_tokens) as f32)) 257 | * start_lr; 258 | sgd.update_sentence(&sentence, lr); 259 | } 260 | 261 | Ok(()) 262 | } 263 | 264 | fn build_vocab( 265 | input_config: C, 266 | output_config: SimpleVocabConfig, 267 | dep_config: DepembedsConfig, 268 | corpus_path: P, 269 | ) -> Result<(V, SimpleVocab)> 270 | where 271 | P: AsRef, 272 | V: Vocab + From>, 273 | VocabBuilder: Into, 274 | { 275 | let f = File::open(corpus_path).context("Cannot open corpus for reading")?; 276 | let file_progress = FileProgress::new(f).context("Cannot create progress bar")?; 277 | let mut input_builder = VocabBuilder::new(input_config); 278 | let mut output_builder: VocabBuilder<_, Dependency> = VocabBuilder::new(output_config); 279 | 280 | let projectivizer = if dep_config.projectivize { 281 | Some(HeadProjectivizer::new()) 282 | } else { 283 | None 284 | }; 285 | 286 | for sentence in SentenceIter::new(BufReader::new(file_progress), projectivizer) { 287 | let sentence = sentence?; 288 | 289 | for token in sentence.iter().filter_map(Node::token) { 290 | input_builder.count(token.form()); 291 | } 292 | 293 | for (_, context) in DependencyIterator::new_from_config(&sentence.dep_graph(), dep_config) { 294 | output_builder.count(context); 295 | } 296 | } 297 | 298 | Ok((input_builder.into(), output_builder.into())) 299 | } 300 | 301 | struct SentenceIter 302 | where 303 | R: ReadSentence, 304 | { 305 | inner: Sentences, 306 | projectivizer: Option

, 307 | } 308 | 309 | impl SentenceIter> 310 | where 311 | R: BufRead, 312 | { 313 | fn new(read: R, projectivizer: Option

) -> Self { 314 | SentenceIter { 315 | inner: Reader::new(read).into_iter(), 316 | projectivizer, 317 | } 318 | } 319 | } 320 | 321 | impl Iterator for SentenceIter 322 | where 323 | P: Projectivize, 324 | R: ReadSentence, 325 | { 326 | type Item = Result; 327 | 328 | fn next(&mut self) -> Option { 329 | let sentence = self.inner.next()?; 330 | let mut sentence = match sentence.context("Cannot read sentence") { 331 | Ok(sentence) => sentence, 332 | err @ Err(_) => return Some(err), 333 | }; 334 | 335 | if let Some(proj) = &self.projectivizer { 336 | // Rewrap error. 337 | if let Err(err) = proj.projectivize(&mut sentence) { 338 | return Some(Err(err).context("Cannot projectivize sentence.")); 339 | } 340 | } 341 | 342 | Some(Ok(sentence)) 343 | } 344 | } 345 | -------------------------------------------------------------------------------- /src/io.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::{self, BufRead, Lines, Read, Seek, Write}; 3 | 4 | use anyhow::{Context, Result}; 5 | use chrono::{DateTime, Local}; 6 | use indicatif::{ProgressBar, ProgressStyle}; 7 | use memmap::{Mmap, MmapOptions}; 8 | use serde::Serialize; 9 | 10 | pub struct FileProgress { 11 | inner: File, 12 | progress: ProgressBar, 13 | } 14 | 15 | #[derive(Clone, Copy, Debug, PartialEq, Eq)] 16 | pub enum EmbeddingFormat { 17 | FastText, 18 | FinalFusion, 19 | Floret, 20 | Word2Vec, 21 | Text, 22 | TextDims, 23 | } 24 | 25 | /// A progress bar that implements the `Read` trait. 26 | /// 27 | /// This wrapper of `indicatif`'s `ProgressBar` updates progress based on the 28 | /// current offset within the file. 29 | impl FileProgress { 30 | pub fn new(file: File) -> io::Result { 31 | let metadata = file.metadata()?; 32 | let progress = ProgressBar::new(metadata.len()); 33 | let style = ProgressStyle::default_bar() 34 | .template("{bar:30} {bytes}/{total_bytes} ETA: {eta_precise}") 35 | .expect("template string expected is to be valid"); 36 | progress.set_style(style); 37 | 38 | Ok(FileProgress { 39 | inner: file, 40 | progress, 41 | }) 42 | } 43 | } 44 | 45 | impl Read for FileProgress { 46 | fn read(&mut self, buf: &mut [u8]) -> io::Result { 47 | let n_read = self.inner.read(buf)?; 48 | let pos = self.inner.stream_position()?; 49 | self.progress.set_position(pos); 50 | Ok(n_read) 51 | } 52 | } 53 | 54 | impl Drop for FileProgress { 55 | fn drop(&mut self) { 56 | self.progress.finish(); 57 | } 58 | } 59 | 60 | /// Sentence iterator. 61 | /// 62 | /// This iterator consumes a reader with tokenized sentences: 63 | /// 64 | /// - One sentence per line. 65 | /// - Tokens separated by a space. 66 | /// 67 | /// It produces `Vec`s with the tokens, adding an end-of-sentence marker 68 | /// to the end of the sentence. Lines that are empty or only consist of 69 | /// whitespace are discarded. 70 | pub struct SentenceIterator { 71 | lines: Lines, 72 | } 73 | 74 | impl SentenceIterator 75 | where 76 | R: BufRead, 77 | { 78 | pub fn new(read: R) -> Self { 79 | SentenceIterator { 80 | lines: read.lines(), 81 | } 82 | } 83 | } 84 | 85 | impl Iterator for SentenceIterator 86 | where 87 | R: BufRead, 88 | { 89 | type Item = Result>; 90 | 91 | fn next(&mut self) -> Option { 92 | for line in &mut self.lines { 93 | let line = match line { 94 | Ok(ref line) => line.trim(), 95 | Err(err) => return Some(Err(err.into())), 96 | }; 97 | 98 | // Skip empty lines. 
99 | if !line.is_empty() { 100 | return Some(Ok(whitespace_tokenize(line))); 101 | } 102 | } 103 | 104 | None 105 | } 106 | } 107 | 108 | /// Get thread-specific data. 109 | /// 110 | /// This function will return a memory map of the corpus data. The initial 111 | /// starting position for the given thread is also returned. This starting 112 | /// Position will always be the beginning of a sentence. 113 | pub fn thread_data_text(f: &File, thread: usize, n_threads: usize) -> Result<(Mmap, usize)> { 114 | assert!( 115 | thread < n_threads, 116 | "Thread {} out of index [0, {})", 117 | thread, 118 | n_threads 119 | ); 120 | 121 | let size = f.metadata().context("Cannot get file metadata")?.len(); 122 | let chunk_size = size as usize / n_threads; 123 | 124 | let mmap = unsafe { MmapOptions::new().map(f)? }; 125 | 126 | if thread == 0 { 127 | return Ok((mmap, 0)); 128 | } 129 | 130 | let mut start = thread * chunk_size; 131 | while start < mmap.len() { 132 | let next = mmap[start]; 133 | start += 1; 134 | if next == b'\n' { 135 | break; 136 | } 137 | } 138 | 139 | Ok((mmap, start)) 140 | } 141 | 142 | /// Get thread-specific data for a CoNLL-U corpus. 143 | /// 144 | /// This function will return a memory map of the corpus data. The initial 145 | /// starting position for the given thread is also returned. This starting 146 | /// Position will always be the beginning of a sentence. 147 | pub fn thread_data_conllu(f: &File, thread: usize, n_threads: usize) -> Result<(Mmap, usize)> { 148 | assert!( 149 | thread < n_threads, 150 | "Thread {} out of index [0, {})", 151 | thread, 152 | n_threads 153 | ); 154 | 155 | let size = f.metadata().context("Cannot get file metadata")?.len(); 156 | let chunk_size = size as usize / n_threads; 157 | 158 | let mmap = unsafe { MmapOptions::new().map(f)? }; 159 | 160 | if thread == 0 { 161 | return Ok((mmap, 0)); 162 | } 163 | 164 | let mut start = thread * chunk_size; 165 | while start < mmap.len() - 1 { 166 | let next = mmap[start]; 167 | start += 1; 168 | if next == b'\n' && mmap[start] == b'\n' { 169 | start += 1; 170 | break; 171 | } 172 | } 173 | 174 | Ok((mmap, start)) 175 | } 176 | 177 | /// Meta information about training. 178 | #[derive(Clone, Serialize)] 179 | pub struct TrainInfo { 180 | corpus: String, 181 | output: String, 182 | n_threads: usize, 183 | start_datetime: String, 184 | end_datetime: Option, 185 | } 186 | 187 | impl TrainInfo { 188 | /// Construct new TrainInfo. 189 | /// 190 | /// Constructs TrainInfo with `start_datetime` set to the current datetime. `end_datetime` is 191 | /// set to `None` and can be set through `TrainInfo::set_end`. 192 | pub fn new(corpus: String, output: String, n_threads: usize) -> Self { 193 | let start_datetime: DateTime = Local::now(); 194 | TrainInfo { 195 | corpus, 196 | output, 197 | n_threads, 198 | start_datetime: start_datetime.format("%Y-%m-%d %H:%M:%S").to_string(), 199 | end_datetime: None, 200 | } 201 | } 202 | 203 | /// Get the corpus path. 204 | pub fn corpus(&self) -> &str { 205 | &self.corpus 206 | } 207 | 208 | /// Get the output file. 209 | pub fn output(&self) -> &str { 210 | &self.output 211 | } 212 | 213 | /// Get the number of threads. 214 | pub fn n_threads(&self) -> usize { 215 | self.n_threads 216 | } 217 | 218 | /// Get the start datetime. 219 | pub fn start_datetime(&self) -> &str { 220 | &self.start_datetime 221 | } 222 | 223 | /// Get the end datetime. 
224 | pub fn end_datetime(&self) -> Option<&str> { 225 | self.end_datetime.as_deref() 226 | } 227 | 228 | /// Set the end datetime to current datetime. 229 | pub fn set_end(&mut self) { 230 | let start_datetime: DateTime = Local::now(); 231 | self.end_datetime = Some(start_datetime.format("%Y-%m-%d %H:%M:%S").to_string()); 232 | } 233 | } 234 | 235 | /// Trait for writing models in binary format. 236 | pub trait WriteModelBinary 237 | where 238 | W: Write, 239 | { 240 | fn write_model_binary( 241 | self, 242 | write: &mut W, 243 | train_info: TrainInfo, 244 | format: EmbeddingFormat, 245 | ) -> Result<()>; 246 | } 247 | 248 | fn whitespace_tokenize(line: &str) -> Vec { 249 | line.split_whitespace() 250 | .map(ToOwned::to_owned) 251 | .collect::>() 252 | } 253 | 254 | /// Trait for writing models in text format. 255 | pub trait WriteModelText 256 | where 257 | W: Write, 258 | { 259 | /// Write the model in text format. 260 | /// 261 | /// This function only writes the word embeddings. The subword 262 | /// embeddings are discarded. 263 | /// 264 | /// The `write_dims` parameter indicates whether the first line 265 | /// should contain the dimensionality of the embedding matrix. 266 | fn write_model_text(&self, write: &mut W, write_dims: bool) -> Result<()>; 267 | } 268 | 269 | /// Trait for writing models in binary format. 270 | pub trait WriteModelWord2Vec 271 | where 272 | W: Write, 273 | { 274 | fn write_model_word2vec(&self, write: &mut W) -> Result<()>; 275 | } 276 | 277 | #[cfg(test)] 278 | mod tests { 279 | use std::fs::File; 280 | use std::io::Cursor; 281 | 282 | use super::SentenceIterator; 283 | use super::{thread_data_conllu, thread_data_text}; 284 | 285 | #[test] 286 | fn sentence_iterator_test() { 287 | let v = b"This is a sentence .\nAnd another one .\n".to_vec(); 288 | let c = Cursor::new(v); 289 | let mut iter = SentenceIterator::new(c); 290 | assert_eq!( 291 | iter.next().unwrap().unwrap(), 292 | vec!["This", "is", "a", "sentence", "."] 293 | ); 294 | assert_eq!( 295 | iter.next().unwrap().unwrap(), 296 | vec!["And", "another", "one", "."] 297 | ); 298 | assert!(iter.next().is_none()); 299 | } 300 | 301 | #[test] 302 | fn sentence_iterator_no_newline_test() { 303 | let v = b"This is a sentence .\nAnd another one .".to_vec(); 304 | let c = Cursor::new(v); 305 | let mut iter = SentenceIterator::new(c); 306 | assert_eq!( 307 | iter.next().unwrap().unwrap(), 308 | vec!["This", "is", "a", "sentence", "."] 309 | ); 310 | assert_eq!( 311 | iter.next().unwrap().unwrap(), 312 | vec!["And", "another", "one", "."] 313 | ); 314 | assert!(iter.next().is_none()); 315 | } 316 | 317 | #[test] 318 | fn sentence_iterator_empty_test() { 319 | let v = b"".to_vec(); 320 | let c = Cursor::new(v); 321 | let mut iter = SentenceIterator::new(c); 322 | assert!(iter.next().is_none()); 323 | } 324 | 325 | #[test] 326 | fn sentence_iterator_empty_newline_test() { 327 | let v = b"\n \n \n".to_vec(); 328 | let c = Cursor::new(v); 329 | let mut iter = SentenceIterator::new(c); 330 | assert!(iter.next().is_none()); 331 | } 332 | 333 | static CHUNKING_TEST_DATA: &str = 334 | "a b c\nd e f\ng h i\nj k l\nm n o\np q r\ns t u\nv w x\ny z\n"; 335 | 336 | static CHUNKING_TEST_DATA_DEPS: &str = 337 | "a b c\nd e f\n\ng h i\nj k l\n\nm n o\np q r\n\ns t u\nv w x\ny z\n"; 338 | 339 | #[test] 340 | fn thread_data_test() { 341 | let f = File::open("testdata/chunking.txt").unwrap(); 342 | 343 | let (mmap, start) = thread_data_text(&f, 0, 3).unwrap(); 344 | assert_eq!( 345 | &*mmap, 346 | CHUNKING_TEST_DATA.as_bytes(), 
347 | "Memory mapping is incorrect" 348 | ); 349 | assert_eq!(start, 0, "Incorrect start index"); 350 | 351 | let (mmap, start) = thread_data_text(&f, 1, 3).unwrap(); 352 | assert_eq!( 353 | &*mmap, 354 | CHUNKING_TEST_DATA.as_bytes(), 355 | "Memory mapping is incorrect" 356 | ); 357 | assert_eq!(start, 18, "Incorrect start index"); 358 | 359 | let (mmap, start) = thread_data_text(&f, 2, 3).unwrap(); 360 | assert_eq!( 361 | &*mmap, 362 | CHUNKING_TEST_DATA.as_bytes(), 363 | "Memory mapping is incorrect" 364 | ); 365 | assert_eq!(start, 36, "Incorrect start index"); 366 | } 367 | 368 | #[test] 369 | fn deps_thread_data_test() { 370 | // file size is 55 bytes 371 | // starts scanning at index 19 372 | // first double linebreak is at 26 373 | // second at 39 374 | let f = File::open("testdata/dep_chunking.txt").unwrap(); 375 | let (mmap, start) = thread_data_conllu(&f, 0, 3).unwrap(); 376 | assert_eq!( 377 | &*mmap, 378 | CHUNKING_TEST_DATA_DEPS.as_bytes(), 379 | "Memory mapping is incorrect" 380 | ); 381 | assert_eq!(start, 0, "Incorrect start index"); 382 | 383 | let (mmap, start) = thread_data_conllu(&f, 1, 3).unwrap(); 384 | assert_eq!( 385 | &*mmap, 386 | CHUNKING_TEST_DATA_DEPS.as_bytes(), 387 | "Memory mapping is incorrect" 388 | ); 389 | assert_eq!(start, 26, "Incorrect start index"); 390 | 391 | let (mmap, start) = thread_data_conllu(&f, 2, 3).unwrap(); 392 | assert_eq!( 393 | &*mmap, 394 | CHUNKING_TEST_DATA_DEPS.as_bytes(), 395 | "Memory mapping is incorrect" 396 | ); 397 | assert_eq!(start, 39, "Incorrect start index"); 398 | } 399 | 400 | #[should_panic] 401 | #[test] 402 | fn thread_data_out_of_bounds_test() { 403 | let f = File::open("testdata/chunking.txt").unwrap(); 404 | let _ = thread_data_conllu(&f, 3, 3).unwrap(); 405 | } 406 | } 407 | -------------------------------------------------------------------------------- /src/opts.rs: -------------------------------------------------------------------------------- 1 | use std::cmp; 2 | 3 | use clap::{Parser, ValueEnum}; 4 | use clap_complete::Shell; 5 | 6 | use crate::io::{EmbeddingFormat, TrainInfo}; 7 | use crate::{ 8 | BucketConfig, BucketIndexerType, CommonConfig, Cutoff, DepembedsConfig, FloretConfig, LossType, 9 | NGramConfig, SimpleVocabConfig, SkipGramConfig, SkipgramModelType, SubwordVocabConfig, 10 | VocabConfig, 11 | }; 12 | 13 | pub struct TrainConfig { 14 | pub train_info: TrainInfo, 15 | pub common_config: CommonConfig, 16 | pub model_config: ModelConfig, 17 | pub vocab_config: VocabConfig, 18 | } 19 | 20 | pub enum ModelConfig { 21 | SkipGram(SkipGramConfig), 22 | DepEmbeds(DepembedsConfig), 23 | } 24 | 25 | const VERSION: &str = if let Some(git_desc) = option_env!("MAYBE_FINALFRONTIER_GIT_DESC") { 26 | git_desc 27 | } else { 28 | env!("CARGO_PKG_VERSION") 29 | }; 30 | 31 | #[derive(Clone, Debug, Parser)] 32 | #[clap(version = VERSION)] 33 | pub enum Opts { 34 | /// Generate shell completions 35 | Completions { 36 | #[clap(value_enum)] 37 | shell: Shell, 38 | }, 39 | 40 | /// Train an embedding model 41 | Train(TrainOpts), 42 | } 43 | 44 | #[derive(Clone, Debug, Parser)] 45 | pub struct TrainOpts { 46 | #[clap(flatten)] 47 | pub common: CommonTrainOpts, 48 | 49 | #[clap(subcommand)] 50 | pub model: ModelSubcommand, 51 | } 52 | 53 | impl TrainOpts { 54 | pub fn to_train_config(&self) -> TrainConfig { 55 | let train_info = TrainInfo::new( 56 | self.common.corpus.clone(), 57 | self.common.output.clone(), 58 | self.common 59 | .threads 60 | .unwrap_or_else(|| cmp::min(num_cpus::get() / 2, 20)), 61 | ); 62 | 63 | let 
common_config = CommonConfig { 64 | loss: LossType::LogisticNegativeSampling, 65 | dims: self.common.dims, 66 | epochs: self.common.epochs, 67 | format: self.common.format.into(), 68 | negative_samples: self.common.negative_samples, 69 | lr: self.common.lr, 70 | }; 71 | 72 | let model_config = match &self.model { 73 | ModelSubcommand::Deps { dep_embeds, .. } => ModelConfig::DepEmbeds(DepembedsConfig { 74 | depth: dep_embeds.dependency_depth, 75 | use_root: dep_embeds.use_root, 76 | normalize: dep_embeds.normalize_context, 77 | projectivize: dep_embeds.projectivize, 78 | untyped: dep_embeds.untyped_deps, 79 | }), 80 | ModelSubcommand::Skipgram { skipgram, .. } => ModelConfig::SkipGram(SkipGramConfig { 81 | model: skipgram.model.into(), 82 | context_size: skipgram.context, 83 | }), 84 | }; 85 | 86 | TrainConfig { 87 | train_info, 88 | common_config, 89 | model_config, 90 | vocab_config: self.vocab_args().to_vocab_config(), 91 | } 92 | } 93 | 94 | pub fn vocab_args(&self) -> &VocabSubcommand { 95 | match &self.model { 96 | ModelSubcommand::Deps { vocab, .. } => vocab, 97 | ModelSubcommand::Skipgram { vocab, .. } => vocab, 98 | } 99 | } 100 | } 101 | 102 | #[derive(Clone, Debug, Parser)] 103 | pub enum ModelSubcommand { 104 | /// Train a dependency embeddings model 105 | Deps { 106 | #[clap(flatten)] 107 | dep_embeds: DepsOpts, 108 | 109 | #[clap(subcommand)] 110 | vocab: VocabSubcommand, 111 | }, 112 | 113 | /// Train a skip-gram model 114 | Skipgram { 115 | #[clap(flatten)] 116 | skipgram: SkipGramOpts, 117 | 118 | #[clap(subcommand)] 119 | vocab: VocabSubcommand, 120 | }, 121 | } 122 | 123 | #[derive(Clone, Debug, Parser)] 124 | pub enum VocabSubcommand { 125 | /// Vocabulary using n-gram bucketing 126 | Buckets { 127 | /// Number of buckets: 2^EXP 128 | #[clap(long, default_value = "21")] 129 | buckets: u32, 130 | 131 | #[clap(flatten)] 132 | vocab_common: CommonVocabOpts, 133 | 134 | /// Hash indexer type 135 | #[clap(value_enum, long, default_value = "finalfusion")] 136 | hash_indexer: BucketIndexerArg, 137 | 138 | #[clap(flatten)] 139 | subword_common: CommonSubwordVocabOpts, 140 | }, 141 | 142 | /// Vocabulary using explicit n-grams 143 | Explicit { 144 | #[clap(flatten)] 145 | common: CommonVocabOpts, 146 | 147 | /// Minimum ngram frequency 148 | #[clap(long, default_value = "5")] 149 | ngram_mincount: usize, 150 | 151 | /// Target ngram vocab size 152 | #[clap(long)] 153 | ngram_target_size: Option, 154 | 155 | #[clap(flatten)] 156 | subword_common: CommonSubwordVocabOpts, 157 | }, 158 | 159 | /// Floret vocab. 
160 | Floret { 161 | #[clap(flatten)] 162 | common: CommonVocabOpts, 163 | 164 | /// Number of buckets 165 | #[clap(long, default_value = "100000")] 166 | buckets: u64, 167 | 168 | /// Number of hashes (1-4) 169 | #[clap(long, default_value = "2")] 170 | n_hashes: u32, 171 | 172 | #[clap(long, default_value = "2166136261")] 173 | seed: u32, 174 | 175 | #[clap(flatten)] 176 | subword_common: CommonSubwordVocabOpts, 177 | }, 178 | 179 | /// Vocabulary without n-grams 180 | Simple { 181 | #[clap(flatten)] 182 | common: CommonVocabOpts, 183 | }, 184 | } 185 | 186 | impl VocabSubcommand { 187 | pub fn to_vocab_config(&self) -> VocabConfig { 188 | match self { 189 | VocabSubcommand::Buckets { 190 | buckets, 191 | vocab_common, 192 | hash_indexer, 193 | subword_common, 194 | } => VocabConfig::SubwordVocab(SubwordVocabConfig { 195 | cutoff: vocab_common 196 | .target_size 197 | .map(Cutoff::TargetSize) 198 | .unwrap_or(Cutoff::MinCount(vocab_common.mincount)), 199 | discard_threshold: vocab_common.discard, 200 | min_n: subword_common.minn, 201 | max_n: subword_common.maxn, 202 | indexer: BucketConfig { 203 | buckets_exp: *buckets, 204 | indexer_type: hash_indexer.into(), 205 | }, 206 | }), 207 | VocabSubcommand::Explicit { 208 | common, 209 | ngram_mincount, 210 | ngram_target_size, 211 | subword_common: subword_command, 212 | } => VocabConfig::NGramVocab(SubwordVocabConfig { 213 | cutoff: common 214 | .target_size 215 | .map(Cutoff::TargetSize) 216 | .unwrap_or(Cutoff::MinCount(common.mincount)), 217 | discard_threshold: common.discard, 218 | min_n: subword_command.minn, 219 | max_n: subword_command.maxn, 220 | indexer: NGramConfig { 221 | cutoff: ngram_target_size 222 | .map(Cutoff::TargetSize) 223 | .unwrap_or(Cutoff::MinCount(*ngram_mincount)), 224 | }, 225 | }), 226 | VocabSubcommand::Floret { 227 | common, 228 | buckets, 229 | n_hashes, 230 | seed, 231 | subword_common, 232 | } => VocabConfig::FloretVocab(SubwordVocabConfig { 233 | cutoff: common 234 | .target_size 235 | .map(Cutoff::TargetSize) 236 | .unwrap_or(Cutoff::MinCount(common.mincount)), 237 | discard_threshold: common.discard, 238 | min_n: subword_common.minn, 239 | max_n: subword_common.maxn, 240 | indexer: FloretConfig { 241 | buckets: *buckets, 242 | n_hashes: *n_hashes, 243 | seed: *seed, 244 | }, 245 | }), 246 | VocabSubcommand::Simple { common } => VocabConfig::SimpleVocab(SimpleVocabConfig { 247 | cutoff: common 248 | .target_size 249 | .map(Cutoff::TargetSize) 250 | .unwrap_or(Cutoff::MinCount(common.mincount)), 251 | discard_threshold: common.discard, 252 | }), 253 | } 254 | } 255 | } 256 | 257 | #[derive(ValueEnum, Clone, Copy, Debug)] 258 | pub enum BucketIndexerArg { 259 | #[clap(name = "finalfusion")] 260 | Finalfusion, 261 | 262 | #[clap(name = "fasttext")] 263 | FastText, 264 | } 265 | 266 | impl From<&BucketIndexerArg> for BucketIndexerType { 267 | fn from(arg: &BucketIndexerArg) -> Self { 268 | match arg { 269 | BucketIndexerArg::Finalfusion => BucketIndexerType::Finalfusion, 270 | BucketIndexerArg::FastText => BucketIndexerType::FastText, 271 | } 272 | } 273 | } 274 | 275 | #[derive(Clone, Debug, Parser)] 276 | pub struct CommonTrainOpts { 277 | /// Training corpus 278 | pub corpus: String, 279 | 280 | /// Embedding dimensionality 281 | #[clap(long, default_value = "300")] 282 | pub dims: u32, 283 | 284 | /// Number of epochs. 285 | #[clap(long, default_value = "15")] 286 | pub epochs: u32, 287 | 288 | /// Output format. 
289 | #[clap(value_enum, long, short, default_value = "finalfusion")] 290 | pub format: EmbeddingFormatArg, 291 | 292 | /// Negative samples per word 293 | #[clap(long, default_value = "5")] 294 | pub negative_samples: u32, 295 | 296 | /// Initial learning rate 297 | #[clap(long, default_value = "0.05")] 298 | pub lr: f32, 299 | 300 | /// File to write the embeddings to 301 | pub output: String, 302 | 303 | /// Number of threads (default: min(logical_cpus / 2, 20)) 304 | #[clap(long)] 305 | pub threads: Option, 306 | } 307 | 308 | #[derive(ValueEnum, Clone, Copy, Debug, Eq, PartialEq)] 309 | pub enum EmbeddingFormatArg { 310 | #[clap(name = "fasttext")] 311 | FastText, 312 | 313 | #[clap(name = "floret")] 314 | Floret, 315 | 316 | #[clap(name = "finalfusion")] 317 | FinalFusion, 318 | 319 | #[clap(name = "word2vec")] 320 | Word2Vec, 321 | 322 | #[clap(name = "text")] 323 | Text, 324 | 325 | #[clap(name = "textdims")] 326 | TextDims, 327 | } 328 | 329 | impl From for EmbeddingFormat { 330 | fn from(format: EmbeddingFormatArg) -> Self { 331 | match format { 332 | EmbeddingFormatArg::FastText => EmbeddingFormat::FastText, 333 | EmbeddingFormatArg::Floret => EmbeddingFormat::Floret, 334 | EmbeddingFormatArg::FinalFusion => EmbeddingFormat::FinalFusion, 335 | EmbeddingFormatArg::Word2Vec => EmbeddingFormat::Word2Vec, 336 | EmbeddingFormatArg::Text => EmbeddingFormat::Text, 337 | EmbeddingFormatArg::TextDims => EmbeddingFormat::TextDims, 338 | } 339 | } 340 | } 341 | 342 | #[derive(Clone, Copy, Debug, Parser)] 343 | pub struct DepsOpts { 344 | /// Dependency depth 345 | #[clap(long, default_value = "1")] 346 | pub dependency_depth: u32, 347 | 348 | /// Use root when extracting dependency contexts. 349 | #[clap(long)] 350 | pub use_root: bool, 351 | 352 | /// Normalize contexts 353 | #[clap(long)] 354 | pub normalize_context: bool, 355 | 356 | /// Projectivize dependency graphs before training. 357 | #[clap(long)] 358 | pub projectivize: bool, 359 | 360 | /// Do not use dependency relation labels. 361 | #[clap(long)] 362 | pub untyped_deps: bool, 363 | } 364 | 365 | /// Hyperparameters for SkipGram-like models. 
366 | #[derive(Clone, Copy, Debug, Parser)] 367 | pub struct SkipGramOpts { 368 | /// Model 369 | #[clap(value_enum, long, default_value = "skipgram")] 370 | pub model: SkipgramModelTypeArg, 371 | 372 | /// Context size 373 | #[clap(long, default_value = "10")] 374 | pub context: u32, 375 | } 376 | 377 | #[derive(ValueEnum, Clone, Copy, Debug)] 378 | pub enum SkipgramModelTypeArg { 379 | #[clap(name = "skipgram")] 380 | SkipGram, 381 | 382 | #[clap(name = "structgram")] 383 | StructuredSkipGram, 384 | 385 | #[clap(name = "dirgram")] 386 | DirectionalSkipgram, 387 | } 388 | 389 | impl From for SkipgramModelType { 390 | fn from(model_type: SkipgramModelTypeArg) -> Self { 391 | match model_type { 392 | SkipgramModelTypeArg::SkipGram => SkipgramModelType::SkipGram, 393 | SkipgramModelTypeArg::StructuredSkipGram => SkipgramModelType::StructuredSkipGram, 394 | SkipgramModelTypeArg::DirectionalSkipgram => SkipgramModelType::DirectionalSkipgram, 395 | } 396 | } 397 | } 398 | 399 | #[derive(Clone, Copy, Debug, Parser)] 400 | pub struct CommonVocabOpts { 401 | /// Discard threshold 402 | #[clap(long, default_value = "1e-4")] 403 | pub discard: f32, 404 | 405 | /// Minimum token frequency 406 | #[clap(long, default_value = "5")] 407 | pub mincount: usize, 408 | 409 | /// Target vocab size 410 | #[clap(long)] 411 | target_size: Option, 412 | } 413 | 414 | #[derive(Clone, Copy, Debug, Parser)] 415 | pub struct CommonSubwordVocabOpts { 416 | /// Minimum ngram length 417 | #[clap(long, default_value = "3")] 418 | pub minn: u32, 419 | 420 | /// Maximum ngram length 421 | #[clap(long, default_value = "6")] 422 | pub maxn: u32, 423 | } 424 | -------------------------------------------------------------------------------- /src/vocab/subword.rs: -------------------------------------------------------------------------------- 1 | use std::borrow::Borrow; 2 | use std::collections::HashMap; 3 | use std::hash::Hash; 4 | 5 | use finalfusion::compat::fasttext::FastTextIndexer; 6 | use finalfusion::compat::floret::FloretIndexer; 7 | use finalfusion::subword::{ 8 | BucketIndexer, ExplicitIndexer, FinalfusionHashIndexer, Indexer, NGrams, SubwordIndices, 9 | }; 10 | use finalfusion::vocab::{SubwordVocab as FiFuSubwordVocab, VocabWrap}; 11 | 12 | use crate::idx::{WordIdx, WordWithSubwordsIdx}; 13 | use crate::vocab::{bracket, create_discards, create_indices}; 14 | use crate::{ 15 | BucketConfig, BucketIndexerType, CountedType, FloretConfig, NGramConfig, SubwordVocabConfig, 16 | Vocab, VocabBuilder, Word, 17 | }; 18 | 19 | /// A corpus vocabulary with subword lookup. 20 | #[derive(Clone)] 21 | pub struct SubwordVocab { 22 | config: SubwordVocabConfig, 23 | words: Vec, 24 | indexer: I, 25 | subwords: Vec>, 26 | discards: Vec, 27 | index: HashMap, 28 | n_tokens: usize, 29 | } 30 | 31 | impl SubwordVocab 32 | where 33 | C: Copy + Clone, 34 | I: Indexer, 35 | { 36 | /// Construct a new vocabulary. 
37 |     pub fn new(
38 |         config: SubwordVocabConfig<C>,
39 |         words: Vec<Word>,
40 |         n_tokens: usize,
41 |         indexer: I,
42 |     ) -> Self {
43 |         let index = create_indices(&words);
44 |         let subwords = Self::create_subword_indices(
45 |             config.min_n as usize,
46 |             config.max_n as usize,
47 |             &indexer,
48 |             &words,
49 |         );
50 |         let discards = create_discards(config.discard_threshold, &words, n_tokens);
51 |         SubwordVocab {
52 |             config,
53 |             words,
54 |             indexer,
55 |             subwords,
56 |             discards,
57 |             index,
58 |             n_tokens,
59 |         }
60 |     }
61 | 
62 |     fn create_subword_indices(
63 |         min_n: usize,
64 |         max_n: usize,
65 |         indexer: &I,
66 |         words: &[Word],
67 |     ) -> Vec<Vec<u64>> {
68 |         let mut subword_indices = Vec::new();
69 | 
70 |         for word in words {
71 |             subword_indices.push(
72 |                 bracket(word.word())
73 |                     .as_str()
74 |                     .subword_indices(min_n, max_n, indexer)
75 |                     .map(|idx| idx + words.len() as u64)
76 |                     .collect(),
77 |             );
78 |         }
79 | 
80 |         assert_eq!(words.len(), subword_indices.len());
81 | 
82 |         subword_indices
83 |     }
84 | 
85 |     /// Get the given word.
86 |     pub fn word(&self, word: &str) -> Option<&Word> {
87 |         self.idx(word)
88 |             .map(|idx| &self.words[idx.word_idx() as usize])
89 |     }
90 | }
91 | 
92 | impl<C, I> SubwordVocab<C, I> {
93 |     pub(crate) fn subword_indices_idx(&self, idx: usize) -> Option<&[u64]> {
94 |         self.subwords.get(idx).map(|v| v.as_slice())
95 |     }
96 | }
97 | 
98 | impl<C, I> Vocab for SubwordVocab<C, I>
99 | where
100 |     C: Copy + Clone,
101 |     I: Indexer,
102 | {
103 |     type VocabType = String;
104 |     type IdxType = WordWithSubwordsIdx;
105 |     type Config = SubwordVocabConfig<C>;
106 | 
107 |     fn config(&self) -> SubwordVocabConfig<C> {
108 |         self.config
109 |     }
110 | 
111 |     fn idx<Q>(&self, key: &Q) -> Option<Self::IdxType>
112 |     where
113 |         Self::VocabType: Borrow<Q>,
114 |         Q: Hash + ?Sized + Eq,
115 |     {
116 |         self.index.get(key).and_then(|idx| {
117 |             self.subword_indices_idx(*idx)
118 |                 .map(|v| WordWithSubwordsIdx::new(*idx as u64, v))
119 |         })
120 |     }
121 | 
122 |     fn discard(&self, idx: usize) -> f32 {
123 |         self.discards[idx]
124 |     }
125 | 
126 |     fn n_input_types(&self) -> usize {
127 |         self.len() + self.indexer.upper_bound() as usize
128 |     }
129 | 
130 |     fn types(&self) -> &[Word] {
131 |         &self.words
132 |     }
133 | 
134 |     fn n_types(&self) -> usize {
135 |         self.n_tokens
136 |     }
137 | }
138 | 
139 | /// Constructs a `SubwordVocab` from a `VocabBuilder` where `T: Into<String>`.
140 | impl<T, I> From<VocabBuilder<SubwordVocabConfig<BucketConfig>, T>> for SubwordVocab<BucketConfig, I>
141 | where
142 |     T: Hash + Eq + Into<String>,
143 |     I: BucketIndexer,
144 | {
145 |     fn from(builder: VocabBuilder<SubwordVocabConfig<BucketConfig>, T>) -> Self {
146 |         let config = builder.config;
147 |         let words = config.cutoff.filter(builder.items);
148 |         let buckets = match config.indexer.indexer_type {
149 |             BucketIndexerType::Finalfusion => config.indexer.buckets_exp as usize,
150 |             BucketIndexerType::FastText => 2u64.pow(config.indexer.buckets_exp) as usize,
151 |         };
152 |         SubwordVocab::new(config, words, builder.n_items, I::new(buckets))
153 |     }
154 | }
155 | 
156 | /// Constructs a `SubwordVocab` from a `VocabBuilder` where `T: Into<String>`.
157 | impl<T> From<VocabBuilder<SubwordVocabConfig<FloretConfig>, T>>
158 |     for SubwordVocab<FloretConfig, FloretIndexer>
159 | where
160 |     T: Hash + Eq + Into<String>,
161 | {
162 |     fn from(builder: VocabBuilder<SubwordVocabConfig<FloretConfig>, T>) -> Self {
163 |         let config = builder.config;
164 |         let words = config.cutoff.filter(builder.items);
165 |         let buckets = config.indexer.buckets;
166 |         let n_hashes = config.indexer.n_hashes;
167 |         let seed = config.indexer.seed;
168 |         SubwordVocab::new(
169 |             config,
170 |             words,
171 |             builder.n_items,
172 |             FloretIndexer::new(buckets, n_hashes, seed),
173 |         )
174 |     }
175 | }
176 | 
177 | /// Constructs a `SubwordVocab` from a `VocabBuilder` where `T: Into<String>`.
178 | impl<T> From<VocabBuilder<SubwordVocabConfig<NGramConfig>, T>>
179 |     for SubwordVocab<NGramConfig, ExplicitIndexer>
180 | where
181 |     T: Hash + Eq + Into<String>,
182 | {
183 |     fn from(builder: VocabBuilder<SubwordVocabConfig<NGramConfig>, T>) -> Self {
184 |         let config = builder.config;
185 |         let words: Vec<Word> = builder.config.cutoff.filter(builder.items);
186 |         let mut ngram_counts: HashMap<String, usize> = HashMap::new();
187 |         for word in words.iter() {
188 |             for ngram in NGrams::new(
189 |                 &bracket(word.label()),
190 |                 config.min_n as usize,
191 |                 config.max_n as usize,
192 |             )
193 |             .map(|ngram| ngram.to_string())
194 |             {
195 |                 let cnt = ngram_counts.entry(ngram).or_default();
196 |                 *cnt += word.count;
197 |             }
198 |         }
199 | 
200 |         let ngrams: Vec<CountedType<String>> = config.indexer.cutoff.filter(ngram_counts);
201 |         let ngrams = ngrams
202 |             .into_iter()
203 |             .map(|counted| counted.label)
204 |             .collect::<Vec<_>>();
205 |         SubwordVocab::new(config, words, builder.n_items, ExplicitIndexer::new(ngrams))
206 |     }
207 | }
208 | 
209 | macro_rules! impl_into_vocabwrap (
210 |     ($vocab:ty) => {
211 |         impl From<$vocab> for VocabWrap {
212 |             fn from(vocab: $vocab) -> Self {
213 |                 let config = vocab.config;
214 |                 let words = vocab
215 |                     .words
216 |                     .into_iter()
217 |                     .map(|word| word.label)
218 |                     .collect::<Vec<_>>();
219 |                 FiFuSubwordVocab::new(words, config.min_n, config.max_n, vocab.indexer).into()
220 |             }
221 |         }
222 |     }
223 | );
224 | 
225 | impl_into_vocabwrap!(SubwordVocab<BucketConfig, FinalfusionHashIndexer>);
226 | impl_into_vocabwrap!(SubwordVocab<BucketConfig, FastTextIndexer>);
227 | impl_into_vocabwrap!(SubwordVocab<FloretConfig, FloretIndexer>);
228 | impl_into_vocabwrap!(SubwordVocab<NGramConfig, ExplicitIndexer>);
229 | 
230 | #[cfg(test)]
231 | mod tests {
232 |     use super::{SubwordVocab, Vocab, VocabBuilder};
233 |     use crate::config::SubwordVocabConfig;
234 |     use crate::idx::WordIdx;
235 |     use crate::{util, BucketConfig, Cutoff, NGramConfig};
236 | 
237 |     use crate::config::BucketIndexerType::Finalfusion;
238 |     use finalfusion::subword::{ExplicitIndexer, FinalfusionHashIndexer, Indexer};
239 | 
240 |     const TEST_SUBWORDCONFIG: SubwordVocabConfig<BucketConfig> = SubwordVocabConfig {
241 |         discard_threshold: 1e-4,
242 |         cutoff: Cutoff::MinCount(2),
243 |         max_n: 6,
244 |         min_n: 3,
245 |         indexer: BucketConfig {
246 |             buckets_exp: 21,
247 |             indexer_type: Finalfusion,
248 |         },
249 |     };
250 | 
251 |     const TEST_NGRAMCONFIG: SubwordVocabConfig<NGramConfig> = SubwordVocabConfig {
252 |         discard_threshold: 1e-4,
253 |         cutoff: Cutoff::MinCount(2),
254 |         max_n: 6,
255 |         min_n: 3,
256 |         indexer: NGramConfig {
257 |             cutoff: Cutoff::MinCount(2),
258 |         },
259 |     };
260 | 
261 |     #[test]
262 |     pub fn vocab_is_sorted() {
263 |         let mut config = TEST_SUBWORDCONFIG;
264 |         config.cutoff = Cutoff::MinCount(1);
265 | 
266 |         let mut builder: VocabBuilder<_, &str> = VocabBuilder::new(config);
267 |         builder.count("to");
268 |         builder.count("be");
269 |         builder.count("or");
270 |         builder.count("not");
271 |         builder.count("to");
272 |         builder.count("be");
273 |         builder.count("");
274 | 
275 |         let vocab: SubwordVocab<_, FinalfusionHashIndexer> = builder.into();
276 |         let words = vocab.types();
277 | 
278 |         for idx in 1..words.len() {
279 |             assert!(
280 | words[idx - 1].count >= words[idx].count, 281 | "Words are not frequency-sorted" 282 | ); 283 | } 284 | } 285 | 286 | #[test] 287 | pub fn test_bucket_vocab_builder() { 288 | let mut builder: VocabBuilder<_, &str> = VocabBuilder::new(TEST_SUBWORDCONFIG); 289 | builder.count("to"); 290 | builder.count("be"); 291 | builder.count("or"); 292 | builder.count("not"); 293 | builder.count("to"); 294 | builder.count("be"); 295 | builder.count(""); 296 | 297 | let vocab: SubwordVocab<_, FinalfusionHashIndexer> = builder.into(); 298 | 299 | // 'or' and 'not' should be filtered due to the minimum count. 300 | assert_eq!(vocab.len(), 2); 301 | 302 | assert_eq!(vocab.n_types(), 7); 303 | 304 | // Check expected properties of 'to'. 305 | let to = vocab.word("to").unwrap(); 306 | assert_eq!("to", to.word()); 307 | assert_eq!(2, to.count); 308 | assert_eq!( 309 | vec![1141946, 215571, 1324229, 0], 310 | vocab.idx("to").unwrap().into_iter().collect::>() 311 | ); 312 | assert!(util::close( 313 | 0.019058, 314 | vocab.discard(vocab.idx("to").unwrap().word_idx() as usize), 315 | 1e-5, 316 | )); 317 | 318 | // Check expected properties of 'be'. 319 | let be = vocab.word("be").unwrap(); 320 | assert_eq!("be", be.label); 321 | assert_eq!(2, be.count); 322 | assert_eq!( 323 | vec![277350, 1105487, 1482881, 1], 324 | vocab.idx("be").unwrap().into_iter().collect::>() 325 | ); 326 | assert!(util::close( 327 | 0.019058, 328 | vocab.discard(vocab.idx("be").unwrap().word_idx() as usize), 329 | 1e-5, 330 | )); 331 | 332 | // Check indices for an unknown word. 333 | assert!(vocab.idx("too").is_none()); 334 | } 335 | 336 | #[test] 337 | pub fn test_ngram_vocab_builder() { 338 | let mut builder: VocabBuilder<_, &str> = VocabBuilder::new(TEST_NGRAMCONFIG); 339 | builder.count("to"); 340 | builder.count("be"); 341 | builder.count("or"); 342 | builder.count("not"); 343 | builder.count("to"); 344 | builder.count("be"); 345 | builder.count(""); 346 | 347 | let vocab: SubwordVocab<_, ExplicitIndexer> = builder.into(); 348 | 349 | // 'or' and 'not' should be filtered due to the minimum count. 350 | assert_eq!(vocab.len(), 2); 351 | 352 | assert_eq!(vocab.n_types(), 7); 353 | 354 | // Check expected properties of 'to'. 355 | let to = vocab.word("to").unwrap(); 356 | assert_eq!("to", to.word()); 357 | assert_eq!(2, to.count); 358 | // 2x ["", "to>", "", "be>"] 359 | // sorted ["to>", "be>", "", "", "", "be>", "", "", ">() 370 | ); 371 | assert!(util::close( 372 | 0.019058, 373 | vocab.discard(vocab.idx("to").unwrap().word_idx() as usize), 374 | 1e-5, 375 | )); 376 | 377 | // Check expected properties of 'be'. 378 | let be = vocab.word("be").unwrap(); 379 | assert_eq!("be", be.label); 380 | assert_eq!(2, be.count); 381 | // see above explanation 382 | assert_eq!( 383 | vec![6, 7, 3, 1], 384 | vocab.idx("be").unwrap().into_iter().collect::>() 385 | ); 386 | assert!(util::close( 387 | 0.019058, 388 | vocab.discard(vocab.idx("be").unwrap().word_idx() as usize), 389 | 1e-5, 390 | )); 391 | 392 | // Check indices for an unknown word. Only " 35 | /// index mappings and word discard probabilities. Additionally the trainer 36 | /// provides the logic to transform some input to an iterator of training 37 | /// examples. 38 | /// 39 | /// `TrainModel` stores the matrices as `HogwildArray`s to share parameters 40 | /// between clones of the same model. The trainer is also shared between 41 | /// clones due to memory considerations. 
42 | #[derive(Clone)]
43 | pub struct TrainModel<T> {
44 |     trainer: T,
45 |     input: HogwildArray2<f32>,
46 |     output: HogwildArray2<f32>,
47 | }
48 | 
49 | impl<T> From<T> for TrainModel<T>
50 | where
51 |     T: Trainer,
52 | {
53 |     /// Construct a model from a Trainer.
54 |     ///
55 |     /// This randomly initializes the input and output matrices using a
56 |     /// uniform distribution in the range [-1/dims, 1/dims).
57 |     ///
58 |     /// The number of rows of the input matrix is the vocabulary size
59 |     /// plus the number of buckets for subword units. The number of rows
60 |     /// of the output matrix is the number of possible outputs for the model.
61 |     fn from(trainer: T) -> TrainModel<T> {
62 |         let config = *trainer.config();
63 |         let init_bound = 1.0 / config.dims as f32;
64 |         let distribution = Uniform::new_inclusive(-init_bound, init_bound);
65 | 
66 |         let input = Array2::random(
67 |             (trainer.input_vocab().n_input_types(), config.dims as usize),
68 |             distribution,
69 |         )
70 |         .into();
71 |         let output = Array2::random(
72 |             (trainer.n_output_types(), config.dims as usize),
73 |             distribution,
74 |         )
75 |         .into();
76 |         TrainModel {
77 |             trainer,
78 |             input,
79 |             output,
80 |         }
81 |     }
82 | }
83 | 
84 | impl<T> TrainModel<T>
85 | where
86 |     T: Trainer,
87 | {
88 |     /// Get the model configuration.
89 |     pub fn config(&self) -> &CommonConfig {
90 |         self.trainer.config()
91 |     }
92 | }
93 | 
94 | impl<T, V> TrainModel<T>
95 | where
96 |     T: Trainer<InputVocab = V>,
97 |     V: Vocab,
98 | {
99 |     /// Get this model's input vocabulary.
100 |     pub fn input_vocab(&self) -> &V {
101 |         self.trainer.input_vocab()
102 |     }
103 | }
104 | 
105 | impl<T> TrainModel<T> {
106 |     /// Get this model's trainer mutably.
107 |     pub fn trainer(&mut self) -> &mut T {
108 |         &mut self.trainer
109 |     }
110 | 
111 |     /// Get the mean input embedding of the given indices.
112 |     pub(crate) fn mean_input_embedding<'a, I>(&self, idx: &'a I) -> Array1<f32>
113 |     where
114 |         I: WordIdx,
115 |         &'a I: IntoIterator<Item = u64>,
116 |     {
117 |         if idx.len() == 1 {
118 |             self.input
119 |                 .view()
120 |                 .row(idx.into_iter().next().unwrap() as usize)
121 |                 .to_owned()
122 |         } else {
123 |             Self::mean_embedding(self.input.view(), idx)
124 |         }
125 |     }
126 | 
127 |     /// Get the mean input embedding of the given indices.
128 |     fn mean_embedding<'a, I>(embeds: ArrayView2<f32>, indices: &'a I) -> Array1<f32>
129 |     where
130 |         I: WordIdx,
131 |         &'a I: IntoIterator<Item = u64>,
132 |     {
133 |         let mut embed = Array1::zeros((embeds.ncols(),));
134 |         let len = indices.len();
135 |         for idx in indices {
136 |             scaled_add(
137 |                 embed.view_mut(),
138 |                 embeds.index_axis(Axis(0), idx as usize),
139 |                 1.0,
140 |             );
141 |         }
142 | 
143 |         scale(embed.view_mut(), 1.0 / len as f32);
144 | 
145 |         embed
146 |     }
147 | 
148 |     /// Get the input embedding with the given index.
149 |     #[allow(dead_code)]
150 |     #[inline]
151 |     pub(crate) fn input_embedding(&self, idx: usize) -> ArrayView1<f32> {
152 |         self.input.subview(Axis(0), idx)
153 |     }
154 | 
155 |     /// Get the input embedding with the given index mutably.
156 |     #[inline]
157 |     pub(crate) fn input_embedding_mut(&mut self, idx: usize) -> ArrayViewMut1<f32> {
158 |         self.input.subview_mut(Axis(0), idx)
159 |     }
160 | 
161 |     pub(crate) fn into_parts(self) -> Result<(T, Array2<f32>)> {
162 |         let input = match Arc::try_unwrap(self.input.into_inner()) {
163 |             Ok(input) => input.into_inner(),
164 |             Err(_) => bail!("Cannot unwrap input matrix."),
165 |         };
166 | 
167 |         Ok((self.trainer, input))
168 |     }
169 | 
170 |     /// Get the output embedding with the given index.
171 |     #[inline]
172 |     pub(crate) fn output_embedding(&self, idx: usize) -> ArrayView1<f32> {
173 |         self.output.subview(Axis(0), idx)
174 |     }
175 | 
176 |     /// Get the output embedding with the given index mutably.
177 |     #[inline]
178 |     pub(crate) fn output_embedding_mut(&mut self, idx: usize) -> ArrayViewMut1<f32> {
179 |         self.output.subview_mut(Axis(0), idx)
180 |     }
181 | }
182 | 
183 | impl<W, T, V, M> WriteModelBinary<W> for TrainModel<T>
184 | where
185 |     W: Seek + Write,
186 |     T: Trainer<InputVocab = V, Metadata = M>,
187 |     V: Vocab + Into<VocabWrap>,
188 |     V::VocabType: ToString,
189 |     for<'a> &'a V::IdxType: IntoIterator<Item = u64>,
190 |     M: Serialize,
191 | {
192 |     fn write_model_binary(
193 |         self,
194 |         write: &mut W,
195 |         mut train_info: TrainInfo,
196 |         format: EmbeddingFormat,
197 |     ) -> Result<()> {
198 |         let (trainer, mut input_matrix) = self.into_parts()?;
199 |         let mut metadata = Map::try_from(trainer.to_metadata())?;
200 |         let build_info = Value::try_from(VersionInfo::new())?;
201 |         metadata.insert("version_info".to_string(), build_info);
202 |         train_info.set_end();
203 |         let train_info = Value::try_from(train_info)?;
204 |         metadata.insert("training_info".to_string(), train_info);
205 | 
206 |         // Compute and write word embeddings.
207 |         let mut norms = vec![0f32; trainer.input_vocab().len()];
208 |         for (i, (norm, word)) in norms
209 |             .iter_mut()
210 |             .zip(trainer.input_vocab().types())
211 |             .take(trainer.input_vocab().len())
212 |             .enumerate()
213 |         {
214 |             let input = trainer.input_vocab().idx(word.label()).unwrap();
215 |             let mut embed = Self::mean_embedding(input_matrix.view(), &input);
216 |             *norm = l2_normalize(embed.view_mut());
217 |             input_matrix.index_axis_mut(Axis(0), i).assign(&embed);
218 |         }
219 | 
220 |         let vocab: VocabWrap = trainer.try_into_input_vocab()?.into();
221 |         let storage = NdArray::new(input_matrix);
222 |         let norms = NdNorms::new(Array1::from(norms));
223 | 
224 |         use self::EmbeddingFormat::*;
225 |         match format {
226 |             FastText => {
227 |                 let vocab = match vocab {
228 |                     VocabWrap::FastTextSubwordVocab(vocab) => vocab,
229 |                     _ => bail!("Only fastText vocabularies can be written to fastText files"),
230 |                 };
231 |                 Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
232 |                     .write_fasttext(write)?
233 |             }
234 |             FinalFusion => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
235 |                 .write_embeddings(write)?,
236 |             Floret => {
237 |                 let vocab = match vocab {
238 |                     VocabWrap::FloretSubwordVocab(vocab) => vocab,
239 |                     _ => bail!("Only floret vocabularies can be written to floret files"),
240 |                 };
241 |                 Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
242 |                     .write_floret_text(write)?
243 |             }
244 |             Word2Vec => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
245 |                 .write_word2vec_binary(write, true)?,
246 |             Text => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
247 |                 .write_text(write, true)?,
248 |             TextDims => Embeddings::new(Some(Metadata::new(metadata)), vocab, storage, norms)
249 |                 .write_text_dims(write, true)?,
250 |         };
251 | 
252 |         Ok(())
253 |     }
254 | }
255 | 
256 | /// Trainer Trait.
257 | pub trait Trainer {
258 |     type InputVocab: Vocab;
259 |     type Metadata;
260 | 
261 |     /// Get the trainer's input vocabulary.
262 |     fn input_vocab(&self) -> &Self::InputVocab;
263 | 
264 |     /// Destruct the trainer and get the input vocabulary.
265 |     fn try_into_input_vocab(self) -> Result<Self::InputVocab>;
266 | 
267 |     /// Get the number of possible input types.
268 |     ///
269 |     /// In a model with subword units this value is calculated as:
270 |     /// `2^n_buckets + input_vocab.len()`.
271 |     fn n_input_types(&self) -> usize;
272 | 
273 |     /// Get the number of possible outputs.
274 |     ///
275 |     /// In a structured skipgram model this value is calculated as:
276 |     /// `output_vocab.len() * context_size * 2`
277 |     fn n_output_types(&self) -> usize;
278 | 
279 |     /// Get this Trainer's common hyperparameters.
280 |     fn config(&self) -> &CommonConfig;
281 | 
282 |     /// Get this Trainer's configuration.
283 |     fn to_metadata(&self) -> Self::Metadata;
284 | }
285 | 
286 | /// TrainIterFrom.
287 | ///
288 | /// This trait defines how some input `&S` is transformed into an iterator of training examples.
289 | pub trait TrainIterFrom<'a, S>
290 | where
291 |     S: ?Sized,
292 | {
293 |     type Iter: Iterator<Item = (Self::Focus, Self::Contexts)>;
294 |     type Focus;
295 |     type Contexts: IntoIterator<Item = usize>;
296 | 
297 |     fn train_iter_from(&mut self, sequence: &S) -> Self::Iter;
298 | }
299 | 
300 | /// Negative Samples
301 | ///
302 | /// This trait defines a method on how to draw a negative sample given some output. The return value
303 | /// should follow the distribution of the underlying output vocabulary.
304 | pub trait NegativeSamples {
305 |     fn negative_sample(&mut self, output: usize) -> usize;
306 | }
307 | 
308 | #[cfg(test)]
309 | mod tests {
310 |     use finalfusion::subword::FinalfusionHashIndexer;
311 |     use ndarray::Array2;
312 |     use rand::SeedableRng;
313 |     use rand_xorshift::XorShiftRng;
314 | 
315 |     use super::TrainModel;
316 |     use crate::config::BucketIndexerType::Finalfusion;
317 |     use crate::config::SubwordVocabConfig;
318 |     use crate::idx::WordWithSubwordsIdx;
319 |     use crate::io::EmbeddingFormat;
320 |     use crate::skipgram_trainer::SkipgramTrainer;
321 |     use crate::util::all_close;
322 |     use crate::{
323 |         BucketConfig, CommonConfig, Cutoff, LossType, SkipGramConfig, SkipgramModelType,
324 |         SubwordVocab, VocabBuilder,
325 |     };
326 | 
327 |     const TEST_COMMON_CONFIG: CommonConfig = CommonConfig {
328 |         dims: 3,
329 |         epochs: 5,
330 |         format: EmbeddingFormat::FinalFusion,
331 |         loss: LossType::LogisticNegativeSampling,
332 |         lr: 0.05,
333 |         negative_samples: 5,
334 |     };
335 | 
336 |     const TEST_SKIP_CONFIG: SkipGramConfig = SkipGramConfig {
337 |         context_size: 5,
338 |         model: SkipgramModelType::SkipGram,
339 |     };
340 | 
341 |     const VOCAB_CONF: SubwordVocabConfig<BucketConfig> = SubwordVocabConfig {
342 |         discard_threshold: 1e-4,
343 |         cutoff: Cutoff::MinCount(2),
344 |         max_n: 6,
345 |         min_n: 3,
346 |         indexer: BucketConfig {
347 |             buckets_exp: 21,
348 |             indexer_type: Finalfusion,
349 |         },
350 |     };
351 | 
352 |     #[test]
353 |     pub fn model_embed_methods() {
354 |         let mut vocab_config = VOCAB_CONF;
355 |         vocab_config.cutoff = Cutoff::MinCount(1);
356 | 
357 |         let common_config = TEST_COMMON_CONFIG;
358 |         let skipgram_config = TEST_SKIP_CONFIG;
359 |         // We just need some bogus vocabulary
360 |         let mut builder: VocabBuilder<_, String> = VocabBuilder::new(vocab_config);
361 |         builder.count("bla".to_string());
362 |         let vocab: SubwordVocab<_, FinalfusionHashIndexer> = builder.into();
363 | 
364 |         let input = Array2::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.])
365 |             .unwrap()
366 |             .into();
367 |         let output = Array2::from_shape_vec((2, 3), vec![-1., -2., -3., -4., -5., -6.])
368 |             .unwrap()
369 |             .into();
370 | 
371 |         let mut model = TrainModel {
372 |             trainer: SkipgramTrainer::new(
373 |                 vocab,
374 |                 XorShiftRng::from_entropy(),
375 |                 common_config,
376 |                 skipgram_config,
377 |             ),
378 |             input,
379 |             output,
380 |         };
381 | 
382 | 
// Input embeddings 383 | assert!(all_close( 384 | model.input_embedding(0).as_slice().unwrap(), 385 | &[1., 2., 3.], 386 | 1e-5 387 | )); 388 | assert!(all_close( 389 | model.input_embedding(1).as_slice().unwrap(), 390 | &[4., 5., 6.], 391 | 1e-5 392 | )); 393 | 394 | // Mutable input embeddings 395 | assert!(all_close( 396 | model.input_embedding_mut(0).as_slice().unwrap(), 397 | &[1., 2., 3.], 398 | 1e-5 399 | )); 400 | assert!(all_close( 401 | model.input_embedding_mut(1).as_slice().unwrap(), 402 | &[4., 5., 6.], 403 | 1e-5 404 | )); 405 | 406 | // Output embeddings 407 | assert!(all_close( 408 | model.output_embedding(0).as_slice().unwrap(), 409 | &[-1., -2., -3.], 410 | 1e-5 411 | )); 412 | assert!(all_close( 413 | model.output_embedding(1).as_slice().unwrap(), 414 | &[-4., -5., -6.], 415 | 1e-5 416 | )); 417 | 418 | // Mutable output embeddings 419 | assert!(all_close( 420 | model.output_embedding_mut(0).as_slice().unwrap(), 421 | &[-1., -2., -3.], 422 | 1e-5 423 | )); 424 | assert!(all_close( 425 | model.output_embedding_mut(1).as_slice().unwrap(), 426 | &[-4., -5., -6.], 427 | 1e-5 428 | )); 429 | 430 | // Mean input embedding. 431 | assert!(all_close( 432 | model 433 | .mean_input_embedding(&WordWithSubwordsIdx::new(0, vec![1])) 434 | .as_slice() 435 | .unwrap(), 436 | &[2.5, 3.5, 4.5], 437 | 1e-5 438 | )); 439 | } 440 | } 441 | --------------------------------------------------------------------------------
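Note: the `model_embed_methods` test above exercises the mean-embedding logic of `TrainModel`: a word's final embedding is the average of its own input row and the rows of its subword n-grams. The following is a minimal, standalone sketch of that computation, not code from the repository; the `mean_embedding` helper and `main` are hypothetical, and the only assumed dependency is `ndarray`, which the crate already uses.

use ndarray::{Array1, Array2, Axis};

/// Average the rows of `embeds` selected by `indices`.
/// In finalfrontier terms, `indices` would be a word index followed by
/// the indices of its subword n-grams.
fn mean_embedding(embeds: &Array2<f32>, indices: &[usize]) -> Array1<f32> {
    let mut sum = Array1::<f32>::zeros(embeds.ncols());
    for &idx in indices {
        // Accumulate the row for this word or subword n-gram.
        sum += &embeds.index_axis(Axis(0), idx);
    }
    // Divide by the number of contributing rows to obtain the mean.
    sum / indices.len() as f32
}

fn main() {
    // Row 0: the word vector; row 1: a single subword n-gram vector,
    // mirroring the two-row input matrix used in the test above.
    let embeds = Array2::from_shape_vec((2, 3), vec![1., 2., 3., 4., 5., 6.]).unwrap();
    println!("{}", mean_embedding(&embeds, &[0, 1])); // prints [2.5, 3.5, 4.5]
}

Run as a plain binary crate with `ndarray` in `Cargo.toml`, the sketch prints `[2.5, 3.5, 4.5]`, which matches the expectation checked against `mean_input_embedding` in the test.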