├── src ├── version.rs ├── port_buffer.rs ├── logging_layer.rs ├── multithread_helpers.rs ├── hogwild.rs ├── radix_tree.rs ├── buffer_handler.rs ├── block_relu.rs ├── block_loss_functions.rs ├── quantization.rs ├── block_helpers.rs ├── vwmap.rs ├── block_normalize.rs ├── lib.rs ├── cache.rs ├── optimizer.rs ├── block_lr.rs ├── feature_transform_executor.rs ├── main.rs └── cmdline.rs ├── .gitignore ├── benchmark ├── run_with_plots.sh ├── run_without_plots.sh ├── run_with_plots_intel.sh ├── clean_caches.py ├── cleanup.py ├── print_system_info.py ├── calc_loss.py ├── measure.py └── generate.py ├── benchmark_results.png ├── weight_patcher ├── Cargo.toml └── src │ └── main.rs ├── profile.sh ├── .github └── workflows │ ├── docker-image-mkl-build.yml │ ├── rust.yml │ ├── rust-Ubuntu18.yml │ └── prediction-and-learning-workflow.yml ├── examples ├── vw-compatibility │ ├── README.md │ ├── datasets │ │ └── vw_namespace_map.csv │ └── run.sh ├── basic │ ├── datasets │ │ └── vw_namespace_map.csv │ └── run.sh └── ffm │ ├── run_vw_equivalent.sh │ ├── run.sh │ ├── README.md │ ├── generate.py │ └── run_fw_with_prediction_tests.sh ├── CHANGELOG ├── LICENSE.md ├── Cargo.toml ├── Dockerfile ├── COMPATIBILITY.md ├── README.md ├── run_one.sh ├── BENCHMARK.md └── SPEED.md /src/version.rs: -------------------------------------------------------------------------------- 1 | pub static LATEST: &str = "0.2"; 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | Cargo.lock 2 | .idea/ 3 | target/ 4 | .DS_Store 5 | __pycache__/ 6 | -------------------------------------------------------------------------------- /benchmark/run_with_plots.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cargo build --release 3 | python3 benchmark.py fw all True 4 | 
-------------------------------------------------------------------------------- /benchmark_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/outbrain-inc/fwumious_wabbit/HEAD/benchmark_results.png -------------------------------------------------------------------------------- /benchmark/run_without_plots.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cargo build --release 4 | 5 | python3 benchmark.py fw all False 6 | -------------------------------------------------------------------------------- /benchmark/run_with_plots_intel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export RUSTFLAGS="-C opt-level=3 -C target-cpu=skylake" 4 | cargo build --release 5 | 6 | python3 benchmark.py fw all True 7 | -------------------------------------------------------------------------------- /weight_patcher/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "automl_patcher" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | gzp = "0.11.3" 10 | log = "0.4.17" 11 | env_logger = "0.10.0" -------------------------------------------------------------------------------- /profile.sh: -------------------------------------------------------------------------------- 1 | git clone git@github.com:brendangregg/FlameGraph.git 2 | git clone git@github.com:Yamakaky/rust-unmangle.git 3 | set -x 4 | perf record --call-graph dwarf,16384 -e cpu-clock -F 997 -- "$@" && perf script \ 5 | | FlameGraph/stackcollapse-perf.pl | sed rust-unmangle/rust-unmangle | FlameGraph/flamegraph.pl > flame.svg && firefox flame.svg 6 | -------------------------------------------------------------------------------- 
/benchmark/clean_caches.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | 4 | def rm_quietly(f): 5 | if os.path.isfile(f): 6 | os.remove(f) 7 | 8 | if __name__ == "__main__": 9 | rm_quietly("train.vw.cache") 10 | rm_quietly("train.vw.fwcache") 11 | rm_quietly("easy.vw.cache") 12 | rm_quietly("easy.vw.fwcache") 13 | rm_quietly("hard.vw.cache") 14 | rm_quietly("hard.vw.fwcache") 15 | 16 | -------------------------------------------------------------------------------- /.github/workflows/docker-image-mkl-build.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | 11 | build: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v3 17 | - name: Build the Docker image 18 | run: docker build . --file Dockerfile --tag my-image-name:$(date +%s) 19 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust-UbuntuLatest 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /.github/workflows/rust-Ubuntu18.yml: -------------------------------------------------------------------------------- 1 | name: Rust-Ubuntu18 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-18.04 
16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Build 20 | run: cargo build --verbose 21 | - name: Run tests 22 | run: cargo test --verbose 23 | -------------------------------------------------------------------------------- /.github/workflows/prediction-and-learning-workflow.yml: -------------------------------------------------------------------------------- 1 | name: FW-prediction-and-learning-test 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Build and test 20 | run: cd examples/ffm; bash run_fw_with_prediction_tests.sh 21 | -------------------------------------------------------------------------------- /examples/vw-compatibility/README.md: -------------------------------------------------------------------------------- 1 | # Vowpal Wabbit compatibility 2 | 3 | Under certain circumstances Fwumious Wabbit makes bit-by-bit the same 4 | predictions as Vowpal Wabbit. This is achieved because FW uses the same 5 | hashing mechanisms with the same constants as VW. 6 | 7 | This mode can be turned on by --vwcompat. This mode is slower as it 8 | does not use Adagrad look up tables. 9 | 10 | Compatibility also isn't perfect. There are multiple edge cases. 
11 | 12 | -------------------------------------------------------------------------------- /benchmark/cleanup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | 4 | def rm_quietly(f): 5 | if os.path.isfile(f): 6 | os.remove(f) 7 | 8 | if __name__ == "__main__": 9 | rm_quietly("train.vw") 10 | rm_quietly("train.vw.gz") 11 | rm_quietly("train.vw.cache") 12 | rm_quietly("train.vw.fwcache") 13 | rm_quietly("easy.vw") 14 | rm_quietly("easy.vw.cache") 15 | rm_quietly("easy.vw.fwcache") 16 | rm_quietly("hard.vw") 17 | rm_quietly("hard.vw.cache") 18 | rm_quietly("hard.vw.fwcache") 19 | 20 | -------------------------------------------------------------------------------- /src/port_buffer.rs: -------------------------------------------------------------------------------- 1 | #[derive(Clone, Debug)] 2 | pub struct PortBuffer { 3 | pub tape: Vec, 4 | pub observations: Vec, 5 | pub tape_len: usize, 6 | } 7 | 8 | impl PortBuffer { 9 | pub fn new(tape_len: usize) -> PortBuffer { 10 | PortBuffer { 11 | tape: Default::default(), 12 | observations: Default::default(), 13 | tape_len, 14 | } 15 | } 16 | 17 | pub fn reset(&mut self) { 18 | self.observations.truncate(0); 19 | self.tape.resize(self.tape_len, 0.0); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /benchmark/print_system_info.py: -------------------------------------------------------------------------------- 1 | import psutil 2 | import platform 3 | 4 | print("="*40, "CPU Info", "="*40) 5 | # number of cores 6 | print("Physical cores:", psutil.cpu_count(logical=False)) 7 | print("Total cores:", psutil.cpu_count(logical=True)) 8 | # CPU frequencies 9 | cpufreq = psutil.cpu_freq() 10 | print(f"Current Frequency: {cpufreq.current:.2f}Mhz") 11 | # CPU usage 12 | 13 | print("="*40, "System Information", "="*40) 14 | uname = platform.uname() 15 | print(f"System: {uname.system}") 16 | print(f"Release: 
{uname.release}") 17 | print(f"Version: {uname.version}") 18 | print(f"Machine: {uname.machine}") 19 | print(f"Processor: {uname.processor}") 20 | -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | 2 | # May 2021 3 | - introduce support for multi-letter namespace names 4 | 5 | # April 2021 6 | - introduce --linear and --ffm_field_verbose for passing verbose names of features for linear combo terms and fields 7 | 8 | # March 2021 9 | - Implement binning functionality; features that are floating point 10 | values can now be binned by customized functions 11 | 12 | # February 2021 13 | - much faster FFM prediction mode when using multi-value features 14 | - more than 15% speedup of FFM training 15 | - refactored Regressor to be more easily extendable 16 | - added "load_hogwile" command to serving mode. Enabling in-place 17 | replacement of the models 18 | 19 | # January 2021 20 | - fix truncated model when saving large >1Gb models 21 | -------------------------------------------------------------------------------- /benchmark/calc_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | 4 | 5 | def calc_loss(model_preds_file, input_file): 6 | model_preds = open(model_preds_file, 'rt') 7 | input = open(input_file, 'rt') 8 | 9 | loss = 0. 
10 | i = 0 11 | for y_hat in model_preds: 12 | i += 1 13 | y = next(input).split("|")[0].strip() 14 | loss += cross_entropy(float(y_hat), float(y)) 15 | 16 | return loss / float(i) 17 | 18 | 19 | def cross_entropy(y_hat, y): 20 | try: 21 | return -math.log(y_hat) if y == 1 else -math.log(1 - y_hat) 22 | except ValueError: 23 | return cross_entropy(1e-15, y) 24 | 25 | 26 | if __name__ == "__main__": 27 | loss = calc_loss(sys.argv[1], sys.argv[2]) 28 | print(f"loss: {loss}") -------------------------------------------------------------------------------- /examples/basic/datasets/vw_namespace_map.csv: -------------------------------------------------------------------------------- 1 | Z,namespace_names_are_not_used_anywhere 2 | j,some_feature1 3 | E,some_feature2 4 | p,some_feature3 5 | K,some_feature4 6 | J,some_feature5 7 | T,some_feature6 8 | U,some_feature7 9 | L,some_feature8 10 | c,some_feature9 11 | R,some_feature10 12 | s,some_feature11 13 | b,some_feature11 14 | M,some_feature12 15 | A,some_feature14 16 | w,some_feature15_also_colons_do_nothing_here:2 17 | x,some_feature16:3 18 | y,some_feature17:4 19 | z,some_feature18:5 20 | 0,some_feature19:6 21 | 1,some_feature20:7 22 | 2,some_feature21:8 23 | P,some_feature22 24 | B,some_feature23 25 | X,some_feature24 26 | F,some_feature25 27 | l,some_feature26 28 | t,some_feature27 29 | k,some_feature28 30 | u,some_feature29 31 | a,some_feature30 32 | v,some_feature31 33 | V,some_feature32 34 | W,some_feature33 35 | H,some_feature34 36 | Y,some_feature35 37 | m,some_feature36 38 | n,some_feature37 39 | I,some_feature38 40 | 3,some_feature39:2 41 | 4,some_feature40:3 42 | 5,some_feature41:4 43 | S,some_feature42 44 | r,some_feature43 45 | e,some_feature44 46 | d,some_feature45 47 | N,some_feature46 48 | C,some_feature47 49 | q,some_feature48 50 | Q,some_feature49 51 | D,some_feature50 52 | h,some_feature51 53 | i,some_feature52 54 | O,some_feature53 55 | o,some_feature54 56 | G,some_feature55 57 | g,some_feature56 58 | 
f,some_feature57 59 | -------------------------------------------------------------------------------- /examples/vw-compatibility/datasets/vw_namespace_map.csv: -------------------------------------------------------------------------------- 1 | Z,namespace_names_are_not_used_anywhere 2 | j,some_feature1 3 | E,some_feature2 4 | p,some_feature3 5 | K,some_feature4 6 | J,some_feature5 7 | T,some_feature6 8 | U,some_feature7 9 | L,some_feature8 10 | c,some_feature9 11 | R,some_feature10 12 | s,some_feature11 13 | b,some_feature11 14 | M,some_feature12 15 | A,some_feature14 16 | w,some_feature15_also_colons_do_nothing_here:2 17 | x,some_feature16:3 18 | y,some_feature17:4 19 | z,some_feature18:5 20 | 0,some_feature19:6 21 | 1,some_feature20:7 22 | 2,some_feature21:8 23 | P,some_feature22 24 | B,some_feature23 25 | X,some_feature24 26 | F,some_feature25 27 | l,some_feature26 28 | t,some_feature27 29 | k,some_feature28 30 | u,some_feature29 31 | a,some_feature30 32 | v,some_feature31 33 | V,some_feature32 34 | W,some_feature33 35 | H,some_feature34 36 | Y,some_feature35 37 | m,some_feature36 38 | n,some_feature37 39 | I,some_feature38 40 | 3,some_feature39:2 41 | 4,some_feature40:3 42 | 5,some_feature41:4 43 | S,some_feature42 44 | r,some_feature43 45 | e,some_feature44 46 | d,some_feature45 47 | N,some_feature46 48 | C,some_feature47 49 | q,some_feature48 50 | Q,some_feature49 51 | D,some_feature50 52 | h,some_feature51 53 | i,some_feature52 54 | O,some_feature53 55 | o,some_feature54 56 | G,some_feature55 57 | g,some_feature56 58 | f,some_feature57 59 | -------------------------------------------------------------------------------- /examples/basic/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | SCRIPT=$(readlink -f "$0") 3 | DIR=$(dirname "$SCRIPT") 4 | 5 | MODELS_DIR=$DIR/models 6 | PREDICTIONS_DIR=$DIR/predictions 7 | DATASETS_DIR=$DIR/datasets 8 | PROJECT_ROOT=$DIR/../../ 9 | 
FW=$PROJECT_ROOT/target/release/fw 10 | 11 | echo $FW 12 | namespaces="--interactions 4G --interactions 4GHX --interactions 4GUW --interactions 4K --interactions 4c --interactions 4go --interactions 4v --interactions BC --interactions BD --interactions BGO --interactions BX --interactions CO --interactions DG --interactions DW --interactions GU --interactions Gx --interactions KR --interactions MN --interactions UW --interactions Ug --interactions eg --keep B --keep C --keep D --keep F --keep G --keep H --keep L --keep O --keep S --keep U --keep W --keep e --keep f --keep g --keep h --keep i --keep o --keep p --keep q --keep r --keep v --keep x " 13 | rest="-l 0.025 -b 25 --adaptive --sgd --link=logistic --loss_function logistic --power_t 0.39 --l2 0.0 --hash all" 14 | 15 | mkdir -p $MODELS_DIR 16 | mkdir -p $PREDICTIONS_DIR 17 | rm -f $MODELS_DIR/*.fw.model 18 | rm -f $DATASETS_DIR/*.fwcache 19 | rm -f $PREDICTIONS_DIR/*.fw.out 20 | echo "Building FW" 21 | (cd $PROJECT_ROOT 22 | cargo build --release) 23 | CMDLINE="$FW $namespaces $rest --data $DATASETS_DIR/train.vw -p $PREDICTIONS_DIR/train.fw.out -f $MODELS_DIR/trained.fw.model --save_resume" 24 | echo "We will run $CMDLINE" 25 | $CMDLINE 26 | echo "DONE" 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Outbrain Inc. 2 | 3 | Redistribution and use in source and binary forms, with or without 4 | modification, are permitted provided that the following conditions 5 | are met: 6 | 1. Redistributions of source code must retain the above copyright 7 | notice, this list of conditions and the following disclaimer. 8 | 2. Redistributions in binary form must reproduce the above copyright 9 | notice, this list of conditions and the following disclaimer in the 10 | documentation and/or other materials provided with the distribution. 11 | 3. 
Neither the name of the author nor the names of its contributors may 12 | be used to endorse or promote products derived from this software 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND 15 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 17 | ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 20 | OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 21 | HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 22 | LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 23 | OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 24 | SUCH DAMAGE. 25 | 26 | https://github.com/outbrain/fwumious_wabbit 27 | -------------------------------------------------------------------------------- /examples/ffm/run_vw_equivalent.sh: -------------------------------------------------------------------------------- 1 | #/bin/sh 2 | SCRIPT=$(readlink -f "$0") 3 | DIR=$(dirname "$SCRIPT") 4 | 5 | MODELS_DIR=$DIR/models 6 | PREDICTIONS_DIR=$DIR/predictions 7 | DATASETS_DIR=$DIR/datasets 8 | PROJECT_ROOT=$DIR/../../ 9 | VW=vw 10 | echo "Generating input datasets" 11 | (cd $DIR 12 | python3 generate.py) 13 | # keep namespaces A and B as regular logistic regression features 14 | # have two fields, one with feature A, and the second with feature B 15 | # we only need k=1, but we use k=10 for test here 16 | namespaces="--keep A --keep B --interactions AB --lrqfa AB10" 17 | 18 | rest="-l 0.1 -b 25 -c --adaptive --sgd --loss_function logistic --link logistic --power_t 0.0 --l2 0.0 --hash all --noconstant" 19 | mkdir -p $MODELS_DIR 20 | mkdir -p $PREDICTIONS_DIR 21 | rm -f $MODELS_DIR/*.vw.model 22 | rm -f 
$DATASETS_DIR/*.cache 23 | rm -f $PREDICTIONS_DIR/*.vw.out 24 | 25 | 26 | 27 | echo "Running training" 28 | $VW $namespaces $rest --data $DATASETS_DIR/train.vw -p $PREDICTIONS_DIR/train.vw.out -f $MODELS_DIR/trained.vw.model --save_resume 29 | echo "Running prediction on \"easy\" data set" 30 | $VW $namespaces $rest --data $DATASETS_DIR/test-easy.vw -p $PREDICTIONS_DIR/test-easy.vw.out -i $MODELS_DIR/trained.vw.model -t 31 | echo "Running prediction on \"hard\" data set (that needs factorization to be succeffully predicted)" 32 | $VW $namespaces $rest --data $DATASETS_DIR/test-hard.vw -p $PREDICTIONS_DIR/test-hard.vw.out -i $MODELS_DIR/trained.vw.model -t 33 | 34 | echo "DONE" 35 | echo "You can find output datasets in directory $PREDICTIONS" 36 | -------------------------------------------------------------------------------- /examples/ffm/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | SCRIPT=$(readlink -f "$0") 3 | DIR=$(dirname "$SCRIPT") 4 | 5 | MODELS_DIR=$DIR/models 6 | PREDICTIONS_DIR=$DIR/predictions 7 | DATASETS_DIR=$DIR/datasets 8 | PROJECT_ROOT=$DIR/../../ 9 | FW=$PROJECT_ROOT/target/release/fw 10 | echo "Generating input datasets" 11 | (cd $DIR 12 | python3 generate.py) 13 | # keep namespaces A and B as regular logistic regression features 14 | # have two fields, one with feature A, and the second with feature B 15 | # we only need k=1, but we use k=10 for test here 16 | namespaces="--keep A --keep B --interactions AB --ffm_k 10 --ffm_field A --ffm_field B" 17 | 18 | rest="-l 0.1 -b 25 -c --adaptive --sgd --loss_function logistic --link logistic --power_t 0.0 --l2 0.0 --hash all --noconstant" 19 | mkdir -p $MODELS_DIR 20 | mkdir -p $PREDICTIONS_DIR 21 | rm -f $MODELS_DIR/*.fw.model 22 | rm -f $DATASETS_DIR/*.fwcache 23 | rm -f $PREDICTIONS_DIR/*.fw.out 24 | echo "Building FW" 25 | (cd $PROJECT_ROOT 26 | cargo build --release) && \ 27 | echo "Running training" 28 | $FW $namespaces $rest --data 
$DATASETS_DIR/train.vw -p $PREDICTIONS_DIR/train.fw.out -f $MODELS_DIR/trained.fw.model --save_resume 29 | echo "Running prediction on \"easy\" data set" 30 | $FW $namespaces $rest --data $DATASETS_DIR/test-easy.vw -p $PREDICTIONS_DIR/test-easy.fw.out -i $MODELS_DIR/trained.fw.model -t 31 | echo "Running prediction on \"hard\" data set (that needs factorization to be succeffully predicted)" 32 | $FW $namespaces $rest --data $DATASETS_DIR/test-hard.vw -p $PREDICTIONS_DIR/test-hard.fw.out -i $MODELS_DIR/trained.fw.model -t 33 | 34 | echo "DONE" 35 | echo "You can find output datasets in directory $PREDICTIONS" 36 | -------------------------------------------------------------------------------- /src/logging_layer.rs: -------------------------------------------------------------------------------- 1 | extern crate log; 2 | use env_logger::Builder; 3 | 4 | pub fn initialize_logging_layer() { 5 | let mut builder = Builder::new(); 6 | let log_level = std::env::var("LOG_LEVEL").unwrap_or_else(|_| "info".to_string()); 7 | match log_level.to_lowercase().as_str() { 8 | "info" => builder.filter_level(log::LevelFilter::Info), 9 | "warn" => builder.filter_level(log::LevelFilter::Warn), 10 | "error" => builder.filter_level(log::LevelFilter::Error), 11 | "trace" => builder.filter_level(log::LevelFilter::Trace), 12 | "debug" => builder.filter_level(log::LevelFilter::Debug), 13 | "off" => builder.filter_level(log::LevelFilter::Off), 14 | _ => builder.filter_level(log::LevelFilter::Info), 15 | }; 16 | 17 | if builder.try_init().is_ok() { 18 | log::info!("Initialized the logger ..") 19 | } 20 | 21 | log_detected_x86_features(); 22 | } 23 | 24 | fn log_detected_x86_features() { 25 | let mut features: Vec = Vec::new(); 26 | if is_x86_feature_detected!("avx") { 27 | features.push("AVX".to_string()); 28 | } 29 | 30 | if is_x86_feature_detected!("avx2") { 31 | features.push("AVX2".to_string()); 32 | } 33 | 34 | if is_x86_feature_detected!("avx512f") { 35 | 
features.push("AVX512F".to_string()); 36 | } 37 | 38 | if is_x86_feature_detected!("fma") { 39 | features.push("FMA".to_string()); 40 | } 41 | 42 | if features.is_empty() { 43 | log::info!("No selected CPU features detected .."); 44 | } else { 45 | log::info!("Detected CPU features: {:?}", features.join(", ")); 46 | } 47 | } -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "fw" 3 | version = "0.2.0" 4 | authors = ["Andraz Tori "] 5 | description = "Like Vowpal Wabbit, but meaner" 6 | edition = "2018" 7 | 8 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 9 | 10 | [dependencies] 11 | csv = "1.2" 12 | # we need new version to enable static builds 13 | fasthash = "0.4" 14 | serde = {version = "1.0.163" , features = ["derive"]} 15 | serde_json = "1.0.96" 16 | clap = "2.33.1" 17 | byteorder = "1.4.3" 18 | merand48 = "0.1.0" 19 | daemonize = "0.5.0" 20 | lz4 = "1.24.0" 21 | nom = "7.1.3" 22 | dyn-clone = "1.0.11" 23 | rand = "0.8.5" 24 | rand_distr = "0.4.3" 25 | rand_xoshiro = "0.6.0" 26 | flate2 = { version = "1.0.26", features = ["zlib-ng"], default-features = false } 27 | shellwords = "1.1.0" 28 | blas = "0.22.0" 29 | intel-mkl-src = {version= "0.8.1", default-features = false, features=["mkl-static-lp64-seq"]} 30 | log = "0.4.18" 31 | env_logger = "0.10.0" 32 | rustc-hash = "1.1.0" 33 | half = "2.3.1" 34 | zstd = "0.13.1" 35 | 36 | [build-dependencies] 37 | cbindgen = "0.23.0" 38 | 39 | [lib] 40 | name = "fw" 41 | path="src/lib.rs" 42 | crate_type = ["lib", "cdylib"] 43 | doctest = false 44 | 45 | [[bin]] 46 | name="fw" 47 | path="src/main.rs" 48 | 49 | [dev-dependencies] 50 | tempfile = "3.1.0" 51 | mockstream = "0.0.3" 52 | 53 | [profile.release] 54 | debug = false 55 | lto = false 56 | panic = 'abort' 57 | codegen-units=1 58 | 59 | [profile.dev] 60 | opt-level = 2 
61 | debug = true 62 | debug-assertions = true 63 | overflow-checks = true 64 | lto = false 65 | panic = 'unwind' 66 | incremental = false 67 | codegen-units = 16 68 | rpath = false 69 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | ENV IMAGENAME="fwumious-builder" 3 | ENV DEBIAN_FRONTEND=noninteractive 4 | ARG RUST_VERSION="1.61.0" 5 | RUN apt-get update && apt-get install gcc g++ -y && apt-get install libboost-dev libboost-thread-dev libboost-program-options-dev libboost-system-dev libboost-math-dev libboost-test-dev zlib1g-dev -y && apt-get install git python3 python3-psutil python3-matplotlib lsb-release wget software-properties-common openjdk-8-jdk curl -y 6 | RUN apt-get install -y libssl-dev 7 | 8 | # Install LLVM 9 | WORKDIR /scripts 10 | RUN wget https://apt.llvm.org/llvm.sh 11 | RUN chmod +x llvm.sh 12 | RUN ./llvm.sh 13 13 | ENV PATH="/usr/lib/llvm-11/bin/:${PATH}" 14 | 15 | # Install newer cmake 16 | RUN wget -O - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | gpg --dearmor - | tee /etc/apt/trusted.gpg.d/kitware.gpg >/dev/null 17 | RUN apt-add-repository "deb https://apt.kitware.com/ubuntu/ $(lsb_release -cs) main" 18 | RUN apt update && apt install cmake -y 19 | 20 | # Compile fbs 21 | WORKDIR / 22 | RUN git clone https://github.com/google/flatbuffers.git 23 | WORKDIR /flatbuffers 24 | RUN git checkout tags/v1.12.0 25 | RUN cmake -G "Unix Makefiles" -DCMAKE_BUILD_TYPE=Release 26 | RUN make 27 | RUN make install 28 | 29 | # Get rust ecosystem operating 30 | WORKDIR / 31 | RUN apt-get update 32 | 33 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y 34 | ENV PATH="/root/.cargo/bin:${PATH}" 35 | RUN rustup install $RUST_VERSION 36 | ENV PATH="/root/.cargo/bin:/vowpal_wabbit/vowpalwabbit/vowpalwabbit/cli/:${PATH}" 37 | 38 | # Conduct benchmark against vw 
+ produce --release bin 39 | WORKDIR /FW 40 | COPY . /FW 41 | 42 | #RUN cargo test 43 | RUN chmod +x build.sh 44 | RUN cargo test 45 | RUN ./build.sh 46 | -------------------------------------------------------------------------------- /benchmark/measure.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from subprocess import CalledProcessError 3 | import sys 4 | import psutil 5 | import platform 6 | from timeit import default_timer as timer 7 | 8 | def eprint(*args, **kwargs): 9 | print(*args, file=sys.stderr, **kwargs) 10 | 11 | def measure(cmd, proc_name): 12 | try: 13 | start = timer() 14 | cmdp = subprocess.Popen(cmd.split(), shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 15 | psp = psutil.Process(cmdp.pid) 16 | 17 | cpu = 0 18 | mem = 0 19 | time = 0 20 | 21 | while True: 22 | with psp.oneshot(): 23 | try: 24 | cpu = max(cpu, psp.cpu_percent()) 25 | 26 | if platform.system() == "Darwin": 27 | mem = max(mem, psp.memory_info().rss / 1024.) 28 | else: 29 | mem = max(mem, psp.memory_full_info().pss / 1024.) 
30 | except psutil.AccessDenied: 31 | pass 32 | except psutil.ZombieProcess: 33 | pass 34 | try: 35 | psp.wait(timeout=0.5) 36 | time = timer() - start 37 | except psutil.TimeoutExpired: 38 | continue 39 | else: 40 | break 41 | return_code = cmdp.poll() 42 | 43 | # eprint(f"\nERROR_CODE: {return_code}\n" + str(b"\n".join(cmdp.stdout.readlines()))) 44 | except CalledProcessError as e: 45 | output = e.output.decode() 46 | print(output) 47 | return None 48 | 49 | 50 | return time, mem, cpu 51 | 52 | 53 | 54 | if __name__ == "__main__": 55 | cmd = " ".join(sys.argv[1:]) 56 | time, mem, cpu = measure(cmd) 57 | print(f"{time}, {mem}, {cpu}") 58 | -------------------------------------------------------------------------------- /examples/vw-compatibility/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | SCRIPT=$(readlink -f "$0") 3 | DIR=$(dirname "$SCRIPT") 4 | 5 | MODELS_DIR=$DIR/models 6 | PREDICTIONS_DIR=$DIR/predictions 7 | DATASETS_DIR=$DIR/datasets 8 | PROJECT_ROOT=$DIR/../../ 9 | FW=$PROJECT_ROOT/target/release/fw 10 | VW=vw 11 | 12 | # the incompatibility problem: Vowpal automatically takes all interaction features also as single features. fw does not. 
13 | #namespaces="--interactions 4G --interactions 4GHX --interactions 4GUW --interactions 4K --interactions 4c --interactions 4go --interactions 4v --interactions BC --interactions BD --interactions BGO --interactions BX --interactions CO --interactions DG --interactions DW --interactions GU --interactions Gx --interactions KR --interactions MN --interactions UW --interactions Ug --interactions eg --keep B --keep C --keep D --keep F --keep G --keep H --keep L --keep O --keep S --keep U --keep W --keep e --keep f --keep g --keep h --keep i --keep o --keep p --keep q --keep r --keep v --keep x " 14 | namespaces="--keep B --keep C --keep D --keep F --keep G --keep H --keep L --keep O --keep S --keep U --keep W --keep e --keep f --keep g --keep h --keep i --keep o --keep p --keep q --keep r --keep v --keep x " 15 | rest="-l 0.025 -b 25 --adaptive --sgd --link=logistic --loss_function logistic --power_t 0.35 --l2 0.0 --hash all" 16 | 17 | mkdir -p $PREDICTIONS_DIR 18 | rm -f $DATASETS_DIR/*.fwcache 19 | rm -f $PREDICTIONS_DIR/*.fw.out 20 | rm -f $DATASETS_DIR/*.cache 21 | rm -f $PREDICTIONS_DIR/*.vw.out 22 | echo "Building FW" 23 | (cd $PROJECT_ROOT 24 | cargo build --release) 25 | VW_CMDLINE="$VW $namespaces $rest --data $DATASETS_DIR/train.vw -p $PREDICTIONS_DIR/train.vw.out" 26 | FW_CMDLINE="$FW $namespaces $rest --data $DATASETS_DIR/train.vw -p $PREDICTIONS_DIR/train.fw.out --vwcompat" 27 | $VW_CMDLINE && $FW_CMDLINE 28 | echo "DONE, now running diff" 29 | 30 | diff -s $PREDICTIONS_DIR/train.vw.out $PREDICTIONS_DIR/train.fw.out 31 | 32 | 33 | -------------------------------------------------------------------------------- /examples/ffm/README.md: -------------------------------------------------------------------------------- 1 | # Tests for Factorization Machines functionality 2 | 3 | We're trying to predict which animal likes to eat what. 
4 | 5 | Our dataset will be a simple combination of animal and food and 6 | outcome - either one likes it (1) or dislikes it (-1) 7 | 8 | Namespace A will be Animal, and B will be Food. Each animal has its latent 9 | type - Herbivore or Carnivore. And each food has its latent type - Plant or 10 | Meat. 11 | 12 | Example of features will be "Herbivore-13" and "Plant-55". However these are 13 | just strings for the algo. These names are handy when we look at the data as 14 | it makes it easy for us to check for correctness. 15 | 16 | We will split each feature set into two sets - let's call them A1 & A2 set 17 | and B1 & B2 set. 18 | 19 | Our training data will only have interactions between 20 | - A1 and (B1 U B2) 21 | - B1 and (A1 U A2) 22 | 23 | Importantly there are no interactions between A2 and B2 in the training set. 24 | Predicting resutls of these interactions correctly requires discovery of 25 | latent variable - which is what FFM can do. 26 | 27 | Datasets created by generate.py: 28 | train.vw - In the training set there are no interactions between A2 and B2. 29 | easy.vw - The distribution here is the same as in the training dataset 30 | hard.vw - The distribution here consists entirely of interactions between A2 31 | and B2. 
32 | 33 | # Notes: 34 | - if we don't use feature combinations - only namespaces A and B in 35 | isolation, then Logistic Regression will not be able to provide any 36 | predictive power 37 | - if we use plain feature combinations - AB, then LR will easily have 38 | correct predictions on easy.vw, but not on hard.vw 39 | - Only if the algo is able to capture the existence of latent variable 40 | it can perform better than random on hard.vw 41 | 42 | # Usage: 43 | Demonstrate factorization machines: 44 | sh run.sh 45 | 46 | Run equivalent vowpal setup: 47 | sh run_vw_equivalent.sh 48 | -------------------------------------------------------------------------------- /COMPATIBILITY.md: -------------------------------------------------------------------------------- 1 | # Vowpal Wabbit compatibility 2 | 3 | WARNING: Fwumious Wabbit cuts a lot of corners. Beware. 4 | 5 | ### Input file format 6 | - [Vowpal Wabbit input format](https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format) is supported 7 | - Namespaces can only be single letters 8 | - In each example each namespace can only be delcared once (and can have multiple features) 9 | - there has to be a map file ("vw_namespace_map.csv") available with all the namespaces declared 10 | 11 | 12 | ### Command line arguments 13 | 14 | #### Vowpal Wabbit compatibility 15 | --vwcompat It causes Fwumious to complain if the arguments used for LR part would not 16 | produce exactly the same results in vowpal wabbit. 17 | 18 | #### Required when using "--vwcompat" to force 19 | --hash all This treats all features as categorical, 20 | Otherwise Vowpal Wabbit treats some as pre-hashed. 
21 | 22 | --adaptive Adagrad mode 23 | 24 | --sgd disable vowpal defaults of "--invariant" and "--normalize" 25 | 26 | 27 | #### Optional 28 | --link logistic Use logistic function for prediction printouts (always on) 29 | 30 | --loss_function logistic Use logloss (always on) 31 | 32 | --power_t 0.5 Value for Adagrad's exponent (default 0.5 = square root) 33 | 34 | --l2 0.0 L2 regularization, not supported. Only 0.0 allowed 35 | 36 | --keep X Include namespace into the feature set 37 | 38 | --interactions XYZ Include namesapce interactions into the feature set 39 | 40 | --noconstant Don't add intercept 41 | 42 | --testonly Don't learn, only predict 43 | 44 | 45 | #### Other known incompatibilities and differences: 46 | - Fwumious Wabbit currently only supports log-loss for loss function 47 | - when not specifying either --keep or --interactions, Vowpal Wabbit will use all 48 | input features. Fwumious Wabbit will use none. 49 | 50 | #### vw_namspace_map.csv 51 | It maps single letter namespaces to their full names. Its purpose is: 52 | - to disclose namespaces ahead of time 53 | - to map from namespace letters to their full names 54 | Check out examples directory to see how it is formatted. 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Fwumious Wabbit is 2 | - a very fast machine learning tool 3 | - built with Rust 4 | - inspired by and partially compatible with Vowpal Wabbit (much love! 
read more about compatibility [here](COMPATIBILITY.md)) 5 | - currently supports logistic regression and (deep) field-aware factorization machines 6 | 7 | [![Rust-UbuntuLatest](https://github.com/outbrain/fwumious_wabbit/actions/workflows/rust.yml/badge.svg?branch=main)](https://github.com/outbrain/fwumious_wabbit/actions/workflows/rust.yml) 8 | [![Rust-Ubuntu18](https://github.com/outbrain/fwumious_wabbit/actions/workflows/rust-Ubuntu18.yml/badge.svg)](https://github.com/outbrain/fwumious_wabbit/actions/workflows/rust-Ubuntu18.yml) 9 | [![Gitter](https://badges.gitter.im/FwumiousWabbit/community.svg)](https://gitter.im/FwumiousWabbit/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge) 10 | 11 | Fwumious Wabbit is actively used in Outbrain for offline research, as well as for some production flows. It 12 | enables "high bandwidth research" when doing feature engineering, feature 13 | selection, hyperparameter tuning, and the like. 14 | 15 | Data scientists can train hundreds of models over hundreds of millions of examples in 16 | a matter of hours on a single machine. 17 | 18 | For our tested scenarios it is almost two orders of magnitude faster than the 19 | fastest Tensorflow implementation of Logistic Regression and FFMs that we could 20 | come up with. It is an order of magnitude faster than Vowpal Wabbit for some specific use-cases. 
21 | 22 | Check out our [benchmark](BENCHMARK.md), here's a teaser: 23 | 24 | ![benchmark results](benchmark_results.png) 25 | 26 | 27 | **Why is it faster?** (see [here](SPEED.md) for more details) 28 | - Only implements Logistic Regression and (Deep) Field-aware Factorization Machines 29 | - Uses hashing trick, lookup table for AdaGrad and a tight encoding format for the "input cache" 30 | - Features' namespaces have to be declared up-front 31 | - Prefetching of weights from memory (avoiding pipeline stalls) 32 | - Written in Rust with heavy use of code specialization (via macros and traits) 33 | - Special emphasis on efficiency of sparse operations and serving 34 | 35 | 36 | # Weight patching 37 | This repo also contains the patching algorithm that enables very fast weight diff computation see `weight_patcher` for more details. 38 | -------------------------------------------------------------------------------- /src/multithread_helpers.rs: -------------------------------------------------------------------------------- 1 | use crate::regressor::Regressor; 2 | use core::ops::{Deref, DerefMut}; 3 | use std::marker::PhantomData; 4 | use std::mem; 5 | use std::mem::ManuallyDrop; 6 | use std::sync::Arc; 7 | use std::sync::Mutex; 8 | 9 | // This is a helper for UNSAFELY sharing data between threads 10 | 11 | pub struct UnsafelySharableTrait { 12 | content: ManuallyDrop, 13 | reference_count: Arc>>, 14 | } 15 | 16 | pub type BoxedRegressorTrait = UnsafelySharableTrait>; 17 | 18 | // SUPER UNSAFE 19 | // SUPER UNSAFE 20 | // SUPER UNSAFE 21 | // This literary means we are on our own -- but it is the only way to implement HogWild performantly 22 | unsafe impl Sync for UnsafelySharableTrait {} 23 | unsafe impl Send for UnsafelySharableTrait {} 24 | 25 | impl Deref for UnsafelySharableTrait { 26 | type Target = T; 27 | 28 | fn deref(&self) -> &T { 29 | &self.content 30 | } 31 | } 32 | 33 | impl DerefMut for UnsafelySharableTrait { 34 | fn deref_mut(&mut self) -> &mut T { 
35 | &mut self.content 36 | } 37 | } 38 | 39 | impl Drop for UnsafelySharableTrait { 40 | fn drop(&mut self) { 41 | unsafe { 42 | // we are called before reference is removed, so we need to decide if to drop it or not 43 | let count = Arc::>>::strong_count(&self.reference_count) - 1; 44 | if count == 0 { 45 | let _box_to_be_dropped = ManuallyDrop::take(&mut self.content); 46 | // Now this means that the content will be dropped 47 | } 48 | } 49 | } 50 | } 51 | 52 | impl UnsafelySharableTrait { 53 | pub fn new(a: T) -> UnsafelySharableTrait { 54 | UnsafelySharableTrait:: { 55 | content: ManuallyDrop::new(a), 56 | reference_count: Arc::new(Mutex::new(std::marker::PhantomData {})), 57 | } 58 | } 59 | } 60 | 61 | // Non-generalized implementation 62 | // Todo - generalize this[A 63 | 64 | impl BoxedRegressorTrait { 65 | pub fn clone(&self) -> BoxedRegressorTrait { 66 | // UNSAFE AS HELL 67 | unsafe { 68 | // Double deref here sounds weird, but you got to know that dyn Trait and Box are the same thing, just box owns it. 
# Deterministic RNG seed so generated datasets are reproducible run-to-run.
random.seed(1)

# Latent animal diets and food kinds; the learner never sees these codes
# directly, only the rendered names below.
HERBIVORE = 1
CARNIVORE = 2
PLANT = 100
MEAT = 101

# Human-readable names used when rendering feature strings.
ADICT = {HERBIVORE: "Herbivore", CARNIVORE: "Carnivore"}
BDICT = {PLANT: "Plant", MEAT: "Meat"}


def get_score(a, b):
    """Label for an (animal, food) pair: 1 when the diet matches the food
    kind (herbivore/plant or carnivore/meat), -1 otherwise.

    `a` and `b` are (type_code, id) tuples; only the type code matters here.
    """
    if (a[0], b[0]) in ((HERBIVORE, PLANT), (CARNIVORE, MEAT)):
        return 1
    return -1


def random_features(num_random_features):
    """Render `num_random_features` noise namespaces (C, D, E, ...) as a
    VW-format suffix, e.g. " |C C123 |D D456".

    Consumes one random.randint per namespace, so output depends on (and
    advances) the module RNG state. Returns "" for zero features.
    """
    parts = [""]  # leading empty element yields the separating space on join
    for offset in range(num_random_features):
        ns = chr(ord('C') + offset)
        parts.append("|{0} {0}{1}".format(ns, random.randint(0, 10000)))
    return " ".join(parts)


def render_example(a, b, num_random_features):
    """Format one labelled VW example line for an (animal, food) pair,
    followed by `num_random_features` random noise namespaces.
    """
    score = get_score(a, b)
    fields = [
        str(score),
        "|A", ADICT[a[0]] + "-" + str(a[1]),
        "|B", BDICT[b[0]] + "-" + str(b[1]),
    ]
    return " ".join(fields) + random_features(num_random_features) + "\n"
def add_dataset_record(f, block_beyond, feature_variety, num_random_features):
    """Sample one train/easy-distribution example and append it to file `f`.

    Exactly one of the two ids is restricted to [0, block_beyond] while the
    other ranges over [0, feature_variety], so high-id animals and high-id
    foods are never paired with each other in this distribution. Consumes a
    fixed sequence of RNG calls (choices, choices, randint, randint, randint),
    keeping the stream reproducible under a fixed seed.
    """
    animal_kind = random.choices([HERBIVORE, CARNIVORE])[0]
    food_kind = random.choices([PLANT, MEAT])[0]
    restrict_food = random.randint(0, 1)
    if restrict_food:
        animal_id = random.randint(0, feature_variety)
        food_id = random.randint(0, block_beyond)
    else:
        animal_id = random.randint(0, block_beyond)
        food_id = random.randint(0, feature_variety)
    f.write(render_example((animal_kind, animal_id), (food_kind, food_id), num_random_features))
--interactions GU --interactions GUx --interactions Gw --interactions KR --interactions KTd --interactions Kx --interactions MN --interactions QS --interactions Sx --interactions UW --interactions Ug --interactions cx --interactions eg --interactions gx --interactions rvx --keep B --keep D --keep F --keep G --keep H --keep L --keep O --keep S --keep U --keep W --keep e --keep f --keep g --keep h --keep i --keep o --keep p --keep q --keep r --keep v --keep w" 4 | #namespaces="--keep A" 5 | namespaces="--interactions 0GU --interactions 0K --interactions 0S --interactions 0b --interactions 0y --interactions BD --interactions BGO --interactions BTe --interactions BX --interactions Bn --interactions Bo --interactions Br --interactions CO --interactions Cu --interactions DG --interactions FG --interactions GHX --interactions Gt --interactions Gz --interactions KR --interactions KRt --interactions KTc --interactions MN --interactions UW --interactions Uf --interactions df --interactions tx --keep F --keep L --keep U --keep d --keep e --keep g --keep h --keep o --keep q --keep t --keep u --keep y --keep z --ffm_bit_precision 25 --ffm_field BGz --ffm_field CFL --ffm_field O --ffm_field UW --ffm_field cfnr --ffm_k 8" 6 | 7 | rest="--data $infile -l 0.025 -b 25 --adaptive --sgd --loss_function logistic --link logistic --power_t 0.38 --l2 0.00 --hash all --noconstant" 8 | #vwonly="--lrqfa AB-64" 9 | #fwonly="--lrqfa AB-64" 10 | 11 | #rest="-b 24 --data $infile -l 0.1 --power_t 0.39 --adaptive --link logistic --sgd --loss_function logistic --noconstant --l2 0.0" 12 | #rm v f 13 | #rm $1.cache $1.fwcache 14 | vw=/home/minmax/minmax_old/zgit/vowpal_wabbit/vowpalwabbit/vw 15 | #vw=/home/minmax/obgit/vw/vowpal_wabbit/vowpalwabbit/vw 16 | #clear; 17 | #echo "build --release && target/release/fw $namespaces $rest -p f" 18 | #echo "cargo build --release && target/release/fw $namespaces $rest -c -p f $fwonly && time $vw $namespaces $rest -c -p v $vwonly" 19 | 20 | echo 
"target/release/fw $namespaces $rest -c -p f $fwonly " 21 | 22 | cargo build --release --target x86_64-unknown-linux-musl --bin fw && \ 23 | target/x86_64-unknown-linux-musl/release/fw $namespaces $rest -c -p f $fwonly -f out.fwmodel --save_resume 24 | #cargo build --release --bin fw && \ 25 | #target/release/fw $namespaces $rest -c -p f $fwonly 26 | 27 | #&& time $vw $namespaces $rest -c -p v $vwonly 28 | #clear; cargo build && target/debug/fw $namespaces $rest -p f && time vw $namespaces $rest -p v 29 | 30 | # an example of command line of some specific model and its debugging 31 | # for debugging use --foreground opt 32 | #target/release/fw --foreground -i /tmp/vw/time=15-25/initial_weights.vw -t --daemon --quiet --port 26542 --interactions 0t --interactions 13 --interactions 2G --interactions 3GHX --interactions 3GU --interactions 3K --interactions 3S --interactions 3c --interactions 3g --interactions BD --interactions BGO --interactions BTf --interactions BX --interactions Bo --interactions Bp --interactions Bs --interactions CO --interactions Cx --interactions DG --interactions DW --interactions FG --interactions GU --interactions Gt --interactions KR --interactions KRt --interactions KTd --interactions Lt --interactions MN --interactions QS --interactions UW --interactions Ug --interactions eg --keep 1 --keep 2 --keep F --keep H --keep L --keep O --keep U --keep W --keep e --keep f --keep h --keep i --keep o --keep p --keep r --keep s --keep t --keep x --num_children 4 33 | -------------------------------------------------------------------------------- /src/hogwild.rs: -------------------------------------------------------------------------------- 1 | use std::sync::mpsc::{Receiver, SyncSender}; 2 | use std::sync::{mpsc, Arc, Mutex}; 3 | use std::thread; 4 | use std::thread::JoinHandle; 5 | 6 | use crate::feature_buffer::FeatureBufferTranslator; 7 | use crate::model_instance::ModelInstance; 8 | use crate::multithread_helpers::BoxedRegressorTrait; 9 | use 
crate::port_buffer::PortBuffer; 10 | 11 | static CHANNEL_CAPACITY: usize = 100_000; 12 | 13 | pub struct HogwildTrainer { 14 | workers: Vec>, 15 | sender: SyncSender>, 16 | } 17 | 18 | pub struct HogwildWorker { 19 | regressor: BoxedRegressorTrait, 20 | feature_buffer_translator: FeatureBufferTranslator, 21 | port_buffer: PortBuffer, 22 | } 23 | 24 | impl HogwildTrainer { 25 | pub fn new( 26 | sharable_regressor: BoxedRegressorTrait, 27 | model_instance: &ModelInstance, 28 | num_workers: u32, 29 | ) -> HogwildTrainer { 30 | let (sender, receiver): (SyncSender>, Receiver>) = 31 | mpsc::sync_channel(CHANNEL_CAPACITY); 32 | let mut trainer = HogwildTrainer { 33 | workers: Vec::with_capacity(num_workers as usize), 34 | sender, 35 | }; 36 | let receiver: Arc>>> = Arc::new(Mutex::new(receiver)); 37 | let feature_buffer_translator = FeatureBufferTranslator::new(model_instance); 38 | let port_buffer = sharable_regressor.new_portbuffer(); 39 | for _ in 0..num_workers { 40 | let worker = HogwildWorker::new( 41 | sharable_regressor.clone(), 42 | feature_buffer_translator.clone(), 43 | port_buffer.clone(), 44 | Arc::clone(&receiver), 45 | ); 46 | trainer.workers.push(worker); 47 | } 48 | trainer 49 | } 50 | 51 | pub fn digest_example(&self, feature_buffer: Vec) { 52 | self.sender.send(feature_buffer).unwrap(); 53 | } 54 | 55 | pub fn block_until_workers_finished(self) { 56 | drop(self.sender); 57 | for worker in self.workers { 58 | worker.join().unwrap(); 59 | } 60 | } 61 | } 62 | 63 | impl Default for HogwildTrainer { 64 | fn default() -> Self { 65 | let (sender, _receiver) = mpsc::sync_channel(0); 66 | HogwildTrainer { 67 | workers: vec![], 68 | sender, 69 | } 70 | } 71 | } 72 | 73 | impl HogwildWorker { 74 | pub fn new( 75 | regressor: BoxedRegressorTrait, 76 | feature_buffer_translator: FeatureBufferTranslator, 77 | port_buffer: PortBuffer, 78 | receiver: Arc>>>, 79 | ) -> JoinHandle<()> { 80 | let mut worker = HogwildWorker { 81 | regressor, 82 | 
def get_score(a, b):
    """Return 1 when the animal's diet matches the food type, -1 otherwise.

    `a` and `b` are ("Herbivore"/"Carnivore", id) and ("Plant"/"Meat", id)
    tuples; only the first element of each is inspected.
    """
    compatible = {("Herbivore", "Plant"), ("Carnivore", "Meat")}
    return 1 if (a[0], b[0]) in compatible else -1
def generate_synthetic_dataset():
    """Write the synthetic datasets used by the FFM example.

    Produces, under DATASETS_DIRECTORY:
      - vw_namespace_map.csv: namespace declarations required by fw
      - train.vw:     TRAIN_EXAMPLES records where at most one of the two
                      ids exceeds BLOCK_BEYOND
      - test-easy.vw: EVAL_EXAMPLES records from the same distribution
      - test-hard.vw: EVAL_EXAMPLES records pairing only ids above
                      BLOCK_BEYOND on both sides, i.e. combinations never
                      seen during training

    Reads module globals DATASETS_DIRECTORY, TRAIN_EXAMPLES, EVAL_EXAMPLES,
    NUM_ANIMALS, NUM_FOODS and BLOCK_BEYOND (set in __main__).
    """
    DATASETS_DIRECTORY.mkdir(exist_ok=True)

    # `with` closes each file promptly; the old code rebound `f` without
    # ever closing the previous handle.
    with open(os.path.join(DATASETS_DIRECTORY, "vw_namespace_map.csv"), "w") as f:
        f.write("A,animal\n")
        f.write("B,food\n")

    with open(os.path.join(DATASETS_DIRECTORY, "train.vw"), "w") as f:
        for _ in range(TRAIN_EXAMPLES):
            _write_train_distribution_example(f)

    # Easy eval set: identical distribution to the training set.
    with open(os.path.join(DATASETS_DIRECTORY, "test-easy.vw"), "w") as f:
        for _ in range(EVAL_EXAMPLES):
            _write_train_distribution_example(f)

    # Hard eval set: only combinations never seen in training — both ids
    # strictly above BLOCK_BEYOND.
    with open(os.path.join(DATASETS_DIRECTORY, "test-hard.vw"), "w") as f:
        for _ in range(EVAL_EXAMPLES):
            animal_type = random.choices(['Herbivore', 'Carnivore'])[0]
            food_type = random.choices(['Plant', 'Meat'])[0]
            animal_id = random.randint(BLOCK_BEYOND + 1, NUM_ANIMALS)
            food_id = random.randint(BLOCK_BEYOND + 1, NUM_FOODS)
            f.write(render_example((animal_type, animal_id), (food_type, food_id)))


def _write_train_distribution_example(f):
    """Append one record where at most one of the two ids exceeds BLOCK_BEYOND.

    Keeps the exact RNG call order of the original inline loops
    (choices, choices, randint, randint, randint) so output under a fixed
    seed is unchanged. Locals renamed from the original's leftover
    `person`/`movie`, which came from a different example domain.
    """
    animal_type = random.choices(['Herbivore', 'Carnivore'])[0]
    food_type = random.choices(['Plant', 'Meat'])[0]
    if random.randint(0, 1):
        animal_id = random.randint(0, NUM_ANIMALS)
        food_id = random.randint(0, BLOCK_BEYOND)
    else:
        animal_id = random.randint(0, BLOCK_BEYOND)
        food_id = random.randint(0, NUM_FOODS)
    f.write(render_example((animal_type, animal_id), (food_type, food_id)))
use a generated model to make predictions over a dataset read from a text file, and print them to an output predictions file - *this is to illustrate potential serving performance, we don't usually predict from file input as our offline flows always apply online learning. note that when running as daemon we use half as much memory since gradients are not loaded - only model weights.* 6 | 7 | 8 | ## Model 9 | We train a logistic regression model, applying online learning one example at a time (no batches), 10 | using '--adaptive' flag for adaptive learning rates (AdaGrad variant). 11 | 12 | ## Results 13 | here are the results for 3 runs for each scenario, taking mean values: 14 | ![benchmark results](benchmark_results.png) 15 | Scenario|Runtime (seconds)|Memory (MB)|CPU % 16 | ----|----:|----:|----: 17 | vw train, no cache|97.34 | 568 | 174.63 18 | fw train, no cache|19.98 | 258 | 102.23 19 | vw train, using cache|99.46 | 566 | 169.30 20 | fw train, using cache|12.69 | 259 | 102.10 21 | vw predict, no cache|81.83 | 141 | 178.53 22 | fw predict, no cache|16.90 | 133 | 101.47 23 | 24 | ### Model equivalence 25 | loss values for the test set: 26 | 27 | ``` 28 | Vowpal Wabbit predictions loss: 0.6370 29 | Fwumious Wabbit predictions loss: 0.6370 30 | ``` 31 | 32 | 33 | for more details on what makes Fwumious Wabbit so fast, see [here](https://github.com/outbrain/fwumious_wabbit/blob/benchmark/SPEED.md) 34 | 35 | ### Dataset details 36 | we generate a synthetic dataset with 10,000,000 train records ('train.vw'), and 10,000,000 test records ('easy.vw'). 37 | 38 | the task is 'Eat-Rate prediction' - each record describes the observed result of a single feeding experiment. 39 | each record is made of a type of animal, a type of food (in Vowpal Wabbit jargon these are our namespaces A and B respectively), and a label indicating whether the animal ate the food. 
40 | the underlying model is simple - animals are either herbivores or carnivores, 41 | and food is either plant based or meat based. 42 | 43 | herbivores always eat plants (and only plants), and carnivores always eat meat (and only meat). 44 | 45 | we name animals conveniently using the pattern 'diet-id', for example 'Herbivore-1234' and 'Carnivore-5678' 46 | and the food similarly as 'food_type-id' - for example 'Plant-678' and 'Meat-234' so the expected label for a record is always obvious. 47 | there are 1,000 animal types, and 1,000 food types. we generate additional 10 random features, 48 | to make the dataset dimensions a bit more realistic. 49 | 50 | see for example the first 5 lines from the train dataset (after some pretty-printing): 51 | 52 | label|animal|food|feat_2|feat_3|feat_4|feat_5|feat_6|feat_7|... 53 | ----:|------|----|----|----|----|----|----|----|---- 54 | -1 |A Herbivore-65 |B Meat-120 |C C8117 |D D7364 |E E7737 |F F6219 |G G3439 |H H1537 |... 55 | 1 |A Carnivore-272 |B Meat-184 |C C3748 |D D9685 |E E1674 |F F5200 |G G501 |H H365 |... 56 | 1 |A Carnivore-135 |B Meat-227 |C C7174 |D D8123 |E E9058 |F F3818 |G G5663 |H H3782 |... 57 | -1 |A Herbivore-47 |B Meat-644 |C C4856 |D D1980 |E E5450 |F F8205 |G G6915 |H H8318 |... 58 | -1 |A Carnivore-603 |B Plant-218 |C C565 |D D7868 |E E3977 |F F6623 |G G6788 |H H2834 |... 59 | 60 | 61 | ### Feature engineering 62 | if we train using separate 'animal type' and 'food type' features, the model won't learn well, 63 | since knowing the animal identity alone isn't enough to predict if it will eat or not - and the same 64 | goes for knowing the food type alone. 65 | so we apply an interaction between the animal type and food type fields. 66 | 67 | ## Prerequisites and running 68 | you should have Vowpal Wabbit installed, as the benchmark invokes it via the 'vw' command. 
69 | 70 | additionally the rust toolchain (particularly cargo and rustc) is required in order to build Fwumious Wabbit (the benchmark invokes '../target/release/fw') 71 | in order to build and run the benchmark use one of these bash scripts: 72 | ``` 73 | ./run_with_plots.sh 74 | ``` 75 | in order to run the benchmark and plot the results (requires matplotlib, last used with version 2.1.2) 76 | or, if you just want the numbers with less dependencies run: 77 | ``` 78 | ./run_without_plots.sh 79 | ``` 80 | ## Latest run setup 81 | 82 | ### versions: 83 | ``` 84 | vowpal wabbit 8.9.2 (git commit: 884420267) 85 | fwumious wabbit 1.6 (git commit: c04ff7e) 86 | ``` 87 | 88 | ### CPU Info 89 | ``` 90 | Intel(R) Xeon(R) CPU E5-2630 v2 @ 2.60GHz 91 | ``` 92 | ### Operating System 93 | ``` 94 | System: Linux 95 | Version: #151-Ubuntu SMP Fri Jun 18 19:21:19 UTC 2021 96 | ``` 97 | -------------------------------------------------------------------------------- /src/radix_tree.rs: -------------------------------------------------------------------------------- 1 | use crate::vwmap::NamespaceDescriptor; 2 | 3 | #[derive(Clone, Copy, Debug, PartialEq)] 4 | pub(crate) struct NamespaceDescriptorWithHash { 5 | pub(crate) descriptor: NamespaceDescriptor, 6 | pub(crate) hash_seed: u32, 7 | } 8 | 9 | impl NamespaceDescriptorWithHash { 10 | pub(crate) fn new(descriptor: NamespaceDescriptor, hash_seed: u32) -> Self { 11 | Self { 12 | descriptor, 13 | hash_seed, 14 | } 15 | } 16 | } 17 | 18 | #[derive(Clone, Debug)] 19 | struct RadixTreeNode { 20 | children: Vec>, 21 | value: Option, 22 | } 23 | 24 | impl Default for RadixTreeNode { 25 | fn default() -> Self { 26 | Self { 27 | children: vec![None; 256], 28 | value: None, 29 | } 30 | } 31 | } 32 | 33 | #[derive(Clone, Default, Debug)] 34 | pub struct RadixTree { 35 | root: RadixTreeNode, 36 | } 37 | 38 | impl RadixTree { 39 | pub(crate) fn insert(&mut self, key: &[u8], value: NamespaceDescriptorWithHash) { 40 | let mut node = &mut 
self.root; 41 | 42 | for &byte in key { 43 | let child = &mut node.children[byte as usize]; 44 | node = child.get_or_insert_with(RadixTreeNode::default); 45 | } 46 | 47 | node.value = Some(value); 48 | } 49 | 50 | pub(crate) fn get(&self, key: &[u8]) -> Option<&NamespaceDescriptorWithHash> { 51 | let mut node = &self.root; 52 | 53 | for &byte in key { 54 | let maybe_child = &node.children[byte as usize]; 55 | if let Some(child) = maybe_child { 56 | node = child; 57 | } else { 58 | return None.as_ref(); 59 | } 60 | } 61 | 62 | node.value.as_ref() 63 | } 64 | } 65 | 66 | #[cfg(test)] 67 | mod tests { 68 | use super::*; 69 | use crate::vwmap::{NamespaceFormat, NamespaceType}; 70 | 71 | #[test] 72 | fn test_insert_and_get() { 73 | let mut tree = RadixTree::default(); 74 | 75 | let namespace_descriptor_with_hash_1 = NamespaceDescriptorWithHash { 76 | descriptor: NamespaceDescriptor { 77 | namespace_index: 0, 78 | namespace_type: NamespaceType::Primitive, 79 | namespace_format: NamespaceFormat::Categorical, 80 | }, 81 | hash_seed: 1, 82 | }; 83 | 84 | let namespace_descriptor_with_hash_2 = NamespaceDescriptorWithHash { 85 | descriptor: NamespaceDescriptor { 86 | namespace_index: 10, 87 | namespace_type: NamespaceType::Primitive, 88 | namespace_format: NamespaceFormat::Categorical, 89 | }, 90 | hash_seed: 2, 91 | }; 92 | 93 | let namespace_descriptor_with_hash_3 = NamespaceDescriptorWithHash { 94 | descriptor: NamespaceDescriptor { 95 | namespace_index: 20, 96 | namespace_type: NamespaceType::Primitive, 97 | namespace_format: NamespaceFormat::Categorical, 98 | }, 99 | hash_seed: 3, 100 | }; 101 | 102 | tree.insert(b"A", namespace_descriptor_with_hash_1); 103 | tree.insert(b"AB", namespace_descriptor_with_hash_2); 104 | tree.insert(b"ABC", namespace_descriptor_with_hash_3); 105 | 106 | assert_eq!(tree.get(b"A"), Some(&namespace_descriptor_with_hash_1)); 107 | assert_eq!(tree.get(b"AB"), Some(&namespace_descriptor_with_hash_2)); 108 | assert_eq!(tree.get(b"ABC"), 
Some(&namespace_descriptor_with_hash_3)); 109 | assert_eq!(tree.get(b"ABCD"), None); 110 | } 111 | 112 | #[test] 113 | fn test_insert_and_get_empty_key() { 114 | let mut tree = RadixTree::default(); 115 | 116 | let namespace_descriptor_with_hash = NamespaceDescriptorWithHash { 117 | descriptor: NamespaceDescriptor { 118 | namespace_index: 0, 119 | namespace_type: NamespaceType::Primitive, 120 | namespace_format: NamespaceFormat::Categorical, 121 | }, 122 | hash_seed: 1, 123 | }; 124 | 125 | tree.insert(b"", namespace_descriptor_with_hash); 126 | 127 | assert_eq!(tree.get(b""), Some(&namespace_descriptor_with_hash)); 128 | assert_eq!(tree.get(b"A"), None); 129 | } 130 | 131 | #[test] 132 | fn test_insert_and_get_long_key() { 133 | let mut tree = RadixTree::default(); 134 | 135 | let namespace_descriptor_with_hash = NamespaceDescriptorWithHash { 136 | descriptor: NamespaceDescriptor { 137 | namespace_index: 0, 138 | namespace_type: NamespaceType::Primitive, 139 | namespace_format: NamespaceFormat::Categorical, 140 | }, 141 | hash_seed: 1, 142 | }; 143 | 144 | tree.insert(b"AB", namespace_descriptor_with_hash); 145 | 146 | assert_eq!(tree.get(b"AB"), Some(&namespace_descriptor_with_hash)); 147 | assert_eq!(tree.get(b"A"), None); 148 | assert_eq!(tree.get(b"B"), None); 149 | assert_eq!(tree.get(b"ABC"), None); 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /src/buffer_handler.rs: -------------------------------------------------------------------------------- 1 | use flate2::read::MultiGzDecoder; 2 | use std::fs::File; 3 | use std::io; 4 | use std::io::BufRead; 5 | use std::path::Path; 6 | use zstd::stream::read::Decoder as ZstdDecoder; 7 | 8 | pub fn create_buffered_input(input_filename: &str) -> Box { 9 | // Handler for different (or no) compression types 10 | 11 | let input = File::open(input_filename).expect("Could not open the input file."); 12 | 13 | let input_format = Path::new(&input_filename) 14 | 
.extension() 15 | .and_then(|ext| ext.to_str()) 16 | .expect("Failed to get the file extension."); 17 | 18 | match input_format { 19 | "gz" => { 20 | let gz_decoder = MultiGzDecoder::new(input); 21 | let reader = io::BufReader::new(gz_decoder); 22 | Box::new(reader) 23 | } 24 | "zst" => { 25 | let zstd_decoder = ZstdDecoder::new(input).unwrap(); 26 | let reader = io::BufReader::new(zstd_decoder); 27 | Box::new(reader) 28 | } 29 | "vw" => { 30 | let reader = io::BufReader::new(input); 31 | Box::new(reader) 32 | } 33 | _ => { 34 | panic!("Please specify a valid input format (.vw, .zst, .gz)"); 35 | } 36 | } 37 | } 38 | 39 | #[cfg(test)] 40 | mod tests { 41 | use super::*; 42 | use flate2::write::GzEncoder; 43 | use flate2::Compression; 44 | use std::io::{self, Read, Write}; 45 | use tempfile::Builder as TempFileBuilder; 46 | use tempfile::NamedTempFile; 47 | use zstd::stream::Encoder as ZstdEncoder; 48 | 49 | fn create_temp_file_with_contents( 50 | extension: &str, 51 | contents: &[u8], 52 | ) -> io::Result { 53 | let temp_file = TempFileBuilder::new() 54 | .suffix(&format!(".{}", extension)) 55 | .tempfile()?; 56 | temp_file.as_file().write_all(contents)?; 57 | Ok(temp_file) 58 | } 59 | 60 | fn create_gzipped_temp_file(contents: &[u8]) -> io::Result { 61 | let temp_file = TempFileBuilder::new().suffix(".gz").tempfile()?; 62 | let gz = GzEncoder::new(Vec::new(), Compression::default()); 63 | let mut gz_writer = io::BufWriter::new(gz); 64 | gz_writer.write_all(contents)?; 65 | let gz = gz_writer.into_inner()?.finish()?; 66 | temp_file.as_file().write_all(&gz)?; 67 | Ok(temp_file) 68 | } 69 | 70 | fn create_zstd_temp_file(contents: &[u8]) -> io::Result { 71 | let temp_file = TempFileBuilder::new().suffix(".zst").tempfile()?; 72 | let mut zstd_encoder = ZstdEncoder::new(Vec::new(), 1)?; 73 | zstd_encoder.write_all(contents)?; 74 | let encoded_data = zstd_encoder.finish()?; 75 | temp_file.as_file().write_all(&encoded_data)?; 76 | Ok(temp_file) 77 | } 78 | 79 | // Test 
for uncompressed file ("vw" extension) 80 | #[test] 81 | fn test_uncompressed_file() { 82 | let contents = b"Sample text for uncompressed file."; 83 | let temp_file = 84 | create_temp_file_with_contents("vw", contents).expect("Failed to create temp file"); 85 | let mut reader = create_buffered_input(temp_file.path().to_str().unwrap()); 86 | 87 | let mut buffer = Vec::new(); 88 | reader 89 | .read_to_end(&mut buffer) 90 | .expect("Failed to read from the reader"); 91 | assert_eq!( 92 | buffer, contents, 93 | "Contents did not match for uncompressed file." 94 | ); 95 | } 96 | 97 | // Test for gzipped files ("gz" extension) 98 | #[test] 99 | fn test_gz_compressed_file() { 100 | let contents = b"Sample text for gzipped file."; 101 | let temp_file = 102 | create_gzipped_temp_file(contents).expect("Failed to create gzipped temp file"); 103 | let mut reader = create_buffered_input(temp_file.path().to_str().unwrap()); 104 | 105 | let mut buffer = Vec::new(); 106 | reader 107 | .read_to_end(&mut buffer) 108 | .expect("Failed to read from the reader"); 109 | assert_eq!(buffer, contents, "Contents did not match for gzipped file."); 110 | } 111 | 112 | // Test for zstd compressed files ("zst" extension) 113 | #[test] 114 | fn test_zstd_compressed_file() { 115 | let contents = b"Sample text for zstd compressed file."; 116 | let temp_file = create_zstd_temp_file(contents).expect("Failed to create zstd temp file"); 117 | let mut reader = create_buffered_input(temp_file.path().to_str().unwrap()); 118 | 119 | let mut buffer = Vec::new(); 120 | reader 121 | .read_to_end(&mut buffer) 122 | .expect("Failed to read from the reader"); 123 | assert_eq!( 124 | buffer, contents, 125 | "Contents did not match for zstd compressed file." 
);
    }

    // Unsupported extensions must abort with a clear message rather than be
    // silently treated as plain text.
    #[test]
    #[should_panic(expected = "Please specify a valid input format (.vw, .zst, .gz)")]
    fn test_unsupported_file_format() {
        let contents = b"Some content";
        let temp_file =
            create_temp_file_with_contents("txt", contents).expect("Failed to create temp file");
        let _reader = create_buffered_input(temp_file.path().to_str().unwrap());
    }
}
-------------------------------------------------------------------------------- /src/block_relu.rs: --------------------------------------------------------------------------------
use std::any::Any;
use std::error::Error;

use crate::block_helpers;
use crate::feature_buffer;
use crate::feature_buffer::FeatureBuffer;
use crate::graph::{BlockGraph, BlockPtrOutput, InputSlot, OutputSlot};
use crate::model_instance;
use crate::port_buffer;
use crate::port_buffer::PortBuffer;
use crate::regressor;
use crate::regressor::BlockCache;
use regressor::BlockTrait;

/// Rectified linear unit block: copies `num_inputs` tape values from the
/// input span to the output span, zeroing negative values. Offsets are
/// assigned later by the graph; usize::MAX marks "not yet set".
pub struct BlockRELU {
    pub num_inputs: usize,
    pub input_offset: usize,
    pub output_offset: usize,
}

/// Adds a ReLU node fed by `input` to the block graph and returns its single
/// output handle.
/// NOTE(review): the result type's generics were stripped by extraction;
/// reconstructed as `Result<BlockPtrOutput, Box<dyn Error>>` — confirm.
pub fn new_relu_block(
    bg: &mut BlockGraph,
    _mi: &model_instance::ModelInstance,
    input: BlockPtrOutput,
) -> Result<BlockPtrOutput, Box<dyn Error>> {
    let num_inputs = bg.get_num_output_values(vec![&input]);
    assert_ne!(num_inputs, 0);
    let block = Box::new(BlockRELU {
        output_offset: usize::MAX,
        input_offset: usize::MAX,
        num_inputs,
    });
    let mut block_outputs = bg.add_node(block, vec![input])?;
    assert_eq!(block_outputs.len(), 1);
    Ok(block_outputs.pop().unwrap())
}

impl BlockRELU {
    /// Prediction-only ReLU: writes max(0, input) to the output span without
    /// touching the inputs.
    #[inline(always)]
    fn internal_forward(&self, pb: &mut port_buffer::PortBuffer) {
        debug_assert!(self.output_offset != usize::MAX);
        debug_assert!(self.input_offset != usize::MAX);
        debug_assert!(self.num_inputs > 0);
45 | unsafe { 46 | for i in 0..self.num_inputs as usize { 47 | let w = *pb.tape.get_unchecked_mut(self.input_offset + i); 48 | if w < 0.0 { 49 | *pb.tape.get_unchecked_mut(self.output_offset + i) = 0.0; 50 | } else { 51 | *pb.tape.get_unchecked_mut(self.output_offset + i) = w; 52 | } 53 | } 54 | } 55 | } 56 | } 57 | 58 | impl BlockTrait for BlockRELU { 59 | fn as_any(&mut self) -> &mut dyn Any { 60 | self 61 | } 62 | 63 | fn get_num_output_values(&self, output: OutputSlot) -> usize { 64 | assert_eq!(output.get_output_index(), 0); 65 | self.num_inputs 66 | } 67 | 68 | fn set_input_offset(&mut self, input: InputSlot, offset: usize) { 69 | assert_eq!(input.get_input_index(), 0); 70 | self.input_offset = offset; 71 | } 72 | 73 | fn set_output_offset(&mut self, output: OutputSlot, offset: usize) { 74 | assert_eq!(output.get_output_index(), 0); 75 | self.output_offset = offset; 76 | } 77 | 78 | #[inline(always)] 79 | fn forward_backward( 80 | &mut self, 81 | further_blocks: &mut [Box], 82 | fb: &feature_buffer::FeatureBuffer, 83 | pb: &mut port_buffer::PortBuffer, 84 | update: bool, 85 | ) { 86 | debug_assert!(self.output_offset != usize::MAX); 87 | debug_assert!(self.input_offset != usize::MAX); 88 | debug_assert!(self.num_inputs > 0); 89 | 90 | unsafe { 91 | for i in 0..self.num_inputs { 92 | let w = *pb.tape.get_unchecked_mut(self.input_offset + i); 93 | if w < 0.0 { 94 | *pb.tape.get_unchecked_mut(self.output_offset + i) = 0.0; 95 | *pb.tape.get_unchecked_mut(self.input_offset + i) = 0.0; 96 | } else { 97 | *pb.tape.get_unchecked_mut(self.output_offset + i) = w; 98 | *pb.tape.get_unchecked_mut(self.input_offset + i) = 1.0; 99 | } 100 | } 101 | 102 | block_helpers::forward_backward(further_blocks, fb, pb, update); 103 | 104 | if update { 105 | for i in 0..self.num_inputs { 106 | let gradient = *pb.tape.get_unchecked(self.output_offset + i); 107 | *pb.tape.get_unchecked_mut(self.input_offset + i) *= gradient; 108 | } 109 | } 110 | } 111 | } 112 | 113 | fn forward( 114 
| &self, 115 | further_blocks: &[Box], 116 | fb: &feature_buffer::FeatureBuffer, 117 | pb: &mut port_buffer::PortBuffer, 118 | ) { 119 | self.internal_forward(pb); 120 | block_helpers::forward(further_blocks, fb, pb); 121 | } 122 | 123 | fn forward_with_cache( 124 | &self, 125 | further_blocks: &[Box], 126 | fb: &FeatureBuffer, 127 | pb: &mut PortBuffer, 128 | caches: &[BlockCache], 129 | ) { 130 | self.internal_forward(pb); 131 | block_helpers::forward_with_cache(further_blocks, fb, pb, caches); 132 | } 133 | } 134 | 135 | #[cfg(test)] 136 | mod tests { 137 | // Note this useful idiom: importing names from outer (for mod tests) scope. 138 | use super::*; 139 | use crate::assert_epsilon; 140 | use crate::block_misc; 141 | use crate::feature_buffer; 142 | use block_helpers::slearn2; 143 | use block_misc::Observe; 144 | 145 | fn fb_vec() -> feature_buffer::FeatureBuffer { 146 | feature_buffer::FeatureBuffer { 147 | label: 0.0, 148 | example_importance: 1.0, 149 | example_number: 0, 150 | lr_buffer: Vec::new(), 151 | ffm_buffer: Vec::new(), 152 | } 153 | } 154 | 155 | #[test] 156 | fn test_simple_positive() { 157 | let mi = model_instance::ModelInstance::new_empty().unwrap(); 158 | let mut bg = BlockGraph::new(); 159 | let input_block = block_misc::new_const_block(&mut bg, vec![2.0]).unwrap(); 160 | let relu_block = new_relu_block(&mut bg, &mi, input_block).unwrap(); 161 | let _observe_block = 162 | block_misc::new_observe_block(&mut bg, relu_block, Observe::Forward, Some(1.0)) 163 | .unwrap(); 164 | bg.finalize(); 165 | bg.allocate_and_init_weights(&mi); 166 | 167 | let mut pb = bg.new_port_buffer(); 168 | 169 | let fb = fb_vec(); 170 | assert_epsilon!(slearn2(&mut bg, &fb, &mut pb, true), 2.0); 171 | assert_epsilon!(slearn2(&mut bg, &fb, &mut pb, true), 2.0); // relu desnt learn 172 | } 173 | } 174 | -------------------------------------------------------------------------------- /src/block_loss_functions.rs: 
-------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::error::Error; 3 | 4 | use crate::block_helpers; 5 | use crate::feature_buffer; 6 | use crate::feature_buffer::FeatureBuffer; 7 | use crate::graph; 8 | use crate::port_buffer; 9 | use crate::port_buffer::PortBuffer; 10 | use crate::regressor; 11 | use crate::regressor::BlockCache; 12 | use regressor::BlockTrait; 13 | 14 | #[inline(always)] 15 | pub fn logistic(t: f32) -> f32 { 16 | (1.0 + (-t).exp()).recip() 17 | } 18 | 19 | pub struct BlockSigmoid { 20 | num_inputs: usize, 21 | input_offset: usize, 22 | output_offset: usize, 23 | copy_to_result: bool, 24 | } 25 | 26 | pub fn new_logloss_block( 27 | bg: &mut graph::BlockGraph, 28 | input: graph::BlockPtrOutput, 29 | copy_to_result: bool, 30 | ) -> Result> { 31 | let num_inputs = bg.get_num_output_values(vec![&input]); 32 | let block = Box::new(BlockSigmoid { 33 | num_inputs, 34 | input_offset: usize::MAX, 35 | output_offset: usize::MAX, 36 | copy_to_result, 37 | }); 38 | let mut block_outputs = bg.add_node(block, vec![input]).unwrap(); 39 | assert_eq!(block_outputs.len(), 1); 40 | Ok(block_outputs.pop().unwrap()) 41 | } 42 | 43 | impl BlockSigmoid { 44 | #[inline(always)] 45 | fn internal_forward( 46 | &self, 47 | fb: &feature_buffer::FeatureBuffer, 48 | pb: &mut port_buffer::PortBuffer, 49 | ) { 50 | unsafe { 51 | debug_assert!(self.input_offset != usize::MAX); 52 | debug_assert!(self.output_offset != usize::MAX); 53 | let wsum: f32 = pb 54 | .tape 55 | .get_unchecked(self.input_offset..(self.input_offset + self.num_inputs)) 56 | .iter() 57 | .sum(); 58 | 59 | let prediction_probability: f32; 60 | if wsum.is_nan() { 61 | log::warn!( 62 | "NAN prediction in example {}, forcing 0.0", 63 | fb.example_number 64 | ); 65 | prediction_probability = logistic(0.0); 66 | } else if wsum < -50.0 { 67 | prediction_probability = logistic(-50.0); 68 | } else if wsum > 50.0 { 69 | prediction_probability = 
logistic(50.0); 70 | } else { 71 | prediction_probability = logistic(wsum); 72 | } 73 | 74 | pb.tape[self.output_offset] = prediction_probability; 75 | if self.copy_to_result { 76 | pb.observations.push(prediction_probability); 77 | } 78 | } 79 | } 80 | } 81 | 82 | impl BlockTrait for BlockSigmoid { 83 | fn as_any(&mut self) -> &mut dyn Any { 84 | self 85 | } 86 | 87 | fn get_num_output_values(&self, output: graph::OutputSlot) -> usize { 88 | assert_eq!(output.get_output_index(), 0); 89 | 1 90 | } 91 | 92 | fn set_input_offset(&mut self, input: graph::InputSlot, offset: usize) { 93 | assert_eq!(input.get_input_index(), 0); 94 | assert_eq!(self.input_offset, usize::MAX); // We only allow a single call 95 | self.input_offset = offset; 96 | } 97 | 98 | fn set_output_offset(&mut self, output: graph::OutputSlot, offset: usize) { 99 | assert_eq!(self.output_offset, usize::MAX); // We only allow a single call 100 | assert_eq!(output.get_output_index(), 0); 101 | self.output_offset = offset; 102 | } 103 | 104 | #[inline(always)] 105 | fn forward_backward( 106 | &mut self, 107 | further_blocks: &mut [Box], 108 | fb: &feature_buffer::FeatureBuffer, 109 | pb: &mut port_buffer::PortBuffer, 110 | update: bool, 111 | ) { 112 | debug_assert!(self.input_offset != usize::MAX); 113 | debug_assert!(self.output_offset != usize::MAX); 114 | 115 | unsafe { 116 | let wsum: f32 = pb 117 | .tape 118 | .get_unchecked(self.input_offset..(self.input_offset + self.num_inputs)) 119 | .iter() 120 | .sum(); 121 | 122 | let prediction_probability: f32; 123 | let general_gradient: f32; 124 | 125 | if wsum.is_nan() { 126 | log::error!( 127 | "NAN prediction in example {}, forcing 0.0", 128 | fb.example_number 129 | ); 130 | prediction_probability = logistic(0.0); 131 | general_gradient = 0.0; 132 | } else if wsum < -50.0 { 133 | prediction_probability = logistic(-50.0); 134 | general_gradient = 0.0; 135 | } else if wsum > 50.0 { 136 | prediction_probability = logistic(50.0); 137 | general_gradient = 
0.0; 138 | } else { 139 | prediction_probability = logistic(wsum); 140 | general_gradient = -(fb.label - prediction_probability) * fb.example_importance; 141 | } 142 | 143 | *pb.tape.get_unchecked_mut(self.output_offset) = prediction_probability; 144 | if self.copy_to_result { 145 | pb.observations.push(prediction_probability); 146 | } 147 | block_helpers::forward_backward(further_blocks, fb, pb, update); 148 | // replace inputs with their gradients 149 | pb.tape 150 | .get_unchecked_mut(self.input_offset..(self.input_offset + self.num_inputs)) 151 | .fill(general_gradient); 152 | } 153 | } 154 | 155 | fn forward( 156 | &self, 157 | further_blocks: &[Box], 158 | fb: &feature_buffer::FeatureBuffer, 159 | pb: &mut port_buffer::PortBuffer, 160 | ) { 161 | self.internal_forward(fb, pb); 162 | block_helpers::forward(further_blocks, fb, pb); 163 | } 164 | 165 | fn forward_with_cache( 166 | &self, 167 | further_blocks: &[Box], 168 | fb: &FeatureBuffer, 169 | pb: &mut PortBuffer, 170 | caches: &[BlockCache], 171 | ) { 172 | self.internal_forward(fb, pb); 173 | block_helpers::forward_with_cache(further_blocks, fb, pb, caches); 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /SPEED.md: -------------------------------------------------------------------------------- 1 | # What Makes Fwumious Wabbit (FW) Fast? 2 | 3 | # Strict Focus 4 | 5 | The biggest advantage that FW has over VW and other logistic regression 6 | implementations is much narrower focus on what it does. Everything that 7 | would cause conditional jumps in inner loops is avoided or specialized 8 | using macros. 9 | 10 | FW does not implement regularization nor multipass. They could be added 11 | without hurting performance in the fast path by using static traits. 12 | Multipass can be done with an external tool by saving the final model and 13 | then doing another pass with that as an initial model. 
Unlike VW we also do not track prediction performance during the run -
we do not continuously compute logloss. Additionally we are interested in
the predictions only on the evaluation part of the dataset, therefore a
new parameter --predictions-after allows for skipping outputting all
predictions. We were surprised to learn that formatting floating point
values for human-readable output can take significant time compared to
making the prediction itself.

# Reduced Flexibility in Input Formats

FW builds on VW's idea that you will likely run a lot of different models
over the same input data, so it is worthwhile to parse it and save it in a
"cache". FW packs that cache format even more tightly than VW.

Generally FW supports a subset of VW's input format with more rigidity
around namespace names. Namespaces can only be single letters and those
that are used in an input file must be listed ahead of time in a separate
file (called vw_namespace_map.csv).

# Carefully Chosen External Libraries

Benchmarking was done to pick the fastest gzip library for our use case
(Cloudflare's). For input cache file compression we use an extremely
efficient LZ4 library (https://github.com/lz4/lz4). The deterministic random
library is a Rust copy of Vowpal's method (merand48). Fasthash's murmur3
algorithm is used for hashing to be compatible with Vowpal.

# Using Rust to an Extreme

The code uses Rust macros in order to create specialized code blocks and
thus avoids branching in inner loops. We are waiting for const generics to
hopefully use them in the future.

Core parts of the parser, translator and regressor use the "unsafe" mode.
This is done in order to avoid the need for bounds checking in inner loops
and the need to initialize memory.

Some frequent codepaths were unrolled manually.
# Specialization
- We have specialized inner loops (with macros) for --ffm_k of 2, 4 and 8.
- Optimizer is specialized as part of the code, so inner-loop ifs are
avoided
- In FFM we optimize situations where feature values (not weights) are 1.0.
Three different multiplications by 1.0 within the loop that iterates ffm_k
times are avoided, yielding a 4% speedup.

# Algorithmic Optimization

We relied heavily on ideas from VW and built on top of them. VW's buffer
management is fully replicated for logistic regression code.

The FFM implementation in VW is not well tested and in our opinion it is
buggy. We created a novel approach to FFM calculation as follows.
The traditional quadruple loop (which VW uses) was replaced by a double loop.
Intra-field combinations were allowed due to better prediction performance
on our datasets and a faster execution time (no conditional jumps in the
inner loop).

We sum all changes to each feature weight in all FFM combinations and do the
final update of each feature weight only once per example.

# Look Up Tables

A large speed boost comes in Adagrad from using a look-up table to
map accumulated squared gradients to the learning rate. A simple bit shift
plus lookup replaces use of the power function (or sqrt) and multiplication.
This removes two mathematical operations from the inner loop and provides a
substantial speed boost.

# Code and Data Structures Optimization

The examples cache file is very tightly packed, inspired by video codecs.
Similarly other data structures were carefully trimmed to avoid cache
thrashing.


# Prefetching
When sensible, we prefetch the weight ahead of the for-loop. Since weights
basically cause random memory access and modern machines are effectively
NUMA, this helps a bit.
96 | 97 | # Compiling for Your Architecture 98 | We are compiling our code with 99 | ``` 100 | export RUSTFLAGS="-C opt-level=3 -C target-cpu=skylake" 101 | ``` 102 | This provides a speed improvement of about 5%. 103 | 104 | # Using Stack for Temporary Buffer 105 | We use a fixed size stack for the FFM temporary buffer, and when the buffer 106 | is bigger than the fixed sized stack, we use a heap-allocated buffer. Our 107 | code is entirely specialized on each codepath. Surprisingly we saw a 5%+ 108 | speedup when using the stack. Our belief is that there are more optimization 109 | opportunities if we make all memory addresses static, however that is really 110 | hard to achieve without per-run recompilation of rust code. 111 | 112 | # Things we Have Tried 113 | - Using a data oriented programming approach: We've separated weights and 114 | accumulated gradients into separate vectors. This caused a 50% slowdown. The 115 | theory is that we are doing lots of random memory accesses (to load weights) 116 | and additional latencies overshadow the benefits of (possibly) better 117 | vectorization. 118 | - Manually rolled-out AVX2 code for LR: While on paper instructions take 119 | less time to execute than LLVM code, in practice there is no difference due 120 | to the floating point operations not being the bottleneck - it looks like 121 | the bottleneck is delivering values from memory. 122 | 123 | 124 | # Ideas for Future Speed Improvements 125 | - On-demand specialization. This would require a compile-per-run, however 126 | everything we have learned indicates that this would bring an additional 127 | speed boost. 128 | - Use vectorization 129 | export RUSTFLAGS="-C opt-level=3 -C target-cpu=skylake -C llvm-args=--force-vector-width=4": 130 | We saw no measurable effect on laptop. Need to do further testing on server. 131 | - Profile Guided Optimizations 132 | We tried using PGO. 
The difference was unmeasurable - at most 0.5% speed 133 | up, which is basically at the noise level. Given the complications of 134 | doing PGO builds it is simply not worth it. 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /src/quantization.rs: -------------------------------------------------------------------------------- 1 | use half::f16; 2 | use std::io; 3 | 4 | const BY_X: usize = 2; 5 | const NUM_BUCKETS: f32 = 65025.0; 6 | const CRITICAL_WEIGHT_BOUND: f32 = 10.0; // naive detection of really bad weights, this should never get to prod. 7 | const MEAN_SAMPLING_RATIO: usize = 10; 8 | const MIN_PREC: f32 = 10_000.0; 9 | const MAX_PREC: f32 = 10_000.0; 10 | 11 | #[derive(Debug)] 12 | struct WeightStat { 13 | min: f32, 14 | max: f32, 15 | mean: f32, 16 | } 17 | 18 | fn emit_weight_statistics(weights: &[f32]) -> WeightStat { 19 | let mut min_weight = weights[0]; 20 | let mut max_weight = weights[0]; 21 | let mut mean_weight = 0.0; 22 | let mut weight_counter = 0; 23 | 24 | for (enx, &weight) in weights.iter().enumerate() { 25 | max_weight = max_weight.max(weight); 26 | min_weight = min_weight.min(weight); 27 | 28 | if enx % MEAN_SAMPLING_RATIO == 0 { 29 | weight_counter += 1; 30 | mean_weight += weight; 31 | } 32 | } 33 | 34 | WeightStat { 35 | min: (min_weight * MIN_PREC).round() / MIN_PREC, 36 | max: (max_weight * MAX_PREC).round() / MAX_PREC, 37 | mean: mean_weight / weight_counter as f32, 38 | } 39 | } 40 | 41 | pub fn quantize_ffm_weights(weights: &[f32]) -> Vec<[u8; BY_X]> { 42 | let weight_statistics = emit_weight_statistics(weights); 43 | let weight_increment = (weight_statistics.max - weight_statistics.min) / NUM_BUCKETS; 44 | 45 | if weight_statistics.mean.abs() > CRITICAL_WEIGHT_BOUND { 46 | log::warn!("Identified a very skewed weight distribution indicating exploded weights, not serving that! 
Mean weight value: {}", weight_statistics.mean); 47 | } 48 | 49 | log::info!( 50 | "Weight values; min: {}, max: {}, mean: {}", 51 | weight_statistics.min, 52 | weight_statistics.max, 53 | weight_statistics.mean 54 | ); 55 | 56 | let weight_increment_bytes = weight_increment.to_le_bytes(); 57 | let min_val_bytes = weight_statistics.min.to_le_bytes(); 58 | 59 | let mut v = Vec::<[u8; BY_X]>::with_capacity(weights.len() + 4); 60 | 61 | // Bytes are stored as pairs 62 | v.push([weight_increment_bytes[0], weight_increment_bytes[1]]); 63 | v.push([weight_increment_bytes[2], weight_increment_bytes[3]]); 64 | v.push([min_val_bytes[0], min_val_bytes[1]]); 65 | v.push([min_val_bytes[2], min_val_bytes[3]]); 66 | 67 | for &weight in weights { 68 | let weight_interval = ((weight - weight_statistics.min) / weight_increment).round(); 69 | v.push(f16::to_le_bytes(f16::from_f32(weight_interval))); 70 | } 71 | 72 | assert_eq!(v.len() - 4, weights.len()); 73 | 74 | v 75 | } 76 | 77 | pub fn dequantize_ffm_weights( 78 | input_bufreader: &mut dyn io::Read, 79 | reference_weights: &mut Vec, 80 | ) { 81 | let mut header: [u8; 8] = [0; 8]; 82 | input_bufreader.read_exact(&mut header).unwrap(); 83 | 84 | let weight_increment = f32::from_le_bytes([header[0], header[1], header[2], header[3]]); 85 | let weight_min = f32::from_le_bytes([header[4], header[5], header[6], header[7]]); 86 | let mut weight_bytes: [u8; 2] = [0; 2]; 87 | 88 | for weight_index in 0..reference_weights.len() { 89 | input_bufreader.read_exact(&mut weight_bytes).unwrap(); 90 | 91 | let weight_interval = f16::from_le_bytes(weight_bytes); 92 | let final_weight = weight_min + weight_interval.to_f32() * weight_increment; 93 | reference_weights[weight_index] = final_weight; 94 | } 95 | } 96 | 97 | #[cfg(test)] 98 | mod tests { 99 | use super::*; 100 | 101 | #[test] 102 | fn test_emit_statistics() { 103 | let some_random_float_weights = [0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; 104 | let out_struct = 
emit_weight_statistics(&some_random_float_weights); 105 | assert_eq!(out_struct.mean, 0.51); 106 | assert_eq!(out_struct.max, 0.6123); 107 | assert_eq!(out_struct.min, 0.11); 108 | } 109 | 110 | #[test] 111 | fn test_quantize() { 112 | let some_random_float_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; 113 | let output_weights = quantize_ffm_weights(&some_random_float_weights); 114 | assert_eq!(output_weights.len(), 10); 115 | } 116 | 117 | #[test] 118 | fn test_dequantize() { 119 | let mut reference_weights = vec![0.51, 0.12, 0.11, 0.1232, 0.6123, 0.23]; 120 | let old_reference_weights = reference_weights.clone(); 121 | let quantized_representation = quantize_ffm_weights(&reference_weights); 122 | let mut all_bytes: Vec = Vec::new(); 123 | for el in quantized_representation { 124 | all_bytes.push(el[0]); 125 | all_bytes.push(el[1]); 126 | } 127 | let mut contents = io::Cursor::new(all_bytes); 128 | dequantize_ffm_weights(&mut contents, &mut reference_weights); 129 | 130 | let matching = old_reference_weights 131 | .iter() 132 | .zip(&reference_weights) 133 | .filter(|&(a, b)| a == b) 134 | .count(); 135 | 136 | assert_ne!(matching, 0); 137 | 138 | let allowed_eps = 0.0001; 139 | let mut all_diffs = 0.0; 140 | for it in old_reference_weights.iter().zip(reference_weights.iter()) { 141 | let (old, new) = it; 142 | all_diffs += (old - new).abs(); 143 | } 144 | assert!(all_diffs < allowed_eps); 145 | } 146 | 147 | #[test] 148 | fn test_large_values() { 149 | let weights = vec![-1e9, 1e9]; 150 | let quantized = quantize_ffm_weights(&weights); 151 | let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::>()); 152 | let mut dequantized = vec![0.0; weights.len()]; 153 | dequantize_ffm_weights(&mut buffer, &mut dequantized); 154 | for (w, dw) in weights.iter().zip(&dequantized) { 155 | assert!( 156 | (w - dw).abs() / w.abs() < 0.1, 157 | "Relative error is too large" 158 | ); 159 | } 160 | } 161 | 162 | #[test] 163 | #[ignore] 164 | fn 
test_performance() {
        let weights: Vec<f32> = (0..10_000_000).map(|x| x as f32).collect();
        let now = std::time::Instant::now();
        let quantized = quantize_ffm_weights(&weights);
        assert!(now.elapsed().as_millis() < 300);

        let mut buffer = io::Cursor::new(quantized.into_iter().flatten().collect::<Vec<u8>>());
        let mut dequantized = vec![0.0; weights.len()];
        let now = std::time::Instant::now();
        dequantize_ffm_weights(&mut buffer, &mut dequantized);
        assert!(now.elapsed().as_millis() < 300);
    }
}
-------------------------------------------------------------------------------- /src/block_helpers.rs: --------------------------------------------------------------------------------
use crate::optimizer::OptimizerTrait;
use std::error::Error;
use std::io;
use std::io::Read;

use crate::feature_buffer;
use crate::optimizer::OptimizerSGD;
use crate::port_buffer;
use crate::regressor::{BlockCache, BlockTrait};
use std::cmp::min;
use std::mem;
use std::slice;

#[cfg(test)]
use crate::graph;

/// Per-weight optimizer state only (no weight value), laid out C-style so it
/// can be bulk-(de)serialized as raw bytes.
/// NOTE(review): the generic parameter was stripped by extraction;
/// reconstructed as `<L: OptimizerTrait>` — confirm against the repository.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct OptimizerData<L: OptimizerTrait> {
    pub optimizer_data: L::PerWeightStore,
}

/// A weight together with its optimizer state; `#[repr(C)]` guarantees a
/// stable layout for the raw-byte serialization helpers below.
/// NOTE(review): generic parameter reconstructed as above — confirm.
#[derive(Clone, Debug)]
#[repr(C)]
pub struct WeightAndOptimizerData<L: OptimizerTrait> {
    pub weight: f32,
    pub optimizer_data: L::PerWeightStore,
}

/// Asserts that two f32 expressions differ by less than 5e-6; evaluates each
/// argument exactly once.
#[macro_export]
macro_rules! assert_epsilon {
    ($x:expr, $y:expr) => {
        let x = $x; // Make sure we evaluate only once
        let y = $y;
        if !(x - y < 0.000005 && y - x < 0.000005) {
            println!("Expectation: {}, Got: {}", y, x);
            panic!();
        }
    };
}

// It's OK! I am a limo driver!
43 | pub fn read_weights_from_buf( 44 | weights: &mut Vec, 45 | input_bufreader: &mut dyn io::Read, 46 | _use_quantization: bool, 47 | ) -> Result<(), Box> { 48 | if weights.is_empty() { 49 | return Err("Loading weights to unallocated weighs buffer".to_string())?; 50 | } 51 | unsafe { 52 | let buf_view: &mut [u8] = slice::from_raw_parts_mut( 53 | weights.as_mut_ptr() as *mut u8, 54 | weights.len() * mem::size_of::(), 55 | ); 56 | input_bufreader.read_exact(buf_view)?; 57 | } 58 | Ok(()) 59 | } 60 | 61 | // We get a vec here just so we easily know the type... 62 | // Skip amount of bytes that a weights vector would be 63 | pub fn skip_weights_from_buf( 64 | weights_len: usize, 65 | input_bufreader: &mut dyn Read, 66 | ) -> Result<(), Box> { 67 | let bytes_skip = weights_len * mem::size_of::(); 68 | io::copy( 69 | &mut input_bufreader.take(bytes_skip as u64), 70 | &mut io::sink(), 71 | )?; 72 | Ok(()) 73 | } 74 | 75 | pub fn write_weights_to_buf( 76 | weights: &Vec, 77 | output_bufwriter: &mut dyn io::Write, 78 | _use_quantization: bool, 79 | ) -> Result<(), Box> { 80 | if weights.is_empty() { 81 | assert!(false); 82 | return Err("Writing weights of unallocated weights buffer".to_string())?; 83 | } 84 | unsafe { 85 | let buf_view: &[u8] = slice::from_raw_parts( 86 | weights.as_ptr() as *const u8, 87 | weights.len() * mem::size_of::(), 88 | ); 89 | output_bufwriter.write_all(buf_view)?; 90 | } 91 | Ok(()) 92 | } 93 | 94 | pub fn read_weights_only_from_buf2( 95 | weights_len: usize, 96 | out_weights: &mut Vec>, 97 | input_bufreader: &mut dyn io::Read, 98 | ) -> Result<(), Box> { 99 | const BUF_LEN: usize = 1024 * 1024; 100 | let mut in_weights: Vec> = Vec::with_capacity(BUF_LEN); 101 | let mut remaining_weights = weights_len; 102 | let mut out_idx: usize = 0; 103 | if weights_len != out_weights.len() { 104 | return Err(format!("read_weights_only_from_buf2 - number of weights to read ({}) and number of weights allocated ({}) isn't the same", weights_len, 
out_weights.len()))?;
    }

    // SAFETY: set_len is sound because the chunk is fully overwritten by
    // read_exact before any element is read.
    unsafe {
        while remaining_weights > 0 {
            let chunk_size = min(remaining_weights, BUF_LEN);
            in_weights.set_len(chunk_size);
            let in_weights_view: &mut [u8] = slice::from_raw_parts_mut(
                in_weights.as_mut_ptr() as *mut u8,
                chunk_size * mem::size_of::<WeightAndOptimizerData<L>>(),
            );
            input_bufreader.read_exact(in_weights_view)?;
            for w in &in_weights {
                out_weights.get_unchecked_mut(out_idx).weight = w.weight;
                out_idx += 1;
            }
            remaining_weights -= chunk_size;
        }
    }
    Ok(())
}

/// Splits the tape into two disjoint mutable windows
/// [start1, start1+len1) and [start2, start2+len2), in either order.
/// The debug_assert documents the non-overlap precondition.
#[inline(always)]
pub fn get_input_output_borrows(
    i: &mut Vec<f32>,
    start1: usize,
    len1: usize,
    start2: usize,
    len2: usize,
) -> (&mut [f32], &mut [f32]) {
    debug_assert!(
        (start1 >= start2 + len2) || (start2 >= start1 + len1),
        "start1: {}, len1: {}, start2: {}, len2 {}",
        start1,
        len1,
        start2,
        len2
    );
    // SAFETY: the ranges are disjoint (asserted above) and assumed in-bounds
    // by the caller; split_at_mut proves disjointness to the borrow checker.
    unsafe {
        if start2 > start1 {
            let (rest, second) = i.split_at_mut(start2);
            let (_, first) = rest.split_at_mut(start1);
            (
                first.get_unchecked_mut(0..len1),
                second.get_unchecked_mut(0..len2),
            )
        } else {
            let (rest, first) = i.split_at_mut(start1);
            let (_, second) = rest.split_at_mut(start2);
            (
                first.get_unchecked_mut(0..len1),
                second.get_unchecked_mut(0..len2),
            )
        }
    }
}

/// Test helper: runs one learn/predict step through the finalized graph and
/// returns the first observation.
#[cfg(test)]
pub fn slearn2(
    bg: &mut graph::BlockGraph,
    fb: &feature_buffer::FeatureBuffer,
    pb: &mut port_buffer::PortBuffer,
    update: bool,
) -> f32 {
    pb.reset();
    let (block_run, further_blocks) = bg.blocks_final.split_at_mut(1);
    block_run[0].forward_backward(further_blocks, fb, pb, update);

    pb.observations[0]
}

/// Test helper: allocates and fills forward caches for `cache_fb`.
#[cfg(test)]
pub fn ssetup_cache2(
    bg: &mut graph::BlockGraph,
    cache_fb:
&feature_buffer::FeatureBuffer, 179 | caches: &mut Vec, 180 | ) { 181 | let (create_block_run, create_further_blocks) = bg.blocks_final.split_at_mut(1); 182 | create_block_run[0].create_forward_cache(create_further_blocks, caches); 183 | 184 | let (prepare_block_run, prepare_further_blocks) = bg.blocks_final.split_at_mut(1); 185 | prepare_block_run[0].prepare_forward_cache( 186 | prepare_further_blocks, 187 | cache_fb, 188 | caches.as_mut_slice(), 189 | ); 190 | } 191 | 192 | #[cfg(test)] 193 | pub fn spredict2_with_cache( 194 | bg: &mut graph::BlockGraph, 195 | fb: &feature_buffer::FeatureBuffer, 196 | pb: &mut port_buffer::PortBuffer, 197 | caches: &[BlockCache], 198 | ) -> f32 { 199 | pb.reset(); 200 | let (block_run, further_blocks) = bg.blocks_final.split_at(1); 201 | block_run[0].forward_with_cache(further_blocks, fb, pb, caches); 202 | 203 | pb.observations[0] 204 | } 205 | 206 | #[cfg(test)] 207 | pub fn spredict2( 208 | bg: &mut graph::BlockGraph, 209 | fb: &feature_buffer::FeatureBuffer, 210 | pb: &mut port_buffer::PortBuffer, 211 | ) -> f32 { 212 | pb.reset(); 213 | let (block_run, further_blocks) = bg.blocks_final.split_at(1); 214 | block_run[0].forward(further_blocks, fb, pb); 215 | pb.observations[0] 216 | } 217 | 218 | #[inline(always)] 219 | pub fn forward_backward( 220 | further_blocks: &mut [Box], 221 | fb: &feature_buffer::FeatureBuffer, 222 | pb: &mut port_buffer::PortBuffer, 223 | update: bool, 224 | ) { 225 | if let Some((next_regressor, further_blocks)) = further_blocks.split_first_mut() { 226 | next_regressor.forward_backward(further_blocks, fb, pb, update) 227 | } 228 | } 229 | 230 | #[inline(always)] 231 | pub fn forward( 232 | further_blocks: &[Box], 233 | fb: &feature_buffer::FeatureBuffer, 234 | pb: &mut port_buffer::PortBuffer, 235 | ) { 236 | match further_blocks.split_first() { 237 | Some((next_regressor, further_blocks)) => next_regressor.forward(further_blocks, fb, pb), 238 | None => {} 239 | } 240 | } 241 | 242 | #[inline(always)] 
243 | pub fn forward_with_cache( 244 | further_blocks: &[Box], 245 | fb: &feature_buffer::FeatureBuffer, 246 | pb: &mut port_buffer::PortBuffer, 247 | caches: &[BlockCache], 248 | ) { 249 | if let Some((next_regressor, further_blocks)) = further_blocks.split_first() { 250 | next_regressor.forward_with_cache(further_blocks, fb, pb, caches) 251 | } 252 | } 253 | 254 | #[inline(always)] 255 | pub fn prepare_forward_cache( 256 | further_blocks: &mut [Box], 257 | fb: &feature_buffer::FeatureBuffer, 258 | caches: &mut [BlockCache], 259 | ) { 260 | if let Some((next_regressor, further_blocks)) = further_blocks.split_first_mut() { 261 | next_regressor.prepare_forward_cache(further_blocks, fb, caches) 262 | } 263 | } 264 | 265 | #[inline(always)] 266 | pub fn create_forward_cache( 267 | further_blocks: &mut [Box], 268 | caches: &mut Vec, 269 | ) { 270 | if let Some((next_regressor, further_blocks)) = further_blocks.split_first_mut() { 271 | next_regressor.create_forward_cache(further_blocks, caches) 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /src/vwmap.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::collections::HashMap; 3 | use std::error::Error; 4 | use std::fs; 5 | use std::io::prelude::*; 6 | use std::io::Error as IOError; 7 | use std::io::ErrorKind; 8 | use std::path::PathBuf; 9 | 10 | #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize, Eq)] 11 | pub enum NamespaceType { 12 | Primitive = 0, 13 | Transformed = 1, 14 | } 15 | 16 | #[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize, Eq)] 17 | pub enum NamespaceFormat { 18 | Categorical = 0, // categorical (binary) features encoding (we have the hash and weight of each feature, value of the feature is assumed to be 1.0 (binary)) 19 | F32 = 1, // f32 features encoding (we have the hash and value of each feature, weight is assumed to be 1.0) 20 | } 21 | 
22 | #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Copy)] 23 | pub struct NamespaceDescriptor { 24 | pub namespace_index: u16, 25 | pub namespace_type: NamespaceType, 26 | pub namespace_format: NamespaceFormat, 27 | } 28 | 29 | #[derive(Clone, Debug)] 30 | pub struct VwNamespaceMap { 31 | pub num_namespaces: usize, 32 | pub map_verbose_to_namespace_descriptor: HashMap, 33 | pub map_vwname_to_namespace_descriptor: HashMap, NamespaceDescriptor>, 34 | pub map_vwname_to_name: HashMap, std::string::String>, 35 | pub vw_source: VwNamespaceMapSource, // this is the source from which VwNamespaceMap can be constructed - for persistence 36 | } 37 | 38 | // this is serializible source from which VwNamespaceMap can be constructed 39 | #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] 40 | pub struct VwNamespaceMapEntry { 41 | pub namespace_vwname: std::string::String, 42 | namespace_verbose: std::string::String, 43 | namespace_index: u16, 44 | namespace_format: NamespaceFormat, 45 | } 46 | 47 | #[derive(Serialize, Deserialize, Debug, Eq, PartialEq, Clone)] 48 | pub struct VwNamespaceMapSource { 49 | pub namespace_skip_prefix: u32, 50 | pub entries: Vec, 51 | } 52 | 53 | impl VwNamespaceMap { 54 | pub fn new_from_source( 55 | vw_source: VwNamespaceMapSource, 56 | ) -> Result> { 57 | let mut vw = VwNamespaceMap { 58 | num_namespaces: 0, 59 | map_verbose_to_namespace_descriptor: HashMap::new(), 60 | map_vwname_to_namespace_descriptor: HashMap::new(), 61 | map_vwname_to_name: HashMap::new(), 62 | vw_source, 63 | }; 64 | 65 | for vw_entry in &vw.vw_source.entries { 66 | //let record = result?; 67 | let name_str = &vw_entry.namespace_verbose; 68 | let vwname_str = &vw_entry.namespace_vwname; 69 | 70 | let namespace_descriptor = NamespaceDescriptor { 71 | namespace_index: vw_entry.namespace_index, 72 | namespace_type: NamespaceType::Primitive, 73 | namespace_format: vw_entry.namespace_format, 74 | }; 75 | 76 | vw.map_vwname_to_name 77 | 
.insert(vwname_str.as_bytes().to_vec(), String::from(name_str)); 78 | vw.map_vwname_to_namespace_descriptor 79 | .insert(vwname_str.as_bytes().to_vec(), namespace_descriptor); 80 | vw.map_verbose_to_namespace_descriptor 81 | .insert(String::from(name_str), namespace_descriptor); 82 | 83 | if vw_entry.namespace_index as usize > vw.num_namespaces { 84 | vw.num_namespaces = vw_entry.namespace_index as usize; 85 | } 86 | } 87 | vw.num_namespaces += 1; 88 | Ok(vw) 89 | } 90 | 91 | pub fn new_from_csv_filepath(path: PathBuf) -> Result> { 92 | let mut input_bufreader = fs::File::open(&path).unwrap_or_else(|_| { 93 | panic!( 94 | "{}", 95 | format!( 96 | "Could not find vw_namespace_map.csv in input dataset directory of {:?}", 97 | path 98 | ) 99 | ) 100 | }); 101 | let mut s = String::new(); 102 | input_bufreader.read_to_string(&mut s)?; 103 | VwNamespaceMap::new(&s) 104 | } 105 | 106 | pub fn new(data: &str) -> Result> { 107 | let mut rdr = csv::ReaderBuilder::new() 108 | .has_headers(false) 109 | .flexible(true) 110 | .from_reader(data.as_bytes()); 111 | let mut vw_source = VwNamespaceMapSource { 112 | entries: vec![], 113 | namespace_skip_prefix: 0, 114 | }; 115 | for (i, record_w) in rdr.records().enumerate() { 116 | let record = record_w?; 117 | let vwname_str = &record[0]; 118 | if vwname_str.as_bytes().len() != 1 && i == 0 { 119 | log::warn!("Warning: multi-byte namespace names are not compatible with old style namespace arguments"); 120 | } 121 | 122 | if vwname_str == "_namespace_skip_prefix" { 123 | let namespace_skip_prefix = record[1] 124 | .parse() 125 | .expect("Couldn't parse _namespace_skip_prefix in vw_namespaces_map.csv"); 126 | log::info!( 127 | "_namespace_skip_prefix set in vw_namespace_map.csv is {}", 128 | namespace_skip_prefix 129 | ); 130 | vw_source.namespace_skip_prefix = namespace_skip_prefix; 131 | continue; 132 | } 133 | 134 | let name_str = &record[1]; 135 | let namespace_format = match &record.get(2) { 136 | Some("f32") => 
NamespaceFormat::F32, 137 | Some("") => NamespaceFormat::Categorical, 138 | None => NamespaceFormat::Categorical, 139 | Some(unknown_type) => return Err(Box::new(IOError::new(ErrorKind::Other, format!("Unknown type used for the feature in vw_namespace_map.csv: \"{}\". Only \"f32\" is possible.", unknown_type)))) 140 | }; 141 | 142 | vw_source.entries.push(VwNamespaceMapEntry { 143 | namespace_vwname: vwname_str.to_string(), 144 | namespace_verbose: name_str.to_string(), 145 | namespace_index: i as u16, 146 | namespace_format, 147 | }); 148 | } 149 | 150 | VwNamespaceMap::new_from_source(vw_source) 151 | } 152 | } 153 | 154 | #[cfg(test)] 155 | mod tests { 156 | // Note this useful idiom: importing names from outer (for mod tests) scope. 157 | use super::*; 158 | 159 | #[test] 160 | fn test_simple() { 161 | let vw_map_string = r#" 162 | A,featureA 163 | B,featureB 164 | C,featureC 165 | "#; 166 | let vw = VwNamespaceMap::new(vw_map_string).unwrap(); 167 | assert_eq!(vw.vw_source.entries.len(), 3); 168 | assert_eq!(vw.vw_source.namespace_skip_prefix, 0); 169 | assert_eq!( 170 | vw.vw_source.entries[0], 171 | VwNamespaceMapEntry { 172 | namespace_vwname: "A".to_string(), 173 | namespace_verbose: "featureA".to_string(), 174 | namespace_index: 0, 175 | namespace_format: NamespaceFormat::Categorical 176 | } 177 | ); 178 | 179 | assert_eq!( 180 | vw.vw_source.entries[1], 181 | VwNamespaceMapEntry { 182 | namespace_vwname: "B".to_string(), 183 | namespace_verbose: "featureB".to_string(), 184 | namespace_index: 1, 185 | namespace_format: NamespaceFormat::Categorical 186 | } 187 | ); 188 | 189 | assert_eq!( 190 | vw.vw_source.entries[2], 191 | VwNamespaceMapEntry { 192 | namespace_vwname: "C".to_string(), 193 | namespace_verbose: "featureC".to_string(), 194 | namespace_index: 2, 195 | namespace_format: NamespaceFormat::Categorical 196 | } 197 | ); 198 | } 199 | 200 | #[test] 201 | fn test_f32() { 202 | { 203 | let vw_map_string = "A,featureA,f32\n_namespace_skip_prefix,2"; 
204 | let vw = VwNamespaceMap::new(vw_map_string).unwrap(); 205 | assert_eq!( 206 | vw.vw_source.entries[0], 207 | VwNamespaceMapEntry { 208 | namespace_vwname: "A".to_string(), 209 | namespace_verbose: "featureA".to_string(), 210 | namespace_index: 0, 211 | namespace_format: NamespaceFormat::F32 212 | } 213 | ); 214 | assert_eq!(vw.vw_source.namespace_skip_prefix, 2); 215 | } 216 | { 217 | let vw_map_string = "A,featureA,blah\n"; 218 | let result = VwNamespaceMap::new(vw_map_string); 219 | assert!(result.is_err()); 220 | assert_eq!(format!("{:?}", result), "Err(Custom { kind: Other, error: \"Unknown type used for the feature in vw_namespace_map.csv: \\\"blah\\\". Only \\\"f32\\\" is possible.\" })"); 221 | } 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /src/block_normalize.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | use std::error::Error; 3 | 4 | use crate::block_helpers; 5 | use crate::feature_buffer; 6 | use crate::feature_buffer::FeatureBuffer; 7 | use crate::graph; 8 | use crate::model_instance; 9 | use crate::port_buffer; 10 | use crate::port_buffer::PortBuffer; 11 | use crate::regressor; 12 | use crate::regressor::BlockCache; 13 | use regressor::BlockTrait; 14 | 15 | const EPS: f32 = 1e-2; 16 | 17 | pub struct BlockNormalize { 18 | pub num_inputs: usize, 19 | pub input_offset: usize, 20 | pub output_offset: usize, 21 | } 22 | 23 | // This is purely variance normalization as described in 24 | // https://arxiv.org/pdf/2006.12753.pdf 25 | // Early results show no improvements for normalization od neural layers 26 | 27 | pub fn new_normalize_layer_block( 28 | bg: &mut graph::BlockGraph, 29 | _mi: &model_instance::ModelInstance, 30 | input: graph::BlockPtrOutput, 31 | ) -> Result> { 32 | let num_inputs = bg.get_num_output_values(vec![&input]); 33 | assert_ne!(num_inputs, 0); 34 | let block = Box::new(BlockNormalize { 35 | output_offset: 
usize::MAX, 36 | input_offset: usize::MAX, 37 | num_inputs, 38 | }); 39 | let mut block_outputs = bg.add_node(block, vec![input])?; 40 | assert_eq!(block_outputs.len(), 1); 41 | Ok(block_outputs.pop().unwrap()) 42 | } 43 | 44 | impl BlockTrait for BlockNormalize { 45 | fn as_any(&mut self) -> &mut dyn Any { 46 | self 47 | } 48 | 49 | fn get_num_output_values(&self, output: graph::OutputSlot) -> usize { 50 | assert_eq!(output.get_output_index(), 0); 51 | return self.num_inputs; 52 | } 53 | 54 | fn set_input_offset(&mut self, input: graph::InputSlot, offset: usize) { 55 | assert_eq!(input.get_input_index(), 0); 56 | self.input_offset = offset; 57 | } 58 | 59 | fn set_output_offset(&mut self, output: graph::OutputSlot, offset: usize) { 60 | assert_eq!(output.get_output_index(), 0); 61 | self.output_offset = offset; 62 | } 63 | 64 | #[inline(always)] 65 | fn forward_backward( 66 | &mut self, 67 | further_blocks: &mut [Box], 68 | fb: &feature_buffer::FeatureBuffer, 69 | pb: &mut port_buffer::PortBuffer, 70 | update: bool, 71 | ) { 72 | debug_assert!(self.output_offset != usize::MAX); 73 | debug_assert!(self.input_offset != usize::MAX); 74 | debug_assert!(self.num_inputs > 0); 75 | 76 | unsafe { 77 | let mut mean: f32 = 0.0; 78 | for i in 0..self.num_inputs { 79 | mean += *pb.tape.get_unchecked_mut(self.input_offset + i); 80 | } 81 | mean /= self.num_inputs as f32; 82 | let meansq = mean * mean; 83 | let mut variance: f32 = 0.0; 84 | for i in 0..self.num_inputs { 85 | let w = meansq - *pb.tape.get_unchecked_mut(self.input_offset + i); 86 | variance += w * w; 87 | } 88 | variance += EPS; 89 | variance /= self.num_inputs as f32; 90 | variance = variance.sqrt(); 91 | 92 | let variance_inv = 1.0 / variance; 93 | 94 | for i in 0..self.num_inputs { 95 | *pb.tape.get_unchecked_mut(self.output_offset + i) = 96 | (*pb.tape.get_unchecked(self.input_offset + i) - mean) * variance_inv; 97 | } 98 | block_helpers::forward_backward(further_blocks, fb, pb, update); 99 | 100 | if update 
{ 101 | for i in 0..self.num_inputs { 102 | *pb.tape.get_unchecked_mut(self.input_offset + i) = 103 | *pb.tape.get_unchecked_mut(self.output_offset + i) * variance_inv; 104 | } 105 | } 106 | } 107 | } 108 | 109 | fn forward( 110 | &self, 111 | further_blocks: &[Box], 112 | fb: &feature_buffer::FeatureBuffer, 113 | pb: &mut port_buffer::PortBuffer, 114 | ) { 115 | self.internal_forward(pb); 116 | block_helpers::forward(further_blocks, fb, pb); 117 | } 118 | 119 | fn forward_with_cache( 120 | &self, 121 | further_blocks: &[Box], 122 | fb: &FeatureBuffer, 123 | pb: &mut PortBuffer, 124 | caches: &[BlockCache], 125 | ) { 126 | self.internal_forward(pb); 127 | block_helpers::forward_with_cache(further_blocks, fb, pb, caches); 128 | } 129 | } 130 | 131 | impl BlockNormalize { 132 | #[inline(always)] 133 | fn internal_forward(&self, pb: &mut port_buffer::PortBuffer) -> f32 { 134 | debug_assert!(self.output_offset != usize::MAX); 135 | debug_assert!(self.input_offset != usize::MAX); 136 | debug_assert!(self.num_inputs > 0); 137 | 138 | unsafe { 139 | let mut mean: f32 = 0.0; 140 | for i in 0..self.num_inputs { 141 | mean += *pb.tape.get_unchecked_mut(self.input_offset + i); 142 | } 143 | mean /= self.num_inputs as f32; 144 | let meansq = mean * mean; 145 | let mut variance: f32 = 0.0; 146 | for i in 0..self.num_inputs { 147 | let w = meansq - *pb.tape.get_unchecked_mut(self.input_offset + i); 148 | variance += w * w; 149 | } 150 | variance += EPS; 151 | variance /= self.num_inputs as f32; 152 | variance = variance.sqrt(); 153 | 154 | let variance_inv = 1.0 / variance; 155 | 156 | for i in 0..self.num_inputs { 157 | *pb.tape.get_unchecked_mut(self.output_offset + i) = 158 | *pb.tape.get_unchecked(self.input_offset + i) * variance_inv; 159 | } 160 | 161 | variance_inv 162 | } 163 | } 164 | } 165 | 166 | pub struct BlockStopBackward { 167 | pub num_inputs: usize, 168 | pub input_offset: usize, 169 | pub output_offset: usize, 170 | } 171 | 172 | // This is purely variance 
normalization as described in 173 | // https://arxiv.org/pdf/2006.12753.pdf 174 | // Early results show no improvements for normalization od neural layers 175 | 176 | pub fn new_stop_block( 177 | bg: &mut graph::BlockGraph, 178 | _mi: &model_instance::ModelInstance, 179 | input: graph::BlockPtrOutput, 180 | ) -> Result> { 181 | let num_inputs = bg.get_num_output_values(vec![&input]); 182 | debug_assert!(num_inputs != 0); 183 | let block = Box::new(BlockStopBackward { 184 | output_offset: usize::MAX, 185 | input_offset: usize::MAX, 186 | num_inputs, 187 | }); 188 | let mut block_outputs = bg.add_node(block, vec![input])?; 189 | assert_eq!(block_outputs.len(), 1); 190 | Ok(block_outputs.pop().unwrap()) 191 | } 192 | 193 | impl BlockTrait for BlockStopBackward { 194 | fn as_any(&mut self) -> &mut dyn Any { 195 | self 196 | } 197 | 198 | fn allocate_and_init_weights(&mut self, _mi: &model_instance::ModelInstance) {} 199 | 200 | fn get_num_output_values(&self, output: graph::OutputSlot) -> usize { 201 | assert_eq!(output.get_output_index(), 0); 202 | return self.num_inputs; 203 | } 204 | 205 | fn set_input_offset(&mut self, input: graph::InputSlot, offset: usize) { 206 | assert_eq!(input.get_input_index(), 0); 207 | self.input_offset = offset; 208 | } 209 | 210 | fn set_output_offset(&mut self, output: graph::OutputSlot, offset: usize) { 211 | assert_eq!(output.get_output_index(), 0); 212 | self.output_offset = offset; 213 | } 214 | 215 | #[inline(always)] 216 | fn forward_backward( 217 | &mut self, 218 | further_blocks: &mut [Box], 219 | fb: &feature_buffer::FeatureBuffer, 220 | pb: &mut port_buffer::PortBuffer, 221 | update: bool, 222 | ) { 223 | self.internal_forward(pb); 224 | 225 | block_helpers::forward_backward(further_blocks, fb, pb, update); 226 | 227 | if update { 228 | pb.tape[self.input_offset..(self.input_offset + self.num_inputs)].fill(0.0); 229 | } 230 | } 231 | 232 | fn forward( 233 | &self, 234 | further_blocks: &[Box], 235 | fb: 
&feature_buffer::FeatureBuffer, 236 | pb: &mut port_buffer::PortBuffer, 237 | ) { 238 | self.internal_forward(pb); 239 | block_helpers::forward(further_blocks, fb, pb); 240 | } 241 | 242 | fn forward_with_cache( 243 | &self, 244 | further_blocks: &[Box], 245 | fb: &FeatureBuffer, 246 | pb: &mut PortBuffer, 247 | caches: &[BlockCache], 248 | ) { 249 | self.internal_forward(pb); 250 | block_helpers::forward_with_cache(further_blocks, fb, pb, caches); 251 | } 252 | } 253 | 254 | impl BlockStopBackward { 255 | #[inline(always)] 256 | fn internal_forward(&self, pb: &mut port_buffer::PortBuffer) { 257 | debug_assert!(self.output_offset != usize::MAX); 258 | debug_assert!(self.input_offset != usize::MAX); 259 | debug_assert!(self.num_inputs > 0); 260 | 261 | pb.tape.copy_within( 262 | self.input_offset..(self.input_offset + self.num_inputs), 263 | self.output_offset, 264 | ); 265 | } 266 | } 267 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | pub mod block_ffm; 2 | pub mod block_helpers; 3 | pub mod block_loss_functions; 4 | pub mod block_lr; 5 | pub mod block_misc; 6 | pub mod block_neural; 7 | pub mod block_normalize; 8 | pub mod block_relu; 9 | pub mod buffer_handler; 10 | pub mod cache; 11 | pub mod cmdline; 12 | pub mod feature_buffer; 13 | pub mod feature_transform_executor; 14 | pub mod feature_transform_implementations; 15 | pub mod feature_transform_parser; 16 | pub mod graph; 17 | pub mod hogwild; 18 | pub mod logging_layer; 19 | pub mod model_instance; 20 | pub mod multithread_helpers; 21 | pub mod optimizer; 22 | pub mod parser; 23 | pub mod persistence; 24 | pub mod port_buffer; 25 | pub mod quantization; 26 | pub mod radix_tree; 27 | pub mod regressor; 28 | pub mod serving; 29 | pub mod version; 30 | pub mod vwmap; 31 | 32 | extern crate blas; 33 | extern crate half; 34 | extern crate intel_mkl_src; 35 | 36 | use 
crate::feature_buffer::FeatureBufferTranslator; 37 | use crate::multithread_helpers::BoxedRegressorTrait; 38 | use crate::parser::VowpalParser; 39 | use crate::port_buffer::PortBuffer; 40 | use crate::regressor::BlockCache; 41 | use crate::vwmap::NamespaceType; 42 | use shellwords; 43 | use std::ffi::CStr; 44 | use std::io::Cursor; 45 | use std::os::raw::c_char; 46 | 47 | const EOF_ERROR_CODE: f32 = -1.0; 48 | const EXCEPTION_ERROR_CODE: f32 = -1.0; 49 | 50 | #[repr(C)] 51 | pub struct FfiPredictor { 52 | _marker: core::marker::PhantomData, 53 | } 54 | 55 | pub struct Predictor { 56 | feature_buffer_translator: FeatureBufferTranslator, 57 | vw_parser: VowpalParser, 58 | regressor: BoxedRegressorTrait, 59 | pb: PortBuffer, 60 | cache: PredictorCache, 61 | } 62 | 63 | pub struct PredictorCache { 64 | blocks: Vec, 65 | input_buffer_size: usize, 66 | } 67 | 68 | impl Predictor { 69 | unsafe fn predict(&mut self, input_buffer: &str) -> f32 { 70 | let mut buffered_input = Cursor::new(input_buffer); 71 | let reading_result = self.vw_parser.next_vowpal(&mut buffered_input); 72 | let buffer = match reading_result { 73 | Ok([]) => { 74 | log::error!("Reading result for prediction returns EOF"); 75 | return EOF_ERROR_CODE; 76 | } // EOF 77 | Ok(buffer2) => buffer2, 78 | Err(e) => { 79 | log::error!("Reading result for prediction returns error {}", e); 80 | return EXCEPTION_ERROR_CODE; 81 | } 82 | }; 83 | self.feature_buffer_translator.translate(buffer, 0); 84 | self.regressor 85 | .predict(&self.feature_buffer_translator.feature_buffer, &mut self.pb) 86 | } 87 | 88 | unsafe fn predict_with_cache(&mut self, input_buffer: &str) -> f32 { 89 | let mut buffered_input = Cursor::new(&input_buffer); 90 | let reading_result = self 91 | .vw_parser 92 | .next_vowpal_with_cache(&mut buffered_input, self.cache.input_buffer_size); 93 | 94 | let buffer = match reading_result { 95 | Ok([]) => { 96 | log::error!("Reading result for prediction with cache returns EOF"); 97 | return 
EOF_ERROR_CODE; 98 | } // EOF 99 | Ok(buffer2) => buffer2, 100 | Err(e) => { 101 | log::error!( 102 | "Reading result for prediction with cache returns error {}", 103 | e 104 | ); 105 | return EXCEPTION_ERROR_CODE; 106 | } 107 | }; 108 | 109 | self.feature_buffer_translator.translate(buffer, 0); 110 | self.regressor.predict_with_cache( 111 | &self.feature_buffer_translator.feature_buffer, 112 | &mut self.pb, 113 | self.cache.blocks.as_slice(), 114 | ) 115 | } 116 | 117 | unsafe fn setup_cache(&mut self, input_buffer: &str) -> f32 { 118 | let mut buffered_input = Cursor::new(input_buffer); 119 | let reading_result = self.vw_parser.next_vowpal_with_size(&mut buffered_input); 120 | let (buffer, input_buffer_size) = match reading_result { 121 | Ok(([], _)) => { 122 | log::error!("Reading result for prediction with cache returns EOF"); 123 | return EOF_ERROR_CODE; 124 | } // EOF 125 | Ok(buffer2) => buffer2, 126 | Err(e) => { 127 | log::error!( 128 | "Reading result for prediction with cache returns error {}", 129 | e 130 | ); 131 | return EXCEPTION_ERROR_CODE; 132 | } 133 | }; 134 | // ignore last newline byte 135 | self.cache.input_buffer_size = input_buffer_size; 136 | self.feature_buffer_translator.translate_and_filter( 137 | buffer, 138 | 0, 139 | Some(NamespaceType::Primitive), 140 | ); 141 | let is_empty = self.cache.blocks.is_empty(); 142 | self.regressor.setup_cache( 143 | &self.feature_buffer_translator.feature_buffer, 144 | &mut self.cache.blocks, 145 | is_empty, 146 | ); 147 | 0.0 148 | } 149 | } 150 | 151 | #[no_mangle] 152 | pub extern "C" fn new_fw_predictor_prototype(command: *const c_char) -> *mut FfiPredictor { 153 | // create a "prototype" predictor that loads the weights file. This predictor is expensive, and is intended 154 | // to only be created once. If additional predictors are needed (e.g. 
for concurrent work), please 155 | // use this "prototype" with the clone_lite function, which will create cheap copies 156 | logging_layer::initialize_logging_layer(); 157 | 158 | let str_command = c_char_to_str(command); 159 | let words = shellwords::split(str_command).unwrap(); 160 | let cmd_matches = cmdline::create_expected_args().get_matches_from(words); 161 | let weights_filename = match cmd_matches.value_of("initial_regressor") { 162 | Some(filename) => filename, 163 | None => panic!("Cannot resolve input weights file name"), 164 | }; 165 | let (model_instance, vw_namespace_map, regressor) = 166 | persistence::new_regressor_from_filename(weights_filename, true, Some(&cmd_matches)) 167 | .unwrap(); 168 | let feature_buffer_translator = FeatureBufferTranslator::new(&model_instance); 169 | let vw_parser = VowpalParser::new(&vw_namespace_map); 170 | let sharable_regressor = BoxedRegressorTrait::new(Box::new(regressor)); 171 | let pb = sharable_regressor.new_portbuffer(); 172 | let predictor = Predictor { 173 | feature_buffer_translator, 174 | vw_parser, 175 | regressor: sharable_regressor, 176 | pb, 177 | cache: PredictorCache { 178 | blocks: Vec::default(), 179 | input_buffer_size: 0, 180 | }, 181 | }; 182 | Box::into_raw(Box::new(predictor)).cast() 183 | } 184 | 185 | #[no_mangle] 186 | pub unsafe extern "C" fn clone_lite(prototype: *mut FfiPredictor) -> *mut FfiPredictor { 187 | // given an expensive "prototype" predictor, this function creates cheap copies of it 188 | // that can be used in different threads concurrently. Note that individually, these predictors 189 | // are not thread safe, but it is safe to use multiple threads, each accessing only one predictor. 
190 | let prototype: &mut Predictor = from_ptr(prototype); 191 | let lite_predictor = Predictor { 192 | feature_buffer_translator: prototype.feature_buffer_translator.clone(), 193 | vw_parser: prototype.vw_parser.clone(), 194 | regressor: prototype.regressor.clone(), 195 | pb: prototype.pb.clone(), 196 | 197 | cache: PredictorCache { 198 | blocks: Vec::new(), 199 | input_buffer_size: 0, 200 | }, 201 | }; 202 | Box::into_raw(Box::new(lite_predictor)).cast() 203 | } 204 | 205 | #[no_mangle] 206 | pub unsafe extern "C" fn fw_predict(ptr: *mut FfiPredictor, input_buffer: *const c_char) -> f32 { 207 | let str_buffer = c_char_to_str(input_buffer); 208 | let predictor: &mut Predictor = from_ptr(ptr); 209 | predictor.predict(str_buffer) 210 | } 211 | 212 | #[no_mangle] 213 | pub unsafe extern "C" fn fw_predict_with_cache( 214 | ptr: *mut FfiPredictor, 215 | input_buffer: *const c_char, 216 | ) -> f32 { 217 | let str_buffer = c_char_to_str(input_buffer); 218 | let predictor: &mut Predictor = from_ptr(ptr); 219 | predictor.predict_with_cache(str_buffer) 220 | } 221 | 222 | #[no_mangle] 223 | pub unsafe extern "C" fn fw_setup_cache( 224 | ptr: *mut FfiPredictor, 225 | input_buffer: *const c_char, 226 | ) -> f32 { 227 | let str_buffer = c_char_to_str(input_buffer); 228 | let predictor: &mut Predictor = from_ptr(ptr); 229 | predictor.setup_cache(str_buffer) 230 | } 231 | 232 | #[no_mangle] 233 | pub unsafe extern "C" fn free_predictor(ptr: *mut FfiPredictor) { 234 | drop::>(Box::from_raw(from_ptr(ptr))); 235 | } 236 | 237 | unsafe fn from_ptr<'a>(ptr: *mut FfiPredictor) -> &'a mut Predictor { 238 | if ptr.is_null() { 239 | log::error!("Fatal error, got NULL `Context` pointer"); 240 | std::process::abort(); 241 | } 242 | &mut *(ptr.cast()) 243 | } 244 | 245 | fn c_char_to_str<'a>(input_buffer: *const c_char) -> &'a str { 246 | let c_str = unsafe { 247 | assert!(!input_buffer.is_null()); 248 | CStr::from_ptr(input_buffer) 249 | }; 250 | let str_buffer = c_str.to_str().unwrap(); 
251 | str_buffer 252 | } 253 | -------------------------------------------------------------------------------- /src/cache.rs: -------------------------------------------------------------------------------- 1 | use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; 2 | use std::error::Error; 3 | use std::fs; 4 | use std::io; 5 | use std::io::Read; 6 | use std::io::Write; 7 | use std::path; 8 | use std::{mem, slice}; 9 | 10 | use crate::vwmap; 11 | 12 | const CACHE_HEADER_MAGIC_STRING: &[u8; 4] = b"FWCA"; // Fwumious Wabbit CAche 13 | const CACHE_HEADER_VERSION: u32 = 11; 14 | /* 15 | Version incompatibilites: 16 | 10->11: float namespaces cannot have a weight attached 17 | 9->10: enable binning 18 | 8->9: enabled multi-byte feature names in vw files 19 | 7->8: add example importance to the parsed buffer format 20 | */ 21 | 22 | // Cache layout: 23 | // 4 bytes: Magic bytes 24 | // u32: Version of the cache format 25 | // u_size + blob: json encoding of vw_source 26 | // ...cached examples 27 | 28 | const READBUF_LEN: usize = 1024 * 100; 29 | 30 | // This is super ugly hack around the fact that we need to call finish() before closing the lz4 stream 31 | // Effectively lz4 implementation we're using is kind of bad 32 | // More info (and where workaround comes from): https://github.com/bozaro/lz4-rs/issues/9 33 | struct Wrapper { 34 | s: Option>, 35 | } 36 | impl Write for Wrapper { 37 | fn write(&mut self, buffer: &[u8]) -> Result { 38 | self.s.as_mut().unwrap().write(buffer) 39 | } 40 | 41 | fn flush(&mut self) -> Result<(), std::io::Error> { 42 | self.s.as_mut().unwrap().flush() 43 | } 44 | } 45 | impl Drop for Wrapper { 46 | fn drop(&mut self) { 47 | if let Some(s) = self.s.take() { 48 | let _result = s.finish(); 49 | } 50 | } 51 | } 52 | 53 | pub struct RecordCache { 54 | output_bufwriter: Box, 55 | input_bufreader: Box, 56 | temporary_filename: String, 57 | final_filename: String, 58 | pub writing: bool, 59 | pub reading: bool, 60 | // pub output_buffer: 
Vec, 61 | pub byte_buffer: Vec, //[u8; READBUF_LEN], 62 | start_pointer: usize, 63 | end_pointer: usize, 64 | total_read: usize, 65 | } 66 | 67 | impl RecordCache { 68 | pub fn new(input_filename: &str, enabled: bool, vw_map: &vwmap::VwNamespaceMap) -> RecordCache { 69 | let temporary_filename: String = format!("{}.fwcache.writing", input_filename); 70 | let final_filename: String = format!("{}.fwcache", input_filename); 71 | let gz = input_filename.ends_with("gz"); 72 | 73 | let mut rc = RecordCache { 74 | output_bufwriter: Box::new(io::BufWriter::new(io::sink())), 75 | input_bufreader: Box::new(io::empty()), 76 | temporary_filename: temporary_filename.to_string(), 77 | final_filename: final_filename.to_string(), 78 | writing: false, 79 | reading: false, 80 | byte_buffer: Vec::new(), 81 | start_pointer: 0, 82 | end_pointer: 0, 83 | total_read: 0, 84 | }; 85 | 86 | if enabled { 87 | if path::Path::new(&final_filename).exists() { 88 | rc.reading = true; 89 | if !gz { 90 | // we buffer ourselves, otherwise i would be wise to use bufreader 91 | rc.input_bufreader = Box::new(fs::File::open(&final_filename).unwrap()); 92 | } else { 93 | rc.input_bufreader = Box::new( 94 | lz4::Decoder::new(fs::File::open(&final_filename).unwrap()).unwrap(), 95 | ); 96 | } 97 | log::info!("using cache_file = {}", final_filename); 98 | log::info!("ignoring text input in favor of cache input"); 99 | match rc.verify_header(vw_map) { 100 | Ok(()) => {} 101 | Err(e) => { 102 | log::error!("Couldn't use the existing cache file: {:?}", e); 103 | rc.reading = false; 104 | } 105 | } 106 | rc.byte_buffer.resize(READBUF_LEN, 0); 107 | } 108 | 109 | if !rc.reading { 110 | rc.writing = true; 111 | log::info!("creating cache file = {}", final_filename); 112 | if !gz { 113 | rc.output_bufwriter = Box::new(io::BufWriter::new( 114 | fs::File::create(temporary_filename).unwrap(), 115 | )); 116 | } else { 117 | let w = Wrapper { 118 | s: Some( 119 | lz4::EncoderBuilder::new() 120 | .level(3) 121 | 
.build(fs::File::create(temporary_filename).unwrap()) 122 | .unwrap(), 123 | ), 124 | }; 125 | rc.output_bufwriter = Box::new(io::BufWriter::new(w)); 126 | } 127 | rc.write_header(vw_map).unwrap(); 128 | } 129 | } 130 | rc 131 | } 132 | 133 | pub fn push_record(&mut self, record_buf: &[u32]) -> Result<(), Box> { 134 | if self.writing { 135 | let element_size = mem::size_of::(); 136 | unsafe { 137 | let vv: &[u8] = slice::from_raw_parts( 138 | record_buf.as_ptr() as *const u8, 139 | record_buf.len() * element_size, 140 | ); 141 | self.output_bufwriter.write_all(vv)?; 142 | } 143 | } 144 | Ok(()) 145 | } 146 | 147 | pub fn write_finish(&mut self) -> Result<(), Box> { 148 | if self.writing { 149 | self.output_bufwriter.flush()?; 150 | fs::rename(&self.temporary_filename, &self.final_filename)?; 151 | } 152 | Ok(()) 153 | } 154 | 155 | pub fn write_header(&mut self, vw_map: &vwmap::VwNamespaceMap) -> Result<(), Box> { 156 | self.output_bufwriter.write_all(CACHE_HEADER_MAGIC_STRING)?; 157 | self.output_bufwriter 158 | .write_u32::(CACHE_HEADER_VERSION)?; 159 | vw_map.save_to_buf(&mut self.output_bufwriter)?; 160 | Ok(()) 161 | } 162 | 163 | pub fn verify_header(&mut self, vwmap: &vwmap::VwNamespaceMap) -> Result<(), Box> { 164 | let mut magic_string: [u8; 4] = [0; 4]; 165 | self.input_bufreader.read(&mut magic_string)?; 166 | if &magic_string != CACHE_HEADER_MAGIC_STRING { 167 | return Err("Cache header does not begin with magic bytes FWFW")?; 168 | } 169 | 170 | let version = self.input_bufreader.read_u32::()?; 171 | if CACHE_HEADER_VERSION != version { 172 | return Err(format!( 173 | "Cache file version of this binary: {}, version of the cache file: {}", 174 | CACHE_HEADER_VERSION, version 175 | ))?; 176 | } 177 | 178 | // Compare vwmap in cache and the one we've been given. 
If they differ, rebuild cache 179 | let vwmap_from_cache = vwmap::VwNamespaceMap::new_from_buf(&mut self.input_bufreader)?; 180 | if vwmap_from_cache.vw_source != vwmap.vw_source { 181 | return Err("vw_namespace_map.csv and the one from cache file differ")?; 182 | } 183 | 184 | Ok(()) 185 | } 186 | 187 | pub fn get_next_record(&mut self) -> Result<&[u32], Box> { 188 | if !self.reading { 189 | return Err("next_recrod() called on reading cache, when not opened in reading mode")?; 190 | } 191 | unsafe { 192 | // We're going to cast another view over the data, so we can read it as u32 193 | // This requires that the allocator we're using gives us sufficiently-aligned bytes, 194 | // but that's not guaranteed, so blow up to avoid UB if the allocator uses that freedom. 195 | assert_eq!( 196 | self.byte_buffer.as_ptr() as usize % mem::align_of::(), 197 | 0 198 | ); 199 | let buf_view: &[u32] = 200 | slice::from_raw_parts(self.byte_buffer.as_ptr() as *const u32, READBUF_LEN / 4); 201 | loop { 202 | // Classical buffer strategy: 203 | // Return if you have full record in buffer, 204 | // Otherwise shift the buffer and backfill it 205 | if self.end_pointer - self.start_pointer >= 4 { 206 | let record_len = buf_view[self.start_pointer / 4] as usize; 207 | if self.start_pointer + record_len * 4 <= self.end_pointer { 208 | let ret_buf = 209 | &buf_view[self.start_pointer / 4..self.start_pointer / 4 + record_len]; 210 | self.start_pointer += record_len * 4; 211 | return Ok(ret_buf); 212 | } 213 | } 214 | self.byte_buffer 215 | .copy_within(self.start_pointer..self.end_pointer, 0); 216 | self.end_pointer -= self.start_pointer; 217 | self.start_pointer = 0; 218 | 219 | let read_len = match self 220 | .input_bufreader 221 | .read(&mut self.byte_buffer[self.end_pointer..READBUF_LEN]) 222 | { 223 | Ok(0) => return Ok(&[]), 224 | Ok(n) => n, 225 | Err(e) => Err(e)?, 226 | }; 227 | 228 | self.end_pointer += read_len; 229 | self.total_read += read_len; 230 | } 231 | } 232 | } 233 | } 
234 | -------------------------------------------------------------------------------- /src/optimizer.rs: -------------------------------------------------------------------------------- 1 | use std::marker::PhantomData; 2 | 3 | pub trait OptimizerTrait: std::clone::Clone { 4 | type PerWeightStore: std::clone::Clone; 5 | fn new() -> Self; 6 | fn init(&mut self, learning_rate: f32, power_t: f32, initial_acc_gradient: f32); 7 | unsafe fn calculate_update(&self, gradient: f32, data: &mut Self::PerWeightStore) -> f32; 8 | fn initial_data(&self) -> Self::PerWeightStore; 9 | fn get_name() -> &'static str; 10 | } 11 | 12 | /******************* SGD **************************/ 13 | // This is non-adaptive fixed learning rate SGD, which is exactly the same as Vowpal when --power_t is 0.0 14 | #[derive(Clone)] 15 | pub struct OptimizerSGD { 16 | learning_rate: f32, 17 | } 18 | 19 | impl OptimizerTrait for OptimizerSGD { 20 | type PerWeightStore = PhantomData<()>; 21 | 22 | fn get_name() -> &'static str { 23 | "SGD" 24 | } 25 | 26 | fn new() -> Self { 27 | OptimizerSGD { learning_rate: 0.0 } 28 | } 29 | 30 | fn init(&mut self, learning_rate: f32, _power_t: f32, _initial_acc_gradient: f32) { 31 | self.learning_rate = learning_rate; 32 | } 33 | 34 | #[inline(always)] 35 | unsafe fn calculate_update(&self, gradient: f32, _data: &mut Self::PerWeightStore) -> f32 { 36 | gradient * self.learning_rate 37 | } 38 | 39 | fn initial_data(&self) -> Self::PerWeightStore { 40 | std::marker::PhantomData {} 41 | } 42 | } 43 | 44 | /******************* Adagrad with flexible power_t **************************/ 45 | /* Regular Adagrad always uses sqrt (power_t = 0.5) */ 46 | /* For power_t = 0.5, this is slower than simply using sqrt */ 47 | /* however we generally always use lookup table for adagrad, so this */ 48 | /* implementation is mainly used as a reference */ 49 | #[derive(Clone)] 50 | pub struct OptimizerAdagradFlex { 51 | learning_rate: f32, 52 | minus_power_t: f32, 53 | 
initial_acc_gradient: f32, 54 | } 55 | 56 | impl OptimizerTrait for OptimizerAdagradFlex { 57 | fn get_name() -> &'static str { 58 | "AdagradFlex" 59 | } 60 | type PerWeightStore = f32; 61 | 62 | fn new() -> Self { 63 | OptimizerAdagradFlex { 64 | learning_rate: 0.0, 65 | minus_power_t: 0.0, 66 | initial_acc_gradient: 0.0, 67 | } 68 | } 69 | 70 | fn init(&mut self, learning_rate: f32, power_t: f32, initial_acc_gradient: f32) { 71 | self.learning_rate = learning_rate; 72 | self.minus_power_t = -power_t; 73 | self.initial_acc_gradient = initial_acc_gradient; 74 | } 75 | 76 | #[inline(always)] 77 | unsafe fn calculate_update(&self, gradient: f32, data: &mut Self::PerWeightStore) -> f32 { 78 | let accumulated_gradient_squared = *data; 79 | let gradient_squared = gradient * gradient; 80 | let new_accumulated_gradient_squared = accumulated_gradient_squared + gradient_squared; 81 | *data = new_accumulated_gradient_squared; 82 | let update = gradient 83 | * self.learning_rate 84 | * (new_accumulated_gradient_squared).powf(self.minus_power_t); 85 | if update.is_nan() || update.is_infinite() { 86 | return 0.0; 87 | } 88 | update 89 | } 90 | 91 | fn initial_data(&self) -> Self::PerWeightStore { 92 | self.initial_acc_gradient 93 | } 94 | } 95 | 96 | /***************** Adagrad using Look Up Table ******************/ 97 | // The intuition about low precision is : sqrt/powf is changing less and less as the parameter 98 | // grows. This means as parameter grows we can use lesser precision while keeping the error small. 99 | // Floating point encoding with separated exponent and mantissa is ideal for such optimization. 
100 | 101 | pub const FASTMATH_LR_LUT_BITS: u8 = 11; 102 | pub const FASTMATH_LR_LUT_SIZE: usize = 1 << FASTMATH_LR_LUT_BITS; 103 | 104 | #[derive(Clone, Copy)] 105 | pub struct OptimizerAdagradLUT { 106 | pub fastmath_lr_lut: [f32; FASTMATH_LR_LUT_SIZE], 107 | } 108 | 109 | impl OptimizerTrait for OptimizerAdagradLUT { 110 | fn get_name() -> &'static str { 111 | "AdagradLUT" 112 | } 113 | type PerWeightStore = f32; 114 | 115 | fn new() -> Self { 116 | OptimizerAdagradLUT { 117 | fastmath_lr_lut: [0.0; FASTMATH_LR_LUT_SIZE], 118 | } 119 | } 120 | 121 | fn init(&mut self, learning_rate: f32, power_t: f32, initial_acc_gradient: f32) { 122 | log::info!("Calculating look-up tables for Adagrad learning rate calculation"); 123 | let minus_power_t = -power_t; 124 | for x in 0..FASTMATH_LR_LUT_SIZE { 125 | // accumulated gradients are always positive floating points, sign is guaranteed to be zero 126 | // floating point: 1 bit of sign, 7 bits of signed exponent then floating point bits (mantissa) 127 | // we will take 7 bits of exponent + whatever most significant bits of mantissa remain 128 | // we take two consequtive such values, so we act as if it had rounding 129 | let float_x = 130 | (f32::from_bits((x as u32) << (31 - FASTMATH_LR_LUT_BITS))) + initial_acc_gradient; 131 | let float_x_plus_one = 132 | (f32::from_bits(((x + 1) as u32) << (31 - FASTMATH_LR_LUT_BITS))) 133 | + initial_acc_gradient; 134 | let mut val = learning_rate 135 | * ((float_x).powf(minus_power_t) + (float_x_plus_one).powf(minus_power_t)) 136 | * 0.5; 137 | // Safety measure 138 | if val.is_nan() || val.is_infinite() { 139 | val = learning_rate; 140 | } 141 | 142 | self.fastmath_lr_lut[x] = val; 143 | } 144 | } 145 | 146 | #[inline(always)] 147 | unsafe fn calculate_update(&self, gradient: f32, data: &mut Self::PerWeightStore) -> f32 { 148 | let accumulated_gradient_squared = *data; 149 | debug_assert!(accumulated_gradient_squared >= 0.0); 150 | let gradient_squared = gradient * gradient; 151 | let 
new_accumulated_gradient_squared = accumulated_gradient_squared + gradient_squared; 152 | *data = new_accumulated_gradient_squared; 153 | let key = new_accumulated_gradient_squared.to_bits() >> (31 - FASTMATH_LR_LUT_BITS); 154 | let update = gradient * *self.fastmath_lr_lut.get_unchecked(key as usize); 155 | update 156 | } 157 | 158 | fn initial_data(&self) -> Self::PerWeightStore { 159 | // We took it into account when calcualting lookup table, so look at init() 160 | 0.0 161 | } 162 | } 163 | 164 | #[cfg(test)] 165 | mod tests { 166 | // Note this useful idiom: importing names from outer (for mod tests) scope. 167 | use super::*; 168 | 169 | #[test] 170 | fn test_sgd() { 171 | let mut l = OptimizerSGD::new(); 172 | l.init(0.15, 0.4, 0.0); 173 | unsafe { 174 | let mut acc: PhantomData<()> = std::marker::PhantomData {}; 175 | let p = l.calculate_update(0.1, &mut acc); 176 | assert_eq!(p, 0.1 * 0.15); 177 | } 178 | } 179 | 180 | #[test] 181 | fn test_adagradflex() { 182 | let mut l = OptimizerAdagradFlex::new(); 183 | l.init(0.15, 0.4, 0.0); 184 | unsafe { 185 | let mut acc: f32; 186 | acc = 0.9; 187 | let p = l.calculate_update(0.1, &mut acc); 188 | assert_eq!(p, 0.015576674); 189 | assert_eq!(acc, 0.9 + 0.1 * 0.1); 190 | 191 | acc = 0.0; 192 | let p = l.calculate_update(0.1, &mut acc); 193 | assert_eq!(p, 0.09464361); 194 | assert_eq!(acc, 0.1 * 0.1); 195 | 196 | acc = 0.0; 197 | l.calculate_update(0.0, &mut acc); 198 | // Here we check that we get NaN back - this is not good, but it's correct 199 | // assert!(p.is_nan()); 200 | assert_eq!(acc, 0.0); 201 | } 202 | } 203 | 204 | #[test] 205 | fn test_adagradlut() { 206 | let mut l = OptimizerAdagradLUT::new(); 207 | l.init(0.15, 0.4, 0.0); 208 | unsafe { 209 | let mut acc: f32; 210 | acc = 0.9; 211 | let p = l.calculate_update(0.1, &mut acc); 212 | assert_eq!(p, 0.015607622); 213 | assert_eq!(acc, 0.9 + 0.1 * 0.1); 214 | 215 | acc = 0.0; 216 | let p = l.calculate_update(0.1, &mut acc); 217 | assert_eq!(p, 
0.09375872); 218 | assert_eq!(acc, 0.1 * 0.1); 219 | 220 | acc = 0.0; 221 | let p = l.calculate_update(0.0, &mut acc); 222 | // Here we check that we don't get Inf back 223 | assert_eq!(p, 0.0); 224 | assert_eq!(acc, 0.0); 225 | } 226 | } 227 | 228 | #[test] 229 | fn test_adagradlut_comparison() { 230 | // Here we test that our implementation of LUT has small enough relative error 231 | let mut l_lut = OptimizerAdagradFlex::new(); 232 | let mut l_flex = OptimizerAdagradLUT::new(); 233 | l_lut.init(0.15, 0.4, 0.0); 234 | l_flex.init(0.15, 0.4, 0.0); 235 | let test_gradients = [-1.0, -0.9, -0.1, -0.00001, 0.0, 0.00001, 0.1, 0.5, 0.9, 1.0]; 236 | let test_accumulations = [ 237 | 0.0000000001, 238 | 0.00001, 239 | 0.1, 240 | 0.5, 241 | 1.1, 242 | 2.0, 243 | 20.0, 244 | 200.0, 245 | 2000.0, 246 | 200000.0, 247 | 2000000.0, 248 | ]; 249 | 250 | unsafe { 251 | for gradient in test_gradients.iter() { 252 | for accumulation in test_accumulations.iter() { 253 | let mut acc_flex: f32 = *accumulation; 254 | let p_flex = l_flex.calculate_update(*gradient, &mut acc_flex); 255 | let mut acc_lut: f32 = *accumulation; 256 | let p_lut = l_lut.calculate_update(*gradient, &mut acc_lut); 257 | let error = (p_flex - p_lut).abs(); 258 | let relative_error: f32; 259 | if p_flex != 0.0 { 260 | relative_error = error / p_flex.abs(); 261 | } else { 262 | relative_error = error; // happens when the update is 0.0 263 | } 264 | 265 | assert!(relative_error < 0.05); 266 | } 267 | } 268 | } 269 | } 270 | } 271 | -------------------------------------------------------------------------------- /src/block_lr.rs: -------------------------------------------------------------------------------- 1 | use std::any::Any; 2 | 3 | use crate::graph; 4 | use crate::model_instance; 5 | use crate::optimizer; 6 | use crate::regressor; 7 | use crate::{feature_buffer, parser}; 8 | 9 | use std::error::Error; 10 | use std::io; 11 | 12 | use crate::block_helpers; 13 | use crate::port_buffer; 14 | use 
crate::regressor::BlockCache; 15 | use block_helpers::WeightAndOptimizerData; 16 | use optimizer::OptimizerTrait; 17 | use regressor::BlockTrait; 18 | 19 | pub struct BlockLR { 20 | pub weights: Vec>, 21 | pub weights_len: u32, 22 | pub optimizer_lr: L, 23 | pub output_offset: usize, 24 | pub num_combos: u32, 25 | } 26 | 27 | impl BlockLR { 28 | fn internal_forward( 29 | &self, 30 | fb: &feature_buffer::FeatureBuffer, 31 | pb: &mut port_buffer::PortBuffer, 32 | ) { 33 | debug_assert!(self.output_offset != usize::MAX); 34 | 35 | unsafe { 36 | let myslice = 37 | &mut pb.tape[self.output_offset..(self.output_offset + self.num_combos as usize)]; 38 | myslice.fill(0.0); 39 | for feature in fb.lr_buffer.iter() { 40 | let feature_index = feature.hash as usize; 41 | let feature_value = feature.value; 42 | let combo_index = feature.combo_index as usize; 43 | *myslice.get_unchecked_mut(combo_index) += 44 | self.weights.get_unchecked(feature_index).weight * feature_value; 45 | } 46 | } 47 | } 48 | } 49 | 50 | fn new_lr_block_without_weights( 51 | mi: &model_instance::ModelInstance, 52 | ) -> Result, Box> { 53 | let mut num_combos = mi.feature_combo_descs.len() as u32; 54 | if mi.add_constant_feature { 55 | num_combos += 1; 56 | } 57 | let mut reg_lr = BlockLR:: { 58 | weights: Vec::new(), 59 | weights_len: 0, 60 | optimizer_lr: L::new(), 61 | output_offset: usize::MAX, 62 | num_combos, 63 | }; 64 | reg_lr 65 | .optimizer_lr 66 | .init(mi.learning_rate, mi.power_t, mi.init_acc_gradient); 67 | reg_lr.weights_len = 1 << mi.bit_precision; 68 | Ok(Box::new(reg_lr)) 69 | } 70 | 71 | pub fn new_lr_block( 72 | bg: &mut graph::BlockGraph, 73 | mi: &model_instance::ModelInstance, 74 | ) -> Result> { 75 | let block = match mi.optimizer { 76 | model_instance::Optimizer::AdagradLUT => { 77 | new_lr_block_without_weights::(mi) 78 | } 79 | model_instance::Optimizer::AdagradFlex => { 80 | new_lr_block_without_weights::(mi) 81 | } 82 | model_instance::Optimizer::SGD => { 83 | 
new_lr_block_without_weights::(mi) 84 | } 85 | } 86 | .unwrap(); 87 | let mut block_outputs = bg.add_node(block, vec![])?; 88 | assert_eq!(block_outputs.len(), 1); 89 | Ok(block_outputs.pop().unwrap()) 90 | } 91 | 92 | impl BlockTrait for BlockLR { 93 | fn as_any(&mut self) -> &mut dyn Any { 94 | self 95 | } 96 | 97 | fn allocate_and_init_weights(&mut self, _mi: &model_instance::ModelInstance) { 98 | self.weights = vec![ 99 | WeightAndOptimizerData:: { 100 | weight: 0.0, 101 | optimizer_data: self.optimizer_lr.initial_data() 102 | }; 103 | self.weights_len as usize 104 | ]; 105 | } 106 | 107 | fn get_num_output_values(&self, output: graph::OutputSlot) -> usize { 108 | assert_eq!(output.get_output_index(), 0); 109 | self.num_combos as usize 110 | } 111 | 112 | fn set_input_offset(&mut self, _input: graph::InputSlot, _offset: usize) { 113 | panic!("You cannot set_input_offset() for BlockLR"); 114 | } 115 | 116 | fn set_output_offset(&mut self, output: graph::OutputSlot, offset: usize) { 117 | assert_eq!(output.get_output_index(), 0); 118 | debug_assert!(self.output_offset == usize::MAX); // We only allow a single call 119 | self.output_offset = offset; 120 | } 121 | 122 | #[inline(always)] 123 | fn forward_backward( 124 | &mut self, 125 | further_blocks: &mut [Box], 126 | fb: &feature_buffer::FeatureBuffer, 127 | pb: &mut port_buffer::PortBuffer, 128 | update: bool, 129 | ) { 130 | unsafe { 131 | self.internal_forward(fb, pb); 132 | 133 | block_helpers::forward_backward(further_blocks, fb, pb, update); 134 | 135 | if update { 136 | let myslice = &mut pb.tape.get_unchecked( 137 | self.output_offset..(self.output_offset + self.num_combos as usize), 138 | ); 139 | 140 | for feature in fb.lr_buffer.iter() { 141 | let feature_index = feature.hash as usize; 142 | let feature_value = feature.value; 143 | let gradient = 144 | myslice.get_unchecked(feature.combo_index as usize) * feature_value; 145 | let update = self.optimizer_lr.calculate_update( 146 | gradient, 147 | &mut 
self.weights.get_unchecked_mut(feature_index).optimizer_data, 148 | ); 149 | self.weights.get_unchecked_mut(feature_index).weight -= update; 150 | } 151 | } 152 | } 153 | } 154 | 155 | fn forward( 156 | &self, 157 | further_blocks: &[Box], 158 | fb: &feature_buffer::FeatureBuffer, 159 | pb: &mut port_buffer::PortBuffer, 160 | ) { 161 | self.internal_forward(fb, pb); 162 | block_helpers::forward(further_blocks, fb, pb); 163 | } 164 | 165 | fn forward_with_cache( 166 | &self, 167 | further_blocks: &[Box], 168 | fb: &feature_buffer::FeatureBuffer, 169 | pb: &mut port_buffer::PortBuffer, 170 | caches: &[BlockCache], 171 | ) { 172 | let Some((next_cache, further_caches)) = caches.split_first() else { 173 | log::warn!("Expected BlockLRCache caches, but non available, executing forward pass without cache"); 174 | self.forward(further_blocks, fb, pb); 175 | return; 176 | }; 177 | 178 | let BlockCache::LR { lr, combo_indexes } = next_cache else { 179 | log::warn!( 180 | "Unable to downcast cache to BlockLRCache, executing forward pass without cache" 181 | ); 182 | self.forward(further_blocks, fb, pb); 183 | return; 184 | }; 185 | 186 | unsafe { 187 | let lr_slice = 188 | &mut pb.tape[self.output_offset..(self.output_offset + self.num_combos as usize)]; 189 | lr_slice.copy_from_slice(lr.as_slice()); 190 | 191 | for feature in fb.lr_buffer.iter() { 192 | let combo_index = feature.combo_index as usize; 193 | if *combo_indexes.get_unchecked(combo_index) { 194 | continue; 195 | } 196 | let feature_index = feature.hash as usize; 197 | let feature_value = feature.value; 198 | *lr_slice.get_unchecked_mut(combo_index) += 199 | self.weights.get_unchecked(feature_index).weight * feature_value; 200 | } 201 | } 202 | block_helpers::forward_with_cache(further_blocks, fb, pb, further_caches); 203 | } 204 | 205 | fn create_forward_cache( 206 | &mut self, 207 | further_blocks: &mut [Box], 208 | caches: &mut Vec, 209 | ) { 210 | caches.push(BlockCache::LR { 211 | lr: vec![0.0; 
self.num_combos as usize], 212 | combo_indexes: vec![false; self.num_combos as usize], 213 | }); 214 | block_helpers::create_forward_cache(further_blocks, caches); 215 | } 216 | 217 | fn prepare_forward_cache( 218 | &mut self, 219 | further_blocks: &mut [Box], 220 | fb: &feature_buffer::FeatureBuffer, 221 | caches: &mut [BlockCache], 222 | ) { 223 | let Some((next_cache, further_caches)) = caches.split_first_mut() else { 224 | log::warn!( 225 | "Expected BlockLRCache caches, but non available, skipping cache preparation" 226 | ); 227 | return; 228 | }; 229 | 230 | let BlockCache::LR { lr, combo_indexes } = next_cache else { 231 | log::warn!("Unable to downcast cache to BlockLRCache, skipping cache preparation"); 232 | return; 233 | }; 234 | 235 | unsafe { 236 | combo_indexes.fill(false); 237 | lr.fill(0.0); 238 | 239 | let lr_slice = lr.as_mut_slice(); 240 | 241 | for feature in fb.lr_buffer.iter() { 242 | if (feature.hash & parser::IS_NOT_SINGLE_MASK) == 0 { 243 | continue; 244 | } 245 | let feature_index = feature.hash as usize; 246 | let feature_value = feature.value; 247 | let combo_index = feature.combo_index as usize; 248 | *lr_slice.get_unchecked_mut(combo_index) += 249 | self.weights.get_unchecked(feature_index).weight * feature_value; 250 | *combo_indexes.get_unchecked_mut(combo_index) = true; 251 | } 252 | } 253 | 254 | block_helpers::prepare_forward_cache(further_blocks, fb, further_caches); 255 | } 256 | 257 | fn get_serialized_len(&self) -> usize { 258 | self.weights_len as usize 259 | } 260 | 261 | fn read_weights_from_buf( 262 | &mut self, 263 | input_bufreader: &mut dyn io::Read, 264 | _use_quantization: bool, 265 | ) -> Result<(), Box> { 266 | block_helpers::read_weights_from_buf(&mut self.weights, input_bufreader, false) 267 | } 268 | 269 | fn write_weights_to_buf( 270 | &self, 271 | output_bufwriter: &mut dyn io::Write, 272 | _use_quantization: bool, 273 | ) -> Result<(), Box> { 274 | block_helpers::write_weights_to_buf(&self.weights, 
output_bufwriter, false) 275 | } 276 | 277 | fn read_weights_from_buf_into_forward_only( 278 | &self, 279 | input_bufreader: &mut dyn io::Read, 280 | forward: &mut Box, 281 | _use_quantization: bool, 282 | ) -> Result<(), Box> { 283 | let forward = forward 284 | .as_any() 285 | .downcast_mut::>() 286 | .unwrap(); 287 | block_helpers::read_weights_only_from_buf2::( 288 | self.weights_len as usize, 289 | &mut forward.weights, 290 | input_bufreader, 291 | ) 292 | } 293 | } 294 | -------------------------------------------------------------------------------- /src/feature_transform_executor.rs: -------------------------------------------------------------------------------- 1 | use crate::parser; 2 | use crate::vwmap; 3 | use std::error::Error; 4 | use std::io::Error as IOError; 5 | use std::io::ErrorKind; 6 | 7 | use std::cell::RefCell; 8 | 9 | use dyn_clone::{clone_trait_object, DynClone}; 10 | use fasthash::murmur3; 11 | 12 | use crate::feature_transform_implementations::{ 13 | TransformerBinner, TransformerCombine, TransformerLogRatioBinner, TransformerWeight, 14 | }; 15 | use crate::feature_transform_parser; 16 | 17 | pub fn default_seeds(to_namespace_index: u32) -> [u32; 5] { 18 | let to_namespace_index = to_namespace_index ^ 1u32 << 31; // compatibility with earlier version 19 | [ 20 | // These are random numbers, i threw a dice! 
21 | murmur3::hash32_with_seed(vec![214, 231, 1, 55], to_namespace_index), 22 | murmur3::hash32_with_seed(vec![255, 6, 14, 69], to_namespace_index), 23 | murmur3::hash32_with_seed(vec![50, 6, 71, 123], to_namespace_index), 24 | murmur3::hash32_with_seed(vec![10, 3, 0, 43], to_namespace_index), 25 | murmur3::hash32_with_seed(vec![0, 53, 10, 201], to_namespace_index), 26 | ] 27 | } 28 | 29 | #[derive(Clone, Copy)] 30 | pub enum SeedNumber { 31 | Default = 0, 32 | One = 1, 33 | Two = 2, 34 | Three = 3, 35 | } 36 | 37 | #[derive(Clone)] 38 | pub struct ExecutorToNamespace { 39 | pub namespace_descriptor: vwmap::NamespaceDescriptor, 40 | pub namespace_seeds: [u32; 5], // These are precomputed namespace seeds 41 | pub tmp_data: Vec<(u32, f32)>, 42 | } 43 | 44 | #[derive(Clone)] 45 | pub struct ExecutorFromNamespace { 46 | pub namespace_descriptor: vwmap::NamespaceDescriptor, 47 | } 48 | 49 | impl ExecutorToNamespace { 50 | // We use const generics here as an experiment to see if they would be useful elsewhere to specialize functions 51 | #[inline(always)] 52 | pub fn emit_i32(&mut self, to_data: i32, hash_value: f32) { 53 | let hash_index = murmur3::hash32_with_seed(to_data.to_le_bytes(), *unsafe { 54 | self.namespace_seeds.get_unchecked(SEED_ID) 55 | }) & parser::MASK31; 56 | self.tmp_data.push((hash_index, hash_value)); 57 | } 58 | 59 | #[inline(always)] 60 | pub fn emit_f32(&mut self, f: f32, hash_value: f32, interpolated: bool) { 61 | if !f.is_finite() { 62 | // these handle INF, -INF and NAN 63 | self.emit_i32::(f.to_bits() as i32, hash_value); 64 | } else if interpolated { 65 | let floor = f.floor(); 66 | let floor_int = floor as i32; 67 | let part = f - floor; 68 | if part != 0.0 { 69 | self.emit_i32::(floor_int + 1, hash_value * part); 70 | } 71 | let part = 1.0 - part; 72 | if part != 0.0 { 73 | self.emit_i32::(floor_int, hash_value * part); 74 | } 75 | } else { 76 | self.emit_i32::(f as i32, hash_value); 77 | } 78 | } 79 | 80 | #[inline(always)] 81 | pub fn 
emit_i32_i32( 82 | &mut self, 83 | to_data1: i32, 84 | to_data2: i32, 85 | hash_value: f32, 86 | ) { 87 | let hash_index = murmur3::hash32_with_seed(to_data1.to_le_bytes(), unsafe { 88 | *self.namespace_seeds.get_unchecked(SEED_ID) 89 | }); 90 | let hash_index = 91 | murmur3::hash32_with_seed(to_data2.to_le_bytes(), hash_index) & parser::MASK31; 92 | self.tmp_data.push((hash_index, hash_value)); 93 | } 94 | } 95 | 96 | #[derive(Clone)] 97 | pub struct TransformExecutor { 98 | pub namespace_to: RefCell, 99 | pub function_executor: Box, 100 | } 101 | 102 | impl TransformExecutor { 103 | pub fn from_namespace_transform( 104 | namespace_transform: &feature_transform_parser::NamespaceTransform, 105 | ) -> Result> { 106 | let namespace_to = ExecutorToNamespace { 107 | namespace_descriptor: namespace_transform.to_namespace.namespace_descriptor, 108 | namespace_seeds: default_seeds( 109 | namespace_transform 110 | .to_namespace 111 | .namespace_descriptor 112 | .namespace_index as u32, 113 | ), 114 | tmp_data: Vec::new(), 115 | }; 116 | 117 | let te = TransformExecutor { 118 | namespace_to: RefCell::new(namespace_to), 119 | function_executor: Self::create_executor( 120 | &namespace_transform.function_name, 121 | &namespace_transform.from_namespaces, 122 | &namespace_transform.function_parameters, 123 | )?, 124 | }; 125 | Ok(te) 126 | } 127 | 128 | pub fn create_executor( 129 | function_name: &str, 130 | namespaces_from: &Vec, 131 | function_params: &Vec, 132 | ) -> Result, Box> { 133 | /* let mut executor_namespaces_from: Vec = Vec::new(); 134 | for namespace in namespaces_from { 135 | executor_namespaces_from.push(ExecutorFromNamespace{namespace_descriptor: namespace.namespace_descriptor, 136 | }); 137 | }*/ 138 | if function_name == "BinnerSqrtPlain" { 139 | TransformerBinner::create_function( 140 | &(|x, resolution| x.sqrt() * resolution), 141 | function_name, 142 | namespaces_from, 143 | function_params, 144 | false, 145 | ) 146 | } else if function_name == 
"BinnerSqrt" { 147 | TransformerBinner::create_function( 148 | &(|x, resolution| x.sqrt() * resolution), 149 | function_name, 150 | namespaces_from, 151 | function_params, 152 | true, 153 | ) 154 | } else if function_name == "BinnerLogPlain" { 155 | TransformerBinner::create_function( 156 | &(|x, resolution| x.ln() * resolution), 157 | function_name, 158 | namespaces_from, 159 | function_params, 160 | false, 161 | ) 162 | } else if function_name == "BinnerLog" { 163 | TransformerBinner::create_function( 164 | &(|x, resolution| x.ln() * resolution), 165 | function_name, 166 | namespaces_from, 167 | function_params, 168 | true, 169 | ) 170 | } else if function_name == "BinnerLogRatioPlain" { 171 | TransformerLogRatioBinner::create_function( 172 | function_name, 173 | namespaces_from, 174 | function_params, 175 | false, 176 | ) 177 | } else if function_name == "BinnerLogRatio" { 178 | TransformerLogRatioBinner::create_function( 179 | function_name, 180 | namespaces_from, 181 | function_params, 182 | true, 183 | ) 184 | } else if function_name == "Combine" { 185 | TransformerCombine::create_function(function_name, namespaces_from, function_params) 186 | } else if function_name == "Weight" { 187 | TransformerWeight::create_function(function_name, namespaces_from, function_params) 188 | } else { 189 | return Err(Box::new(IOError::new( 190 | ErrorKind::Other, 191 | format!("Unknown transformer function: {}", function_name), 192 | ))); 193 | } 194 | } 195 | } 196 | 197 | #[derive(Clone)] 198 | pub struct TransformExecutors { 199 | pub executors: Vec, 200 | } 201 | 202 | impl TransformExecutors { 203 | pub fn from_namespace_transforms( 204 | namespace_transforms: &feature_transform_parser::NamespaceTransforms, 205 | ) -> TransformExecutors { 206 | let mut executors: Vec = Vec::new(); 207 | for transformed_namespace in &namespace_transforms.v { 208 | let transformed_namespace_executor = 209 | TransformExecutor::from_namespace_transform(transformed_namespace).unwrap(); 210 | 
executors.push(transformed_namespace_executor); 211 | } 212 | TransformExecutors { executors } 213 | } 214 | 215 | /* 216 | // We don't use this function as we have put it into feature_reader! macro 217 | #[inline(always)] 218 | pub fn get_transformations<'a>(&self, record_buffer: &[u32], feature_index_offset: u32) -> &TransformExecutor { 219 | let executor_index = feature_index_offset & !feature_transform_parser::TRANSFORM_NAMESPACE_MARK; // remove transform namespace mark 220 | let executor = unsafe {&self.executors.get_unchecked(executor_index as usize)}; 221 | 222 | // If we have a cyclic defintion (which is a bug), this will panic! 223 | let mut namespace_to = executor.namespace_to.borrow_mut(); 224 | namespace_to.tmp_data.truncate(0); 225 | 226 | executor.function_executor.execute_function(record_buffer, &mut namespace_to, &self); 227 | executor 228 | } 229 | */ 230 | } 231 | 232 | // Some black magic from: https://stackoverflow.com/questions/30353462/how-to-clone-a-struct-storing-a-boxed-trait-object 233 | // We need clone() because of serving. There is also an option of doing FeatureBufferTransform from scratch in each thread 234 | pub trait FunctionExecutorTrait: DynClone + Send { 235 | fn execute_function( 236 | &self, 237 | record_buffer: &[u32], 238 | to_namespace: &mut ExecutorToNamespace, 239 | transform_executors: &TransformExecutors, 240 | ); 241 | } 242 | clone_trait_object!(FunctionExecutorTrait); 243 | 244 | #[cfg(test)] 245 | mod tests { 246 | // Note this useful idiom: importing names from outer (for mod tests) scope. 
247 | use super::*; 248 | use crate::feature_transform_executor::default_seeds; 249 | use crate::parser; 250 | 251 | fn ns_desc(i: u16) -> vwmap::NamespaceDescriptor { 252 | vwmap::NamespaceDescriptor { 253 | namespace_index: i, 254 | namespace_type: vwmap::NamespaceType::Primitive, 255 | namespace_format: vwmap::NamespaceFormat::Categorical, 256 | } 257 | } 258 | 259 | #[test] 260 | fn test_interpolation() { 261 | let to_namespace_empty = ExecutorToNamespace { 262 | namespace_descriptor: ns_desc(1), 263 | namespace_seeds: default_seeds(1), // These are precomputed namespace seeds 264 | tmp_data: Vec::new(), 265 | }; 266 | let mut to_namespace = to_namespace_empty; 267 | to_namespace.emit_f32::<{ SeedNumber::Default as usize }>(5.4, 20.0, true); 268 | let to_data_1: i32 = 6; 269 | let to_data_1_value = 20.0 * (5.4 - 5.0); 270 | let hash_index_1 = murmur3::hash32_with_seed( 271 | to_data_1.to_le_bytes(), 272 | to_namespace.namespace_seeds[SeedNumber::Default as usize], 273 | ) & parser::MASK31; 274 | let to_data_2: i32 = 5; 275 | let to_data_2_value = 20.0 * (6.0 - 5.4); 276 | let hash_index_2 = murmur3::hash32_with_seed( 277 | to_data_2.to_le_bytes(), 278 | to_namespace.namespace_seeds[SeedNumber::Default as usize], 279 | ) & parser::MASK31; 280 | assert_eq!( 281 | to_namespace.tmp_data, 282 | vec![ 283 | (hash_index_1, to_data_1_value), 284 | (hash_index_2, to_data_2_value) 285 | ] 286 | ); 287 | } 288 | } 289 | -------------------------------------------------------------------------------- /examples/ffm/run_fw_with_prediction_tests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | ##################################################################################### 5 | # A script tailored for fast inspection(s) of fw's output space and its properties. # 6 | # Complains if: # 7 | # 1. Predictions with different weight files don't match # 8 | # 2. 
Predictions look too random on simple data sets # 9 | # 3. Not high-enough balanced accuracy margin is observed w.r.t. random # 10 | ##################################################################################### 11 | 12 | function compute_main_metrics { 13 | # A method that takes err1/2 counts from context and overwrites main statistics 14 | 15 | PRECISION=$(bc <<<"scale=5 ; $TP / ($TP + $FP)") 16 | RECALL=$(bc <<<"scale=5 ; $TP / ($TP + $FN)") 17 | F1=$(bc <<<"scale=5 ; $TP / ($TP + 0.5 * ($FP + $FN))") 18 | SENSITIVITY=$(bc <<<"scale=5 ; $TP / $ALL_INSTANCES_POSITIVE") 19 | SPECIFICITY=$(bc <<<"scale=5 ; $TN / $ALL_INSTANCES_NEGATIVE") 20 | BALANCED_ACCURACY=$(bc <<<"scale=5 ; ($SENSITIVITY + $SPECIFICITY) / 2") 21 | LOGLOSS=$(cat $LOGLOSS_FRAME | awk 'BEGIN { 22 | totalLoss=0; 23 | allInstances=0; 24 | eps=10e-8; 25 | } 26 | { 27 | # Encode negative class appropriately 28 | if($2==-1) 29 | $2=0.0; 30 | 31 | # L_{\log}(y, p) = -(y log (p) + (1 - y) log (1 - p)) 32 | termFirst = $2 * log($1 + eps); 33 | termSecond = (1.0 - $2) * log(1.0 - $1 + eps); 34 | localLoss = -(termFirst + termSecond); 35 | totalLoss += localLoss; 36 | allInstances += 1; 37 | } 38 | END {print totalLoss / allInstances}') 39 | 40 | } 41 | 42 | SCRIPT=$(readlink -f "$0") 43 | DIR=$(dirname "$SCRIPT") 44 | echo "Generating input datasets" 45 | rm -rf datasets 46 | ( 47 | cd $DIR 48 | python3 generate.py --num_animals 300 --num_foods 200 --num_train_examples 30000 49 | ) 50 | 51 | # Probability threshold considered 52 | THRESHOLD=0.5 53 | 54 | # Training performance margin required to pass (Balanced acc.) 
55 | MARGIN_OF_PERFORMANCE_BA=0.45 56 | MARGIN_OF_PERFORMANCE_HARD_TEST_BA=0.80 57 | 58 | # Project structure 59 | PROJECT_ROOT=$DIR/../../ 60 | FW=$PROJECT_ROOT/target/release/fw 61 | DATASET_FOLDER=$DIR/datasets 62 | TRAIN_DATA=$DATASET_FOLDER/data.vw 63 | 64 | # Nicer printing 65 | INFO_STRING="==============>" 66 | 67 | # Cleanup 68 | rm -rf models 69 | rm -rf predictions 70 | mkdir -p models 71 | mkdir -p predictions 72 | 73 | echo "Building FW" 74 | ( 75 | cd $PROJECT_ROOT 76 | cargo build --release 77 | ) 78 | 79 | # Change this to your preference if required - this is tailored for the toy example 80 | namespaces="--keep A --keep B --interactions AB --ffm_k 10 --ffm_field A --ffm_field B" 81 | rest="-l 0.1 -b 25 -c --sgd --loss_function logistic --link logistic --power_t 0.0 --l2 0.0 --hash all --noconstant" 82 | 83 | # Train on a given data set 84 | $FW $namespaces $rest --data $DATASET_FOLDER/train.vw -p $DIR/predictions/training.txt -f $DIR/models/full_weights.fw.model --save_resume 85 | 86 | # Create inference weights 87 | $FW $namespaces $rest -i models/full_weights.fw.model --convert_inference_regressor ./models/inference_weights.fw.model 88 | 89 | # Test full weights on a given data set 90 | $FW $namespaces $rest -i models/full_weights.fw.model --data $DATASET_FOLDER/train.vw -p ./predictions/eval_full_weight_space.txt -t 91 | 92 | # Test inference weights on a given data set 93 | $FW $namespaces $rest -i models/inference_weights.fw.model -d $DATASET_FOLDER/train.vw -t -p ./predictions/eval_inference_only.txt 94 | 95 | ########################################### 96 | # Test the predictions and their validity # 97 | ########################################### 98 | 99 | # Create ground truth labels first 100 | cat datasets/train.vw | awk '{print $1}' >predictions/ground_truth.txt 101 | 102 | # get last n predictions of training 103 | cat ./predictions/training.txt | tail -n $(cat predictions/ground_truth.txt | wc -l) 
>./predictions/training_eval_part_only.txt 104 | 105 | # check line counts first (same amount of eval instances) 106 | if [ $(cat predictions/ground_truth.txt | wc -l) = $(cat predictions/training_eval_part_only.txt | wc -l) ]; then 107 | echo "$INFO_STRING Matching prediction counts! The test can proceed .." 108 | else 109 | echo "$INFO_STRING Ground truth number different to eval number of training predictions, exiting .." 110 | exit 1 111 | fi 112 | 113 | ###################################################################################################################### 114 | # Create a single file for subsequent prediction analysis; columns are: # 115 | # training's predictions -- predictions only inference -- predictions using full weight space -- ground truth labels # 116 | ###################################################################################################################### 117 | paste predictions/training_eval_part_only.txt predictions/eval_inference_only.txt predictions/eval_full_weight_space.txt predictions/ground_truth.txt >./predictions/joint_prediction_space.txt 118 | 119 | # Generate a "dummy" prediction space 120 | yes "0.0" | head -n $(cat predictions/joint_prediction_space.txt | wc -l) >./predictions/all_negative.txt 121 | 122 | # Form the final dataframe 123 | paste predictions/joint_prediction_space.txt predictions/all_negative.txt >./tmp.txt 124 | mv ./tmp.txt ./predictions/joint_prediction_space.txt 125 | 126 | # All instances 127 | ALL_INSTANCES=$(cat predictions/joint_prediction_space.txt | wc -l) 128 | 129 | # Are inference weights' predictions the same? 130 | INFERENCE_SAME_COUNT=$(cat ./predictions/joint_prediction_space.txt | awk '$2==$3' | wc -l) 131 | 132 | if [ $ALL_INSTANCES = $INFERENCE_SAME_COUNT ]; then 133 | echo "$INFO_STRING All inferences' weights' predictions are the same .." 134 | else 135 | echo "$INFO_STRING inference weights produce different predictions to full weights!" 
136 | exit 1 137 | fi 138 | 139 | NUM_UNIQUE_INFERENCE_ONLY_EVAL=$(cat predictions/eval_inference_only.txt | sort -u | wc -l) 140 | NUM_UNIQUE_FULL_WEIGHTS_EVAL=$(cat predictions/eval_full_weight_space.txt | sort -u | wc -l) 141 | NUM_UNIQUE_TRAINING_RUN_EVAL=$(cat predictions/training.txt | sort -u | wc -l) 142 | 143 | # Are all predictions for full weights the same? 144 | if [ $NUM_UNIQUE_FULL_WEIGHTS_EVAL = 1 ]; then 145 | echo "$INFO_STRING WARNING: all predictions are the same if using full weights file for inference only." 146 | exit 1 147 | fi 148 | 149 | # Are all inference weights-based predictions fine? 150 | if [ $NUM_UNIQUE_INFERENCE_ONLY_EVAL = 1 ]; then 151 | echo "$INFO_STRING WARNING: all predictions are the same if using inference weights file for inference only." 152 | exit 1 153 | fi 154 | 155 | # Are all training predictions same? 156 | if [ $NUM_UNIQUE_TRAINING_RUN_EVAL = 1 ]; then 157 | echo "$INFO_STRING WARNING: all predictions are the same during training." 158 | exit 1 159 | 160 | fi 161 | 162 | ####################################### 163 | # PART 1 - benchmarks of training set # 164 | ####################################### 165 | 166 | # Create a benchmark against random classifier 167 | echo -e "OUTPUT_TAG\tTHRESHOLD\tPRECISION\tRECALL\tF1\tBALANCED_ACCURACY\tLOGLOSS" 168 | ALL_INSTANCES=$(cat predictions/joint_prediction_space.txt | wc -l) 169 | ALL_INSTANCES_POSITIVE=$(cat predictions/joint_prediction_space.txt | awk '{print $4}' | grep -v '\-1' | wc -l) 170 | ALL_INSTANCES_NEGATIVE=$(cat predictions/joint_prediction_space.txt | awk '{print $4}' | grep '\-1' | wc -l) 171 | 172 | TP=$(cat predictions/joint_prediction_space.txt | awk -v THRESHOLD="$THRESHOLD" '($4=="1") && ($3>=THRESHOLD) {positiveMatch++} END {print positiveMatch}') 173 | 174 | TN=$(cat predictions/joint_prediction_space.txt | awk -v THRESHOLD="$THRESHOLD" '($4=="-1") && ($3=THRESHOLD) {positiveMatch++} END {print positiveMatch}') 177 | 178 | FN=$(cat 
predictions/joint_prediction_space.txt | awk -v THRESHOLD="$THRESHOLD" '($4=="1") && ($3./predictions/logloss_frame_training.txt 181 | LOGLOSS_FRAME="./predictions/logloss_frame_training.txt" 182 | 183 | # Account for corner cases 184 | if [ "$FP" = "" ]; then 185 | FP=0 186 | fi 187 | 188 | if [ "$FN" = "" ]; then 189 | FN=0 190 | fi 191 | 192 | compute_main_metrics 193 | echo -e "FW\t$THRESHOLD\t$PRECISION\t$RECALL\t$F1\t$BALANCED_ACCURACY\t$LOGLOSS" 194 | 195 | # Random baseline 196 | TP=$(cat predictions/joint_prediction_space.txt | awk -v THRESHOLD="$THRESHOLD" '($4=="1") && (rand()>=THRESHOLD) {positiveMatch++} END {print positiveMatch}') 197 | 198 | TN=$(cat predictions/joint_prediction_space.txt | awk -v THRESHOLD="$THRESHOLD" '($4=="-1") && (rand()=THRESHOLD) {positiveMatch++} END {print positiveMatch}') 201 | 202 | FN=$(cat predictions/joint_prediction_space.txt | awk -v THRESHOLD="$THRESHOLD" '($4=="1") && (rand()./predictions/logloss_frame_random.txt 205 | LOGLOSS_FRAME="./predictions/logloss_frame_random.txt" 206 | 207 | compute_main_metrics 208 | 209 | echo -e "RANDOM\t$THRESHOLD\t$PRECISION\t$RECALL\t$F1\t$BALANCED_ACCURACY\t$LOGLOSS" 210 | 211 | # Is the difference substantial (in BA) 212 | BA_DIFF=$(bc <<<"scale=5 ; $BALANCED_ACCURACY_FW - $BALANCED_ACCURACY") 213 | ZERO_VAR="0.0" 214 | 215 | # BA margin must be beyond specified threshold for this to pass 216 | if [ 1 -eq "$(echo "$BA_DIFF > $ZERO_VAR" | bc)" ]; then 217 | echo "$INFO_STRING FW learned much better than random (on training), exiting gracefully." 
218 | fi 219 | 220 | ####################################### 221 | # PART 2 - benchmarks on the test set # 222 | ####################################### 223 | 224 | # Test inference weights on a given data set 225 | $FW $namespaces $rest -i models/inference_weights.fw.model -d $DATASET_FOLDER/test-hard.vw -t -p ./predictions/test_hard_predictions.txt 226 | 227 | cat ./datasets/test-hard.vw | awk '{print $1}' >./predictions/hard_ground_truth.txt 228 | paste predictions/test_hard_predictions.txt predictions/hard_ground_truth.txt >./predictions/joint_hard_predictions_and_ground.txt 229 | 230 | ALL_INSTANCES_POSITIVE=$(cat predictions/joint_hard_predictions_and_ground.txt | awk '{print $2}' | grep -v '\-1' | wc -l) 231 | ALL_INSTANCES_NEGATIVE=$(cat predictions/joint_hard_predictions_and_ground.txt | awk '{print $2}' | grep '\-1' | wc -l) 232 | 233 | # Random baseline 234 | TP=$(cat predictions/joint_hard_predictions_and_ground.txt | awk -v THRESHOLD="$THRESHOLD" '($2=="1") && ($1>=THRESHOLD) {positiveMatch++} END {print positiveMatch}') 235 | 236 | TN=$(cat predictions/joint_hard_predictions_and_ground.txt | awk -v THRESHOLD="$THRESHOLD" '($2=="-1") && ($1=THRESHOLD) {positiveMatch++} END {print positiveMatch}') 239 | 240 | FN=$(cat predictions/joint_hard_predictions_and_ground.txt | awk -v THRESHOLD="$THRESHOLD" '($2=="1") && ($1 0" | bc)" ]; then 248 | echo "$INFO_STRING FW learned much better than random (on hard test), exiting gracefully." 249 | else 250 | echo "$INFO_STRING FW did not learn to classify the hard problem well enough!" 
251 | exit 1 252 | fi 253 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused_variables)] 2 | #![allow(unused_mut)] 3 | #![allow(non_snake_case)] 4 | #![allow(redundant_semicolons)] 5 | #![allow(dead_code, unused_imports)] 6 | 7 | use std::collections::VecDeque; 8 | use std::error::Error; 9 | use std::f32; 10 | use std::fs::File; 11 | use std::io; 12 | use std::io::BufRead; 13 | use std::io::BufWriter; 14 | use std::io::Write; 15 | use std::path::Path; 16 | use std::time::Instant; 17 | 18 | extern crate blas; 19 | extern crate half; 20 | extern crate intel_mkl_src; 21 | 22 | #[macro_use] 23 | extern crate nom; 24 | extern crate core; 25 | 26 | use fw::cache::RecordCache; 27 | use fw::feature_buffer::FeatureBufferTranslator; 28 | use fw::hogwild::HogwildTrainer; 29 | use fw::model_instance::{ModelInstance, Optimizer}; 30 | use fw::multithread_helpers::BoxedRegressorTrait; 31 | use fw::parser::VowpalParser; 32 | use fw::buffer_handler::create_buffered_input; 33 | use fw::persistence::{ 34 | new_regressor_from_filename, save_regressor_to_filename, save_sharable_regressor_to_filename, 35 | }; 36 | use fw::regressor::{get_regressor_with_weights, Regressor}; 37 | use fw::serving::Serving; 38 | use fw::vwmap::VwNamespaceMap; 39 | use fw::{cmdline, feature_buffer, logging_layer, regressor}; 40 | 41 | fn main() { 42 | logging_layer::initialize_logging_layer(); 43 | 44 | if let Err(e) = main_fw_loop() { 45 | log::error!("Global error: {:?}", e); 46 | std::process::exit(1) 47 | } 48 | } 49 | 50 | fn build_cache_without_training(cl: clap::ArgMatches) -> Result<(), Box> { 51 | /*! A method that enables creating the cache file without training the first model instance. 52 | This is done in order to reduce building time of the cache and running the first model instance multi threaded. 
*/ 53 | // We'll parse once the command line into cl and then different objects will examine it 54 | let input_filename = cl.value_of("data").expect("--data expected"); 55 | let vw_namespace_map_filepath = Path::new(input_filename) 56 | .parent() 57 | .expect("Couldn't access path given by --data") 58 | .join("vw_namespace_map.csv"); 59 | 60 | let vw: VwNamespaceMap = VwNamespaceMap::new_from_csv_filepath(vw_namespace_map_filepath)?; 61 | let mut cache = RecordCache::new(input_filename, true, &vw); 62 | let input = File::open(input_filename)?; 63 | 64 | let mut bufferred_input = create_buffered_input(input_filename); 65 | let mut pa = VowpalParser::new(&vw); 66 | let mut example_num = 0; 67 | loop { 68 | let reading_result; 69 | let buffer: &[u32]; 70 | if !cache.reading { 71 | reading_result = pa.next_vowpal(&mut bufferred_input); 72 | buffer = match reading_result { 73 | Ok([]) => break, // EOF 74 | Ok(buffer2) => buffer2, 75 | Err(_e) => return Err(_e), 76 | }; 77 | if cache.writing { 78 | cache.push_record(buffer)?; 79 | } 80 | } else { 81 | reading_result = cache.get_next_record(); 82 | match reading_result { 83 | Ok([]) => break, // EOF 84 | Ok(buffer) => buffer, 85 | Err(_e) => return Err(_e), 86 | }; 87 | } 88 | example_num += 1; 89 | } 90 | 91 | log::info!("Built cache only, exiting."); 92 | cache.write_finish()?; 93 | Ok(()) 94 | } 95 | 96 | fn main_fw_loop() -> Result<(), Box> { 97 | // We'll parse once the command line into cl and then different objects will examine it 98 | let cl = cmdline::parse(); 99 | if cl.is_present("build_cache_without_training") { 100 | return build_cache_without_training(cl); 101 | } 102 | // Where will we be putting perdictions (if at all) 103 | let mut predictions_file = match cl.value_of("predictions") { 104 | Some(filename) => Some(BufWriter::new(File::create(filename)?)), 105 | None => None, 106 | }; 107 | 108 | let testonly = cl.is_present("testonly"); 109 | let quantize_weights = cl.is_present("weight_quantization"); 110 
| let final_regressor_filename = cl.value_of("final_regressor"); 111 | let output_pred_sto: bool = cl.is_present("predictions_stdout"); 112 | if let Some(filename) = final_regressor_filename { 113 | if !cl.is_present("save_resume") { 114 | return Err("You need to use --save_resume with --final_regressor, for vowpal wabbit compatibility")?; 115 | } 116 | log::info!("final_regressor = {}", filename); 117 | }; 118 | 119 | let inference_regressor_filename = cl.value_of("convert_inference_regressor"); 120 | if let Some(filename) = inference_regressor_filename { 121 | log::info!("inference_regressor = {}", filename); 122 | }; 123 | 124 | /* setting up the pipeline, either from command line or from existing regressor */ 125 | // we want heal-allocated objects here 126 | 127 | if cl.is_present("daemon") { 128 | let filename = cl 129 | .value_of("initial_regressor") 130 | .expect("Daemon mode only supports serving from --initial regressor"); 131 | log::info!("initial_regressor = {}", filename); 132 | let (mi2, vw2, re_fixed) = new_regressor_from_filename(filename, true, Option::Some(&cl))?; 133 | 134 | let mut se = Serving::new(&cl, &vw2, Box::new(re_fixed), &mi2)?; 135 | se.serve()?; 136 | } else if cl.is_present("convert_inference_regressor") { 137 | let filename = cl 138 | .value_of("initial_regressor") 139 | .expect("Convert mode requires --initial regressor"); 140 | let (mut mi2, vw2, re_fixed) = 141 | new_regressor_from_filename(filename, true, Option::Some(&cl))?; 142 | mi2.optimizer = Optimizer::SGD; 143 | if cl.is_present("weight_quantization") { 144 | mi2.dequantize_weights = Some(true); 145 | } 146 | if let Some(filename1) = inference_regressor_filename { 147 | save_regressor_to_filename(filename1, &mi2, &vw2, re_fixed, quantize_weights).unwrap() 148 | } 149 | } else { 150 | let vw: VwNamespaceMap; 151 | let mut re: Regressor; 152 | let mut sharable_regressor: BoxedRegressorTrait; 153 | let mi: ModelInstance; 154 | 155 | if let Some(filename) = 
cl.value_of("initial_regressor") { 156 | log::info!("initial_regressor = {}", filename); 157 | (mi, vw, re) = new_regressor_from_filename(filename, testonly, Option::Some(&cl))?; 158 | sharable_regressor = BoxedRegressorTrait::new(Box::new(re)); 159 | } else { 160 | // We load vw_namespace_map.csv just so we know all the namespaces ahead of time 161 | // This is one of the major differences from vowpal 162 | 163 | let input_filename = cl.value_of("data").expect("--data expected"); 164 | let vw_namespace_map_filepath = Path::new(input_filename) 165 | .parent() 166 | .expect("Couldn't access path given by --data") 167 | .join("vw_namespace_map.csv"); 168 | vw = VwNamespaceMap::new_from_csv_filepath(vw_namespace_map_filepath)?; 169 | mi = ModelInstance::new_from_cmdline(&cl, &vw)?; 170 | re = get_regressor_with_weights(&mi); 171 | sharable_regressor = BoxedRegressorTrait::new(Box::new(re)); 172 | }; 173 | 174 | let input_filename = cl.value_of("data").expect("--data expected"); 175 | let mut cache = RecordCache::new(input_filename, cl.is_present("cache"), &vw); 176 | let mut fbt = FeatureBufferTranslator::new(&mi); 177 | let mut pb = sharable_regressor.new_portbuffer(); 178 | 179 | let predictions_after: u64 = match cl.value_of("predictions_after") { 180 | Some(examples) => examples.parse()?, 181 | None => 0, 182 | }; 183 | 184 | let holdout_after_option: Option = 185 | cl.value_of("holdout_after").map(|s| s.parse().unwrap()); 186 | 187 | let hogwild_training = cl.is_present("hogwild_training"); 188 | let mut hogwild_trainer = if hogwild_training { 189 | let hogwild_threads = match cl.value_of("hogwild_threads") { 190 | Some(hogwild_threads) => hogwild_threads 191 | .parse() 192 | .expect("hogwild_threads should be integer"), 193 | None => 16, 194 | }; 195 | HogwildTrainer::new(sharable_regressor.clone(), &mi, hogwild_threads) 196 | } else { 197 | HogwildTrainer::default() 198 | }; 199 | 200 | let prediction_model_delay: u64 = match 
cl.value_of("prediction_model_delay") { 201 | Some(delay) => delay.parse()?, 202 | None => 0, 203 | }; 204 | 205 | let mut delayed_learning_fbs: VecDeque = 206 | VecDeque::with_capacity(prediction_model_delay as usize); 207 | 208 | let mut bufferred_input = create_buffered_input(input_filename); 209 | let mut pa = VowpalParser::new(&vw); 210 | 211 | let now = Instant::now(); 212 | let mut example_num = 0; 213 | loop { 214 | let reading_result; 215 | let buffer: &[u32]; 216 | if !cache.reading { 217 | reading_result = pa.next_vowpal(&mut bufferred_input); 218 | buffer = match reading_result { 219 | Ok([]) => break, // EOF 220 | Ok(buffer2) => buffer2, 221 | Err(_e) => return Err(_e), 222 | }; 223 | if cache.writing { 224 | cache.push_record(buffer)?; 225 | } 226 | } else { 227 | reading_result = cache.get_next_record(); 228 | buffer = match reading_result { 229 | Ok([]) => break, // EOF 230 | Ok(buffer) => buffer, 231 | Err(_e) => return Err(_e), 232 | }; 233 | } 234 | example_num += 1; 235 | let mut prediction: f32 = 0.0; 236 | 237 | if prediction_model_delay == 0 { 238 | let update = match holdout_after_option { 239 | Some(holdout_after) => !testonly && example_num < holdout_after, 240 | None => !testonly, 241 | }; 242 | if hogwild_training && update { 243 | hogwild_trainer.digest_example(Vec::from(buffer)); 244 | } else { 245 | fbt.translate(buffer, example_num); 246 | prediction = sharable_regressor.learn(&fbt.feature_buffer, &mut pb, update); 247 | } 248 | } else { 249 | fbt.translate(buffer, example_num); 250 | if example_num > predictions_after { 251 | prediction = sharable_regressor.learn(&fbt.feature_buffer, &mut pb, false); 252 | } 253 | delayed_learning_fbs.push_back(fbt.feature_buffer.clone()); 254 | if (prediction_model_delay as usize) < delayed_learning_fbs.len() { 255 | let delayed_buffer = delayed_learning_fbs.pop_front().unwrap(); 256 | sharable_regressor.learn(&delayed_buffer, &mut pb, !testonly); 257 | } 258 | } 259 | 260 | if example_num > 
predictions_after { 261 | if output_pred_sto { 262 | println!("{:.6}", prediction); 263 | } 264 | 265 | match predictions_file.as_mut() { 266 | Some(file) => writeln!(file, "{:.6}", prediction)?, 267 | None => {} 268 | } 269 | } 270 | } 271 | cache.write_finish()?; 272 | 273 | if hogwild_training { 274 | hogwild_trainer.block_until_workers_finished(); 275 | } 276 | let elapsed = now.elapsed(); 277 | log::info!("Elapsed: {:.2?} rows: {}", elapsed, example_num); 278 | 279 | if let Some(filename) = final_regressor_filename { 280 | save_sharable_regressor_to_filename( 281 | filename, 282 | &mi, 283 | &vw, 284 | sharable_regressor, 285 | quantize_weights, 286 | ) 287 | .unwrap() 288 | } 289 | } 290 | 291 | Ok(()) 292 | } 293 | -------------------------------------------------------------------------------- /weight_patcher/src/main.rs: -------------------------------------------------------------------------------- 1 | use gzp::par::compress::{ParCompress, ParCompressBuilder}; 2 | use gzp::par::decompress::{ParDecompress, ParDecompressBuilder}; 3 | use gzp::{deflate, Compression, ZWriter}; 4 | use std::env::args; 5 | use std::fs::File; 6 | use std::io::{self, BufReader, BufWriter, Read, Seek, Write}; 7 | 8 | #[derive(Debug, PartialEq)] 9 | struct DiffEntry { 10 | // we use a relative index rather than absolute, so we can represent the values of this field with fewer bits than 64 11 | relative_index: u64, 12 | to: u8, 13 | } 14 | 15 | const CHUNK_SIZE: usize = 1024 * 64; 16 | 17 | fn main() -> io::Result<()> { 18 | env_logger::init(); 19 | let (action, file_a_path, file_b_path, output_path) = parse_args()?; 20 | 21 | match action.as_str() { 22 | "create_diff" => create_diff_file(&file_a_path, &file_b_path, &output_path), 23 | "recreate" => recreate_file(&file_a_path, &file_b_path, &output_path), 24 | _ => { 25 | log::error!("Invalid action: {}", action); 26 | std::process::exit(1); 27 | } 28 | } 29 | } 30 | 31 | fn parse_args() -> io::Result<(String, String, String, 
String)> {
    let args: Vec<String> = args().collect();
    if args.len() < 5 {
        // NOTE(review): the bracketed placeholders in this usage string were lost
        // in transcription; reconstructed — confirm wording against VCS.
        log::error!(
            "Usage: {} <action> <file_a> <file_b> <output>",
            args[0]
        );
        log::error!("       action: 'create_diff' or 'recreate'");
        std::process::exit(1);
    }

    Ok((
        args[1].clone(),
        args[2].clone(),
        args[3].clone(),
        args[4].clone(),
    ))
}

/// Create a diff file between file_a & file_b, so file_b can be restored. Restoration is supported only for the second param file_b. Diff file
/// is compressed using zlib.
fn create_diff_file(file_a_path: &str, file_b_path: &str, diff_file_path: &str) -> io::Result<()> {
    let (file_a, file_b) = open_input_files(file_a_path, file_b_path)?;
    let diff_file = File::create(diff_file_path)?;
    // NOTE(review): the concrete gzp format parameter was lost in transcription;
    // `deflate::Zlib` matches the "compressed using zlib" doc above — confirm against VCS.
    let mut zlib_writer: ParCompress<deflate::Zlib> = ParCompressBuilder::new()
        .compression_level(Compression::fast())
        .from_writer(diff_file);

    compare_files_and_write_diff(file_a, file_b, &mut zlib_writer)?;

    zlib_writer.finish().unwrap();

    log::info!("Compressed diff file created: {}", diff_file_path);
    Ok(())
}

/// This will create a diff file so that file_b can be recreated. file_a recreation is not supported.
/// Assumes file_a and file_b have the same length.
fn compare_files_and_write_diff<R: Read, W: Write>(
    file_a: R,
    file_b: R,
    diff_writer: &mut W,
) -> io::Result<()> {
    let mut reader_a = BufReader::new(file_a);
    let mut reader_b = BufReader::new(file_b);
    let mut buf_a = [0u8; CHUNK_SIZE];
    let mut buf_b = [0u8; CHUNK_SIZE];
    // Absolute byte offset of the start of the current chunk.
    let mut position: u64 = 0;
    // Absolute index of the previous mismatch, used for delta encoding.
    let mut prev_index: u64 = 0;

    let mut diff_entries = Vec::with_capacity(CHUNK_SIZE);

    loop {
        // FIX: `Read::read` may legally return fewer bytes than requested, and the
        // two readers may return *different* counts, which silently desynchronized
        // the byte-by-byte comparison (and `unwrap_or(0)` turned I/O errors into a
        // fake EOF). Fill both buffers fully and propagate errors instead.
        let bytes_a = read_full(&mut reader_a, &mut buf_a)?;
        let bytes_b = read_full(&mut reader_b, &mut buf_b)?;
        debug_assert_eq!(bytes_a, bytes_b, "input files must be the same size");

        // We're done with reading both files, so break loop. file_a and file_b are
        // always the same size so we need to check only one of them.
        if bytes_a == 0 {
            break;
        }

        for i in 0..bytes_a {
            let a_val = buf_a[i];
            let b_val = buf_b.get(i).copied();

            // mismatch byte between file_a and file_b
            if Some(a_val) != b_val {
                let current_index = position + i as u64;
                let delta = current_index - prev_index;
                let diff_entry = DiffEntry {
                    relative_index: delta,
                    to: b_val.unwrap_or(0),
                };
                prev_index = current_index;
                diff_entries.push(diff_entry);

                if diff_entries.len() == CHUNK_SIZE {
                    // Write all accumulated diff entries and clear the buffer
                    for diff_entry in &diff_entries {
                        write_diff_entry(diff_writer, diff_entry)?;
                    }
                    diff_entries.clear();
                }
            }
        }
        position += bytes_a as u64;
    }

    // Write any remaining diff entries and clear the buffer
    for diff_entry in &diff_entries {
        write_diff_entry(diff_writer, diff_entry)?;
    }

    // Flush buffered writer before returning
    diff_writer.flush()?;

    Ok(())
}

/// Reads until `buf` is full or EOF, retrying on `Interrupted`; returns the number
/// of bytes read (0 only at EOF). Unlike a bare `read()`, a short count here
/// always means end of file.
fn read_full<R: Read>(reader: &mut R, buf: &mut [u8]) -> io::Result<usize> {
    let mut filled = 0;
    while filled < buf.len() {
        match reader.read(&mut buf[filled..]) {
            Ok(0) => break,
            Ok(n) => filled += n,
            Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
    Ok(filled)
}

fn recreate_file(file_a_path: &str, diff_file_path: &str, output_path: &str) -> io::Result<()> {
    let (mut file_a, diff_file) = open_input_files(file_a_path, diff_file_path)?;
    let diff_file = BufReader::new(diff_file); // Wrap the File in a BufReader
    // NOTE(review): format parameter reconstructed; must match create_diff_file.
    let diff_file: ParDecompress<deflate::Zlib> =
        ParDecompressBuilder::new().from_reader(diff_file);

    let output_file = File::create(output_path)?;
    recreate_file_inner(&mut file_a, diff_file, output_file)?;
    log::info!(
        "Output file recreated from compressed diff file: {}",
        output_path
    );
    Ok(())
}

// Recreate file_b from file_a + diff_file
fn recreate_file_inner<R: Read, G: Read, W: Write>(
    file_a: &mut R,
    diff_file: G,
    mut output_file: W,
) -> io::Result<()> {
    let mut reader_a = BufReader::new(file_a);
    let mut diff_reader = BufReader::new(diff_file);
    let mut writer = BufWriter::new(&mut output_file);

    let mut buf_a = [0u8; CHUNK_SIZE];
    let mut current_position: u64 = 0;
    // Entries carry delta-encoded indices; each is converted to an absolute index
    // as soon as it is read (see the `+=` below), so `relative_index` is compared
    // directly against `current_position`.
    let mut diff_entry = read_diff_entry(&mut diff_reader);

    let mut output_buffer: Vec<u8> = Vec::with_capacity(CHUNK_SIZE);
    loop {
        // FIX: `unwrap_or(0)` treated read errors as EOF, silently truncating the
        // recreated file; propagate them instead.
        let bytes_a = reader_a.read(&mut buf_a)?;

        // file_a content exhausted
        if bytes_a == 0 {
            break;
        }

        output_buffer.clear();

        for &byte_a in &buf_a[..bytes_a] {
            let mut next_entry = None;
            if let Some(ref entry) = diff_entry {
                if current_position == entry.relative_index {
                    // Apply the diff entry
                    output_buffer.push(entry.to);
                    next_entry = read_diff_entry(&mut diff_reader);
                    if let Some(ref mut next_e) = next_entry {
                        // delta -> absolute index
                        next_e.relative_index += entry.relative_index;
                    }
                } else {
                    // Write the byte from file_a
                    output_buffer.push(byte_a);
                }
            } else {
                // Write the byte from file_a
                output_buffer.push(byte_a);
            }

            current_position += 1;

            if next_entry.is_some() {
                diff_entry = next_entry;
            }
        }

        // Write the buffer to the output file
        writer.write_all(&output_buffer)?;
    }

    // Flush the buffered writer before returning
    writer.flush()?;

    Ok(())
}

/// Reads one (varint index, replacement byte) pair; `None` at end of the diff
/// stream or on any read failure.
fn read_diff_entry<R: Read>(diff_reader: &mut R) -> Option<DiffEntry> {
    let index = read_varint(diff_reader).ok()?;
    let mut buf = [0u8; 1];
    if diff_reader.read_exact(&mut buf).is_ok() {
        Some(DiffEntry {
            relative_index: index,
            to: buf[0],
        })
    } else {
        None
    }
}

/// Reads a variable-length integer (varint) from the given reader and returns
/// it as a u64 value.
///
/// The varint encoding uses the least significant 7 bits of each byte to store
/// the integer value, with the most significant bit used as a continuation flag.
/// The continuation flag is set to 1 for all bytes except the last one, which
/// signals the end of the varint. This encoding is efficient for small integer
/// values, as it uses fewer bytes compared to a fixed-size integer.
///
/// # Errors
/// Propagates any underlying I/O error, and returns `InvalidData` for a varint
/// longer than a `u64` can hold (i.e. a corrupt diff stream).
fn read_varint<R: Read>(reader: &mut R) -> io::Result<u64> {
    let mut value: u64 = 0;
    let mut shift: u64 = 0;
    let mut buf = [0u8; 1];

    loop {
        reader.read_exact(&mut buf)?;
        let byte = buf[0];

        // FIX: guard the shift — a malformed stream with more than ten
        // continuation bytes previously overflowed the shift amount
        // (a panic in debug builds, silently wrong values in release).
        if shift > 63 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "varint too long for u64",
            ));
        }
        value |= ((byte & 0x7F) as u64) << shift;
        if byte & 0x80 == 0 {
            break;
        }
        shift += 7;
    }
    Ok(value)
}

/// Serializes one diff entry as a varint-encoded relative index followed by the
/// replacement byte.
fn write_diff_entry<W: Write>(diff_file: &mut W, diff_entry: &DiffEntry) -> io::Result<()> {
    write_varint(diff_entry.relative_index, diff_file)?;
    diff_file.write_all(&[diff_entry.to])
}

/// Writes a u64 value as a variable-length integer to the given writer.
///
/// The varint encoding uses the least significant 7 bits of each byte to store
/// the integer value, with the most significant bit used as a continuation flag.
/// The continuation flag is set to 1 for all bytes except the last one, which
/// signals the end of the varint. This encoding is efficient for small integer
/// values, as it uses fewer bytes compared to a fixed-size integer.
fn write_varint<W: Write>(mut value: u64, writer: &mut W) -> io::Result<()> {
    while value >= 0x80 {
        writer.write_all(&[(value & 0x7F) as u8 | 0x80])?;
        value >>= 7;
    }
    writer.write_all(&[value as u8])
}

/// Opens both input files, failing fast if either is unreadable.
fn open_input_files(file_a_path: &str, file_b_path: &str) -> io::Result<(File, File)> {
    let file_a = File::open(file_a_path)?;
    let file_b = File::open(file_b_path)?;
    Ok((file_a, file_b))
}

#[cfg(test)]
mod tests {

    use super::*;
    use std::io::Cursor;

    fn create_diff(file_a_content: &[u8], file_b_content: &[u8]) -> io::Result<Vec<u8>> {
        let file_a = Cursor::new(file_a_content);
        let file_b = Cursor::new(file_b_content);
        let mut diff_file = Cursor::new(Vec::new());

        compare_files_and_write_diff(file_a, file_b, &mut diff_file)?;
        Ok(diff_file.into_inner())
    }

    fn test_recreation(
        file_a_content: &[u8],
        file_b_content: &[u8],
        diff_file_content: &[u8],
    ) -> io::Result<()> {
        let mut file_a = Cursor::new(file_a_content);
        let diff_file = Cursor::new(diff_file_content);
        let mut recreated_file_b = Cursor::new(Vec::new());

        recreate_file_inner(&mut file_a, diff_file, &mut recreated_file_b)?;

        assert_eq!(recreated_file_b.into_inner(), file_b_content);
        Ok(())
    }

    #[test]
    fn file_a_and_file_b_are_the_same() {
        let file_a_content = b"hello world";
        let file_b_content = b"hello world";
        let diff_file_content = create_diff(file_a_content, file_b_content).unwrap();
        test_recreation(file_a_content, file_b_content, &diff_file_content).unwrap();
    }

    #[test]
    fn file_a_and_file_b_are_different() {
        let file_a_content = b"hello";
        let file_b_content = b"world";
        let diff_file_content = create_diff(file_a_content, file_b_content).unwrap();
        test_recreation(file_a_content, file_b_content, &diff_file_content).unwrap();
    }

    #[test]
    fn test_write_varint() {
        let mut buffer = Vec::new();
        let value: u64 = 12345;
        write_varint(value, &mut buffer).unwrap();

        assert_eq!(buffer, vec![0xB9, 0x60]);
    }

    #[test]
    fn test_read_varint() {
        let mut buffer = Cursor::new(vec![0xB9, 0x60]);
        let value = read_varint(&mut buffer).unwrap();

        assert_eq!(value, 12345);
    }

    #[test]
    fn test_read_write_varint() {
        let test_values = vec![0, 1, 127, 128, 16383, 16384, 2097151, 2097152, u64::MAX];

        for value in test_values {
            let mut buffer = Vec::new();
            write_varint(value, &mut buffer).unwrap();

            let mut buffer = Cursor::new(buffer);
            let read_value = read_varint(&mut buffer).unwrap();

            assert_eq!(value, read_value);
        }
    }
}
-------------------------------------------------------------------------------- /src/cmdline.rs: --------------------------------------------------------------------------------
use crate::version;
use clap::{App, AppSettings, Arg};

pub fn parse<'a>() -> clap::ArgMatches<'a> {
    let matches = create_expected_args().get_matches();
    matches
}

pub fn create_expected_args<'a>() -> App<'a, 'a> {
    App::new("fwumious wabbit")
        .version(version::LATEST)
        .author("Andraz Tori ")
        .about("Superfast Logistic Regression & Field Aware Factorization Machines")
        .setting(AppSettings::DeriveDisplayOrder)
        .arg(Arg::with_name("data")
            .long("data")
            .short("d")
            .value_name("filename")
            .help("File with input examples")
            .takes_value(true))
        .arg(Arg::with_name("quiet")
            .long("quiet")
            .help("Quiet mode, does nothing currently (as we don't output diagnostic data anyway)")
            .takes_value(false))
        .arg(Arg::with_name("predictions")
            .short("p")
            .value_name("output predictions file")
28 | .help("Output predictions file") 29 | .takes_value(true)) 30 | .arg(Arg::with_name("cache") 31 | .short("c") 32 | .long("cache") 33 | .help("Use cache file") 34 | .takes_value(false)) 35 | .arg(Arg::with_name("save_resume") 36 | .long("save_resume") 37 | .help("save extra state so learning can be resumed later with new data") 38 | .takes_value(false)) 39 | .arg(Arg::with_name("interactions") 40 | .long("interactions") 41 | .value_name("namespace_char,namespace_char[:value]") 42 | .help("Adds interactions") 43 | .multiple(true) 44 | .takes_value(true)) 45 | .arg(Arg::with_name("linear") 46 | .long("linear") 47 | .value_name("verbose_namespace,verbose_namespace[:value]") 48 | .help("Adds linear feature term with optional value") 49 | .multiple(true) 50 | .takes_value(true)) 51 | .arg(Arg::with_name("keep") 52 | .long("keep") 53 | .value_name("namespace") 54 | .help("Adds single features") 55 | .multiple(true) 56 | .takes_value(true)) 57 | .arg(Arg::with_name("build_cache_without_training") 58 | .long("build_cache_without_training") 59 | .value_name("arg") 60 | .help("Build cache file without training the first model instance") 61 | .takes_value(false)) 62 | 63 | .arg(Arg::with_name("learning_rate") 64 | .short("l") 65 | .long("learning_rate") 66 | .value_name("0.5") 67 | .help("Learning rate") 68 | .takes_value(true)) 69 | .arg(Arg::with_name("ffm_learning_rate") 70 | .long("ffm_learning_rate") 71 | .value_name("0.5") 72 | .help("Learning rate") 73 | .takes_value(true)) 74 | .arg(Arg::with_name("nn_learning_rate") 75 | .long("nn_learning_rate") 76 | .value_name("0.5") 77 | .help("Learning rate") 78 | .takes_value(true)) 79 | 80 | .arg(Arg::with_name("minimum_learning_rate") 81 | .long("minimum_learning_rate") 82 | .value_name("0.0") 83 | .help("Minimum learning rate (in adaptive algos)") 84 | .takes_value(true)) 85 | .arg(Arg::with_name("power_t") 86 | .long("power_t") 87 | .value_name("0.5") 88 | .help("How to apply Adagrad (0.5 = sqrt)") 89 | 
.takes_value(true)) 90 | .arg(Arg::with_name("ffm_power_t") 91 | .long("ffm_power_t") 92 | .value_name("0.5") 93 | .help("How to apply Adagrad (0.5 = sqrt)") 94 | .takes_value(true)) 95 | .arg(Arg::with_name("nn_power_t") 96 | .long("nn_power_t") 97 | .value_name("0.5") 98 | .help("How to apply Adagrad (0.5 = sqrt)") 99 | .takes_value(true)) 100 | .arg(Arg::with_name("l2") 101 | .long("l2") 102 | .value_name("0.0") 103 | .help("Regularization is not supported (only 0.0 will work)") 104 | .takes_value(true)) 105 | 106 | .arg(Arg::with_name("sgd") 107 | .long("sgd") 108 | .value_name("") 109 | .help("Disable the Adagrad, normalization and invariant updates") 110 | .takes_value(false)) 111 | .arg(Arg::with_name("adaptive") 112 | .long("adaptive") 113 | .value_name("") 114 | .help("Use Adagrad") 115 | .takes_value(false)) 116 | .arg(Arg::with_name("noconstant") 117 | .long("noconstant") 118 | .value_name("") 119 | .help("No intercept") 120 | .takes_value(false)) 121 | .arg(Arg::with_name("link") 122 | .long("link") 123 | .value_name("logistic") 124 | .help("What link function to use") 125 | .takes_value(true)) 126 | .arg(Arg::with_name("loss_function") 127 | .long("loss_function") 128 | .value_name("logistic") 129 | .help("What loss function to use") 130 | .takes_value(true)) 131 | .arg(Arg::with_name("bit_precision") 132 | .short("b") 133 | .long("bit_precision") 134 | .value_name("18") 135 | .help("Size of the hash space for feature weights") 136 | .takes_value(true)) 137 | .arg(Arg::with_name("hash") 138 | .long("hash") 139 | .value_name("all") 140 | .help("We do not support treating strings as already hashed numbers, so you have to use --hash all") 141 | .takes_value(true)) 142 | 143 | // Regressor 144 | .arg(Arg::with_name("final_regressor") 145 | .short("f") 146 | .long("final_regressor") 147 | .value_name("arg") 148 | .help("Final regressor to save (arg is filename)") 149 | .takes_value(true)) 150 | .arg(Arg::with_name("initial_regressor") 151 | .short("i") 152 
| .long("initial_regressor") 153 | .value_name("arg") 154 | .help("Initial regressor(s) to load into memory (arg is filename)") 155 | .takes_value(true)) 156 | .arg(Arg::with_name("testonly") 157 | .short("t") 158 | .long("testonly") 159 | .help("Ignore label information and just test") 160 | .takes_value(false)) 161 | .arg(Arg::with_name("vwcompat") 162 | .long("vwcompat") 163 | .help("vowpal compatibility mode. Uses slow adagrad, emits warnings for non-compatible features") 164 | .multiple(false) 165 | .takes_value(false)) 166 | .arg(Arg::with_name("convert_inference_regressor") 167 | .long("convert_inference_regressor") 168 | .value_name("arg") 169 | .conflicts_with("adaptive") 170 | .help("Inference regressor to save (arg is filename)") 171 | .takes_value(true)) 172 | 173 | .arg(Arg::with_name("transform") 174 | .long("transform") 175 | .value_name("target_namespace=func(source_namespaces)(parameters)") 176 | .help("Create new namespace by transforming one or more other namespaces") 177 | .multiple(true) 178 | .takes_value(true)) 179 | 180 | .arg(Arg::with_name("ffm_field") 181 | .long("ffm_field") 182 | .value_name("namespace,namespace,...") 183 | .help("Define a FFM field by listing namespace letters") 184 | .multiple(true) 185 | .takes_value(true)) 186 | .arg(Arg::with_name("ffm_field_verbose") 187 | .long("ffm_field_verbose") 188 | .value_name("namespace_verbose,namespace_verbose,...") 189 | .help("Define a FFM field by listing verbose namespace names") 190 | .multiple(true) 191 | .takes_value(true)) 192 | .arg(Arg::with_name("ffm_k") 193 | .long("ffm_k") 194 | .value_name("k") 195 | .help("Lenght of a vector to use for FFM") 196 | .takes_value(true)) 197 | .arg(Arg::with_name("ffm_bit_precision") 198 | .long("ffm_bit_precision") 199 | .value_name("N") 200 | .help("Bits to use for ffm hash space") 201 | .takes_value(true)) 202 | .arg(Arg::with_name("ffm_k_threshold") 203 | .long("ffm_k_threshold") 204 | .help("A minum gradient on left and right side to 
increase k") 205 | .multiple(false) 206 | .takes_value(true)) 207 | .arg(Arg::with_name("ffm_init_center") 208 | .long("ffm_init_center") 209 | .help("Center of the initial weights distribution") 210 | .multiple(false) 211 | .takes_value(true)) 212 | .arg(Arg::with_name("ffm_init_width") 213 | .long("ffm_init_width") 214 | .help("Total width of the initial weights distribution") 215 | .multiple(false) 216 | .takes_value(true)) 217 | .arg(Arg::with_name("ffm_init_zero_band") 218 | .long("ffm_init_zero_band") 219 | .help("Percentage of ffm_init_width where init is zero") 220 | .multiple(false) 221 | .takes_value(true)) 222 | 223 | .arg(Arg::with_name("nn_init_acc_gradient") 224 | .long("nn_init_acc_gradient") 225 | .help("Adagrad initial accumulated gradient for nn") 226 | .multiple(false) 227 | .takes_value(true)) 228 | .arg(Arg::with_name("ffm_init_acc_gradient") 229 | .long("ffm_init_acc_gradient") 230 | .help("Adagrad initial accumulated gradient for ffm") 231 | .multiple(false) 232 | .takes_value(true)) 233 | .arg(Arg::with_name("init_acc_gradient") 234 | .long("init_acc_gradient") 235 | .help("Adagrad initial accumulated gradient for ") 236 | .multiple(false) 237 | .takes_value(true)) 238 | 239 | 240 | .arg(Arg::with_name("nn_layers") 241 | .long("nn_layers") 242 | .help("Enable deep neural network on top of LR+FFM") 243 | .multiple(false) 244 | .takes_value(true)) 245 | 246 | 247 | .arg(Arg::with_name("nn") 248 | .long("nn") 249 | .help("Parameters of layers, for example 1:activation:relu or 2:width:20") 250 | .multiple(true) 251 | .takes_value(true)) 252 | 253 | .arg(Arg::with_name("nn_topology") 254 | .long("nn_topology") 255 | .help("How should connections be organized - possiblities 'one' and 'two'") 256 | .multiple(false) 257 | .takes_value(true)) 258 | 259 | 260 | // Daemon parameterts 261 | .arg(Arg::with_name("daemon") 262 | .long("daemon") 263 | .help("read data from port 26542") 264 | .takes_value(false)) 265 | 
.arg(Arg::with_name("ffm_initialization_type") 266 | .long("ffm_initialization_type") 267 | .help("Which weight initialization to consider") 268 | .multiple(false) 269 | .takes_value(true)) 270 | .arg(Arg::with_name("port") 271 | .long("port") 272 | .value_name("arg") 273 | .help("port to listen on") 274 | .takes_value(true)) 275 | .arg(Arg::with_name("num_children") 276 | .long("num_children") 277 | .value_name("arg (=10") 278 | .help("number of children for persistent daemon mode") 279 | .takes_value(true)) 280 | .arg(Arg::with_name("foreground") 281 | .long("foreground") 282 | .help("in daemon mode, do not fork and run and run fw process in the foreground") 283 | .takes_value(false)) 284 | .arg(Arg::with_name("prediction_model_delay") 285 | .conflicts_with("test_only") 286 | .long("prediction_model_delay") 287 | .value_name("examples (0)") 288 | .help("Output predictions with a model that is delayed by a number of examples") 289 | .takes_value(true)) 290 | .arg(Arg::with_name("predictions_after") 291 | .long("predictions_after") 292 | .value_name("examples (=0)") 293 | .help("After how many examples start printing predictions") 294 | .takes_value(true)) 295 | .arg(Arg::with_name("holdout_after") 296 | .conflicts_with("testonly") 297 | .required(false) 298 | .long("holdout_after") 299 | .value_name("examples") 300 | .help("After how many examples stop updating weights") 301 | .takes_value(true)) 302 | .arg(Arg::with_name("hogwild_training") 303 | .long("hogwild_training") 304 | .required(false) 305 | .help("Use faster lock-free multithreading training") 306 | .takes_value(false)) 307 | .arg(Arg::with_name("hogwild_threads") 308 | .long("hogwild_threads") 309 | .value_name("num_threads") 310 | .help("Number of threads to use with hogwild training") 311 | .takes_value(true)) 312 | .arg(Arg::with_name("weight_quantization") 313 | .long("weight_quantization") 314 | .value_name("Whether to consider weight quantization when reading/writing weights.") 315 | 
.help("Half-float quantization trigger (inference only is the suggested use).") 316 | .takes_value(false)) 317 | .arg(Arg::with_name("predictions_stdout") 318 | .long("predictions_stdout") 319 | .value_name("Output predictions to stdout") 320 | .help("Output predictions file to stdout") 321 | .takes_value(false)) 322 | } 323 | --------------------------------------------------------------------------------