├── .github └── workflows │ └── ci.yml ├── .gitignore ├── CHANGELOG.md ├── Cargo.lock ├── Cargo.toml ├── Makefile ├── README.md ├── docs ├── adding-operators.md ├── debugging.md ├── performance.md ├── quantization.md ├── release.md └── rten-file-format.md ├── index.js ├── js-examples └── image-classification │ ├── .gitignore │ ├── README.md │ ├── classify-node.js │ ├── classify-web.js │ ├── image-classifier.js │ ├── imagenet-classes.js │ ├── index.html │ ├── package-lock.json │ └── package.json ├── package-lock.json ├── package.json ├── pytorch-ref-tests ├── README.md ├── common.py ├── rnn.json └── rnn.py ├── rten-bench ├── Cargo.toml ├── README.md └── src │ └── lib.rs ├── rten-cli ├── Cargo.toml ├── README.md └── src │ ├── dim_size.rs │ └── main.rs ├── rten-convert ├── LICENSE ├── Makefile ├── README.md ├── pyproject.toml ├── requirements.dev.txt └── rten_convert │ ├── __init__.py │ ├── attr_reader.py │ ├── converter.py │ ├── errors.py │ ├── graph.py │ ├── schema_generated.py │ ├── tensor_data.py │ └── util.py ├── rten-examples ├── Cargo.toml ├── README.md ├── data │ ├── README.md │ ├── coco.names │ ├── dump_mel_filters.py │ ├── imagenet22k_synsets.txt │ ├── imagenet_synset_to_lemma.txt │ ├── imagenet_synsets.txt │ ├── mel_filters.json │ └── rust-questions.txt └── src │ ├── bert_qa.rs │ ├── bert_qa_reference.py │ ├── clip.rs │ ├── clip_reference.py │ ├── deeplab.rs │ ├── deeplab_reference.py │ ├── depth_anything.rs │ ├── detr.rs │ ├── distilvit.rs │ ├── export-deeplab.py │ ├── gpt2.rs │ ├── gpt2_reference.py │ ├── imagenet.rs │ ├── jina_similarity.rs │ ├── jina_similarity_reference.py │ ├── modernbert.rs │ ├── modernbert_reference.py │ ├── nougat.rs │ ├── piper.rs │ ├── qwen2_chat.rs │ ├── rmbg.rs │ ├── segment_anything.rs │ ├── silero.rs │ ├── trocr.rs │ ├── wav2vec2.rs │ ├── whisper.rs │ └── yolo.rs ├── rten-generate ├── Cargo.toml ├── README.md └── src │ ├── filter.rs │ ├── generator.rs │ ├── lib.rs │ ├── metrics.rs │ ├── model.rs │ ├── sampler.rs │ └── text_decoder.rs ├── rten-imageio ├── Cargo.toml ├── README.md └── src │ └── lib.rs ├── rten-imageproc ├── Cargo.toml ├── README.md └── src │ ├── contours.rs │ ├── drawing.rs │ ├── lib.rs │ ├── math.rs │ ├── normalize.rs │ ├── poly_algos.rs │ └── shapes.rs ├── rten-simd ├── Cargo.toml ├── README.md └── src │ ├── README.md │ ├── arch.rs │ ├── arch │ ├── aarch64.rs │ ├── generic.rs │ ├── wasm32.rs │ ├── x86_64.rs │ └── x86_64 │ │ ├── avx2.rs │ │ └── avx512.rs │ ├── dispatch.rs │ ├── elem.rs │ ├── functional.rs │ ├── isa_detection.rs │ ├── iter.rs │ ├── lib.rs │ ├── ops.rs │ ├── simd.rs │ ├── span.rs │ └── writer.rs ├── rten-tensor ├── Cargo.toml ├── README.md └── src │ ├── assume_init.rs │ ├── copy.rs │ ├── errors.rs │ ├── impl_debug.rs │ ├── impl_serialize.rs │ ├── index_iterator.rs │ ├── iterators.rs │ ├── layout.rs │ ├── lib.rs │ ├── macros.rs │ ├── overlap.rs │ ├── rng.rs │ ├── slice_range.rs │ ├── storage.rs │ ├── tensor.rs │ ├── test_util.rs │ └── type_num.rs ├── rten-testing ├── Cargo.toml ├── README.md └── src │ └── lib.rs ├── rten-text ├── Cargo.toml ├── README.md ├── src │ ├── lib.rs │ ├── models.rs │ ├── models │ │ ├── bpe.rs │ │ └── wordpiece.rs │ ├── normalizers.rs │ ├── pre_tokenizers.rs │ ├── split.rs │ ├── tokenizer.rs │ └── tokenizer │ │ └── json.rs ├── test-data │ ├── reftests │ │ ├── Metal_umlaut-bert-base-uncased.json │ │ ├── Metal_umlaut.txt │ │ ├── README.md │ │ ├── Rust_(programming_language)-bert-base-cased.json │ │ ├── Rust_(programming_language)-bert-base-uncased.json │ │ ├── Rust_(programming_language).txt │ │ 
├── models │ │ │ ├── bert-base-cased │ │ │ │ └── vocab.txt │ │ │ ├── bert-base-uncased │ │ │ │ └── vocab.txt │ │ │ └── gpt2 │ │ │ │ ├── README.md │ │ │ │ ├── merges.txt │ │ │ │ ├── tokenizer.json │ │ │ │ └── vocab.json │ │ ├── monty-python-credits-bert-base-uncased.json │ │ ├── monty-python-credits-gpt2.json │ │ └── monty-python-credits.txt │ └── tokenizer-json │ │ ├── wordpiece-lower.json │ │ └── wordpiece.json ├── tests │ └── reftest.rs └── tools │ ├── fetch_wikipedia.py │ ├── reference_tokenize.py │ └── requirements.txt ├── rten-vecmath ├── Cargo.toml ├── README.md └── src │ ├── erf.rs │ ├── exp.rs │ ├── extend_init.rs │ ├── lib.rs │ ├── min_max.rs │ ├── normalize.rs │ ├── quantize.rs │ ├── softmax.rs │ ├── sum.rs │ ├── tanh.rs │ ├── testing.rs │ └── ulp.rs ├── src ├── constant_storage.rs ├── ctc.rs ├── downcast.rs ├── env.rs ├── gemm.rs ├── gemm │ ├── errors.rs │ ├── im2col.rs │ ├── kernels.rs │ ├── kernels │ │ ├── aarch64.rs │ │ ├── generic.rs │ │ ├── simd_generic.rs │ │ ├── wasm.rs │ │ └── x86_64.rs │ ├── packing.rs │ ├── packing │ │ └── int8.rs │ ├── prepack.rs │ ├── reduced_range_rng.rs │ └── tiles.rs ├── graph.rs ├── graph │ ├── builder.rs │ ├── capture_env.rs │ ├── node.rs │ ├── node_id.rs │ ├── noop_hash.rs │ └── planner.rs ├── header.rs ├── iter_util.rs ├── lib.rs ├── model.rs ├── model_builder.rs ├── model_metadata.rs ├── number.rs ├── op_registry.rs ├── ops │ ├── binary_elementwise.rs │ ├── concat.rs │ ├── control_flow.rs │ ├── conv.rs │ ├── conv │ │ ├── depthwise.rs │ │ └── im2col.rs │ ├── convert.rs │ ├── einsum.rs │ ├── gather.rs │ ├── generate.rs │ ├── identity.rs │ ├── layout.rs │ ├── matmul.rs │ ├── mod.rs │ ├── non_max_suppression.rs │ ├── norm.rs │ ├── operators.rs │ ├── pad.rs │ ├── pooling.rs │ ├── quantize.rs │ ├── random.rs │ ├── reduce.rs │ ├── resize.rs │ ├── rnn.rs │ ├── slice.rs │ ├── split.rs │ ├── transform_inputs.rs │ ├── trilu.rs │ ├── unary_elementwise.rs │ └── variadic_elementwise.rs ├── optimize.rs ├── optimize │ └── pattern_matcher.rs ├── schema.fbs ├── schema_generated.rs ├── shift_cast.rs ├── slice_cast.rs ├── slice_reductions.rs ├── tensor_pool.rs ├── threading.rs ├── timing.rs ├── wasm_api.rs └── weight_cache.rs └── tools ├── __init__.py ├── add-node-outputs-to-model.py ├── benchmarks └── wasm-gemm.js ├── compare-tensors.py ├── debug_utils.py ├── export-timm-model.py ├── generate-coverage.sh ├── optimize-wasm.sh ├── ort-infer.py ├── ort-profile-summarize.py ├── ort-quantize.py ├── requirements.txt ├── test-images ├── README.md └── horses.jpeg └── update-onnx-model.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - 'main' 7 | - 'ci/**' 8 | jobs: 9 | ci: 10 | strategy: 11 | matrix: 12 | # The macOS build tests Arm and macOS-specific code paths. 13 | # The Linux build tests everything else (x64, wasm, Python ...) 
14 | os: [ubuntu-latest, macos-14, ubuntu-24.04-arm] 15 | runs-on: ${{ matrix.os }} 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | - name: Install Rust WASM targets 20 | run: | 21 | rustup target add wasm32-unknown-unknown 22 | rustup target add wasm32-wasip1 23 | if: ${{ matrix.os == 'ubuntu-latest' }} 24 | - name: Install wasmtime 25 | run: | 26 | mkdir -p ~/.wasmtime 27 | curl -L https://github.com/bytecodealliance/wasmtime/releases/download/v29.0.1/wasmtime-v29.0.1-x86_64-linux.tar.xz | tar xf - --xz -C ~/.wasmtime --strip-components=1 28 | echo "$HOME/.wasmtime" >> "$GITHUB_PATH" 29 | if: ${{ matrix.os == 'ubuntu-latest' }} 30 | - name: Install Rust nightly toolchain 31 | run: rustup toolchain install nightly 32 | if: ${{ matrix.os == 'ubuntu-latest' }} 33 | - name: Install Rust x86_64-apple-darwin target 34 | run: rustup target add x86_64-apple-darwin 35 | if: ${{ matrix.os == 'macos-14' }} 36 | - name: Query Rust version 37 | run: | 38 | rustc --version 39 | cargo --version 40 | - name: Cache 41 | uses: actions/cache@v3 42 | with: 43 | path: | 44 | ~/.cargo/ 45 | target/ 46 | key: ${{ matrix.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} 47 | - name: Install wasm-bindgen 48 | # nb. wasm-bindgen-cli version must match `wasm-bindgen` version in Cargo.lock 49 | run: cargo install wasm-bindgen-cli --version 0.2.100 50 | if: ${{ matrix.os == 'ubuntu-latest' }} 51 | - name: Build 52 | run: cargo build 53 | - name: Test 54 | run: make test 55 | # We compile AVX-512 in CI but don't run tests as GitHub Actions' default 56 | # runners don't support it yet (https://github.com/actions/runner/issues/1069). 57 | - name: Build (AVX-512) 58 | run: cargo +nightly check -p rten --features avx512 59 | if: ${{ matrix.os == 'ubuntu-latest' }} 60 | - name: Build (WASM) 61 | run: make wasm 62 | if: ${{ matrix.os == 'ubuntu-latest' }} 63 | - name: Test (WASM) 64 | run: | 65 | make wasm-test wasm-test-simd 66 | if: ${{ matrix.os == 'ubuntu-latest' }} 67 | - name: Build (Intel macOS) 68 | run: cargo check --workspace --target x86_64-apple-darwin 69 | if: ${{ matrix.os == 'macos-14' }} 70 | - name: Lint 71 | run: | 72 | make checkformatting 73 | make lint 74 | - name: Docs 75 | run: | 76 | make docs 77 | - name: Setup Python 78 | run: | 79 | python -m venv .venv 80 | .venv/bin/pip install --upgrade pip 81 | if: ${{ matrix.os == 'ubuntu-latest' }} 82 | - name: Python Lint 83 | run: | 84 | source .venv/bin/activate 85 | cd rten-convert 86 | pip install -e . 87 | pip install -r requirements.dev.txt 88 | make check 89 | if: ${{ matrix.os == 'ubuntu-latest' }} 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Rust 2 | target/ 3 | 4 | # JS 5 | node_modules/ 6 | 7 | # Python 8 | __pycache__ 9 | *.egg-info/ 10 | dist/ 11 | 12 | # Converted rten ML models 13 | *.rten 14 | /models/ 15 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | ".", 4 | "rten-cli", 5 | "rten-generate", 6 | "rten-imageio", 7 | "rten-imageproc", 8 | "rten-simd", 9 | "rten-tensor", 10 | "rten-text", 11 | "rten-vecmath", 12 | 13 | # Example crates. These are not published. 14 | "rten-examples", 15 | 16 | # Development crates. These are not published.
17 | "rten-bench", 18 | "rten-testing", 19 | ] 20 | default-members = [ 21 | ".", 22 | "rten-imageproc", 23 | "rten-tensor", 24 | "rten-text" 25 | ] 26 | 27 | [workspace.dependencies] 28 | image = { version = "0.25.1", default-features = false, features = ["png", "jpeg", "webp"] } 29 | serde = { version = "1.0.202" } 30 | serde_json = { version = "1.0.117" } 31 | 32 | [package] 33 | name = "rten" 34 | version = "0.18.0" 35 | edition = "2021" 36 | authors = ["Robert Knight"] 37 | description = "Machine learning runtime" 38 | license = "MIT OR Apache-2.0" 39 | homepage = "https://github.com/robertknight/rten" 40 | repository = "https://github.com/robertknight/rten" 41 | resolver = "2" 42 | include = ["/src", "/CHANGELOG.md", "/README.md"] 43 | 44 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 45 | 46 | [dependencies] 47 | flatbuffers = "24.3.25" 48 | rayon = "1.7.0" 49 | smallvec = { version = "1.10.0", features = ["union", "const_generics", "const_new"] } 50 | rten-tensor = { path = "./rten-tensor", version = "0.18.0" } 51 | rten-vecmath = { path = "./rten-vecmath", version = "0.18.0" } 52 | rten-simd = { path = "./rten-simd", version = "0.18.0" } 53 | fastrand = { version = "2.0.2", optional = true } 54 | fastrand-contrib = { version = "0.1.0", optional = true } 55 | rustc-hash = "2.0.0" 56 | memmap2 = { version = "0.9.4", optional = true } 57 | num_cpus = "1.16.0" 58 | 59 | [dev-dependencies] 60 | libm = "0.2.6" 61 | rten-bench = { path = "./rten-bench" } 62 | rten-testing = { path = "./rten-testing" } 63 | serde_json = { workspace = true } 64 | 65 | [lib] 66 | crate-type = ["lib", "cdylib"] 67 | 68 | [features] 69 | # Use AVX-512 instructions if available. Requires nightly Rust for AVX-512 intrinsics. 70 | avx512 = ["rten-simd/avx512", "rten-vecmath/avx512"] 71 | # Enable loading models using memory mapping 72 | mmap = ["dep:memmap2"] 73 | # Generate WebAssembly API using wasm-bindgen. 74 | wasm_api = [] 75 | # Enable operators that generate random numbers. 76 | random = ["dep:fastrand", "dep:fastrand-contrib"] 77 | 78 | [target.'cfg(target_arch = "wasm32")'.dependencies] 79 | wasm-bindgen = "0.2.100" 80 | 81 | [lints.clippy] 82 | # `assert!(const)` effectively used as a static assert, which the compiler will 83 | # optimize away. 84 | assertions_on_constants = "allow" 85 | # Clippy frequently suggests replacing for loops with const bounds (often used 86 | # in performance-critical loops) with iterators, which is more verbose and 87 | # potentially less efficient. 88 | needless_range_loop = "allow" 89 | too_many_arguments = "allow" 90 | manual_repeat_n = "allow" # TODO - Address existing failures 91 | 92 | [package.metadata.docs.rs] 93 | # These features should match the features enabled by `make docs`.
94 | features = [ 95 | "mmap", 96 | "random", 97 | ] 98 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: all 2 | all: 3 | cargo build --workspace 4 | 5 | .PHONY: schema 6 | schema: src/schema_generated.rs rten-convert/rten_convert/schema_generated.py 7 | 8 | .PHONY: clean 9 | clean: 10 | rm -rf dist/* 11 | rm -rf target/ 12 | 13 | .PHONY: check 14 | check: checkformatting test lint 15 | 16 | .PHONY: checkformatting 17 | checkformatting: 18 | cargo fmt --check 19 | 20 | .PHONY: docs 21 | docs: 22 | RUSTDOCFLAGS='-D warnings' cargo doc -p rten --features mmap,random 23 | 24 | .PHONY: lint 25 | lint: 26 | cargo clippy --workspace 27 | 28 | .PHONY: miri 29 | miri: 30 | # - Only the tensor lib is currently tested. Testing the main crate will 31 | # require changes to prevent tests taking too long to run. 32 | cargo +nightly miri test -p rten-tensor 33 | 34 | # Run tests for all crates with all features enabled that do not require 35 | # nightly Rust. 36 | .PHONY: test 37 | test: 38 | cargo test --no-fail-fast --workspace --features mmap,random,text-decoder,serde 39 | 40 | .PHONY: wasm 41 | wasm: 42 | RUSTFLAGS="-C target-feature=+simd128" cargo build --features=wasm_api --release --target wasm32-unknown-unknown 43 | wasm-bindgen target/wasm32-unknown-unknown/release/rten.wasm --out-dir dist/ --target web --weak-refs 44 | # This makes the binary smaller but also removes all symbols. Comment this 45 | # out to get a release WASM build with symbols. 46 | tools/optimize-wasm.sh dist/rten_bg.wasm 47 | 48 | .PHONY: wasm-relaxedsimd 49 | wasm-relaxedsimd: 50 | RUSTFLAGS="-C target-feature=+simd128,+relaxed-simd" cargo build --features=wasm_api --release --target wasm32-unknown-unknown 51 | wasm-bindgen target/wasm32-unknown-unknown/release/rten.wasm --out-dir dist/ --target web --weak-refs 52 | 53 | .PHONY: wasm-nosimd 54 | wasm-nosimd: 55 | cargo build --release --target wasm32-unknown-unknown 56 | wasm-bindgen target/wasm32-unknown-unknown/release/rten.wasm --out-dir dist/ --out-name rten-nosimd --target web --weak-refs 57 | tools/optimize-wasm.sh dist/rten-nosimd_bg.wasm 58 | 59 | .PHONY: wasm-all 60 | wasm-all: wasm wasm-nosimd 61 | 62 | # WASM tests run with `--nocapture` as otherwise assertion failure panic messages 63 | # are not printed when an assertion fails. 64 | .PHONY: wasm-test 65 | wasm-test: 66 | rm -f target/wasm32-wasip1/debug/deps/rten-*.wasm 67 | RUSTFLAGS="-C target-feature=+simd128" cargo build --target wasm32-wasip1 --tests -p rten 68 | wasmtime --dir . target/wasm32-wasip1/debug/deps/rten-*.wasm --nocapture 69 | 70 | .PHONY: wasm-test-simd 71 | wasm-test-simd: 72 | rm -f target/wasm32-wasip1/debug/deps/rten_simd-*.wasm 73 | RUSTFLAGS="-C target-feature=+simd128" cargo build --target wasm32-wasip1 --tests -p rten-simd 74 | wasmtime --dir . target/wasm32-wasip1/debug/deps/rten_simd-*.wasm --nocapture 75 | 76 | .PHONY: wasm-bench-gemm 77 | wasm-bench-gemm: 78 | rm -f target/wasm32-wasip1/release/deps/rten-*.wasm 79 | RUSTFLAGS="-C target-feature=+simd128" cargo build --target wasm32-wasip1 --tests -p rten -r 80 | wasmtime --dir .
target/wasm32-wasip1/release/deps/rten-*.wasm --nocapture --ignored bench_gemm_mix 81 | 82 | src/schema_generated.rs: src/schema.fbs 83 | flatc -o src/ --rust src/schema.fbs 84 | cargo fmt 85 | (echo "#![allow(clippy::all)]" && cat src/schema_generated.rs) > src/schema_generated.rs.tmp 86 | mv src/schema_generated.rs.tmp src/schema_generated.rs 87 | 88 | rten-convert/rten_convert/schema_generated.py: src/schema.fbs 89 | flatc -o rten-convert/rten_convert --gen-onefile --gen-object-api --python src/schema.fbs 90 | 91 | 92 | .PHONY: gen-pytorch-references 93 | gen-pytorch-references: 94 | python -m pytorch-ref-tests.rnn 95 | -------------------------------------------------------------------------------- /docs/adding-operators.md: -------------------------------------------------------------------------------- 1 | # Adding new operators 2 | 3 | Adding support for a new operator involves several steps: 4 | 5 | 1. Reading the ONNX operator specification to understand how the operator 6 | works. See https://onnx.ai/onnx/operators/. 7 | 2. Defining the implementation of the new operator in Rust code 8 | 3. Adding the new operator to the FlatBuffers model schema, implementing 9 | support for reading it on the Rust side, and writing it in the Python 10 | script that converts ONNX models to this library's format. 11 | 4. Adding tests for the new operator's implementation and deserialization 12 | 13 | In detail, the process is: 14 | 15 | 1. Add the new operator to the end of the `OperatorType` enum in schema.fbs. 16 | 2. If the new operator requires attributes, add a new table in schema.fbs and 17 | add the table to the end of the `OperatorAttrs` union. If the new operator 18 | uses the same attributes as an existing operator, it can re-use the 19 | attributes from that operator. 20 | 3. Run `make schema` to regenerate the Rust and Python code that reads the updated 21 | FlatBuffers schema 22 | 4. If the new operator has attributes, edit `rten-convert/rten_convert/converter.py` to read the attributes 23 | from ONNX and convert them to this library's model format, then reinstall rten-convert 24 | 5. Define the implementation of the new operator in Rust. This is a struct 25 | that implements the `Operator` trait. 26 | 6. Add tests for the new operator at the bottom of the module where the 27 | operator is defined. 28 | 7. Export the operator from the `ops/mod.rs` module 29 | 8. Modify `read_operator` in `model.rs` to read the new operator from model 30 | files. 31 | 9. Add support for the new operator in `model_builder.rs` 32 | 10. Update the `test_all_op_types` test at the bottom of model.rs to run the 33 | new operator with test input. 34 | 35 | ## Adding partial support for a new operator 36 | 37 | It is OK to add a new operator without support for all features, but ONNX model 38 | conversion and the operator implementation must report an error if an 39 | unsupported capability is used by a model, rather than silently producing 40 | incorrect results. 41 | 42 | The Python ONNX conversion script will check that all attributes of an operator 43 | in the ONNX model are read. Unsupported attributes can be ignored if they have 44 | a value equal to the default. 45 | 46 | ## FlatBuffers binary compatibility 47 | 48 | Additions to the FlatBuffers schema for models should preserve binary 49 | compatibility with existing model files. This is achieved for enums, unions and 50 | tables by making additions at the end of the item.
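To make the compatibility rule concrete, the sketch below shows why additions must go at the end. It uses invented operator names rather than the project's real schema or generated code: FlatBuffers stores enum variants as integers, so a reader decodes values by their numeric representation, and inserting a variant in the middle would silently change the meaning of every later value.

```rust
// Illustration only: invented operator names, not the generated schema code.
#[derive(Debug, PartialEq)]
enum OperatorType {
    Add = 0,
    Conv = 1,
    MatMul = 2,
    // New operators are appended here, so the values above never change.
}

fn operator_from_int(value: u8) -> Option<OperatorType> {
    match value {
        0 => Some(OperatorType::Add),
        1 => Some(OperatorType::Conv),
        2 => Some(OperatorType::MatMul),
        // A model produced with a newer schema may contain values unknown to
        // this reader; it can report them as errors rather than misreading
        // an existing operator.
        _ => None,
    }
}
```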
51 | -------------------------------------------------------------------------------- /docs/debugging.md: -------------------------------------------------------------------------------- 1 | # Debugging 2 | 3 | This document provides strategies for debugging cases where RTen produces 4 | incorrect or unexpectedly different outputs compared to other runtimes. 5 | 6 | ## Inspecting models 7 | 8 | [Netron](https://netron.app) is an online tool, also available as an Electron 9 | app, for visualizing ONNX models. 10 | 11 | ## Comparing against ONNX Runtime 12 | 13 | [ONNX Runtime](https://onnxruntime.ai) ("ORT") is the most mature implementation 14 | of the ONNX specification, and is often used as a reference for correctness and 15 | performance testing. 16 | 17 | The general steps to use ORT to compare and debug unexpected output from RTen 18 | are: 19 | 20 | 1. Create a Python script to execute the model with ORT, and a 21 | corresponding Rust binary to execute the model with RTen. 22 | 23 | 2. Verify that the model produces the expected results with ORT. 24 | 25 | 3. Verify that the inputs to the model, after all preprocessing and conversion 26 | to tensors, are the same in RTen and ORT. 27 | 28 | 4. Verify that there are significant differences between the RTen and ORT outputs. 29 | 30 | 5. Compare the values of intermediate outputs in the graph to find where 31 | significant differences begin to arise. For small tensors, the values can 32 | simply be printed and inspected by eye. Most tensors will be larger however, 33 | so it is useful to compute statistics from a comparison. A typical approach is: 34 | 35 | 1. Run the model specifying an intermediate node as an output. 36 | 37 | RTen allows any node in the graph to be specified as an output. ORT on the 38 | other hand only allows nodes specified in the model's output list to be 39 | fetched as an output from a model run. The 40 | `tools/add-node-outputs-to-model.py` script works around this limitation of 41 | ORT by producing a modified ONNX model that lists every node in the graph 42 | in the output list. 43 | 44 | 2. Write out the resulting intermediate tensors. The `Tensor::write` method 45 | can be used for this in RTen and the `write_tensor` function in 46 | `tools/debug_utils.py` in Python. 47 | 48 | 3. Compare the resulting tensors. `tools/compare-tensors.py` compares tensor 49 | shapes and reports statistics on the absolute difference between 50 | corresponding values. 51 | 52 | Repeat steps 1-3 until you have identified where in the model discrepancies 53 | begin to arise. Note that very small differences for individual values, 54 | eg. on the order of 5 or 6 places after the decimal point, are normal for 55 | certain operations due to implementation differences. 56 | -------------------------------------------------------------------------------- /docs/release.md: -------------------------------------------------------------------------------- 1 | # Release process 2 | 3 | The release process uses 4 | [cargo-release](https://github.com/crate-ci/cargo-release) to simplify releasing 5 | multiple crates from a single repository. 6 | 7 | 1. Update `CHANGELOG.md` in the root of the repo. To find PRs merged since the last 8 | release you can use `git log --oneline v{prev_version}..main --merges`. 9 | 2. Run `cargo release changes` to find out which crates changed since the 10 | previous release 11 | 3. Run `cargo release --workspace <level>` to do a dry run. Alternatively 12 | use `cargo release -p crate1 -p crate2 <level>` to do this just for 13 | crates with changes in the new release.
You can also use the `--exclude` flag 14 | to exclude packages that haven't changed. 15 | 4. If the dry run looks good, run step 3 again with the `--execute` flag 16 | 5. Bump the package version of `rten-convert` in its `pyproject.toml` file 17 | 6. Publish `rten-convert` to PyPI by running `make release` in the directory 18 | containing rten-convert's pyproject.toml. 19 | -------------------------------------------------------------------------------- /docs/rten-file-format.md: -------------------------------------------------------------------------------- 1 | # RTen model format 2 | 3 | RTen model files (`.rten`) contain the computation graph for a machine learning 4 | model, model metadata and weights. The format is designed to be efficient to 5 | load and to minimize additional memory required, beyond the size of the file 6 | itself. 7 | 8 | RTen files are produced by exporting models from a machine learning framework 9 | such as PyTorch or Keras into [ONNX](https://onnx.ai) format, and converting the 10 | ONNX model to `.rten` using the 11 | [rten-convert](https://pypi.org/project/rten-convert/) tool. 12 | 13 | ## Compatibility 14 | 15 | The `rten-convert` tool and `rten` Rust crate have version numbers that are 16 | aligned. A `.rten` model produced by version X of `rten-convert` can be read by 17 | version X of the `rten` crate or newer. Models produced by version X of 18 | `rten-convert` _may_ work with earlier versions of `rten` as long as the model 19 | does not rely on operators or attributes that were added in version X. 20 | 21 | ## History 22 | 23 | There are two versions of the RTen model format. The second version added 24 | support for models larger than 2GB. RTen can load models in either format. The 25 | `rten-convert` tool generates the V2 format by default, and will generate the V1 26 | format if the `--v1` flag is passed. 27 | 28 | ## V2 format 29 | 30 | ### Overall structure 31 | 32 | The overall structure of a `.rten` file is: 33 | 34 | ``` 35 | [header] … [model_data] … [tensor_data] 36 | ``` 37 | 38 | ### Header 39 | 40 | The header identifies the file type, the major version of the format and 41 | contains the offsets of the other sections. The structure of the header is: 42 | 43 | ``` 44 | [magic:u8x4] [version:u32] [model_data_offset:u64] [model_data_len:u64] [tensor_data_offset:u64] 45 | ``` 46 | 47 | All numbers are encoded in little-endian order. 48 | 49 | - `magic` - The ASCII bytes `RTEN` 50 | - `version` - Currently 2 51 | - `model_data_offset` - Offset of the data describing the model 52 | - `model_data_len` - Length of the data describing the model 53 | - `tensor_data_offset` - Offset of the start of tensor data. Tensor references in 54 | the model data are relative to this. 55 | 56 | ### Model data 57 | 58 | The model data is a [FlatBuffers](https://flatbuffers.dev) buffer which 59 | describes the computation graph for the model. It also contains metadata about 60 | the model. 61 | 62 | The computation graph consists of three kinds of nodes: constants (weights, 63 | biases etc.), values (inputs or outputs from computation steps) and operators 64 | (computation steps such as matrix multiplication). The operators correspond 65 | closely to operators in the [ONNX 66 | specification](https://onnx.ai/onnx/operators/). Constant nodes describe the 67 | data type and shape of tensors. The data for a tensor can either be stored 68 | inline in the model or externally in the tensor data section. 69 | 70 | The FlatBuffers schema can be found in `src/schema.fbs`.
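As a rough sketch of how a reader can consume this layout (based only on the description in this document, not on the `rten` crate's actual loader), the fixed 32-byte header can be parsed like so:

```rust
// Minimal sketch of parsing the V2 header described above. Field names and
// sizes follow this document; this is an illustration, not rten's real code.
fn parse_header(data: &[u8]) -> Result<(u32, u64, u64, u64), String> {
    // magic (4 bytes) + version (4) + three u64 fields (24) = 32 bytes.
    if data.len() < 32 {
        return Err("file too short to contain a header".into());
    }
    if &data[0..4] != b"RTEN" {
        return Err("missing RTEN magic bytes".into());
    }
    // All numbers are little-endian.
    let u32_at = |off: usize| u32::from_le_bytes(data[off..off + 4].try_into().unwrap());
    let u64_at = |off: usize| u64::from_le_bytes(data[off..off + 8].try_into().unwrap());

    let version = u32_at(4);
    let model_data_offset = u64_at(8);
    let model_data_len = u64_at(16);
    let tensor_data_offset = u64_at(24);
    Ok((version, model_data_offset, model_data_len, tensor_data_offset))
}
```

The model data is then the byte range starting at `model_data_offset` with length `model_data_len`, and tensor references inside it are resolved relative to `tensor_data_offset`.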
71 | 72 | ### Tensor data 73 | 74 | The tensor data section is a block of bytes referenced by the model data. The 75 | shape of tensors, type of elements and other metadata is contained in the model 76 | data. 77 | 78 | ## V1 format 79 | 80 | The first version of the `.rten` model format consisted of just the model 81 | data without the header or tensor data sections. The FlatBuffers schema used by 82 | V1 is the same as V2. 83 | 84 | This was changed due to FlatBuffers having a 2GB file 85 | size limit, and also to enable more control over the alignment of tensor data. 86 | -------------------------------------------------------------------------------- /index.js: -------------------------------------------------------------------------------- 1 | export { 2 | default as init, 3 | initSync, 4 | Model, 5 | Tensor, 6 | } from "./dist/rten.js"; 7 | 8 | /** 9 | * Return true if the current JS environment supports the SIMD extension for 10 | * WebAssembly. 11 | */ 12 | function simdSupported() { 13 | // Tiny WebAssembly file generated from the following source using `wat2wasm`: 14 | // 15 | // (module 16 | // (func (result v128) 17 | // i32.const 0 18 | // i8x16.splat 19 | // i8x16.popcnt 20 | // ) 21 | // ) 22 | const simdTest = Uint8Array.from([ 23 | 0, 97, 115, 109, 1, 0, 0, 0, 1, 5, 1, 96, 0, 1, 123, 3, 2, 1, 0, 10, 10, 1, 24 | 8, 0, 65, 0, 253, 15, 253, 98, 11, 25 | ]); 26 | return WebAssembly.validate(simdTest); 27 | } 28 | 29 | /** 30 | * Return the filename of the preferred RTen binary for the current 31 | * environment. 32 | */ 33 | export function binaryName() { 34 | if (simdSupported()) { 35 | return "rten_bg.wasm"; 36 | } else { 37 | return "rten-nosimd_bg.wasm"; 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /js-examples/image-classification/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules/ 2 | *.model 3 | *.onnx 4 | -------------------------------------------------------------------------------- /js-examples/image-classification/README.md: -------------------------------------------------------------------------------- 1 | # MobileNet image classification 2 | 3 | This demo classifies the content of images using MobileNet v2. The known 4 | categories are listed in `imagenet-classes.js`. The model works best when 5 | there is an obvious central subject in the image, and the subject is a common 6 | kind of object. 7 | 8 | ## Setup 9 | 10 | 1. Build the main RTen project for WebAssembly. See the README.md file at the 11 | root of the repository. 12 | 2. In this directory, run `npm install` 13 | 3. Download the ONNX MobileNet model from the ONNX Model Zoo and convert it 14 | to `.rten` format: 15 | 16 | ```sh 17 | curl -L https://github.com/onnx/models/raw/main/Computer_Vision/mobilenetv2_110d_Opset18_timm/mobilenetv2_110d_Opset18.onnx -o mobilenet.onnx 18 | 19 | rten-convert mobilenet.onnx mobilenet.rten 20 | ``` 21 | 4. Follow either of the subsections below to run the example in Node or the 22 | browser 23 | 24 | ## Running in Node 25 | 26 | ```sh 27 | $ node classify-node.js espresso.png 28 | 29 | # Example output 30 | Most likely categories: 31 | - espresso 32 | - chocolate sauce, chocolate syrup 33 | - cup 34 | - ice cream, icecream 35 | - plate 36 | ``` 37 | 38 | ## Running in a browser 39 | 40 | 1. Start a web server: 41 | 42 | ``` 43 | python -m http.server 3010 44 | ``` 45 | 46 | 2. Open http://localhost:3010/ 47 | 3. 
Click "Choose file" and select a photo or image to classify 48 | -------------------------------------------------------------------------------- /js-examples/image-classification/classify-node.js: -------------------------------------------------------------------------------- 1 | import { readFileSync } from "fs"; 2 | 3 | import sharp from "sharp"; 4 | import { initSync, binaryName } from "rten"; 5 | 6 | import { ImageClassifier } from "./image-classifier.js"; 7 | import { IMAGENET_CLASSES } from "./imagenet-classes.js"; 8 | 9 | /** 10 | * Load a JPEG or PNG image from `path`, resize it to `width`x`height` and 11 | * return the RGB image data as an `ImageData`-like object. 12 | */ 13 | async function loadImage(path, width, height) { 14 | const image = await sharp(path) 15 | .removeAlpha() 16 | .resize(width, height, { fit: "fill" }); 17 | return { 18 | data: new Uint8Array(await image.raw().toBuffer()), 19 | width, 20 | height, 21 | }; 22 | } 23 | 24 | const path = process.argv[2]; 25 | const modelPath = process.argv[3] ?? "./mobilenet.rten"; 26 | 27 | // Initialize RTen. 28 | const rtenBinary = readFileSync("node_modules/rten/dist/" + binaryName()); 29 | initSync(rtenBinary); 30 | 31 | // Load the MobileNet classification model. 32 | const modelData = new Uint8Array(readFileSync(modelPath)); 33 | const classifier = new ImageClassifier(modelData); 34 | const { width, height } = classifier.inputSize(); 35 | const image = await loadImage(path, width, height); 36 | 37 | const classifyStart = Date.now(); 38 | const top5 = classifier.classify(image); 39 | const classifyEnd = Date.now(); 40 | 41 | const topCategories = top5.map( 42 | ([classIndex, score]) => IMAGENET_CLASSES[classIndex] 43 | ); 44 | 45 | console.log( 46 | `Analyzed image in ${classifyEnd - classifyStart}ms. Most likely categories:` 47 | ); 48 | for (let category of topCategories) { 49 | console.log(" - " + category); 50 | } 51 | -------------------------------------------------------------------------------- /js-examples/image-classification/classify-web.js: -------------------------------------------------------------------------------- 1 | import { init as initRTen, binaryName } from "./node_modules/rten/index.js"; 2 | 3 | import { ImageClassifier } from "./image-classifier.js"; 4 | import { IMAGENET_CLASSES } from "./imagenet-classes.js"; 5 | 6 | /** 7 | * Fetch a binary file from `url`. 8 | * 9 | * @param {string} url 10 | * @return {Promise} 11 | */ 12 | async function fetchBinary(url) { 13 | const response = await fetch(url); 14 | if (!response.ok) { 15 | throw new Error(`Failed to fetch ${url}`); 16 | } 17 | const buffer = await response.arrayBuffer(); 18 | return new Uint8Array(buffer); 19 | } 20 | 21 | /** 22 | * Extract the pixel data from an ImageBitmap. 
23 | * 24 | * @param {ImageBitmap} bitmap 25 | * @return {ImageData} 26 | */ 27 | function imageDataFromBitmap(bitmap) { 28 | let canvas; 29 | if (typeof OffscreenCanvas !== "undefined") { 30 | canvas = new OffscreenCanvas(bitmap.width, bitmap.height); 31 | } else if (typeof HTMLCanvasElement !== "undefined") { 32 | const canvasEl = document.createElement("canvas"); 33 | canvasEl.width = bitmap.width; 34 | canvasEl.height = bitmap.height; 35 | canvas = canvasEl; 36 | } else { 37 | throw new Error("No canvas implementation available"); 38 | } 39 | 40 | const context = canvas.getContext("2d"); 41 | context.drawImage(bitmap, 0, 0, bitmap.width, bitmap.height); 42 | return context.getImageData(0, 0, bitmap.width, bitmap.height); 43 | } 44 | 45 | /** 46 | * Initialize an image classifier using the RTen engine and MobileNet v2 47 | * model. 48 | */ 49 | async function createClassifier() { 50 | // Fetch the RTen engine and MobileNet model in parallel. 51 | const [, modelData] = await Promise.all([ 52 | fetch("./node_modules/rten/dist/" + binaryName()).then(initRTen), 53 | fetchBinary("./mobilenet.rten"), 54 | ]); 55 | 56 | // Initialize the classifier. This must be done after RTen is initialized. 57 | return new ImageClassifier(modelData); 58 | } 59 | 60 | async function init() { 61 | // Start to initialize the classifier pre-emptively, before an image is 62 | // selected. This reduces the delay for the user after the initial selection. 63 | const classifierPromise = createClassifier(); 64 | 65 | const fileInput = document.querySelector("#file"); 66 | const resultList = document.querySelector("#result-list"); 67 | const statusInfo = document.querySelector("#status"); 68 | 69 | fileInput.onchange = async () => { 70 | statusInfo.textContent = "Downloading model..."; 71 | const classifier = await classifierPromise; 72 | const { width, height } = classifier.inputSize(); 73 | 74 | const bitmap = await createImageBitmap(fileInput.files[0], { 75 | // Resize image to input dimensions expected by model. 76 | resizeWidth: width, 77 | resizeHeight: height, 78 | }); 79 | 80 | statusInfo.textContent = "Thinking..."; 81 | const imageData = imageDataFromBitmap(bitmap); 82 | const classes = classifier.classify(imageData); 83 | 84 | statusInfo.textContent = "Things that may be in this image:"; 85 | 86 | resultList.innerHTML = ""; 87 | const listItems = classes.map(([classIndex, score]) => { 88 | const item = document.createElement("li"); 89 | item.textContent = IMAGENET_CLASSES[classIndex]; 90 | return item; 91 | }); 92 | resultList.append(...listItems); 93 | }; 94 | } 95 | 96 | init().catch((err) => { 97 | console.error("Error initializing classifier:", err); 98 | }); 99 | -------------------------------------------------------------------------------- /js-examples/image-classification/index.html: -------------------------------------------------------------------------------- 1 | <!DOCTYPE html> 2 | <html> 3 | <head> 4 | <title>RTen MobileNet image classification demo</title> 5 | </head> 6 | <body> 7 | 8 | <h1>RTen image classification demo</h1> 9 | 10 | <p> 11 | This demo guesses what is in an image. It uses the popular 12 | MobileNet v2 13 | model. 14 | </p> 15 | 16 | <input type="file" id="file"> 17 | 18 | 19 | <h2>Results:</h2> 20 | <p id="status"></p> 21 | <ol id="result-list"> 22 | </ol>
23 | 24 | <script type="module" src="classify-web.js"></script> 25 | 26 | </body> 27 | </html> -------------------------------------------------------------------------------- /js-examples/image-classification/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "rten-mobilenet-demo", 3 | "version": "1.0.0", 4 | "description": "RTen MobileNet image classification demo", 5 | "main": "mobilenet.js", 6 | "type": "module", 7 | "scripts": { 8 | "test": "echo \"Error: no test specified\" && exit 1" 9 | }, 10 | "repository": { 11 | "type": "git", 12 | "url": "git+https://github.com/robertknight/rten.git" 13 | }, 14 | "author": "Robert Knight ", 15 | "license": "BSD-2-Clause", 16 | "bugs": { 17 | "url": "https://github.com/robertknight/rten/issues" 18 | }, 19 | "homepage": "https://github.com/robertknight/rten#readme", 20 | "dependencies": { 21 | "rten": "file:../../", 22 | "sharp": "^0.33.1" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /pytorch-ref-tests/README.md: -------------------------------------------------------------------------------- 1 | This directory contains scripts which use PyTorch to generate a set of test 2 | cases with reference outputs in a JSON format that Rust tests can easily load. 3 | -------------------------------------------------------------------------------- /pytorch-ref-tests/common.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | import torch.nn as nn 3 | 4 | def tensor_json(x: Tensor): 5 | """ 6 | Convert a tensor to a JSON-serializable representation. 7 | """ 8 | return [list(x.shape), x.flatten().tolist()] 9 | 10 | 11 | def params_json(m: nn.Module): 12 | """ 13 | Convert a PyTorch module's parameters to a JSON-serializable dict.
14 | """ 15 | params_dict = {} 16 | for name, tensor in m.state_dict().items(): 17 | params_dict[name] = tensor_json(tensor) 18 | return params_dict 19 | 20 | 21 | -------------------------------------------------------------------------------- /pytorch-ref-tests/rnn.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from torch import Tensor 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .common import tensor_json, params_json 9 | 10 | 11 | def gen_lstm_test_case( 12 | module: nn.Module, inputs: Tensor, initial: tuple[Tensor, Tensor] | None = None 13 | ) -> dict: 14 | output, (last_hidden, last_cell) = module(inputs, initial) 15 | case = { 16 | "input": tensor_json(inputs), 17 | "output": tensor_json(output), 18 | "params": params_json(module), 19 | } 20 | if initial: 21 | case["initial_hidden"] = tensor_json(initial[0]) 22 | case["initial_cell"] = tensor_json(initial[1]) 23 | return case 24 | 25 | 26 | def gen_gru_test_case( 27 | module: nn.Module, inputs: Tensor, initial: Tensor | None = None 28 | ) -> dict: 29 | ( 30 | output, 31 | last_hidden, 32 | ) = module(inputs, initial) 33 | case = { 34 | "input": tensor_json(inputs), 35 | "output": tensor_json(output), 36 | "params": params_json(module), 37 | } 38 | if initial is not None: 39 | case["initial_hidden"] = tensor_json(initial) 40 | return case 41 | 42 | 43 | # Ensure we get the same output on every run. 44 | torch.manual_seed(1234) 45 | 46 | input_features = 10 47 | hidden_size = 5 48 | seq_len = 7 49 | 50 | x = torch.rand((seq_len, input_features)) 51 | initial_hidden = torch.rand((1, hidden_size)) 52 | initial_cell = torch.rand((1, hidden_size)) 53 | 54 | lstm = nn.LSTM(input_size=input_features, hidden_size=hidden_size) 55 | lstm_bidirectional = nn.LSTM( 56 | input_size=input_features, hidden_size=hidden_size, bidirectional=True 57 | ) 58 | 59 | gru = nn.GRU(input_size=input_features, hidden_size=hidden_size) 60 | gru_bidirectional = nn.GRU( 61 | input_size=input_features, hidden_size=hidden_size, bidirectional=True 62 | ) 63 | 64 | test_cases = { 65 | "__comment__": f"Generated with {os.path.basename(__file__)}", 66 | "lstm_forwards": gen_lstm_test_case(lstm, x), 67 | "lstm_bidirectional": gen_lstm_test_case(lstm_bidirectional, x), 68 | "lstm_initial": gen_lstm_test_case(lstm, x, (initial_hidden, initial_cell)), 69 | "gru_forwards": gen_gru_test_case(gru, x), 70 | "gru_bidirectional": gen_gru_test_case(gru_bidirectional, x), 71 | "gru_initial": gen_gru_test_case(gru, x, initial_hidden), 72 | } 73 | 74 | script_dir = os.path.dirname(__file__) 75 | with open(f"{script_dir}/rnn.json", "w") as f: 76 | json.dump(test_cases, f, indent=2) 77 | -------------------------------------------------------------------------------- /rten-bench/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-bench" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Benchmarking utilities for use in RTen development" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | 11 | [package.metadata.release] 12 | release = false 13 | 14 | [lib] 15 | crate-type = ["lib"] 16 | -------------------------------------------------------------------------------- /rten-bench/README.md: -------------------------------------------------------------------------------- 1 | Internal benchmarking utilities used in
RTen development. 2 | -------------------------------------------------------------------------------- /rten-bench/src/lib.rs: -------------------------------------------------------------------------------- 1 | use std::time::Instant; 2 | 3 | /// Statistics from a benchmark run. All fields are durations in milliseconds. 4 | #[derive(Default)] 5 | pub struct BenchStats { 6 | /// Duration of longest run. 7 | pub max: f32, 8 | 9 | /// Mean duration. 10 | pub mean: f32, 11 | 12 | /// Median duration. 13 | pub median: f32, 14 | 15 | /// Minimum duration. 16 | pub min: f32, 17 | 18 | /// Variance of durations. 19 | pub var: f32, 20 | } 21 | 22 | /// Run a benchmark function `f` for `trials` iterations and return statistics. 23 | /// 24 | /// Prints the statistics if `description` is provided. 25 | pub fn run_bench<F: FnMut()>(trials: usize, description: Option<&str>, mut f: F) -> BenchStats { 26 | if trials == 0 { 27 | return BenchStats::default(); 28 | } 29 | 30 | let mut times = Vec::with_capacity(trials); 31 | for _ in 0..trials { 32 | let start = Instant::now(); 33 | 34 | f(); 35 | 36 | let duration_ms = start.elapsed().as_secs_f64() * 1000.0; 37 | times.push(duration_ms as f32); 38 | } 39 | 40 | times.sort_by(|a, b| a.total_cmp(b)); 41 | let min = times.first().copied().unwrap(); 42 | let max = times.last().copied().unwrap(); 43 | 44 | let mid = times.len() / 2; 45 | let median = if times.len() % 2 == 1 { 46 | times[mid] 47 | } else { 48 | (times[mid - 1] + times[mid]) / 2. 49 | }; 50 | let mean = times.iter().sum::<f32>() / times.len() as f32; 51 | let var = times.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / times.len() as f32; 52 | 53 | if let Some(description) = description { 54 | println!( 55 | "{}. mean {:.3}ms median {:.3} var {:.3} min {:.3} max {:.3}", 56 | description, mean, median, var, min, max 57 | ); 58 | } 59 | 60 | BenchStats { 61 | max, 62 | mean, 63 | median, 64 | min, 65 | var, 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /rten-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-cli" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "CLI tool for inspecting and running RTen models" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [dependencies] 13 | fastrand = "2.0.2" 14 | rten = { path = "../", version = "0.18.0", features=["mmap", "random"] } 15 | rten-tensor = { path = "../rten-tensor", version = "0.18.0" } 16 | lexopt = "0.3.0" 17 | 18 | [dev-dependencies] 19 | rten-testing = { path = "../rten-testing" } 20 | 21 | [features] 22 | # Use AVX-512 instructions if available. Requires nightly Rust for AVX-512 intrinsics. 23 | avx512 = ["rten/avx512"] 24 | 25 | [[bin]] 26 | name = "rten" 27 | path = "src/main.rs" 28 | -------------------------------------------------------------------------------- /rten-cli/README.md: -------------------------------------------------------------------------------- 1 | # rten-cli 2 | 3 | rten-cli is a CLI tool for inspecting RTen models and running them with 4 | randomly generated inputs. 5 | -------------------------------------------------------------------------------- /rten-convert/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 RTen project contributors.
2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 26 | -------------------------------------------------------------------------------- /rten-convert/Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: check 2 | check: checkformat lint typecheck 3 | 4 | .PHONY: checkformat 5 | checkformat: 6 | ruff format --check rten_convert 7 | 8 | .PHONY: format 9 | format: 10 | ruff format rten_convert 11 | 12 | .PHONY: lint 13 | lint: 14 | ruff check rten_convert 15 | 16 | .PHONY: typecheck 17 | typecheck: 18 | mypy rten_convert 19 | 20 | # See https://packaging.python.org/en/latest/tutorials/packaging-projects/#generating-distribution-archives 21 | .PHONY: release 22 | release: check 23 | rm -rf dist/ 24 | python -m build 25 | python -m twine upload dist/* 26 | -------------------------------------------------------------------------------- /rten-convert/README.md: -------------------------------------------------------------------------------- 1 | # rten-convert 2 | 3 | rten-convert converts ONNX models to `.rten` format, for use with the 4 | [RTen](https://github.com/robertknight/rten) machine learning runtime. 5 | 6 | ## Installation 7 | 8 | The conversion tool requires Python >= 3.10. To install the tool, run: 9 | 10 | ```sh 11 | pip install rten-convert 12 | ``` 13 | 14 | ## Usage 15 | 16 | ```sh 17 | rten-convert your-model.onnx your-model.rten 18 | ``` 19 | 20 | The second argument is optional. If omitted the output filename will be the 21 | input filename with the `.onnx` extension replaced with `.rten`. 22 | 23 | ## Versioning 24 | 25 | The `rten-convert` tool and `rten` library use common version numbering. A 26 | model produced by `rten-convert` version X can be executed by `rten` version X 27 | or newer. 28 | 29 | ## Development 30 | 31 | To install this tool from a checkout of the Git repository, run: 32 | 33 | ```sh 34 | pip install -e . 35 | ``` 36 | 37 | After making changes, run the QA checks. 
First, install the development 38 | dependencies: 39 | 40 | ``` 41 | pip install -r requirements.dev.txt 42 | ``` 43 | 44 | Then run: 45 | 46 | ``` 47 | make check 48 | ``` 49 | -------------------------------------------------------------------------------- /rten-convert/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "rten-convert" 3 | description = "Convert ONNX models to .rten format" 4 | requires-python = ">=3.10" 5 | version = "0.17.0" 6 | dependencies = ["flatbuffers", "onnx", "numpy"] 7 | readme = "README.md" 8 | classifiers = [ 9 | "License :: OSI Approved :: MIT License", 10 | ] 11 | 12 | [project.scripts] 13 | rten-convert = "rten_convert.converter:main" 14 | 15 | [project.urls] 16 | Homepage = "https://github.com/robertknight/rten" 17 | Issues = "https://github.com/robertknight/rten/issues" 18 | 19 | [build-system] 20 | requires = ["setuptools>=61.0"] 21 | build-backend = "setuptools.build_meta" 22 | 23 | [tool.ruff] 24 | exclude = ["rten_convert/schema_generated.py"] 25 | 26 | [[tool.mypy.overrides]] 27 | module = "rten_convert.schema_generated" 28 | disable_error_code = [ 29 | "annotation-unchecked", 30 | "import-untyped" # for flatbuffers 31 | ] 32 | 33 | [[tool.mypy.overrides]] 34 | module = "rten_convert.converter" 35 | disable_error_code = [ 36 | "import-untyped", # for flatbuffers 37 | ] 38 | -------------------------------------------------------------------------------- /rten-convert/requirements.dev.txt: -------------------------------------------------------------------------------- 1 | mypy==1.10.0 2 | ruff==0.4.3 3 | -------------------------------------------------------------------------------- /rten-convert/rten_convert/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robertknight/rten/c646a5ac1bed69d0717fc37eb53e3ba290b89233/rten-convert/rten_convert/__init__.py -------------------------------------------------------------------------------- /rten-convert/rten_convert/errors.py: -------------------------------------------------------------------------------- 1 | """Errors reported during model conversion.""" 2 | 3 | 4 | class ConversionError(Exception): 5 | """Errors when converting ONNX models to .rten format.""" 6 | 7 | def __init__(self, message: str): 8 | super().__init__(message) 9 | 10 | 11 | class UnsupportedOperatorError(ConversionError): 12 | """Conversion failed because an operator is unsupported.""" 13 | 14 | op_type: str 15 | """The name of the unsupported operator, eg. `Conv`""" 16 | 17 | def __init__(self, op_type: str): 18 | self.op_type = op_type 19 | super().__init__(f'Unsupported operator "{op_type}"') 20 | -------------------------------------------------------------------------------- /rten-convert/rten_convert/graph.py: -------------------------------------------------------------------------------- 1 | """Types used in intermediate representation of parsed ONNX graphs. 2 | 3 | This module defines graph and node types produced by parsing ONNX files. These 4 | are in a format that is convenient for serialization into the FlatBuffers 5 | format used by .rten models. 
6 | """ 7 | 8 | from typing import Any 9 | 10 | import numpy as np 11 | 12 | from rten_convert.errors import ConversionError 13 | 14 | 15 | class Node: 16 | """Base class for all graph nodes (constants, values, operators).""" 17 | 18 | def __init__(self, name: str): 19 | self.name = name 20 | 21 | 22 | class ConstantNode(Node): 23 | """ 24 | Data for a constant graph node. 25 | 26 | These are used for model weights, biases etc. 27 | """ 28 | 29 | shape: list[int] 30 | data: np.ndarray 31 | 32 | def __init__(self, name: str, shape: list[int], data: np.ndarray): 33 | super().__init__(name) 34 | self.shape = shape 35 | self.data = data 36 | 37 | shape_numel = np.prod(shape) 38 | if shape_numel != data.size: 39 | raise ConversionError( 40 | f'Shape {shape} product {shape_numel} does not match data length {data.size} in node "{name}"' 41 | ) 42 | 43 | # Verify that this is a data type that we'll be able to serialize later. 44 | match data.dtype: 45 | case np.float32 | np.int32 | np.int8 | np.uint8: 46 | pass 47 | case _: 48 | dtype_name: str = data.dtype.name # type:ignore[union-attr] 49 | raise ConversionError( 50 | f'Tried to construct ConstantNode "{name}" with unsupported data type {dtype_name}' 51 | ) 52 | 53 | def get_scalar(self): 54 | if self.shape != []: 55 | return None 56 | return self.data[0] 57 | 58 | 59 | class OperatorNode(Node): 60 | """ 61 | Data for an operator graph node. 62 | """ 63 | 64 | # RTen operator name. This should match the operator name in the FlatBuffers 65 | # schema. 66 | op_type: str 67 | 68 | attrs: Any 69 | """ 70 | Attributes object or None. 71 | 72 | This should be the operator-specific attributes object generated by flatc. 73 | eg. `sg.AveragePoolAttrsT` for the AveragePool op. 74 | """ 75 | 76 | inputs: list[int | None] 77 | outputs: list[int | None] 78 | 79 | def __init__( 80 | self, 81 | name: str, 82 | op_type: str, 83 | attrs: Any, 84 | inputs: list[int | None], 85 | outputs: list[int | None], 86 | ): 87 | super().__init__(name) 88 | self.op_type = op_type 89 | self.attrs = attrs 90 | self.inputs = inputs 91 | self.outputs = outputs 92 | 93 | 94 | class ValueNode(Node): 95 | """ 96 | Data for a value placeholder graph node. 97 | 98 | These are used for operator inputs and outputs. 99 | 100 | The shape can be missing, or a mix of fixed and symbolic (unknown at model 101 | export time) sizes. 102 | """ 103 | 104 | def __init__(self, name: str, shape: list[int | str] | None, dtype: int | None): 105 | """ 106 | Initialize a value node. 107 | 108 | :param name: Unique name of the value 109 | :param shape: Expected shape of tensor at runtime 110 | :param dtype: Expected data type of tensor at runtime. Value from `sg.DataType`. 
111 | """ 112 | super().__init__(name) 113 | 114 | self.shape = shape 115 | self.dtype = dtype 116 | 117 | 118 | class Graph: 119 | nodes: list[Node] 120 | 121 | inputs: list[int] 122 | """Indices of nodes in `nodes` that are model inputs.""" 123 | 124 | outputs: list[int] 125 | """Indices of nodes in `nodes` that are model outputs.""" 126 | 127 | captures: list[int] | None 128 | """Indices of nodes in `nodes` that are captured from parent scopes at runtime.""" 129 | 130 | def __init__( 131 | self, 132 | nodes: list[Node], 133 | inputs: list[int], 134 | outputs: list[int], 135 | captures: list[int] | None = None, 136 | ): 137 | self.nodes = nodes 138 | self.inputs = inputs 139 | self.outputs = outputs 140 | self.captures = captures 141 | -------------------------------------------------------------------------------- /rten-convert/rten_convert/tensor_data.py: -------------------------------------------------------------------------------- 1 | from typing import BinaryIO 2 | 3 | import numpy as np 4 | 5 | from rten_convert.util import round_up, write_padding 6 | 7 | 8 | class TensorDataBuilder: 9 | offset: int 10 | """End offset of written data from start of tensor data.""" 11 | 12 | tensors: list[np.ndarray] 13 | """List of tensors to write.""" 14 | 15 | align: int 16 | """Alignment of each tensor's data, relative to the start of the tensor data.""" 17 | 18 | def __init__(self): 19 | self.offset = 0 20 | self.tensors = [] 21 | self.tensor_offsets = [] 22 | self.tensor_lengths = [] 23 | self.align = 64 24 | 25 | def add_tensor(self, array: np.ndarray, dtype=None) -> int: 26 | """ 27 | Add a tensor to be written to the tensor data segment. 28 | 29 | Returns the offset that the data will be stored at, relative to the 30 | start of the tensor data segment. 31 | """ 32 | self.tensors.append(array) 33 | 34 | match array.dtype: 35 | case np.float32 | np.int32: 36 | element_size = 4 37 | case np.int8 | np.uint8: 38 | element_size = 1 39 | case _: 40 | raise ValueError("Unsupported NumPy array type {}".format(array.dtype)) 41 | 42 | prev_offset = self.offset 43 | padding = round_up(prev_offset, self.align) - prev_offset 44 | tensor_len = array.size * element_size 45 | 46 | self.offset += padding 47 | self.tensor_offsets.append(self.offset) 48 | self.tensor_lengths.append(tensor_len) 49 | self.offset += tensor_len 50 | 51 | return self.tensor_offsets[-1] 52 | 53 | def write(self, fp: BinaryIO): 54 | """ 55 | Write out tensor data to a file. 
56 | """ 57 | 58 | offset = 0 59 | 60 | for i, tensor in enumerate(self.tensors): 61 | expected_offset = self.tensor_offsets[i] 62 | padding = round_up(offset, self.align) - offset 63 | 64 | assert ( 65 | expected_offset == offset + padding 66 | ), f"actual offset {offset} of tensor {i} does not match expected offset {expected_offset}" 67 | 68 | write_padding(fp, padding) 69 | offset += padding 70 | 71 | tensor_data = tensor.tobytes() 72 | assert ( 73 | len(tensor_data) == self.tensor_lengths[i] 74 | ), f"actual length {len(tensor_data)} of tensor {i} does not match expected length {self.tensor_lengths[i]}" 75 | 76 | fp.write(tensor_data) 77 | offset += len(tensor_data) 78 | -------------------------------------------------------------------------------- /rten-convert/rten_convert/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import BinaryIO 3 | import sys 4 | 5 | 6 | def round_up(value: int, base: int) -> int: 7 | """Round up `value` to the next multiple of `base`.""" 8 | return base * math.ceil(value / base) 9 | 10 | 11 | def write_padding(fp: BinaryIO, n: int, max_padding=1024): 12 | """ 13 | Write `n` bytes of zero padding at the end of a file. 14 | 15 | :param max_padding: 16 | Maximum value for `n`. This is a sanity check to catch unexpectedly 17 | large padding sizes. 18 | """ 19 | 20 | if n < 0 or n >= max_padding: 21 | raise ValueError(f"Padding size {n} is out of range") 22 | 23 | if n == 0: 24 | return 25 | fp.write(b"\x00" * n) 26 | 27 | 28 | EMITTED_WARNINGS: set[str] = set() 29 | 30 | 31 | def warn_once(msg: str): 32 | """ 33 | Emit a warning if not already emitted. 34 | 35 | This is used to reduce output noise if the same problem arises many times 36 | when converting a model. 37 | """ 38 | if msg in EMITTED_WARNINGS: 39 | return 40 | EMITTED_WARNINGS.add(msg) 41 | print(f"WARNING: {msg}", file=sys.stderr) 42 | -------------------------------------------------------------------------------- /rten-examples/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-examples" 3 | version = "0.3.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Examples for using the rten library" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | 11 | [dependencies] 12 | fastrand = "2.0.2" 13 | hound = "3.5.1" 14 | image = { workspace = true } 15 | lexopt = "0.3.0" 16 | microfft = { version = "0.6.0", default-features = false, features = ["size-512"] } 17 | png = "0.17.6" 18 | serde = { workspace = true, features = ["derive"] } 19 | serde_json = { workspace = true } 20 | rten = { path = "../", features = ["mmap", "random"] } 21 | rten-generate = { path = "../rten-generate", features=["text-decoder"] } 22 | rten-imageio = { path = "../rten-imageio" } 23 | rten-imageproc = { path = "../rten-imageproc" } 24 | rten-tensor = { path = "../rten-tensor", features=["serde"] } 25 | rten-text = { path = "../rten-text" } 26 | smallvec = "1.13.2" 27 | 28 | [features] 29 | # Use AVX-512 instructions if available. Requires nightly Rust for AVX-512 intrinsics. 30 | avx512 = ["rten/avx512"] 31 | 32 | [lints.clippy] 33 | # Allows use of `..Default::default()` for future compatibility even when not 34 | # currently needed. 
35 | needless_update = "allow" 36 | manual_repeat_n = "allow" # TODO - Address existing failures 37 | 38 | [package.metadata.release] 39 | release = false 40 | 41 | # Vision 42 | [[bin]] 43 | name = "clip" 44 | path = "src/clip.rs" 45 | 46 | [[bin]] 47 | name = "deeplab" 48 | path = "src/deeplab.rs" 49 | test = false 50 | 51 | [[bin]] 52 | name = "detr" 53 | path = "src/detr.rs" 54 | test = false 55 | 56 | [[bin]] 57 | name = "distilvit" 58 | path = "src/distilvit.rs" 59 | test = false 60 | 61 | [[bin]] 62 | name = "imagenet" 63 | path = "src/imagenet.rs" 64 | test = false 65 | 66 | [[bin]] 67 | name = "nougat" 68 | path = "src/nougat.rs" 69 | test = false 70 | 71 | [[bin]] 72 | name = "yolo" 73 | path = "src/yolo.rs" 74 | test = false 75 | 76 | [[bin]] 77 | name = "depth_anything" 78 | path = "src/depth_anything.rs" 79 | test = false 80 | 81 | [[bin]] 82 | name = "rmbg" 83 | path = "src/rmbg.rs" 84 | test = false 85 | 86 | [[bin]] 87 | name = "segment_anything" 88 | path = "src/segment_anything.rs" 89 | test = false 90 | 91 | [[bin]] 92 | name = "trocr" 93 | path = "src/trocr.rs" 94 | test = false 95 | 96 | # Text 97 | [[bin]] 98 | name = "bert_qa" 99 | path = "src/bert_qa.rs" 100 | test = false 101 | 102 | [[bin]] 103 | name = "gpt2" 104 | path = "src/gpt2.rs" 105 | test = false 106 | 107 | [[bin]] 108 | name = "modernbert" 109 | path = "src/modernbert.rs" 110 | test = false 111 | 112 | [[bin]] 113 | name = "jina_similarity" 114 | path = "src/jina_similarity.rs" 115 | test = false 116 | 117 | [[bin]] 118 | name = "qwen2_chat" 119 | path = "src/qwen2_chat.rs" 120 | test = false 121 | 122 | # Audio 123 | [[bin]] 124 | name = "piper" 125 | path = "src/piper.rs" 126 | test = false 127 | 128 | [[bin]] 129 | name = "silero" 130 | path = "src/silero.rs" 131 | test = false 132 | 133 | [[bin]] 134 | name = "wav2vec2" 135 | path = "src/wav2vec2.rs" 136 | test = false 137 | 138 | [[bin]] 139 | name = "whisper" 140 | path = "src/whisper.rs" 141 | test = false 142 | -------------------------------------------------------------------------------- /rten-examples/data/README.md: -------------------------------------------------------------------------------- 1 | # Example support data 2 | 3 | This directory contains support files for examples, such as class label 4 | mappings. See sections below for notes on the data source, licenses etc. 5 | 6 | ## ImageNet labels 7 | 8 | The `imagenet*.txt` files were obtained from 9 | [pytorch-image-models](https://github.com/huggingface/pytorch-image-models/tree/main/timm/data/_info). 10 | 11 | ## COCO labels 12 | 13 | COLO labels for YOLO were obtained from 14 | https://github.com/pjreddie/darknet/blob/master/data/coco.names. 
15 | 16 | -------------------------------------------------------------------------------- /rten-examples/data/coco.names: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush 81 | -------------------------------------------------------------------------------- /rten-examples/data/dump_mel_filters.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | import librosa.filters 5 | 6 | 7 | def ndarray_to_dict(array): 8 | """ 9 | Return a JSON-serializable representation of an ndarray. 10 | 11 | This representation is compatible with rten-tensor's serde deserialization. 12 | """ 13 | return { 14 | "shape": array.shape, 15 | "data": array.flatten().tolist(), 16 | } 17 | 18 | 19 | # Generate mel filter matrices using the same method as Whisper's original 20 | # preprocessing code and export them to JSON. 21 | # 22 | # See https://github.com/openai/whisper/blob/25639fc17ddc013d56c594bfbf7644f2185fad84/whisper/audio.py#L92 23 | # 24 | # Most of the entries are zero so the output could be made smaller by 25 | # representing it as an (index, value) array. However, the non-sparse 26 | # representation is pretty small when compressed. 27 | mel_80 = librosa.filters.mel(sr=16_000, n_fft=400, n_mels=80) 28 | mel_128 = librosa.filters.mel(sr=16_000, n_fft=400, n_mels=128) 29 | data = { 30 | "_note": "Generated with dump_mel_filters.py", 31 | "mel_80": ndarray_to_dict(mel_80), 32 | "mel_128": ndarray_to_dict(mel_128), 33 | } 34 | with open("mel_filters.json", "w") as fp: 35 | json.dump(data, fp) 36 | -------------------------------------------------------------------------------- /rten-examples/data/rust-questions.txt: -------------------------------------------------------------------------------- 1 | Why can't I store a value and a reference to that value in the same struct? 2 | Why is it discouraged to accept a reference &String, &Vec, or &Box as a function argument? 3 | Is there any way to return a reference to a variable created in a function? 4 | How do I create a global, mutable singleton? 5 | What are non-lexical lifetimes? 6 | What are Rust's exact auto-dereferencing rules? 7 | How can I pass a reference to a stack variable to a thread? 8 | Return local String as a slice (&str) 9 | What are the differences between Rust's `String` and `str`? 10 | How do I implement a trait I don't own for a type I don't own? 11 | Returning a reference from a HashMap or Vec causes a borrow to last beyond the scope it's in? 
12 | Why does my string not match when reading user input from stdin? 13 | What is the correct way to return an Iterator (or any other trait)? 14 | Why can I return a reference to a local literal but not a variable? 15 | What is the difference between iter and into_iter? 16 | How to get mutable references to two array elements at the same time? 17 | How do I write an iterator that returns references to itself? 18 | How to lookup from and insert into a HashMap efficiently? 19 | How do I create a heterogeneous collection of objects? 20 | Why is it legal to borrow a temporary? 21 | What's the difference between placing "mut" before a variable name and after the ":"? 22 | Why doesn't Rust support trait object upcasting? 23 | How to get a reference to a concrete type from a trait object? 24 | "Expected type parameter" error in the constructor of a generic struct 25 | How can I swap in a new value for a field in a mutable reference to a structure? 26 | Why is a trait not implemented for a type that clearly has it implemented? 27 | Why do try!() and ? not compile when used in a function that doesn't return Option or Result? 28 | The compiler suggests I add a 'static lifetime because the parameter type may not live long enough, but I don't think that's what I want 29 | How do I print in Rust the type of a variable? 30 | How does "for<>" syntax differ from a regular lifetime bound? 31 | How do I return a reference to something inside a RefCell without breaking encapsulation? 32 | Cannot obtain a mutable reference when iterating a recursive structure: cannot borrow as mutable more than once at a time 33 | Why are explicit lifetimes needed in Rust? 34 | Cannot move out of borrowed content / cannot move out of behind a shared reference 35 | How can I create my own data structure with an iterator that returns mutable references? 36 | How can you make a safe static singleton in Rust? 37 | Why is this match pattern unreachable when using non-literal patterns? 38 | Cannot move out of value which is behind a shared reference when unwrapping 39 | Is it possible to control the size of an array using the type parameter of a generic? 40 | How can I include a module from another file from the same project? 41 | How to clone a struct storing a boxed trait object? 42 | Should trait bounds be duplicated in struct and impl? 43 | Conditionally iterate over one of several possible iterators 44 | How do I import from a sibling module? 45 | How do I synchronously return a value calculated in an asynchronous Future? 46 | How is there a conflicting implementation of `From` when using a generic type? 47 | Do mutable references have move semantics? 48 | How do I require a generic type implement an operation like Add, Sub, Mul, or Div in a generic function? 49 | What does "Sized is not implemented" mean? 50 | How do I express mutually recursive data structures in safe Rust? 51 | -------------------------------------------------------------------------------- /rten-examples/src/bert_qa_reference.py: -------------------------------------------------------------------------------- 1 | # Reference implementation of BERT extractive question answering using 2 | # Hugging Face Transformers. 3 | 4 | from argparse import ArgumentParser 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | from transformers import AutoTokenizer, AutoModelForQuestionAnswering 10 | from transformers import pipeline 11 | 12 | 13 | # Run inference on a question answering model using a Hugging Face transformers 14 | # pipeline. 
15 | # 16 | # See https://huggingface.co/docs/transformers/v4.35.2/en/tasks/question_answering#inference. 17 | def eval_qa_model(model_name: str, context: str, question: str): 18 | """ 19 | :param model_name: Name of Hugging Face model trained for question answering 20 | :param context: Context to search for answer to question 21 | :param question: Question to answer 22 | """ 23 | 24 | tokenizer = AutoTokenizer.from_pretrained(model_name) 25 | model = AutoModelForQuestionAnswering.from_pretrained(model_name) 26 | oracle = pipeline(task="question-answering", model=model, tokenizer=tokenizer) 27 | 28 | start = time.perf_counter() 29 | result = oracle(question=question, context=context) 30 | end = time.perf_counter() 31 | print(f"Result from `transformers` pipeline in {end-start:.2f}s:", result) 32 | 33 | 34 | parser = ArgumentParser( 35 | description=""" 36 | Perform extractive question answering using BERT. 37 | 38 | This is a reference implementation using Hugging Face Transformers. 39 | """ 40 | ) 41 | parser.add_argument("context", help="Path to text file containing context") 42 | parser.add_argument("question", help="Question to answer") 43 | parser.add_argument( 44 | "--model", 45 | help="Name of the Hugging Face model", 46 | # For more models, search for "bert squad" on HF: 47 | # https://huggingface.co/models?pipeline_tag=question-answering&sort=downloads&search=bert+squad 48 | default="deepset/bert-base-cased-squad2", 49 | ) 50 | args = parser.parse_args() 51 | 52 | with open(args.context) as context_fp: 53 | context = context_fp.read() 54 | 55 | eval_qa_model(args.model, context, args.question) 56 | -------------------------------------------------------------------------------- /rten-examples/src/clip_reference.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from PIL import Image 4 | from transformers import CLIPProcessor, CLIPModel 5 | 6 | parser = ArgumentParser(description="Reference implementation for the CLIP example.") 7 | parser.add_argument("-i", "--image", type=str, action="append", help="Path to image") 8 | parser.add_argument("-c", "--caption", type=str, action="append", help="Text caption") 9 | parser.add_argument("-t", "--tokens", action="store_true", help="Print text token IDs") 10 | args = parser.parse_args() 11 | 12 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 13 | processor = CLIPProcessor.from_pretrained( 14 | "openai/clip-vit-base-patch32", clean_up_tokenization_spaces=True 15 | ) 16 | 17 | images = [Image.open(img_path) for img_path in args.image] 18 | 19 | inputs = processor(text=args.caption, images=images, return_tensors="pt", padding=True) 20 | if args.tokens: 21 | print("Tokens", inputs["input_ids"]) 22 | 23 | outputs = model(**inputs) 24 | logits_per_image = outputs.logits_per_image # this is the image-text similarity score 25 | probs = logits_per_image.softmax(dim=1) 26 | 27 | for img_idx, img_path in enumerate(args.image): 28 | for cap_idx, caption in enumerate(args.caption): 29 | prob = probs[img_idx, cap_idx] 30 | print(f'image "{img_path}" caption "{caption}" probability {prob:.2f}') 31 | -------------------------------------------------------------------------------- /rten-examples/src/deeplab_reference.py: -------------------------------------------------------------------------------- 1 | # Reference inference for DeepLab example using ONNX Runtime.
2 | # 3 | # To use this, first export the DeepLab model then run inference: 4 | # 5 | # ``` 6 | # python export-deeplab.py 7 | # python deeplab_reference.py deeplab.onnx path/to/test_image.jpeg 8 | # ``` 9 | # 10 | # This will produce an `out_reference.png` image containing the segmentation map. 11 | from argparse import ArgumentParser 12 | 13 | from PIL import Image 14 | import numpy as np 15 | import onnxruntime 16 | 17 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 18 | IMAGENET_STD_DEV = [0.229, 0.224, 0.225] 19 | 20 | # Labels and colors for the different categories of object that DeepLabv3 can 21 | # detect. 22 | # 23 | # For the labels, see https://github.com/NVIDIA/DIGITS/blob/master/examples/semantic-segmentation/pascal-voc-classes.txt. 24 | PASCAL_VOC_LABELS = [ 25 | ("background", (0.0, 0.0, 0.0)), # Black 26 | ("aeroplane", (0.0, 1.0, 0.0)), # Green 27 | ("bicycle", (0.0, 0.0, 1.0)), # Blue 28 | ("bird", (1.0, 1.0, 0.0)), # Yellow 29 | ("boat", (1.0, 0.0, 1.0)), # Magenta 30 | ("bottle", (0.0, 1.0, 1.0)), # Cyan 31 | ("bus", (0.5, 0.0, 0.0)), # Dark Red 32 | ("car", (0.0, 0.5, 0.0)), # Dark Green 33 | ("cat", (0.0, 0.0, 0.5)), # Dark Blue 34 | ("chair", (0.5, 0.5, 0.0)), # Olive 35 | ("cow", (0.5, 0.0, 0.5)), # Purple 36 | ("diningtable", (0.0, 0.5, 0.5)), # Teal 37 | ("dog", (0.75, 0.75, 0.75)), # Light Gray 38 | ("horse", (0.5, 0.5, 0.5)), # Gray 39 | ("motorbike", (0.25, 0.25, 0.25)), # Dark Gray 40 | ("person", (1.0, 0.5, 0.0)), # Orange 41 | ("pottedplant", (0.5, 1.0, 0.5)), # Pastel Green 42 | ("sheep", (0.5, 0.5, 1.0)), # Pastel Blue 43 | ("sofa", (1.0, 0.75, 0.8)), # Pink 44 | ("train", (0.64, 0.16, 0.16)), # Brown 45 | ("tvmonitor", (1.0, 1.0, 1.0)), # White 46 | ] 47 | 48 | parser = ArgumentParser() 49 | parser.add_argument("model", help="Path to DeepLab ONNX model") 50 | parser.add_argument("image", help="Image to segment") 51 | args = parser.parse_args() 52 | 53 | session = onnxruntime.InferenceSession(args.model) 54 | 55 | # Input image size expected by model 56 | input_width = 693 57 | input_height = 520 58 | 59 | # Load image, normalize and convert to NHWC layout 60 | image = Image.open(args.image) 61 | image = image.resize([input_width, input_height]) 62 | image = np.asarray(image).astype("float32") / 255.0 63 | image = np.transpose(image, (2, 0, 1)) # HWC => CHW 64 | 65 | norm_mean = np.array(IMAGENET_MEAN, dtype="float32").reshape(-1, 1, 1) 66 | norm_std_dev = np.array(IMAGENET_STD_DEV, dtype="float32").reshape(-1, 1, 1) 67 | image = (image - norm_mean) / norm_std_dev 68 | image = np.expand_dims(image, axis=0) # Insert batch dim 69 | 70 | # Segment image, producing an HW tensor containing the class index for each pixel. 71 | seg_classes = session.run(["output"], {"input": image})[0] 72 | seg_classes = np.transpose(seg_classes, (0, 2, 3, 1)) # (N,class,H,W) => (N,H,W,class) 73 | seg_classes = np.argmax(seg_classes[0], axis=-1) 74 | 75 | # Produce a segmentation map with pixels colored based on predicted class for 76 | # each pixel. 
77 | out_height, out_width = seg_classes.shape 78 | seg_map = np.zeros((out_height, out_width, 3), dtype="float32") 79 | for cls_id, cls_info in enumerate(PASCAL_VOC_LABELS): 80 | cls_name, cls_color = cls_info 81 | cls_mask = seg_classes == cls_id 82 | for chan in range(3): 83 | seg_map[cls_mask, chan] = cls_color[chan] 84 | 85 | out_im = Image.fromarray(np.uint8(seg_map * 255)) 86 | out_im.save("out_reference.png") 87 | -------------------------------------------------------------------------------- /rten-examples/src/depth_anything.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::error::Error; 3 | 4 | use rten::{FloatOperators, Model, Operators}; 5 | use rten_imageio::{read_image, write_image}; 6 | use rten_imageproc::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV}; 7 | use rten_tensor::prelude::*; 8 | use rten_tensor::Tensor; 9 | 10 | struct Args { 11 | model: String, 12 | image: String, 13 | output: String, 14 | } 15 | 16 | fn parse_args() -> Result<Args, lexopt::Error> { 17 | use lexopt::prelude::*; 18 | 19 | let mut values = VecDeque::new(); 20 | let mut parser = lexopt::Parser::from_env(); 21 | 22 | while let Some(arg) = parser.next()? { 23 | match arg { 24 | Value(val) => values.push_back(val.string()?), 25 | Long("help") => { 26 | println!( 27 | "Perform monocular depth estimation on an image. 28 | 29 | Usage: {bin_name} <model> <image> [<output>] 30 | 31 | Args: 32 | 33 | <model> - Input Depth Anything model 34 | <image> - Image to process 35 | <output> - Path to save depth image to. Defaults to \"depth-map.png\". 36 | ", 37 | bin_name = parser.bin_name().unwrap_or("depth_anything") 38 | ); 39 | std::process::exit(0); 40 | } 41 | _ => return Err(arg.unexpected()), 42 | } 43 | } 44 | 45 | let model = values.pop_front().ok_or("missing `model` arg")?; 46 | let image = values.pop_front().ok_or("missing `image` arg")?; 47 | let output = values.pop_front().unwrap_or("depth-map.png".into()); 48 | 49 | let args = Args { 50 | image, 51 | model, 52 | output, 53 | }; 54 | 55 | Ok(args) 56 | } 57 | 58 | /// Perform monocular depth estimation using [Depth Anything][depth_anything]. 59 | /// 60 | /// The ONNX models can be obtained from 61 | /// https://github.com/fabio-sim/Depth-Anything-ONNX. See the 62 | /// [releases](https://github.com/fabio-sim/Depth-Anything-ONNX/releases) page 63 | /// for pre-trained model links. The small ("vits") model is recommended for 64 | /// CPU inference. 65 | /// 66 | /// After downloading the model, it can be run on an image using: 67 | /// 68 | /// ``` 69 | /// rten-convert depth_anything.onnx 70 | /// cargo run --release --bin depth_anything depth_anything.rten image.jpg 71 | /// ``` 72 | /// 73 | /// This will generate a depth map as `depth-map.png`. 74 | /// 75 | /// [depth_anything]: <https://github.com/LiheYoung/Depth-Anything> 76 | fn main() -> Result<(), Box<dyn Error>> { 77 | let args = parse_args()?; 78 | let model = Model::load_file(args.model)?; 79 | 80 | let mut image: Tensor = read_image(&args.image)?.into(); 81 | let [_, orig_height, orig_width] = image.shape().try_into()?; 82 | normalize_image(image.nd_view_mut(), IMAGENET_MEAN, IMAGENET_STD_DEV); 83 | image.insert_axis(0); // Add batch dim 84 | 85 | // Input size taken from README in https://github.com/fabio-sim/Depth-Anything-ONNX. 86 | let [input_h, input_w] = [518, 518]; 87 | let image = image.resize_image([input_h, input_w])?; 88 | 89 | // Run model to estimate depth for each pixel.
90 | // 91 | // Depending on the model variant used, the output will be either 92 | // 3D (batch, height, width) or 4D (batch, 1, height, width). 93 | let mut output: Tensor = model.run_one(image.view().into(), None)?.try_into()?; 94 | if output.ndim() == 3 { 95 | output.insert_axis(1); // Add channel dim 96 | } 97 | 98 | // Normalize depth values to be in the range [0, 1]. 99 | let min = output 100 | .reduce_min(None, false /* keep_dims */)? 101 | .item() 102 | .copied() 103 | .unwrap(); 104 | let max = output 105 | .reduce_max(None, false /* keep_dims */)? 106 | .item() 107 | .copied() 108 | .unwrap(); 109 | output.apply(|x| (x - min) / (max - min)); 110 | 111 | // Resize output map back to original input size and write to file. 112 | let resized = output.resize_image([orig_height, orig_width])?; 113 | let resized = resized.nd_view::<4>().slice(0); 114 | write_image(&args.output, resized)?; 115 | 116 | Ok(()) 117 | } 118 | -------------------------------------------------------------------------------- /rten-examples/src/distilvit.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::error::Error; 3 | use std::io::prelude::*; 4 | 5 | use rten::{FloatOperators, Model}; 6 | use rten_generate::{Generator, GeneratorUtils}; 7 | use rten_imageio::read_image; 8 | use rten_tensor::prelude::*; 9 | use rten_tensor::NdTensor; 10 | use rten_text::Tokenizer; 11 | 12 | struct Args { 13 | encoder_model: String, 14 | decoder_model: String, 15 | tokenizer_config: String, 16 | image_path: String, 17 | } 18 | 19 | fn parse_args() -> Result<Args, lexopt::Error> { 20 | use lexopt::prelude::*; 21 | 22 | let mut values = VecDeque::new(); 23 | let mut parser = lexopt::Parser::from_env(); 24 | 25 | while let Some(arg) = parser.next()? { 26 | match arg { 27 | Value(val) => values.push_back(val.string()?), 28 | Long("help") => { 29 | println!( 30 | "Generate a caption for an image. 31 | 32 | Usage: {bin_name} [options] <encoder_model> <decoder_model> <tokenizer> <image> 33 | 34 | Args: 35 | 36 | <encoder_model> - Image encoder model 37 | <decoder_model> - Text decoder model 38 | <tokenizer> - `tokenizer.json` file 39 | <image> - Image path 40 | ", 41 | bin_name = parser.bin_name().unwrap_or("distilvit") 42 | ); 43 | std::process::exit(0); 44 | } 45 | _ => return Err(arg.unexpected()), 46 | } 47 | } 48 | 49 | let encoder_model = values.pop_front().ok_or("missing `encoder_model` arg")?; 50 | let decoder_model = values.pop_front().ok_or("missing `decoder_model` arg")?; 51 | let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?; 52 | let image_path = values.pop_front().ok_or("missing `image_path` arg")?; 53 | 54 | let args = Args { 55 | encoder_model, 56 | decoder_model, 57 | tokenizer_config, 58 | image_path, 59 | }; 60 | 61 | Ok(args) 62 | } 63 | 64 | /// Generates captions for an image using Mozilla's DistilViT. 65 | /// 66 | /// 1. Download the `onnx/encoder.onnx` and `onnx/decoder_with_past.onnx` ONNX 67 | /// models from https://huggingface.co/Mozilla/distilvit/tree/main, as well 68 | /// as the `tokenizer.json` file. 69 | /// 2. Convert the models 70 | /// 71 | /// ```sh 72 | /// rten-convert encoder_model.onnx 73 | /// rten-convert decoder_model_with_past.onnx 74 | /// ``` 75 | /// 76 | /// 3.
Run the converted model, specifying the image to caption: 77 | /// 78 | /// ```sh 79 | /// cargo run --release --bin distilvit encoder_model.rten decoder_model.rten tokenizer.json <image> 80 | /// ``` 81 | fn main() -> Result<(), Box<dyn Error>> { 82 | let args = parse_args()?; 83 | let encoder_model = Model::load_file(args.encoder_model)?; 84 | let decoder_model = Model::load_file(args.decoder_model)?; 85 | let tokenizer = Tokenizer::from_file(&args.tokenizer_config)?; 86 | let mut image = read_image(args.image_path)?.into_dyn(); 87 | image.insert_axis(0); // Add batch dim 88 | let image = image.resize_image([224, 224])?; 89 | 90 | let encoded_image: NdTensor<f32, 3> = encoder_model 91 | .run_one(image.view().into(), None)? 92 | .try_into()?; 93 | 94 | let encoder_hidden_states_id = decoder_model.node_id("encoder_hidden_states")?; 95 | 96 | // `decoder_start_token_id` value from 97 | // https://huggingface.co/Mozilla/distilvit/blob/main/config.json. 98 | let bos_token = 50256; 99 | let eos_token = bos_token; 100 | 101 | // Taken from https://github.com/mozilla/distilvit/blob/9c301fd5ba1f62ab407ca0a342642666a1ec13c5/distilvit/infere.py#L45 102 | let max_tokens = 40; 103 | 104 | let prompt = vec![bos_token]; 105 | let generator = Generator::from_model(&decoder_model)? 106 | .with_prompt(&prompt) 107 | .with_constant_input(encoder_hidden_states_id, encoded_image.view().into()) 108 | .stop_on_tokens([eos_token]) 109 | .take(max_tokens) 110 | .decode(&tokenizer); 111 | 112 | for token in generator { 113 | let token = token?; 114 | 115 | print!("{}", token); 116 | let _ = std::io::stdout().flush(); 117 | } 118 | 119 | Ok(()) 120 | } 121 | -------------------------------------------------------------------------------- /rten-examples/src/export-deeplab.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from torchvision.models.segmentation import ( 5 | deeplabv3_mobilenet_v3_large, 6 | DeepLabV3_MobileNet_V3_Large_Weights, 7 | ) 8 | 9 | # Load model 10 | weights = DeepLabV3_MobileNet_V3_Large_Weights.DEFAULT 11 | model = deeplabv3_mobilenet_v3_large(weights=weights) 12 | model.eval() 13 | 14 | # Load transforms. These will resize the input to what the model expects. 15 | preprocess = weights.transforms() 16 | 17 | # Generate a random input and resize it.
18 | img = torch.rand((3, 480, 640)) 19 | batch = preprocess(img).unsqueeze(0) 20 | 21 | parser = ArgumentParser() 22 | parser.add_argument("-f", "--filename", default="deeplab.onnx") 23 | parser.add_argument( 24 | "--dynamo", action="store_true", help="Use TorchDynamo-based exporter" 25 | ) 26 | args = parser.parse_args() 27 | 28 | if args.dynamo: 29 | print("Exporting model using TorchDynamo...") 30 | onnx_prog = torch.onnx.export( 31 | model, 32 | args=(batch), 33 | verbose=False, 34 | input_names=["input"], 35 | output_names=["output"], 36 | dynamo=True, 37 | ) 38 | onnx_prog.optimize() 39 | onnx_prog.save(args.filename) 40 | else: 41 | print("Exporting model using TorchScript...") 42 | torch.onnx.export( 43 | model, 44 | args=(batch), 45 | f=args.filename, 46 | verbose=False, 47 | input_names=["input"], 48 | output_names=["output"], 49 | ) 50 | -------------------------------------------------------------------------------- /rten-examples/src/gpt2.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::error::Error; 3 | use std::io::prelude::*; 4 | 5 | use rten::Model; 6 | use rten_generate::metrics::Metrics; 7 | use rten_generate::sampler::TopKSampler; 8 | use rten_generate::{Generator, GeneratorUtils}; 9 | use rten_text::Tokenizer; 10 | 11 | struct Args { 12 | model: String, 13 | tokenizer_config: String, 14 | prompt: String, 15 | output_length: usize, 16 | top_k: usize, 17 | } 18 | 19 | fn parse_args() -> Result<Args, lexopt::Error> { 20 | use lexopt::prelude::*; 21 | 22 | let mut values = VecDeque::new(); 23 | let mut parser = lexopt::Parser::from_env(); 24 | let mut output_length = 30; 25 | let mut top_k = 50; 26 | 27 | while let Some(arg) = parser.next()? { 28 | match arg { 29 | Short('l') | Long("length") => { 30 | output_length = parser.value()?.parse()?; 31 | } 32 | Short('k') | Long("top-k") => { 33 | top_k = parser.value()?.parse()?; 34 | } 35 | Value(val) => values.push_back(val.string()?), 36 | Long("help") => { 37 | println!( 38 | "Generate text using a prompt. 39 | 40 | Usage: {bin_name} [options] <model> <tokenizer> <prompt>... 41 | 42 | Args: 43 | 44 | <model> - Input GPT-2 model 45 | <tokenizer> - `tokenizer.json` file 46 | <prompt> - Text generation prompt 47 | 48 | Options: 49 | 50 | -l, --length N - Set max output length (in tokens) 51 | 52 | -k, --top-k K - Sample from top `K` tokens at each step. 53 | ", 54 | bin_name = parser.bin_name().unwrap_or("gpt2") 55 | ); 56 | std::process::exit(0); 57 | } 58 | _ => return Err(arg.unexpected()), 59 | } 60 | } 61 | 62 | let model = values.pop_front().ok_or("missing `model` arg")?; 63 | let tokenizer_config = values.pop_front().ok_or("missing `tokenizer` arg")?; 64 | let prompt = values.make_contiguous().join(" "); 65 | 66 | let args = Args { 67 | model, 68 | tokenizer_config, 69 | prompt, 70 | output_length, 71 | top_k, 72 | }; 73 | 74 | Ok(args) 75 | } 76 | 77 | /// Generates text using GPT-2 [1] and a prompt. 78 | /// 79 | /// To obtain the model from Hugging Face, use Optimum [2], then convert it: 80 | /// 81 | /// ```sh 82 | /// optimum-cli export onnx --model gpt2 gpt2_onnx/ 83 | /// rten-convert gpt2_onnx/model.onnx 84 | /// ``` 85 | /// 86 | /// Run the converted model with a prompt: 87 | /// 88 | /// ```sh 89 | /// cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json <prompt> 90 | /// ``` 91 | /// 92 | /// Where `<prompt>` is the start of a sentence that the model should complete.
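/// For example (the prompt text here is just an illustration):
///
/// ```sh
/// cargo run --release --bin gpt2 gpt2_onnx/model.rten gpt2_onnx/tokenizer.json "The Rust programming language is"
/// ```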
93 | /// 94 | /// [1] https://openai.com/research/better-language-models 95 | /// [2] https://huggingface.co/docs/optimum/index 96 | fn main() -> Result<(), Box<dyn Error>> { 97 | let args = parse_args()?; 98 | let model = Model::load_file(args.model)?; 99 | let tokenizer = Tokenizer::from_file(&args.tokenizer_config)?; 100 | 101 | let prompt = args.prompt.as_str(); 102 | let encoded_prompt = tokenizer.encode(prompt, None)?; 103 | 104 | // The output starts with the user's prompt. 105 | print!("{}", prompt); 106 | 107 | let mut metrics = Metrics::new(); 108 | let temperature = 1.0; 109 | let generator = Generator::from_model(&model)? 110 | .with_prompt(encoded_prompt.token_ids()) 111 | .with_sampler(TopKSampler::new(args.top_k, temperature)) 112 | .take(args.output_length) 113 | .profile(&mut metrics) 114 | .decode(&tokenizer); 115 | 116 | for token in generator { 117 | let token = token?; 118 | print!("{}", token); 119 | let _ = std::io::stdout().flush(); 120 | } 121 | println!(); 122 | 123 | println!( 124 | "Metrics: {:.2}s total, {:.2}s warmup, {:.2} tokens/sec, {:.2} ms/token.", 125 | metrics.total_duration().as_secs_f32(), 126 | metrics 127 | .warmup_duration() 128 | .map(|dur| dur.as_secs_f32()) 129 | .unwrap_or(0.), 130 | metrics.tokens_per_second(), 131 | metrics.mean_duration() 132 | ); 133 | 134 | Ok(()) 135 | } 136 | -------------------------------------------------------------------------------- /rten-examples/src/gpt2_reference.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from transformers import pipeline, set_seed 3 | 4 | 5 | def main(): 6 | parser = ArgumentParser(description="Generate text using GPT-2 and a prompt") 7 | parser.add_argument("prompt", nargs="*") 8 | parser.add_argument("--seed", type=int, help="Random seed") 9 | args = parser.parse_args() 10 | 11 | prompt = " ".join(args.prompt) 12 | if args.seed is not None: 13 | set_seed(args.seed) 14 | 15 | print(f'prompt: "{prompt}"') 16 | generator = pipeline("text-generation", model="gpt2") 17 | 18 | sequences = generator(prompt, max_length=30, num_return_sequences=1, do_sample=False) 19 | for seq in sequences: 20 | print(seq) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /rten-examples/src/jina_similarity_reference.py: -------------------------------------------------------------------------------- 1 | # Reference implementation of sentence similarity estimation using a BERT 2 | # embedding model.
Adapted from example on 3 | # https://huggingface.co/jinaai/jina-embeddings-v2-small-en 4 | 5 | from argparse import ArgumentParser 6 | 7 | from numpy.linalg import norm 8 | import torch 9 | import torch.nn.functional as F 10 | from transformers import AutoModel, AutoTokenizer 11 | 12 | 13 | def mean_pooling(model_output, attention_mask): 14 | token_embeddings = model_output[0] 15 | input_mask_expanded = ( 16 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 17 | ) 18 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( 19 | input_mask_expanded.sum(1), min=1e-9 20 | ) 21 | 22 | 23 | def main(): 24 | parser = ArgumentParser() 25 | parser.add_argument("first_sentence") 26 | parser.add_argument("second_sentence") 27 | args = parser.parse_args() 28 | 29 | cos_sim = lambda a, b: a.dot(b) / (norm(a) * norm(b)) 30 | model_name = "jinaai/jina-embeddings-v2-small-en" 31 | tokenizer = AutoTokenizer.from_pretrained(model_name) 32 | model = AutoModel.from_pretrained( 33 | model_name, trust_remote_code=True 34 | ) # trust_remote_code is needed to use the encode method 35 | 36 | sentences = [args.first_sentence, args.second_sentence] 37 | encoded_input = tokenizer( 38 | sentences, padding=True, truncation=True, return_tensors="pt" 39 | ) 40 | 41 | with torch.no_grad(): 42 | model_output = model(**encoded_input) 43 | 44 | embeddings = mean_pooling(model_output, encoded_input["attention_mask"]) 45 | 46 | # FIXME - Is this needed 47 | embeddings = F.normalize(embeddings, p=2, dim=1) 48 | 49 | similarity = cos_sim(embeddings[0], embeddings[1]) 50 | 51 | print(f'First sentence: "{args.first_sentence}"') 52 | print(f'Second sentence: "{args.second_sentence}"') 53 | print(f"Similarity: {similarity}") 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /rten-examples/src/modernbert_reference.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from transformers import AutoTokenizer, AutoModelForMaskedLM 4 | 5 | parser = ArgumentParser( 6 | description="Replace [MASK] tokens in the input text with model predictions." 7 | ) 8 | parser.add_argument( 9 | "-m", 10 | "--model", 11 | type=str, 12 | help="Name of the model to use", 13 | 14 | # nb. If you get an error that this model is not supported, see 15 | # https://huggingface.co/answerdotai/ModernBERT-base/discussions/3. 16 | # 17 | # You can also use an older BERT model such as "bert-base-uncased". 18 | default="answerdotai/ModernBERT-base", 19 | ) 20 | parser.add_argument("text", type=str, help="Input text containing [MASK] tokens") 21 | args = parser.parse_args() 22 | 23 | tokenizer = AutoTokenizer.from_pretrained(args.model) 24 | model = AutoModelForMaskedLM.from_pretrained(args.model) 25 | inputs = tokenizer(args.text, return_tensors="pt") 26 | 27 | # Print input and output token IDs to enable comparison against RTen's 28 | # tokenization and model output. 29 | input_ids = inputs["input_ids"][0].tolist() 30 | print("Input IDs:", input_ids) 31 | 32 | outputs = model(**inputs) 33 | 34 | raw_output_ids = outputs.logits[0].argmax(axis=-1) 35 | print("Output IDs:", raw_output_ids.tolist()) 36 | 37 | # Keep only the output IDs for positions where the input contained a mask token. 
38 | output_ids = input_ids.copy() 39 | for pos in range(len(output_ids)): 40 | if output_ids[pos] == tokenizer.mask_token_id: 41 | output_ids[pos] = raw_output_ids[pos] 42 | 43 | predicted_text = tokenizer.decode(output_ids, skip_special_tokens=True) 44 | print(predicted_text) 45 | -------------------------------------------------------------------------------- /rten-examples/src/rmbg.rs: -------------------------------------------------------------------------------- 1 | use std::collections::VecDeque; 2 | use std::error::Error; 3 | 4 | use rten::{FloatOperators, Model}; 5 | use rten_imageio::{read_image, write_image}; 6 | use rten_imageproc::normalize_image; 7 | use rten_tensor::prelude::*; 8 | use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut}; 9 | 10 | struct Args { 11 | model: String, 12 | image: String, 13 | output: String, 14 | } 15 | 16 | fn parse_args() -> Result<Args, lexopt::Error> { 17 | use lexopt::prelude::*; 18 | 19 | let mut values = VecDeque::new(); 20 | let mut parser = lexopt::Parser::from_env(); 21 | 22 | while let Some(arg) = parser.next()? { 23 | match arg { 24 | Value(val) => values.push_back(val.string()?), 25 | Long("help") => { 26 | println!( 27 | "Perform background removal on an image. 28 | 29 | Usage: {bin_name} <model> <image> [<output>] 30 | 31 | Args: 32 | 33 | <model> - Background removal model 34 | <image> - Image to process 35 | <output> - Path to save image to. Defaults to \"output.png\". 36 | ", 37 | bin_name = parser.bin_name().unwrap_or("rmbg") 38 | ); 39 | std::process::exit(0); 40 | } 41 | _ => return Err(arg.unexpected()), 42 | } 43 | } 44 | 45 | let model = values.pop_front().ok_or("missing `model` arg")?; 46 | let image = values.pop_front().ok_or("missing `image` arg")?; 47 | let output = values.pop_front().unwrap_or("output.png".into()); 48 | 49 | let args = Args { 50 | image, 51 | model, 52 | output, 53 | }; 54 | 55 | Ok(args) 56 | } 57 | 58 | /// Fill a CHW image with `color` using a mask. 59 | fn fill_mask(mut image: NdTensorViewMut<f32, 3>, mask: NdTensorView<bool, 2>, color: [f32; 3]) { 60 | let [chans, rows, cols] = image.shape(); 61 | assert_eq!(chans, 3); 62 | for y in 0..rows { 63 | for x in 0..cols { 64 | if mask[[y, x]] { 65 | image.set_array([0, y, x], 0, color); 66 | } 67 | } 68 | } 69 | } 70 | 71 | /// Remove the background of an image using [BRIA Background 72 | /// Removal](https://huggingface.co/briaai/RMBG-1.4). 73 | /// 74 | /// The ONNX models can be obtained from https://huggingface.co/briaai/RMBG-1.4. 75 | /// See the "Files and Versions" page. 76 | /// 77 | /// Assuming the model has been downloaded and named as "rmbg.onnx", it can be 78 | /// run on an image using: 79 | /// 80 | /// ``` 81 | /// rten-convert rmbg.onnx 82 | /// cargo run --release --bin rmbg rmbg.rten image.jpg 83 | /// ``` 84 | /// 85 | /// This will generate `output.png`, a copy of the input image with the 86 | /// background replaced with a fixed color.
87 | fn main() -> Result<(), Box<dyn Error>> { 88 | let args = parse_args()?; 89 | let model = Model::load_file(args.model)?; 90 | 91 | let mut image: NdTensor<f32, 3> = read_image(&args.image)?; 92 | 93 | let mut normalized_image = image.clone(); 94 | let mean = [0.5, 0.5, 0.5]; 95 | let std_dev = [1.0, 1.0, 1.0]; 96 | normalize_image(normalized_image.view_mut(), mean, std_dev); 97 | 98 | let [_, orig_height, orig_width] = image.shape(); 99 | 100 | let mut normalized_image = normalized_image.into_dyn(); 101 | normalized_image.insert_axis(0); // Add batch dim 102 | 103 | let [input_h, input_w] = [1024, 1024]; 104 | let resized_image = normalized_image.resize_image([input_h, input_w])?; 105 | 106 | // Run the model to get the probability of each pixel being part of the 107 | // image's foreground, then apply a threshold to get a mask indicating which 108 | // pixels are part of the image's background. 109 | let foreground_prob: NdTensor<f32, 4> = model 110 | .run_one(resized_image.view().into(), None)? 111 | .try_into()?; 112 | let foreground_prob: NdTensor<f32, 4> = foreground_prob 113 | .resize_image([orig_height, orig_width])? 114 | .try_into()?; 115 | let foreground_threshold = 0.5; 116 | let background_mask = foreground_prob.map(|x| *x < foreground_threshold); 117 | 118 | // Replace background with a fixed color. 119 | let bg_color = [0., 1., 0.]; // RGB 120 | fill_mask( 121 | image.view_mut(), 122 | background_mask.slice([0, 0]), // Extract first mask and channel 123 | bg_color, 124 | ); 125 | 126 | write_image(&args.output, image.nd_view())?; 127 | 128 | Ok(()) 129 | } 130 | -------------------------------------------------------------------------------- /rten-generate/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-generate" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Utilities to simplify running auto-regressive models with RTen" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [dependencies] 13 | fastrand = { version = "2.0.2" } 14 | rten = { path = "../", version = "0.18.0" } 15 | rten-text = { path = "../rten-text", version = "0.18.0", optional = true } 16 | rten-tensor = { path = "../rten-tensor", version = "0.18.0" } 17 | 18 | [dev-dependencies] 19 | rten-testing = { path = "../rten-testing" } 20 | 21 | [features] 22 | # Enable text decoding using tokenizers from rten-text 23 | text-decoder = ["dep:rten-text"] 24 | 25 | [package.metadata.docs.rs] 26 | features = ["text-decoder"] 27 | -------------------------------------------------------------------------------- /rten-generate/README.md: -------------------------------------------------------------------------------- 1 | rten-generate is a layer on top of RTen which handles the generation loop for 2 | auto-regressive transformer models (aka. "transformer decoders" or "generative 3 | AI"). This includes managing the KV cache, sampling and post-processing logits 4 | etc. 5 | -------------------------------------------------------------------------------- /rten-generate/src/filter.rs: -------------------------------------------------------------------------------- 1 | //! Filters for processing model outputs prior to sampling. 2 | //! 3 | //! This module defines the [`LogitsFilter`] trait implemented by all filters, 4 | //! plus convenience functions to simplify implementing filters.
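//!
//! For example, a sketch of a filter that restricts sampling to an allow-list
//! of token IDs (the IDs below are illustrative):
//!
//! ```no_run
//! use rten_generate::filter::token_id_filter;
//!
//! // Tokens outside the allow-list are mapped to negative infinity, so the
//! // sampler will never pick them.
//! let allowed = [42u32, 77, 1024];
//! let filter = token_id_filter(move |id| allowed.contains(&id));
//! ```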
5 | 6 | use rten_tensor::prelude::*; 7 | use rten_tensor::{NdTensor, NdTensorView}; 8 | 9 | use crate::generator::TokenId; 10 | 11 | /// Filter which modifies the output logits from a model. 12 | /// 13 | /// The filter is applied to the model outputs before a token is sampled. 14 | pub trait LogitsFilter { 15 | /// Filter the model's output and return the modified logits. 16 | /// 17 | /// If this method returns `None`, the input logits are passed unmodified 18 | /// to the sampler. `prev_tokens` contains the previously sampled tokens, 19 | /// including the prompt. 20 | fn filter( 21 | &self, 22 | logits: NdTensorView<f32, 1>, 23 | prev_tokens: &[TokenId], 24 | ) -> Option<NdTensor<f32, 1>>; 25 | } 26 | 27 | struct TokenIdFilter<F: Fn(TokenId) -> bool> { 28 | predicate: F, 29 | } 30 | 31 | impl<F: Fn(TokenId) -> bool> LogitsFilter for TokenIdFilter<F> { 32 | fn filter( 33 | &self, 34 | logits: NdTensorView<f32, 1>, 35 | _prev_tokens: &[TokenId], 36 | ) -> Option<NdTensor<f32, 1>> { 37 | Some(NdTensor::from_fn(logits.shape(), |[i]| { 38 | let token_id = i as TokenId; 39 | if (self.predicate)(token_id) { 40 | logits[[i]] 41 | } else { 42 | f32::NEG_INFINITY 43 | } 44 | })) 45 | } 46 | } 47 | 48 | /// Create a filter which suppresses all tokens that do not match a predicate by 49 | /// setting the value to `f32::NEG_INFINITY`. 50 | pub fn token_id_filter<F: Fn(TokenId) -> bool>(predicate: F) -> impl LogitsFilter { 51 | TokenIdFilter { predicate } 52 | } 53 | 54 | #[cfg(test)] 55 | mod tests { 56 | use rten_tensor::prelude::*; 57 | use rten_tensor::NdTensor; 58 | 59 | use super::{token_id_filter, LogitsFilter}; 60 | 61 | #[test] 62 | fn test_token_id_filter() { 63 | let logits = NdTensor::from([0., 1., 2., 3., 4.]); 64 | let filter = token_id_filter(|id| id % 2 == 0); 65 | let output = filter.filter(logits.view(), &[]); 66 | assert_eq!( 67 | output, 68 | Some(NdTensor::from([ 69 | 0., 70 | f32::NEG_INFINITY, 71 | 2., 72 | f32::NEG_INFINITY, 73 | 4. 74 | ])) 75 | ); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /rten-generate/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Utilities to simplify running auto-regressive [RTen][rten] models such 2 | //! as transformer decoders. 3 | //! 4 | //! For working examples, see the examples in the [rten-examples][rten-examples] 5 | //! crate which import `rten_generate`. 6 | //! 7 | //! [rten]: https://github.com/robertknight/rten 8 | //! [rten-examples]: https://github.com/robertknight/rten/tree/main/rten-examples 9 | 10 | pub mod filter; 11 | pub mod generator; 12 | pub mod metrics; 13 | pub mod model; 14 | pub mod sampler; 15 | 16 | #[cfg(feature = "text-decoder")] 17 | pub mod text_decoder; 18 | 19 | pub use generator::{ 20 | Generator, GeneratorConfig, GeneratorError, GeneratorUtils, ModelInputsConfig, 21 | }; 22 | -------------------------------------------------------------------------------- /rten-generate/src/model.rs: -------------------------------------------------------------------------------- 1 | //! Abstraction over [`rten::Model`] for querying and executing ML models. 2 | 3 | use std::error::Error; 4 | 5 | use rten::{Dimension, NodeId, RunOptions, Value, ValueOrView}; 6 | 7 | /// Describes the name and shape of a model input or output. 8 | /// 9 | /// This is similar to [`rten::NodeInfo`] but the name and shape are required.
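///
/// For example, a minimal sketch of describing a model input (the name and
/// dimensions are illustrative):
///
/// ```no_run
/// use rten::Dimension;
/// use rten_generate::model::NodeInfo;
///
/// // An input with a fixed batch size and a symbolic sequence length.
/// let info = NodeInfo::from_name_shape(
///     "input_ids",
///     &[Dimension::Fixed(1), Dimension::Symbolic("seq".to_string())],
/// );
/// assert_eq!(info.name(), "input_ids");
/// ```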
10 | #[derive(Clone)] 11 | pub struct NodeInfo { 12 | name: String, 13 | shape: Vec<Dimension>, 14 | } 15 | 16 | impl NodeInfo { 17 | pub fn name(&self) -> &str { 18 | &self.name 19 | } 20 | 21 | pub fn shape(&self) -> &[Dimension] { 22 | &self.shape 23 | } 24 | 25 | pub fn from_name_shape(name: &str, shape: &[Dimension]) -> NodeInfo { 26 | NodeInfo { 27 | name: name.to_string(), 28 | shape: shape.to_vec(), 29 | } 30 | } 31 | } 32 | 33 | /// Abstraction over [`rten::Model`] used by [`Generator`](crate::Generator) to 34 | /// query and execute a machine learning model. 35 | /// 36 | /// This is implemented by [`rten::Model`] and the trait's methods correspond 37 | /// to methods of the same name in that type. 38 | pub trait Model { 39 | /// Get the ID of an input or output node. 40 | fn find_node(&self, name: &str) -> Option<NodeId>; 41 | 42 | /// Get the name and shape of an input or output node. 43 | /// 44 | /// Returns `None` if the node does not exist, or name or shape information 45 | /// is not available. 46 | fn node_info(&self, id: NodeId) -> Option<NodeInfo>; 47 | 48 | /// Return the node IDs of the model's inputs. 49 | fn input_ids(&self) -> &[NodeId]; 50 | 51 | /// Run the model with the provided inputs and return the results. 52 | fn run( 53 | &self, 54 | inputs: Vec<(NodeId, ValueOrView)>, 55 | outputs: &[NodeId], 56 | opts: Option<RunOptions>, 57 | ) -> Result<Vec<Value>, Box<dyn Error>>; 58 | 59 | /// Run as much of the model as possible given the provided inputs and 60 | /// return the leaves of the evaluation where execution stopped. 61 | fn partial_run( 62 | &self, 63 | inputs: Vec<(NodeId, ValueOrView)>, 64 | outputs: &[NodeId], 65 | opts: Option<RunOptions>, 66 | ) -> Result<Vec<Value>, Box<dyn Error>>; 67 | } 68 | 69 | impl Model for rten::Model { 70 | fn find_node(&self, name: &str) -> Option<NodeId> { 71 | self.find_node(name) 72 | } 73 | 74 | fn node_info(&self, id: NodeId) -> Option<NodeInfo> { 75 | self.node_info(id).and_then(|info| { 76 | let name = info.name()?; 77 | let dims = info.shape()?; 78 | 79 | Some(NodeInfo { 80 | name: name.to_string(), 81 | shape: dims, 82 | }) 83 | }) 84 | } 85 | 86 | fn input_ids(&self) -> &[NodeId] { 87 | self.input_ids() 88 | } 89 | 90 | fn run( 91 | &self, 92 | inputs: Vec<(NodeId, ValueOrView)>, 93 | outputs: &[NodeId], 94 | opts: Option<RunOptions>, 95 | ) -> Result<Vec<Value>, Box<dyn Error>> { 96 | self.run(inputs, outputs, opts).map_err(|e| e.into()) 97 | } 98 | 99 | fn partial_run( 100 | &self, 101 | inputs: Vec<(NodeId, ValueOrView)>, 102 | outputs: &[NodeId], 103 | opts: Option<RunOptions>, 104 | ) -> Result<Vec<Value>, Box<dyn Error>> { 105 | self.partial_run(inputs, outputs, opts) 106 | .map_err(|e| e.into()) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /rten-imageio/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-imageio" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Utilities for loading images for use with RTen" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [dependencies] 13 | png = "0.17.6" 14 | rten-tensor = { path = "../rten-tensor", version = "0.18.0" } 15 | image = { workspace = true } 16 | -------------------------------------------------------------------------------- /rten-imageio/README.md: -------------------------------------------------------------------------------- 1 | # rten-imageio 2 | 3 | Crate for loading images and converting them to tensors that can be
used as 4 | inputs for RTen models. 5 | -------------------------------------------------------------------------------- /rten-imageio/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Provides utilities for loading, saving and preprocessing images for use with 2 | //! [RTen](https://github.com/robertknight/rten). 3 | //! 4 | //! The APIs are limited to keep them simple for the most common use cases. 5 | //! If you need more flexibility from a function, copy and adjust the 6 | //! implementation. 7 | 8 | use std::error::Error; 9 | use std::path::Path; 10 | 11 | use rten_tensor::errors::FromDataError; 12 | use rten_tensor::prelude::*; 13 | use rten_tensor::{NdTensor, NdTensorView}; 14 | 15 | /// Errors reported when creating a tensor from an image. 16 | #[derive(Debug)] 17 | pub enum ReadImageError { 18 | /// The image could not be loaded. 19 | ImageError(image::ImageError), 20 | /// The loaded image could not be converted to a tensor. 21 | ConvertError(FromDataError), 22 | } 23 | 24 | impl std::fmt::Display for ReadImageError { 25 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 26 | match self { 27 | ReadImageError::ImageError(e) => write!(f, "failed to read image: {}", e), 28 | ReadImageError::ConvertError(e) => write!(f, "failed to create tensor: {}", e), 29 | } 30 | } 31 | } 32 | 33 | impl Error for ReadImageError {} 34 | 35 | /// Convert an image into a CHW tensor with 3 channels and values in the range 36 | /// [0, 1]. 37 | pub fn image_to_tensor(image: image::DynamicImage) -> Result<NdTensor<f32, 3>, ReadImageError> { 38 | let image = image.into_rgb8(); 39 | let (width, height) = image.dimensions(); 40 | let layout = image.sample_layout(); 41 | 42 | let chw_tensor = NdTensorView::from_data_with_strides( 43 | [height as usize, width as usize, 3], 44 | image.as_raw().as_slice(), 45 | [ 46 | layout.height_stride, 47 | layout.width_stride, 48 | layout.channel_stride, 49 | ], 50 | ) 51 | .map_err(ReadImageError::ConvertError)? 52 | .permuted([2, 0, 1]) // HWC => CHW 53 | .map(|x| *x as f32 / 255.); // Rescale from [0, 255] to [0, 1] 54 | 55 | Ok(chw_tensor) 56 | } 57 | 58 | /// Read an image from a file into a CHW tensor. 59 | /// 60 | /// To load an image from a byte buffer or other source, use [`image::open`] 61 | /// and pass the result to [`image_to_tensor`]. 62 | pub fn read_image<P: AsRef<Path>>(path: P) -> Result<NdTensor<f32, 3>, ReadImageError> { 63 | image::open(path) 64 | .map_err(ReadImageError::ImageError) 65 | .and_then(image_to_tensor) 66 | } 67 | 68 | /// Errors returned when writing a tensor to an image. 69 | #[derive(Debug)] 70 | pub enum WriteImageError { 71 | /// The number of channels in the image tensor is unsupported. 72 | UnsupportedChannelCount, 73 | /// The image could not be written. 74 | ImageError(image::ImageError), 75 | } 76 | 77 | impl std::fmt::Display for WriteImageError { 78 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 79 | match self { 80 | Self::ImageError(e) => write!(f, "failed to write image: {}", e), 81 | Self::UnsupportedChannelCount => write!(f, "image has unsupported number of channels"), 82 | } 83 | } 84 | } 85 | 86 | impl Error for WriteImageError {} 87 | 88 | /// Convert a CHW tensor to an image and write it to a PNG file.
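///
/// For example, a minimal sketch that writes an all-black grayscale image
/// (the path and tensor contents are illustrative; `no_run` as it writes to
/// disk):
///
/// ```no_run
/// use rten_imageio::write_image;
/// use rten_tensor::prelude::*;
/// use rten_tensor::NdTensor;
///
/// // A 1-channel 8x8 CHW image with every value zero (black).
/// let img: NdTensor<f32, 3> = NdTensor::zeros([1, 8, 8]);
/// write_image("black.png", img.view()).unwrap();
/// ```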
89 | pub fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), WriteImageError> { 90 | let [channels, height, width] = img.shape(); 91 | let color_type = match channels { 92 | 1 => image::ColorType::L8, 93 | 3 => image::ColorType::Rgb8, 94 | 4 => image::ColorType::Rgba8, 95 | _ => return Err(WriteImageError::UnsupportedChannelCount), 96 | }; 97 | 98 | let hwc_img = img 99 | .permuted([1, 2, 0]) // CHW => HWC 100 | .map(|x| (x.clamp(0., 1.) * 255.0) as u8); 101 | 102 | image::save_buffer( 103 | path, 104 | hwc_img.data().unwrap(), 105 | width as u32, 106 | height as u32, 107 | color_type, 108 | ) 109 | .map_err(WriteImageError::ImageError)?; 110 | 111 | Ok(()) 112 | } 113 | -------------------------------------------------------------------------------- /rten-imageproc/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-imageproc" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Image tensor processing and geometry functions" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [dependencies] 13 | rten-tensor = { path = "../rten-tensor", version = "0.18.0" } 14 | serde = { workspace = true, features = ["derive"], optional = true } 15 | 16 | [dev-dependencies] 17 | rten-bench = { path = "../rten-bench" } 18 | rten-testing = { path = "../rten-testing" } 19 | 20 | [lib] 21 | crate-type = ["lib"] 22 | 23 | [features] 24 | # Implement serde Serialize and Deserialize traits on items where it makes sense 25 | serde_traits = ["serde"] 26 | -------------------------------------------------------------------------------- /rten-imageproc/README.md: -------------------------------------------------------------------------------- 1 | # rten-imageproc 2 | 3 | Library for pre and post-processing image data stored in matrices. It includes 4 | functionality for: 5 | 6 | - Finding contours of objects in segmentation masks 7 | - Working with axis-aligned and oriented bounding boxes / rectangles 8 | - Simplifying polygons 9 | - Simple drawing of shapes 10 | 11 | The genesis of this library was a need in the ocrs OCR engine for a Rust 12 | implementation of a subset of the geometry and image processing functionality 13 | provided by libraries like OpenCV and Shapely in Python. 14 | -------------------------------------------------------------------------------- /rten-imageproc/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Provides 2D geometry and image processing functions. 2 | //! 3 | //! This includes: 4 | //! 5 | //! - 2D vectors and related math 6 | //! - 2D shapes and associated algorithms: [Point], [Line], [Rect], 7 | //! [RotatedRect], [Polygon] 8 | //! - Rudimentary drawing functions 9 | //! - Algorithms for finding the contours of connected components in an image 10 | //! ([find_contours]) 11 | //! - Algorithms for simplifying polygons and finding various kinds of shape 12 | //!
that contain a polygon: [simplify_polygon], [min_area_rect], [convex_hull] 13 | 14 | mod contours; 15 | mod drawing; 16 | mod math; 17 | mod normalize; 18 | mod poly_algos; 19 | mod shapes; 20 | 21 | pub use contours::{find_contours, RetrievalMode}; 22 | pub use drawing::{draw_line, draw_polygon, fill_rect, stroke_rect, FillIter, Painter, Rgb}; 23 | pub use math::Vec2; 24 | pub use normalize::{normalize_image, IMAGENET_MEAN, IMAGENET_STD_DEV}; 25 | pub use poly_algos::{convex_hull, min_area_rect, simplify_polygon, simplify_polyline}; 26 | pub use shapes::{ 27 | bounding_rect, BoundingRect, Coord, Line, LineF, Point, PointF, Polygon, PolygonF, Polygons, 28 | Rect, RectF, RotatedRect, 29 | }; 30 | 31 | #[cfg(test)] 32 | mod tests { 33 | use std::fmt::Display; 34 | 35 | use rten_tensor::{MatrixLayout, NdTensorView, NdTensorViewMut}; 36 | 37 | use super::{Coord, Point, Rect}; 38 | 39 | /// Return a list of the points on the border of `rect`, in counter-clockwise 40 | /// order starting from the top-left corner. 41 | /// 42 | /// If `omit_corners` is true, the corner points of the rect are not 43 | /// included. 44 | pub fn border_points(rect: Rect, omit_corners: bool) -> Vec<Point> { 45 | let mut points = Vec::new(); 46 | 47 | let left_range = if omit_corners { 48 | rect.top() + 1..rect.bottom() - 1 49 | } else { 50 | rect.top()..rect.bottom() 51 | }; 52 | 53 | // Left edge 54 | for y in left_range.clone() { 55 | points.push(Point::from_yx(y, rect.left())); 56 | } 57 | 58 | // Bottom edge 59 | for x in rect.left() + 1..rect.right() - 1 { 60 | points.push(Point::from_yx(rect.bottom() - 1, x)); 61 | } 62 | 63 | // Right edge 64 | for y in left_range.rev() { 65 | points.push(Point::from_yx(y, rect.right() - 1)); 66 | } 67 | 68 | // Top edge 69 | for x in (rect.left() + 1..rect.right() - 1).rev() { 70 | points.push(Point::from_yx(rect.top(), x)); 71 | } 72 | 73 | points 74 | } 75 | 76 | /// Set the elements of a grid listed in `points` to `value`. 77 | #[allow(dead_code)] 78 | pub fn plot_points<T: Copy>(mut grid: NdTensorViewMut<T, 2>, points: &[Point], value: T) { 79 | for point in points { 80 | grid[point.coord()] = value; 81 | } 82 | } 83 | 84 | /// Plot the 1-based indices of points in `points` on a grid. `step` is the 85 | /// increment value for each plotted point. 86 | #[allow(dead_code)] 87 | fn plot_point_indices<T: Copy + Default + std::ops::AddAssign>( 88 | mut grid: NdTensorViewMut<T, 2>, 89 | points: &[Point], 90 | step: T, 91 | ) { 92 | let mut value = T::default(); 93 | value += step; 94 | for point in points { 95 | grid[point.coord()] = value; 96 | value += step; 97 | } 98 | } 99 | 100 | /// Convert a slice of `[y, x]` coordinates to `Point`s 101 | pub fn points_from_coords<T: Coord>(coords: &[[T; 2]]) -> Vec<Point<T>> { 102 | coords.iter().map(|[y, x]| Point::from_yx(*y, *x)).collect() 103 | } 104 | 105 | /// Convert an array of `[y, x]` coordinates to `Point`s 106 | pub fn points_from_n_coords<T: Coord, const N: usize>(coords: [[T; 2]; N]) -> [Point<T>; N] { 107 | coords.map(|[y, x]| Point::from_yx(y, x)) 108 | } 109 | 110 | /// Print out elements of a 2D grid for debugging.
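// Worked example for `border_points` above (derived by tracing the code,
// shown for illustration): for a 3x3 rect with top-left (0, 0) and
// omit_corners = false, the returned points are (0,0), (1,0), (2,0) down the
// left edge, then (2,1) along the bottom, (2,2), (1,2), (0,2) up the right
// edge, and finally (0,1) along the top -- each border cell exactly once.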
111 | #[allow(dead_code)] 112 | pub fn print_grid<T: Display>(grid: NdTensorView<T, 2>) { 113 | for y in 0..grid.rows() { 114 | for x in 0..grid.cols() { 115 | print!("{:2} ", grid[[y, x]]); 116 | } 117 | println!(); 118 | } 119 | println!(); 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /rten-imageproc/src/math.rs: -------------------------------------------------------------------------------- 1 | use crate::Point; 2 | 3 | #[derive(Copy, Clone, Debug, PartialEq)] 4 | #[cfg_attr(feature = "serde_traits", derive(serde::Serialize, serde::Deserialize))] 5 | pub struct Vec2 { 6 | pub x: f32, 7 | pub y: f32, 8 | } 9 | 10 | impl Vec2 { 11 | pub fn from_yx(y: f32, x: f32) -> Vec2 { 12 | Vec2 { y, x } 13 | } 14 | 15 | pub fn from_xy(x: f32, y: f32) -> Vec2 { 16 | Vec2 { x, y } 17 | } 18 | 19 | /// Return the vector from `start` to `end`. 20 | pub fn from_points(start: Point, end: Point) -> Vec2 { 21 | let dx = end.x - start.x; 22 | let dy = end.y - start.y; 23 | Vec2::from_yx(dy as f32, dx as f32) 24 | } 25 | 26 | pub fn length(&self) -> f32 { 27 | (self.x * self.x + self.y * self.y).sqrt() 28 | } 29 | 30 | /// Return the magnitude of the cross product of this vector with `other`. 31 | pub fn cross_product_norm(&self, other: Vec2) -> f32 { 32 | self.x * other.y - self.y * other.x 33 | } 34 | 35 | /// Return the dot product of this vector with `other`. 36 | pub fn dot(&self, other: Vec2) -> f32 { 37 | self.x * other.x + self.y * other.y 38 | } 39 | 40 | /// Return a copy of this vector scaled such that the length is 1. 41 | pub fn normalized(&self) -> Vec2 { 42 | let inv_len = 1. / self.length(); 43 | Vec2::from_yx(self.y * inv_len, self.x * inv_len) 44 | } 45 | 46 | /// Return a vector perpendicular to this vector. 47 | pub fn perpendicular(&self) -> Vec2 { 48 | Vec2 { 49 | y: -self.x, 50 | x: self.y, 51 | } 52 | } 53 | } 54 | 55 | impl std::ops::Add for Vec2 { 56 | type Output = Vec2; 57 | 58 | fn add(self, rhs: Vec2) -> Vec2 { 59 | Vec2 { 60 | y: self.y + rhs.y, 61 | x: self.x + rhs.x, 62 | } 63 | } 64 | } 65 | 66 | impl std::ops::Neg for Vec2 { 67 | type Output = Vec2; 68 | 69 | fn neg(self) -> Vec2 { 70 | Vec2 { 71 | y: -self.y, 72 | x: -self.x, 73 | } 74 | } 75 | } 76 | 77 | impl std::ops::Mul<f32> for Vec2 { 78 | type Output = Vec2; 79 | 80 | fn mul(self, rhs: f32) -> Vec2 { 81 | Vec2 { 82 | y: self.y * rhs, 83 | x: self.x * rhs, 84 | } 85 | } 86 | } 87 | 88 | impl std::ops::Sub<f32> for Vec2 { 89 | type Output = Vec2; 90 | 91 | fn sub(self, rhs: f32) -> Vec2 { 92 | Vec2 { 93 | y: self.y - rhs, 94 | x: self.x - rhs, 95 | } 96 | } 97 | } 98 | 99 | impl std::ops::Sub for Vec2 { 100 | type Output = Vec2; 101 | 102 | fn sub(self, rhs: Vec2) -> Vec2 { 103 | Vec2 { 104 | y: self.y - rhs.y, 105 | x: self.x - rhs.x, 106 | } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /rten-imageproc/src/normalize.rs: -------------------------------------------------------------------------------- 1 | use rten_tensor::prelude::*; 2 | use rten_tensor::NdTensorViewMut; 3 | 4 | /// Standard ImageNet normalization mean values, for use with 5 | /// [`normalize_image`]. 6 | pub const IMAGENET_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; 7 | 8 | /// Standard ImageNet normalization standard deviation values, for use with 9 | /// [`normalize_image`]. 10 | pub const IMAGENET_STD_DEV: [f32; 3] = [0.229, 0.224, 0.225]; 11 | 12 | /// Normalize the mean and standard deviation of all pixels in an image.
13 | /// 14 | /// `img` should be a CHW tensor with `C` channels. For each channel `c`, the 15 | /// output pixel values are computed as `y = (x - mean[c]) / std_dev[c]`. 16 | /// 17 | /// This is a common preprocessing step for inputs to machine learning models. 18 | /// Many models use standard "ImageNet" constants ([`IMAGENET_MEAN`], 19 | /// [`IMAGENET_STD_DEV`]), but check the expected values for the model you are 20 | /// using. 21 | pub fn normalize_image<const C: usize>( 22 | mut img: NdTensorViewMut<f32, 3>, 23 | mean: [f32; C], 24 | std_dev: [f32; C], 25 | ) { 26 | let n_chans = img.size(0); 27 | assert_eq!( 28 | n_chans, C, 29 | "expected image to have {} channels but found {}", 30 | C, n_chans 31 | ); 32 | 33 | for chan in 0..n_chans { 34 | let inv_std_dev = 1. / std_dev[chan]; 35 | img.slice_mut(chan) 36 | .apply(|x| (x - mean[chan]) * inv_std_dev); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /rten-simd/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-simd" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Portable SIMD for stable Rust" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [lib] 13 | crate-type = ["lib"] 14 | 15 | [lints.clippy] 16 | # See comments about `needless_range_loop` in root Cargo.toml. 17 | needless_range_loop = "allow" 18 | manual_memcpy = "allow" 19 | 20 | [features] 21 | avx512 = [] -------------------------------------------------------------------------------- /rten-simd/README.md: -------------------------------------------------------------------------------- 1 | # rten-simd 2 | 3 | Portable SIMD library for stable Rust. 4 | 5 | rten-simd is a library for defining operations that are accelerated using 6 | [SIMD](https://en.wikipedia.org/wiki/Single_instruction,_multiple_data) 7 | instruction sets such as AVX2, Arm Neon or WebAssembly SIMD. Operations are 8 | defined once using safe, portable APIs, then _dispatched_ at runtime to 9 | evaluate the operation using the best available SIMD instruction set (ISA) 10 | on the current CPU. 11 | 12 | The design is inspired by Google's 13 | [Highway](https://github.com/google/highway) library for C++ and the 14 | [pulp](https://docs.rs/pulp/latest/pulp/) crate. -------------------------------------------------------------------------------- /rten-simd/src/README.md: -------------------------------------------------------------------------------- 1 | This crate contains RTen's portable SIMD library. It is used to implement 2 | vectorized kernels for neural network operators. -------------------------------------------------------------------------------- /rten-simd/src/arch.rs: -------------------------------------------------------------------------------- 1 | #[cfg(target_arch = "aarch64")] 2 | pub mod aarch64; 3 | 4 | #[cfg(target_arch = "x86_64")] 5 | pub mod x86_64; 6 | 7 | #[cfg(target_arch = "wasm32")] 8 | #[cfg(target_feature = "simd128")] 9 | pub mod wasm32; 10 | 11 | pub mod generic; 12 | 13 | use crate::simd::Simd; 14 | 15 | /// Return the number of lanes in a SIMD vector with compile-time known size. 16 | const fn lanes<S: Simd>() -> usize { 17 | size_of::<S>() / size_of::<S::Elem>() 18 | } 19 | 20 | /// Create a wrapper type for a platform-specific intrinsic type.
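// Illustrative invocation of the macro defined below (these names are
// placeholders, not the real per-arch definitions):
//
//     simd_type!(F32x4, float32x4_t, f32, Mask32x4, NeonIsa);
//
// This would expand to a `#[repr(transparent)]` wrapper
// `pub struct F32x4(pub float32x4_t)` implementing `Simd` with
// `Elem = f32` and `Mask = Mask32x4`.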
21 | #[allow(unused_macros)] // Not used on some platforms 22 | macro_rules! simd_type { 23 | ($type:ident, $inner:ty, $elem:ty, $mask:ty, $isa:ty) => { 24 | // The platform intrinsic is exposed as a public field so that 25 | // downstream crates can implement custom SIMD operations. It might be 26 | // better to support an `Into` conversion from the wrapper to the 27 | // platform type instead? 28 | 29 | #[derive(Copy, Clone, Debug)] 30 | #[repr(transparent)] 31 | pub struct $type(pub $inner); 32 | 33 | impl From<$inner> for $type { 34 | fn from(val: $inner) -> Self { 35 | Self(val) 36 | } 37 | } 38 | 39 | impl Simd for $type { 40 | type Elem = $elem; 41 | type Mask = $mask; 42 | type Array = [Self::Elem; size_of::<Self>() / size_of::<$elem>()]; 43 | type Isa = $isa; 44 | 45 | #[inline] 46 | fn to_bits(self) -> <Self::Isa as Isa>::Bits { 47 | #[allow(clippy::useless_transmute)] 48 | unsafe { 49 | transmute::<Self, <Self::Isa as Isa>::Bits>(self) 50 | } 51 | } 52 | 53 | #[inline] 54 | fn from_bits(bits: <Self::Isa as Isa>::Bits) -> Self { 55 | #[allow(clippy::useless_transmute)] 56 | unsafe { 57 | transmute::<<Self::Isa as Isa>::Bits, Self>(bits) 58 | } 59 | } 60 | 61 | #[inline] 62 | fn to_array(self) -> Self::Array { 63 | unsafe { transmute::<Self, Self::Array>(self) } 64 | } 65 | } 66 | }; 67 | } 68 | 69 | #[allow(unused_imports)] // Not used on some platforms 70 | use simd_type; 71 | -------------------------------------------------------------------------------- /rten-simd/src/arch/x86_64.rs: -------------------------------------------------------------------------------- 1 | mod avx2; 2 | pub use avx2::Avx2Isa; 3 | 4 | #[cfg(feature = "avx512")] 5 | mod avx512; 6 | 7 | #[cfg(feature = "avx512")] 8 | pub use avx512::Avx512Isa; -------------------------------------------------------------------------------- /rten-simd/src/elem.rs: -------------------------------------------------------------------------------- 1 | //! Traits for elements of SIMD vectors. 2 | 3 | /// Types used as elements (or _lanes_) of SIMD vectors. 4 | pub trait Elem: Copy + Default + WrappingAdd { 5 | /// Return the 1 value of this type. 6 | fn one() -> Self; 7 | } 8 | 9 | impl Elem for f32 { 10 | fn one() -> Self { 11 | 1. 12 | } 13 | } 14 | 15 | macro_rules! impl_elem_for_int { 16 | ($int:ty) => { 17 | impl Elem for $int { 18 | fn one() -> Self { 19 | 1 20 | } 21 | } 22 | }; 23 | } 24 | 25 | impl_elem_for_int!(i32); 26 | impl_elem_for_int!(i16); 27 | impl_elem_for_int!(i8); 28 | impl_elem_for_int!(u8); 29 | impl_elem_for_int!(u16); 30 | 31 | /// Wrapping addition of numbers. 32 | /// 33 | /// For float types, this is the same as [`std::ops::Add`]. For integer types, 34 | /// this is the same as the type's inherent `wrapping_add` method. 35 | pub trait WrappingAdd: Sized { 36 | type Output; 37 | 38 | fn wrapping_add(self, x: Self) -> Self; 39 | } 40 | 41 | macro_rules! impl_wrapping_add { 42 | ($type:ty) => { 43 | impl WrappingAdd for $type { 44 | type Output = Self; 45 | 46 | fn wrapping_add(self, x: Self) -> Self { 47 | Self::wrapping_add(self, x) 48 | } 49 | } 50 | }; 51 | } 52 | 53 | impl_wrapping_add!(i32); 54 | impl_wrapping_add!(i16); 55 | impl_wrapping_add!(i8); 56 | impl_wrapping_add!(u8); 57 | impl_wrapping_add!(u16); 58 | 59 | impl WrappingAdd for f32 { 60 | type Output = Self; 61 | 62 | fn wrapping_add(self, x: f32) -> f32 { 63 | self + x 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /rten-simd/src/isa_detection.rs: -------------------------------------------------------------------------------- 1 | //!
Functions for testing the availability of instruction sets at runtime. 2 | 3 | /// Functions for reading system info on macOS. 4 | #[cfg(target_os = "macos")] 5 | #[allow(unused)] 6 | pub mod macos { 7 | /// Detect availability of AVX-512 on macOS, where `is_x86_feature_detected` 8 | /// can return false even if AVX-512 is available. 9 | /// 10 | /// See https://github.com/golang/go/issues/43089. Go chose to use the 11 | /// `commpage` to get the info. We use `sysctlbyname` instead since it is 12 | /// a documented API. 13 | pub(super) fn test_for_avx512_on_macos() -> bool { 14 | use std::sync::OnceLock; 15 | 16 | static AVX512_AVAILABLE: OnceLock<bool> = OnceLock::new(); 17 | 18 | *AVX512_AVAILABLE.get_or_init(|| { 19 | // We test for the minimum AVX-512 extensions we require, but not 20 | // avx512f, as that is implied if the extensions are supported. 21 | sysctl_bool(c"hw.optional.avx512vl").unwrap_or(false) 22 | && sysctl_bool(c"hw.optional.avx512bw").unwrap_or(false) 23 | && sysctl_bool(c"hw.optional.avx512dq").unwrap_or(false) 24 | }) 25 | } 26 | 27 | /// Error code returned by `sysctlbyname`. 28 | #[derive(Copy, Clone, Debug, PartialEq)] 29 | pub struct SysctlError(i32); 30 | 31 | /// Read system info on macOS via the `sysctlbyname` API. 32 | /// 33 | /// See the output of `sysctl -a` on the command line for available settings. 34 | /// 35 | /// See also [Apple's documentation](https://developer.apple.com/documentation/kernel/1387446-sysctlbyname). 36 | pub fn sysctl_int(name: &std::ffi::CStr) -> Result<i64, SysctlError> { 37 | use std::os::raw::{c_char, c_int, c_void}; 38 | 39 | #[link(name = "c")] 40 | extern "C" { 41 | /// See https://developer.apple.com/documentation/kernel/1387446-sysctlbyname. 42 | fn sysctlbyname( 43 | name: *const c_char, 44 | oldp: *mut c_void, 45 | oldlenp: *mut usize, 46 | newp: *const c_void, 47 | newlen: usize, 48 | ) -> c_int; 49 | } 50 | 51 | // Use i64 for the return type per example in Apple's docs. 52 | let mut result = 0i64; 53 | let mut size = std::mem::size_of::<i64>(); 54 | 55 | let sysctl_ret = unsafe { 56 | sysctlbyname( 57 | name.as_ptr(), 58 | &mut result as *mut i64 as *mut c_void, 59 | &mut size, 60 | std::ptr::null(), 61 | 0, 62 | ) 63 | }; 64 | 65 | if sysctl_ret != 0 { 66 | return Err(SysctlError(sysctl_ret)); 67 | } 68 | Ok(result) 69 | } 70 | 71 | /// Read a system configuration integer value and interpret it as a boolean. 72 | pub fn sysctl_bool(name: &std::ffi::CStr) -> Result<bool, SysctlError> { 73 | sysctl_int(name).map(|val| val == 1) 74 | } 75 | } 76 | 77 | /// Test if the current system has basic AVX-512 support. 78 | /// 79 | /// "Basic support" is defined as: 80 | /// - AVX512F 81 | /// - AVX512VL 82 | /// - AVX512BW 83 | /// - AVX512DQ 84 | /// 85 | /// These features are available on Skylake (2016) and later. 86 | /// See . 87 | /// 88 | /// This is unfortunately not as simple as using `is_x86_feature_detected` 89 | /// because that can return incorrect results on macOS.
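// Typical caller-side pattern (a sketch; `run_avx512_kernel` and
// `run_portable_kernel` are hypothetical functions, not part of this crate):
//
//     if is_avx512_supported() {
//         // Safety: the required AVX-512 extensions were verified above.
//         unsafe { run_avx512_kernel(data) }
//     } else {
//         run_portable_kernel(data)
//     }
//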
90 | #[cfg(target_arch = "x86_64")] 91 | pub fn is_avx512_supported() -> bool { 92 | if is_x86_feature_detected!("avx512f") 93 | && is_x86_feature_detected!("avx512vl") 94 | && is_x86_feature_detected!("avx512bw") 95 | && is_x86_feature_detected!("avx512dq") 96 | { 97 | true 98 | } else { 99 | #[cfg(target_os = "macos")] 100 | { 101 | macos::test_for_avx512_on_macos() 102 | } 103 | #[cfg(not(target_os = "macos"))] 104 | { 105 | false 106 | } 107 | } 108 | } 109 | 110 | #[cfg(test)] 111 | mod tests { 112 | #[cfg(target_arch = "x86_64")] 113 | use super::is_avx512_supported; 114 | 115 | #[cfg(target_arch = "x86_64")] 116 | #[test] 117 | fn test_is_avx512_supported() { 118 | // Just test that the function runs. 119 | is_avx512_supported(); 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /rten-simd/src/simd.rs: -------------------------------------------------------------------------------- 1 | //! Traits for SIMD vectors and masks. 2 | 3 | use std::fmt::Debug; 4 | 5 | use crate::elem::Elem; 6 | use crate::ops::Isa; 7 | 8 | /// Masks used or returned by SIMD operations. 9 | /// 10 | /// Most operations on masks are available via the 11 | /// [`MaskOps`](crate::ops::MaskOps) trait. Implementations are obtained via 12 | /// [`NumOps::mask_ops`](crate::ops::NumOps::mask_ops). 13 | pub trait Mask: Copy + Debug { 14 | type Array: AsRef<[bool]> 15 | + Copy 16 | + Debug 17 | + IntoIterator<Item = bool> 18 | + PartialEq 19 | + std::ops::Index<usize, Output = bool>; 20 | 21 | /// Convert this mask to a bool array. 22 | fn to_array(self) -> Self::Array; 23 | 24 | /// Return true if all lanes in the mask are one. 25 | fn all_true(self) -> bool { 26 | self.to_array().as_ref().iter().all(|&x| x) 27 | } 28 | 29 | /// Return true if all lanes in the mask are false. 30 | fn all_false(self) -> bool { 31 | self.to_array().as_ref().iter().all(|&x| !x) 32 | } 33 | } 34 | 35 | /// SIMD vector type. 36 | #[allow(clippy::len_without_is_empty)] 37 | pub trait Simd: Copy + Debug { 38 | /// Representation of this vector as a `[Self::Elem; N]` array. 39 | type Array: AsRef<[Self::Elem]> 40 | + Copy 41 | + Debug 42 | + IntoIterator<Item = Self::Elem> 43 | + PartialEq 44 | + std::ops::Index<usize, Output = Self::Elem> 45 | + std::ops::IndexMut<usize>; 46 | 47 | /// Type of data held in each SIMD lane. 48 | type Elem: Elem; 49 | 50 | /// Mask with the same number of elements as this vector. 51 | type Mask: Mask; 52 | 53 | /// The ISA associated with this SIMD vector. 54 | type Isa: Isa; 55 | 56 | /// Convert this SIMD vector to the common "bits" type used by all vectors 57 | /// in this family. 58 | fn to_bits(self) -> <Self::Isa as Isa>::Bits; 59 | 60 | /// Convert this SIMD vector from the common "bits" type used by all vectors 61 | /// in this family. 62 | fn from_bits(bits: <Self::Isa as Isa>::Bits) -> Self; 63 | 64 | /// Reinterpret the bits of this vector as another vector from the same 65 | /// family. 66 | fn reinterpret_cast<T>(self) -> T 67 | where 68 | T: Simd<Isa = Self::Isa>, 69 | { 70 | T::from_bits(self.to_bits()) 71 | } 72 | 73 | /// Cast this vector to another with the same ISA and element type. 74 | /// 75 | /// This cast is a no-op which doesn't generate any code. It is needed in 76 | /// some cases to downcast a `Simd` type to one of an `Isa`'s associated 77 | /// types, or vice-versa. 78 | fn same_cast<T>(self) -> T 79 | where 80 | T: Simd<Isa = Self::Isa, Elem = Self::Elem>, 81 | { 82 | T::from_bits(self.to_bits()) 83 | } 84 | 85 | /// Convert `self` to a SIMD array. 86 | /// 87 | /// This is a cheap transmute in most cases, since SIMD vectors usually 88 | /// have the same layout as `[S::Elem; N]` but a greater alignment.
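// For example (a sketch with placeholder vector types), `reinterpret_cast`
// above lets an f32 vector's bits be reused as i32 lanes without moving any
// data:
//
//     let ints: I32Vec = floats.reinterpret_cast();
//
// where `I32Vec` and the type of `floats` share the same `Isa`.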
89 | fn to_array(self) -> Self::Array; 90 | } 91 | -------------------------------------------------------------------------------- /rten-simd/src/span.rs: -------------------------------------------------------------------------------- 1 | //! Slice-like types used as inputs and outputs for vectorized operations. 2 | 3 | use std::mem::{transmute, MaybeUninit}; 4 | 5 | enum SrcDestInner<'src, 'dst, T> { 6 | InOut(&'src [T], &'dst mut [MaybeUninit<T>]), 7 | InMut(&'dst mut [T]), 8 | } 9 | 10 | /// Input-output buffer for vectorized operations. 11 | /// 12 | /// This can either be a single mutable buffer for operations that execute 13 | /// in-place (`&mut [T]`) or a pair of input and output buffers where the 14 | /// output is uninitialized (`(&[T], &mut [MaybeUninit<T>])`) and both buffers 15 | /// must have the same length. 16 | pub struct SrcDest<'src, 'dst, T: Copy> { 17 | inner: SrcDestInner<'src, 'dst, T>, 18 | } 19 | 20 | impl<'dst, T: Copy> SrcDest<'_, 'dst, T> { 21 | /// Return the source slice. 22 | pub fn src(&self) -> &[T] { 23 | match &self.inner { 24 | SrcDestInner::InOut(src, _dest) => src, 25 | SrcDestInner::InMut(src_mut) => src_mut, 26 | } 27 | } 28 | 29 | /// Return the length of the input and output slices. 30 | pub fn len(&self) -> usize { 31 | self.src().len() 32 | } 33 | 34 | /// Return true if the input and output slices are empty. 35 | pub fn is_empty(&self) -> bool { 36 | self.src().is_empty() 37 | } 38 | 39 | /// Return source and destination slice pointers and the length. 40 | /// 41 | /// The source and destination will either alias, or the destination will 42 | /// be a non-aliasing, uninitialized slice. 43 | pub fn src_dest_ptr(&mut self) -> (*const T, *mut MaybeUninit<T>, usize) { 44 | match &mut self.inner { 45 | SrcDestInner::InOut(src, dest) => (src.as_ptr(), dest.as_mut_ptr(), src.len()), 46 | SrcDestInner::InMut(src) => ( 47 | src.as_ptr(), 48 | src.as_mut_ptr() as *mut MaybeUninit<T>, 49 | src.len(), 50 | ), 51 | } 52 | } 53 | 54 | /// Return the initialized destination slice. 55 | /// 56 | /// # Safety 57 | /// 58 | /// If this instance was constructed with an uninitialized destination 59 | /// buffer, all elements must have been initialized before this is called. 60 | pub unsafe fn dest_assume_init(self) -> &'dst mut [T] { 61 | match self.inner { 62 | SrcDestInner::InOut(_src, dest) => transmute::<&mut [MaybeUninit<T>], &mut [T]>(dest), 63 | SrcDestInner::InMut(src) => src, 64 | } 65 | } 66 | } 67 | 68 | impl<'src, 'dst, T: Copy> From<(&'src [T], &'dst mut [MaybeUninit<T>])> for SrcDest<'src, 'dst, T> { 69 | fn from(val: (&'src [T], &'dst mut [MaybeUninit<T>])) -> Self { 70 | let (src, dest) = val; 71 | assert_eq!( 72 | src.len(), 73 | dest.len(), 74 | "src len {} != dest len {}", 75 | src.len(), 76 | dest.len(), 77 | ); 78 | SrcDest { 79 | inner: SrcDestInner::InOut(src, dest), 80 | } 81 | } 82 | } 83 | 84 | impl<'dst, T: Copy> From<&'dst mut [T]> for SrcDest<'dst, 'dst, T> { 85 | fn from(val: &'dst mut [T]) -> Self { 86 | SrcDest { 87 | inner: SrcDestInner::InMut(val), 88 | } 89 | } 90 | } 91 | -------------------------------------------------------------------------------- /rten-simd/src/writer.rs: -------------------------------------------------------------------------------- 1 | use std::mem::{transmute, MaybeUninit}; 2 | 3 | use crate::ops::NumOps; 4 | use crate::Elem; 5 | 6 | /// Utility for incrementally filling an uninitialized slice, one SIMD vector 7 | /// at a time.
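// The intended pattern: write whole vectors in the main loop via `write_vec`,
// fall back to `write_scalar` for the tail elements, then call
// `into_mut_slice` to recover the initialized prefix. The `MemCopy` test
// below is a complete example of this pattern.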
8 | pub struct SliceWriter<'a, T> { 9 | buf: &'a mut [MaybeUninit<T>], 10 | n_init: usize, 11 | } 12 | 13 | impl<'a, T: Elem> SliceWriter<'a, T> { 14 | /// Create a writer which initializes elements of `buf`. 15 | pub fn new(buf: &'a mut [MaybeUninit<T>]) -> Self { 16 | SliceWriter { buf, n_init: 0 } 17 | } 18 | 19 | /// Initialize the next `ops.len()` elements of the slice from the contents 20 | /// of SIMD vector `xs`. 21 | /// 22 | /// Panics if the slice does not have space for `ops.len()` elements. 23 | pub fn write_vec<O: NumOps<T>>(&mut self, ops: O, xs: O::Simd) { 24 | let written = ops.store_uninit(xs, &mut self.buf[self.n_init..]); 25 | self.n_init += written.len(); 26 | } 27 | 28 | /// Initialize the next element of the slice from `x`. 29 | /// 30 | /// Panics if the slice does not have space for writing any more elements. 31 | pub fn write_scalar(&mut self, x: T) { 32 | self.buf[self.n_init].write(x); 33 | self.n_init += 1; 34 | } 35 | 36 | /// Finish writing the slice and return the initialized portion. 37 | pub fn into_mut_slice(self) -> &'a mut [T] { 38 | let init = &mut self.buf[0..self.n_init]; 39 | 40 | // Safety: All elements in `init` have been initialized. 41 | unsafe { transmute::<&mut [MaybeUninit<T>], &mut [T]>(init) } 42 | } 43 | } 44 | 45 | #[cfg(test)] 46 | mod tests { 47 | use std::mem::MaybeUninit; 48 | 49 | use crate::ops::NumOps; 50 | use crate::{Isa, SimdOp, SliceWriter}; 51 | 52 | #[test] 53 | fn test_slice_writer() { 54 | struct MemCopy<'src, 'dest> { 55 | src: &'src [f32], 56 | dest: &'dest mut [MaybeUninit<f32>], 57 | } 58 | 59 | impl<'src, 'dest> SimdOp for MemCopy<'src, 'dest> { 60 | type Output = &'dest mut [f32]; 61 | 62 | fn eval<I: Isa>(self, isa: I) -> &'dest mut [f32] { 63 | let ops = isa.f32(); 64 | 65 | let mut src_chunks = self.src.chunks_exact(ops.len()); 66 | let mut dest_writer = SliceWriter::new(self.dest); 67 | 68 | for chunk in src_chunks.by_ref() { 69 | let xs = ops.load(chunk); 70 | dest_writer.write_vec(ops, xs); 71 | } 72 | 73 | for x in src_chunks.remainder() { 74 | dest_writer.write_scalar(*x); 75 | } 76 | 77 | dest_writer.into_mut_slice() 78 | } 79 | } 80 | 81 | // Length which should cover the vectorized body and tail cases for 82 | // every ISA. 83 | let len = 17; 84 | let src: Vec<_> = (0..len).map(|x| x as f32).collect(); 85 | let mut dest = Vec::with_capacity(src.len()); 86 | 87 | let copied = MemCopy { 88 | src: &src, 89 | dest: dest.spare_capacity_mut(), 90 | } 91 | .dispatch(); 92 | assert_eq!(copied, src); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /rten-tensor/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-tensor" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Tensor library for the RTen machine learning runtime" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [dependencies] 13 | serde = { workspace = true, optional = true } 14 | smallvec = { version = "1.10.0", features=["union", "const_generics", "const_new"] } 15 | 16 | [dev-dependencies] 17 | rten-testing = { path = "../rten-testing" } 18 | rten-bench = { path ="../rten-bench" } 19 | serde_json = { workspace = true } 20 | 21 | [lib] 22 | crate-type = ["lib"] 23 | 24 | [lints.clippy] 25 | # See comments about `needless_range_loop` in root Cargo.toml.
26 | needless_range_loop = "allow" 27 | manual_memcpy = "allow" 28 | manual_repeat_n = "allow" # TODO - Address existing failures 29 | 30 | [features] 31 | serde = ["dep:serde"] -------------------------------------------------------------------------------- /rten-tensor/README.md: -------------------------------------------------------------------------------- 1 | # rten-tensor 2 | 3 | rten-tensor is the foundational library that provides multi-dimensional arrays 4 | used by [RTen](https://github.com/robertknight/rten). It is similar to 5 | [ndarray](https://github.com/rust-ndarray/ndarray) but tailored for use in the 6 | RTen library. -------------------------------------------------------------------------------- /rten-tensor/src/assume_init.rs: -------------------------------------------------------------------------------- 1 | use std::mem::MaybeUninit; 2 | 3 | /// Trait for converting collections of uninitialized (`MaybeUninit<T>`) values 4 | /// to collections of corresponding initialized values (`T`). 5 | /// 6 | /// ## Example 7 | /// 8 | /// ``` 9 | /// use std::mem::MaybeUninit; 10 | /// use rten_tensor::AssumeInit; 11 | /// 12 | /// fn scale_values<'a>(dst: &'a mut [MaybeUninit<f32>], src: &[f32], scale: f32) -> &'a mut [f32] { 13 | ///     for (y, x) in dst.into_iter().zip(src) { 14 | ///         y.write(x * scale); 15 | ///     } 16 | ///     // Safety: All elements have been initialized. 17 | ///     unsafe { dst.assume_init() } 18 | /// } 19 | /// 20 | /// let src = [1., 2., 3.]; 21 | /// let mut dst = [MaybeUninit::uninit(); 3]; 22 | /// let scaled = scale_values(&mut dst, &src, 2.); 23 | /// assert_eq!(scaled, [2., 4., 6.]); 24 | /// ``` 25 | pub trait AssumeInit { 26 | /// The type of the initialized storage. 27 | type Output; 28 | 29 | /// Cast `self` to a collection of initialized values. 30 | /// 31 | /// # Safety 32 | /// 33 | /// The caller must guarantee that all elements have been initialized. 34 | unsafe fn assume_init(self) -> Self::Output; 35 | } 36 | 37 | impl<T> AssumeInit for Vec<MaybeUninit<T>> { 38 | type Output = Vec<T>; 39 | 40 | unsafe fn assume_init(mut self) -> Self::Output { 41 | let (ptr, len, capacity) = (self.as_mut_ptr(), self.len(), self.capacity()); 42 | 43 | // Don't drop self, as that would deallocate. 44 | std::mem::forget(self); 45 | 46 | // Safety: We're re-constructing a `Vec` with the same length and 47 | // capacity and an element type that has the same size and alignment, 48 | // just cast from uninitialized to initialized.
49 | unsafe { Vec::from_raw_parts(ptr as *mut T, len, capacity) } 50 | } 51 | } 52 | 53 | impl<'a, T> AssumeInit for &'a [MaybeUninit<T>] { 54 | type Output = &'a [T]; 55 | 56 | unsafe fn assume_init(self) -> Self::Output { 57 | std::mem::transmute(self) 58 | } 59 | } 60 | 61 | impl<'a, T> AssumeInit for &'a mut [MaybeUninit<T>] { 62 | type Output = &'a mut [T]; 63 | 64 | unsafe fn assume_init(self) -> Self::Output { 65 | std::mem::transmute(self) 66 | } 67 | } 68 | 69 | #[cfg(test)] 70 | mod tests { 71 | use std::mem::MaybeUninit; 72 | 73 | use super::AssumeInit; 74 | 75 | #[test] 76 | fn test_assume_init_vec() { 77 | let mut vec = vec![MaybeUninit::uninit(); 3]; 78 | vec.reserve(4); 79 | 80 | for x in &mut vec { 81 | x.write(2.); 82 | } 83 | 84 | let vec = unsafe { vec.assume_init() }; 85 | assert_eq!(vec.len(), 3); 86 | assert_eq!(vec.capacity(), 7); 87 | assert_eq!(vec, &[2., 2., 2.]); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /rten-tensor/src/rng.rs: -------------------------------------------------------------------------------- 1 | use crate::RandomSource; 2 | 3 | /// Simple, non-cryptographically secure random number generator. 4 | /// 5 | /// See . 6 | pub struct XorShiftRng { 7 | state: u64, 8 | } 9 | 10 | impl XorShiftRng { 11 | pub fn new(seed: u64) -> XorShiftRng { 12 | XorShiftRng { state: seed } 13 | } 14 | 15 | /// Return a random value in the range [0, 2^64) 16 | pub fn next_u64(&mut self) -> u64 { 17 | let mut tmp = self.state; 18 | tmp ^= tmp << 13; 19 | tmp ^= tmp >> 7; 20 | tmp ^= tmp << 17; 21 | self.state = tmp; 22 | tmp 23 | } 24 | 25 | /// Return a random value in the range [0, 1) 26 | pub fn next_f32(&mut self) -> f32 { 27 | // Number of most significant bits to use 28 | let n_bits = 40; 29 | let scale = 1.0 / (1u64 << n_bits) as f32; 30 | let val = self.next_u64() >> (64 - n_bits); 31 | (val as f32) * scale 32 | } 33 | 34 | /// Return an infinite iterator that yields random values of type `T`. 35 | pub fn iter<T>(&mut self) -> impl Iterator<Item = T> + '_ 36 | where 37 | Self: RandomSource<T>, 38 | { 39 | std::iter::from_fn(|| Some(self.next())) 40 | } 41 | } 42 | 43 | impl RandomSource<f32> for XorShiftRng { 44 | fn next(&mut self) -> f32 { 45 | self.next_f32() 46 | } 47 | } 48 | 49 | macro_rules! impl_random_source { 50 | ($ty:ty) => { 51 | impl RandomSource<$ty> for XorShiftRng { 52 | fn next(&mut self) -> $ty { 53 | // Take the least significant bits of the 64-bit value as the 54 | // result.
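// (An `as` cast from u64 to a narrower integer type truncates, keeping
// only the low bits -- exactly the behaviour described above.)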
55 | self.next_u64() as $ty 56 | } 57 | } 58 | }; 59 | } 60 | 61 | impl_random_source!(u8); 62 | impl_random_source!(i8); 63 | impl_random_source!(i16); 64 | impl_random_source!(u16); 65 | impl_random_source!(i32); 66 | impl_random_source!(u32); 67 | 68 | #[cfg(test)] 69 | mod tests { 70 | use super::XorShiftRng; 71 | 72 | #[test] 73 | fn test_f32() { 74 | let mut rng = XorShiftRng::new(1234); 75 | let x: Vec<f32> = rng.iter().take(10).collect(); 76 | assert_eq!( 77 | x, 78 | &[ 79 | 7.2381226e-8, 80 | 0.12971127, 81 | 0.44675463, 82 | 6.69676e-5, 83 | 0.44387037, 84 | 0.24518594, 85 | 0.84056354, 86 | 0.9960614, 87 | 0.32433507, 88 | 0.9239961 89 | ] 90 | ); 91 | } 92 | 93 | #[test] 94 | fn test_i8() { 95 | let mut rng = XorShiftRng::new(1234); 96 | let x: Vec<i8> = rng.iter().take(10).collect(); 97 | assert_eq!(x, &[91, 123, 3, -73, 8, -102, -19, 118, 88, 58]); 98 | } 99 | 100 | #[test] 101 | fn test_u8() { 102 | let mut rng = XorShiftRng::new(1234); 103 | let x: Vec<u8> = rng.iter().take(10).collect(); 104 | assert_eq!(x, &[91, 123, 3, 183, 8, 154, 237, 118, 88, 58]); 105 | } 106 | 107 | #[test] 108 | fn test_i32() { 109 | let mut rng = XorShiftRng::new(1234); 110 | let x: Vec<i32> = rng.iter().take(10).collect(); 111 | assert_eq!( 112 | x, 113 | &[ 114 | -533893029, 115 | -1874043781, 116 | -2014135805, 117 | -1501708361, 118 | 330844424, 119 | 1872264090, 120 | -1812926995, 121 | -306325642, 122 | 692957528, 123 | -1439925190 124 | ] 125 | ); 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /rten-testing/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-testing" 3 | version = "0.1.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Testing utilities for use in RTen development" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | 11 | [package.metadata.release] 12 | release = false 13 | 14 | [lib] 15 | crate-type = ["lib"] -------------------------------------------------------------------------------- /rten-testing/README.md: -------------------------------------------------------------------------------- 1 | Internal testing utilities used in RTen development.
2 | -------------------------------------------------------------------------------- /rten-text/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-text" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "Text tokenization and other ML pre/post-processing functions" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [lib] 13 | crate-type = ["lib"] 14 | 15 | [dependencies] 16 | fancy-regex = { version = "0.14.0", default-features = false, features = ["std", "unicode"] } 17 | unicode_categories = "0.1.1" 18 | unicode-normalization = "0.1.22" 19 | serde = { workspace = true, features = ["derive"] } 20 | serde_json = { workspace = true } 21 | 22 | [dev-dependencies] 23 | rten-testing = { path = "../rten-testing" } 24 | 25 | [lints.clippy] 26 | manual_repeat_n = "allow" # TODO - Address existing failures -------------------------------------------------------------------------------- /rten-text/README.md: -------------------------------------------------------------------------------- 1 | # rten-text 2 | 3 | Library containing text tokenization and related functionality, for preparing 4 | inputs and decoding outputs for text models (e.g. BERT). 5 | 6 | The functionality is a subset of that found in [Hugging Face 7 | Tokenizers](https://github.com/huggingface/tokenizers). It has less 8 | functionality, but also fewer dependencies, and none that require C/C++. -------------------------------------------------------------------------------- /rten-text/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate provides tokenizers for encoding text into token IDs 2 | //! for model inputs and decoding output token IDs back into text. 3 | //! 4 | //! The tokenization process follows the 5 | //! [pipeline](https://huggingface.co/docs/tokenizers/en/pipeline) used by the 6 | //! Hugging Face [Tokenizers](https://huggingface.co/docs/tokenizers/en/) 7 | //! library. Tokenizers can either be constructed manually or loaded from 8 | //! Hugging Face `tokenizer.json` files. 9 | //! 10 | //! ## Comparison to _tokenizers_ crate 11 | //! 12 | //! The canonical implementation of this tokenization pipeline is the 13 | //! [`tokenizers`](https://github.com/huggingface/tokenizers) crate. The main 14 | //! differences compared to that crate are: 15 | //! 16 | //! - rten-text focuses on inference only and does not support training 17 | //! tokenizers. 18 | //! - rten-text is a pure Rust library with no dependencies written in C/C++. 19 | //! This means it is easy to build for WebAssembly and other targets where 20 | //! non-Rust dependencies may cause difficulties. 21 | //! - rten-text is integrated with the 22 | //! [rten-generate](https://docs.rs/rten-generate/) library which handles 23 | //! running the complete inference loop for auto-regressive transformer 24 | //! models. Note that you can use rten-generate's outputs with other tokenizer 25 | //! libraries if rten-text is not suitable. 26 | //! - Not all tokenizer features are currently implemented in rten-text. Please 27 | //! file an issue if you find that rten-text is missing a feature needed for a 28 | //! particular model's tokenizer. 29 | //! 30 | //! ## Loading a pre-trained tokenizer 31 | //! 32 | //!
The main entry point is the [`Tokenizer`] type. Use [`Tokenizer::from_file`] 33 | //! or [`Tokenizer::from_json`] to construct a tokenizer from a `tokenizer.json` 34 | //! file. 35 | //! 36 | //! ## Encoding text 37 | //! 38 | //! The [`Tokenizer::encode`] method is used to encode text into token IDs. 39 | //! This can be used for example to encode a model's prompt: 40 | //! 41 | //! ```no_run 42 | //! use rten_text::Tokenizer; 43 | //! 44 | //! let tokenizer = Tokenizer::from_file("gpt2/tokenizer.json")?; 45 | //! let encoded = tokenizer.encode("some text to tokenize", None)?; 46 | //! let token_ids = encoded.token_ids(); // Sequence of token IDs 47 | //! # Ok::<_, Box<dyn std::error::Error>>(()) 48 | //! ``` 49 | //! 50 | //! ## Decoding text 51 | //! 52 | //! Given token IDs generated by a model, you can decode them back into text 53 | //! using the [`Tokenizer::decode`] method: 54 | //! 55 | //! ```no_run 56 | //! use rten_text::Tokenizer; 57 | //! 58 | //! let tokenizer = Tokenizer::from_file("gpt2/tokenizer.json")?; 59 | //! // Run model and get token IDs from outputs... 60 | //! let token_ids = [101, 4256, 300]; 61 | //! let text = tokenizer.decode(&token_ids)?; 62 | //! # Ok::<_, Box<dyn std::error::Error>>(()) 63 | //! ``` 64 | //! 65 | //! ## More examples 66 | //! 67 | //! See the 68 | //! [rten-examples](https://github.com/robertknight/rten/tree/main/rten-examples) 69 | //! crate for various examples showing how to use this crate as part of an 70 | //! end-to-end pipeline. 71 | 72 | pub mod models; 73 | pub mod normalizers; 74 | pub mod pre_tokenizers; 75 | pub mod tokenizer; 76 | 77 | mod split; 78 | 79 | pub use tokenizer::{TokenId, Tokenizer, TokenizerError}; 80 | -------------------------------------------------------------------------------- /rten-text/test-data/reftests/README.md: -------------------------------------------------------------------------------- 1 | This directory contains text files from various sources (e.g. Wikipedia) and 2 | reference tokenizations created with existing Python libraries, by scripts in 3 | the tools/ directory. -------------------------------------------------------------------------------- /rten-text/test-data/reftests/models/gpt2/README.md: -------------------------------------------------------------------------------- 1 | # GPT 2 tokenizer 2 | 3 | This directory contains the Byte Pair Encoding (BPE) merge list and vocabulary 4 | mapping retrieved from https://huggingface.co/gpt2. 5 | 6 | These should match the data files for the 124M model from the original GPT-2 7 | repository. See https://github.com/openai/gpt-2/blob/master/DEVELOPERS.md. -------------------------------------------------------------------------------- /rten-text/test-data/reftests/monty-python-credits.txt: -------------------------------------------------------------------------------- 1 | Mønti Pythøn lk den Hølie Grailen 2 | 3 | Røtern nik Akten Di 4 | 5 | Wik 6 | 7 | Alsø wik 8 | 9 | Alsø alsø wik 10 | 11 | Wi nøt trei a høliday in Sweden this yer? 12 | 13 | See the løveli lakes 14 | 15 | The wonderful telephøne system 16 | 17 | And mani interesting furry animals 18 | 19 | The Producers would like to thank The Forestry Commission 20 | Doune Admissions Ltd, Keir and Cowdor Estates, Stirling 21 | University, and the people of Doune for their help in the 22 | making of this film. 23 | The Characters and incidents portrayed and the names used 24 | are fictitious and any similarity to the names, characters, 25 | or history of any person is entirely accidental and 26 | unintentional.
27 | Signed RICHARD M. NIXON 28 | 29 | Including the majestic møøse 30 | 31 | A Møøse once bit my sister ... 32 | 33 | No realli! She was Karving her initials on the møøse 34 | with the sharpened end of an interspace tøøthbrush given 35 | her by Svenge - her brother-in-law - an Oslo dentist and 36 | star of many Norwegian møvies: "The Høt Hands of an Oslo 37 | Dentist", "Fillings of Passion", "The Huge Mølars of Horst 38 | Nordfink". 39 | 40 | We apologise for the fault in the 41 | subtitles. Those responsible have been 42 | sacked. 43 | 44 | Mynd you, møøse bites Kan be pretty nasti... 45 | 46 | We apologise again for the fault in the subtitles. Those 47 | responsible for sacking the people who have just been sacked 48 | have been sacked. 49 | 50 | Møøse trained by TUTTE HERMSGERVORDENBROTBORDA 51 | 52 | Special Møøse Effects OLAF PROT 53 | Møøse Costumes SIGGI CHURCHILL 54 | Møøse Choreographed by HORST PROT III 55 | Miss Taylor's Møøses by HENGST DOUGLAS-HOME 56 | Møøse trained to mix 57 | concrete and sign com- 58 | plicated insurance 59 | forms by JURGEN WIGG 60 | Møøses' noses wiped by BJORN IRKESTOM-SLATER WALKER 61 | 62 | Large møøse on the left 63 | half side of the screen 64 | in the third scene from 65 | the end, given a thorough 66 | grounding in Latin, 67 | French and "O" Level 68 | Geography by BO BENN 69 | 70 | Suggestive poses for the 71 | Møøse suggested by VIC ROTTER 72 | Antler-care by LIV THATCHER 73 | 74 | The directors of the firm hired to 75 | continue the credits after the other 76 | people had been sacked, wish it to 77 | be known that they have just been 78 | sacked. 79 | 80 | The credits have been completed 81 | in an entirely different style at 82 | great expense and at the last 83 | minute. 84 | 85 | Executive Producer 86 | JOHN GOLDSTONE & "RALPH" The Wonder Llama 87 | 88 | The Producers would like to thank The Forestry Commission 89 | Doune Admissions Ltd, Keir and Cowdor Estates, Stirling 90 | University, and the people of Doune for their help in the 91 | making of this film. 92 | The Characters and incidents portrayed and the names used 93 | are fictitious and any similarity to the names, characters, 94 | or history of any person is entirely accidental and 95 | unintentional. 96 | Signed RICHARD M. NIXON 97 | JOHN GOLDSTONE & "RALPH" The Wonder Llama 98 | EARL J. LLAMA 99 | MIKE Q. LLAMA III 100 | SY LLAMA 101 | MERLE Z. 
LLAMA IX 102 | Directed By 103 | 40 SPECIALLY TRAINED 104 | ECUADORIAN MOUNTAIN LLAMAS 105 | 6 VENEZUELAN RED LLAMAS 106 | 142 MEXICAN WHOOPING LLAMAS 107 | 14 NORTH CHILEAN GUANACOS 108 | (CLOSELY RELATED TO THE LLAMA) 109 | REG LLAMA OF BRIXTON 110 | 76000 BATTERY LLAMAS 111 | FROM "LLAMA-FRESH" FARMS NEARE PARAGUAY 112 | and 113 | TERRY GILLIAM AND TERRY JONES 114 | -------------------------------------------------------------------------------- /rten-text/test-data/tokenizer-json/wordpiece-lower.json: -------------------------------------------------------------------------------- 1 | { 2 | "tokenizer": { 3 | "normalizer": { 4 | "type": "BertNormalizer", 5 | "lowercase": true, 6 | "strip_accents": null 7 | }, 8 | "model": { 9 | "type": "WordPiece", 10 | "vocab": { 11 | "foo": 1, 12 | "##bar": 2, 13 | "[CLS]": 3, 14 | "[SEP]": 4 15 | } 16 | } 17 | }, 18 | "cases": [ 19 | { 20 | "text": "foobar", 21 | "token_ids": [3, 1, 2, 4] 22 | } 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /rten-text/test-data/tokenizer-json/wordpiece.json: -------------------------------------------------------------------------------- 1 | { 2 | "tokenizer": { 3 | "model": { 4 | "type": "WordPiece", 5 | "vocab": { 6 | "foo": 1, 7 | "##bar": 2, 8 | "[CLS]": 3, 9 | "[SEP]": 4 10 | } 11 | } 12 | }, 13 | "cases": [ 14 | { 15 | "text": "foobar", 16 | "token_ids": [3, 1, 2, 4] 17 | } 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /rten-text/tools/fetch_wikipedia.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import wikipediaapi as wiki 4 | 5 | 6 | def main(): 7 | parser = ArgumentParser(description="Fetch text of Wikipedia pages") 8 | parser.add_argument("page_name", help="Name of page to fetch") 9 | parser.add_argument("-o", "--output", help="Output filename") 10 | args = parser.parse_args() 11 | 12 | page_name = args.page_name.strip().replace(" ", "_") 13 | output_file = args.output or f"{page_name}.txt" 14 | 15 | wiki_wiki = wiki.Wikipedia("rten-text (robertknight@gmail.com)", "en") 16 | page = wiki_wiki.page(page_name) 17 | 18 | with open(output_file, "w") as output: 19 | output.write(page.text) 20 | 21 | 22 | if __name__ == "__main__": 23 | main() 24 | -------------------------------------------------------------------------------- /rten-text/tools/reference_tokenize.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from os.path import splitext 3 | import json 4 | 5 | from tokenizers import Tokenizer 6 | 7 | 8 | def main(): 9 | parser = ArgumentParser( 10 | description=""" 11 | Create a reference tokenization of text using the `tokenizers` package. 
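Example (illustrative): running `reference_tokenize.py gpt2 input.txt` writes
the reference token IDs and tokens to `input-gpt2.json`, following the output
naming scheme in `main()` below.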
12 | """ 13 | ) 14 | parser.add_argument("model_name", help="Name of pretrained model from Hugging Face") 15 | parser.add_argument("text_file", help="Text to tokenize") 16 | args = parser.parse_args() 17 | 18 | tokenizer = Tokenizer.from_pretrained(args.model_name) 19 | 20 | with open(args.text_file) as text_fp: 21 | text = text_fp.read() 22 | 23 | encoded = tokenizer.encode(text) 24 | 25 | output = { 26 | "input_file": args.text_file, 27 | "model_name": args.model_name, 28 | "token_ids": encoded.ids, 29 | "tokens": encoded.tokens, 30 | } 31 | json_output = json.dumps(output, indent=2) 32 | 33 | text_file_base, _ = splitext(args.text_file) 34 | model_fname = args.model_name.replace("/", "_") 35 | output_fname = f"{text_file_base}-{model_fname}.json" 36 | with open(output_fname, "w") as output: 37 | output.write(json_output) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /rten-text/tools/requirements.txt: -------------------------------------------------------------------------------- 1 | tokenizers==0.15.0 2 | Wikipedia-API==0.6.0 3 | -------------------------------------------------------------------------------- /rten-vecmath/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rten-vecmath" 3 | version = "0.18.0" 4 | edition = "2021" 5 | authors = ["Robert Knight"] 6 | description = "SIMD vectorized implementations of various math functions used in ML models" 7 | license = "MIT OR Apache-2.0" 8 | homepage = "https://github.com/robertknight/rten" 9 | repository = "https://github.com/robertknight/rten" 10 | include = ["/src", "/README.md"] 11 | 12 | [dependencies] 13 | rten-simd = { path = "../rten-simd", version = "0.18.0" } 14 | 15 | [dev-dependencies] 16 | fastrand = "2.0.2" 17 | libm = "0.2.6" 18 | 19 | [lib] 20 | crate-type = ["lib"] 21 | 22 | [lints.clippy] 23 | # See comments about `needless_range_loop` in root Cargo.toml. 24 | needless_range_loop = "allow" 25 | manual_memcpy = "allow" 26 | 27 | [features] 28 | avx512 = ["rten-simd/avx512"] 29 | -------------------------------------------------------------------------------- /rten-vecmath/README.md: -------------------------------------------------------------------------------- 1 | # rten-vecmath 2 | 3 | This crate contains SIMD-vectorized kernels ("vectorized math") for various 4 | operations used in machine learning models. This includes: 5 | 6 | - Math functions such as exp, erf, tanh 7 | - Activation function such as gelu 8 | - Normalization functions such as softmax and mean-variance normalization 9 | - Reduction functions such as sums and sum-of-square 10 | 11 | SIMD operations are implemented using portable SIMD types from the rten-simd 12 | crate. 13 | -------------------------------------------------------------------------------- /rten-vecmath/src/extend_init.rs: -------------------------------------------------------------------------------- 1 | use std::mem::MaybeUninit; 2 | 3 | /// Extend a buffer by incrementally initializing spare capacity. 4 | /// 5 | /// This is implemented for [`Vec`], where it provides a safe API to 6 | /// initialize the spare capacity returned by 7 | /// [`spare_capacity_mut`](Vec::spare_capacity_mut). 8 | pub trait ExtendInit { 9 | /// Element type in the buffer. 10 | type Elem; 11 | 12 | /// Extend the buffer by initializing a portion of the buffer's spare 13 | /// capacity. 
14 | /// 15 | /// The function `f` is passed the uninitialized portion of the buffer and 16 | /// should return the portion that it has initialized. `extend_init` can 17 | /// be called many times, until the entire buffer has been initialized. 18 | /// 19 | /// # Panics 20 | /// 21 | /// Panics if `f` returns a slice that is not a prefix of the slice that 22 | /// was passed to it. 23 | fn extend_init<F: FnOnce(&mut [MaybeUninit<Self::Elem>]) -> &[Self::Elem]>(&mut self, f: F); 24 | } 25 | 26 | impl<T> ExtendInit for Vec<T> { 27 | type Elem = T; 28 | 29 | fn extend_init<F: FnOnce(&mut [MaybeUninit<Self::Elem>]) -> &[Self::Elem]>(&mut self, f: F) { 30 | let cap = self.spare_capacity_mut(); 31 | let cap_ptr = cap.as_ptr(); 32 | let cap_len = cap.len(); 33 | 34 | let initialized = f(cap); 35 | assert_eq!( 36 | initialized.as_ptr(), 37 | cap_ptr as *const T, 38 | "returned slice must be a prefix of the input" 39 | ); 40 | assert!( 41 | initialized.len() <= cap_len, 42 | "initialized slice length {} is longer than input {}", 43 | initialized.len(), 44 | cap_len 45 | ); 46 | let n_init = initialized.len(); 47 | 48 | // Safety: `n_init` elements from the spare capacity have been initialized. 49 | unsafe { self.set_len(self.len() + n_init) } 50 | } 51 | } 52 | 53 | #[cfg(test)] 54 | mod tests { 55 | use std::mem::MaybeUninit; 56 | 57 | use super::ExtendInit; 58 | 59 | // Implementation of `MaybeUninit::fill` from nightly Rust. 60 | fn fill<T: Copy>(xs: &mut [MaybeUninit<T>], value: T) -> &mut [T] { 61 | for x in xs.iter_mut() { 62 | x.write(value); 63 | } 64 | unsafe { std::mem::transmute::<&mut [MaybeUninit<T>], &mut [T]>(xs) } 65 | } 66 | 67 | #[test] 68 | fn test_extend_init() { 69 | let mut vec = Vec::with_capacity(7); 70 | 71 | vec.extend_init(|uninit| { 72 | assert_eq!(uninit.len(), 7); 73 | fill(&mut uninit[..3], 1.) 74 | }); 75 | assert_eq!(vec.len(), 3); 76 | assert_eq!(vec, &[1., 1., 1.]); 77 | 78 | vec.extend_init(|uninit| { 79 | assert_eq!(uninit.len(), 4); 80 | fill(uninit, 2.) 81 | }); 82 | assert_eq!(vec.len(), 7); 83 | assert_eq!(vec, &[1., 1., 1., 2., 2., 2., 2.]); 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /rten-vecmath/src/min_max.rs: -------------------------------------------------------------------------------- 1 | use rten_simd::ops::NumOps; 2 | use rten_simd::{Isa, Simd, SimdIterable, SimdOp}; 3 | 4 | /// Compute the minimum and maximum values in a slice of floats. 5 | pub struct MinMax<'a> { 6 | input: &'a [f32], 7 | } 8 | 9 | impl<'a> MinMax<'a> { 10 | pub fn new(input: &'a [f32]) -> Self { 11 | MinMax { input } 12 | } 13 | } 14 | 15 | impl SimdOp for MinMax<'_> { 16 | type Output = (f32, f32); 17 | 18 | #[inline(always)] 19 | fn eval<I: Isa>(self, isa: I) -> Self::Output { 20 | let ops = isa.f32(); 21 | let [vec_min, vec_max] = self.input.simd_iter(ops).fold_n_unroll::<2, 4>( 22 | [ops.splat(f32::MAX), ops.splat(f32::MIN)], 23 | #[inline(always)] 24 | |[min, max], x| [ops.min(x, min), ops.max(x, max)], 25 | #[inline(always)] 26 | |[min_a, max_a], [min_b, max_b]| [ops.min(min_a, min_b), ops.max(max_a, max_b)], 27 | ); 28 | let min = vec_min 29 | .to_array() 30 | .as_ref() 31 | .iter() 32 | .fold(f32::MAX, |min, x| x.min(min)); 33 | let max = vec_max 34 | .to_array() 35 | .as_ref() 36 | .iter() 37 | .fold(f32::MIN, |max, x| x.max(max)); 38 | (min, max) 39 | } 40 | } 41 | 42 | #[cfg(test)] 43 | mod tests { 44 | use super::MinMax; 45 | use rten_simd::SimdOp; 46 | 47 | // Chosen to not be a multiple of vector size, so that tail handling is 48 | // exercised.
49 | const LEN: usize = 100; 50 | 51 | fn reference_min_max(xs: &[f32]) -> (f32, f32) { 52 | let min = xs.iter().fold(f32::MAX, |min, x| x.min(min)); 53 | let max = xs.iter().fold(f32::MIN, |max, x| x.max(max)); 54 | (min, max) 55 | } 56 | 57 | #[test] 58 | fn test_min_max() { 59 | let xs: Vec<f32> = (0..LEN).map(|i| i as f32 * 0.1).collect(); 60 | let expected = reference_min_max(&xs); 61 | let min_max = MinMax::new(&xs).dispatch(); 62 | assert_eq!(min_max, expected); 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /rten-vecmath/src/quantize.rs: -------------------------------------------------------------------------------- 1 | use std::mem::MaybeUninit; 2 | 3 | use rten_simd::ops::{FloatOps, NarrowSaturate, NumOps}; 4 | use rten_simd::{Isa, SimdOp, SliceWriter}; 5 | 6 | /// Quantize a slice of `f32` elements to 8-bit integers using the formula: 7 | /// 8 | /// ```text 9 | /// y = saturate(round(x * inv_scale) + zero_point) 10 | /// ``` 11 | /// 12 | /// Where `round` rounds to the nearest `i32` value with ties to even and 13 | /// `saturate` converts `i32` to the small integer type `To` with saturation. 14 | pub struct Quantize<'s, 'd, To> { 15 | src: &'s [f32], 16 | dest: &'d mut [MaybeUninit<To>], 17 | inv_scale: f32, 18 | zero_point: To, 19 | } 20 | 21 | impl<'s, 'd, To> Quantize<'s, 'd, To> { 22 | pub fn new( 23 | src: &'s [f32], 24 | dest: &'d mut [MaybeUninit<To>], 25 | inv_scale: f32, 26 | zero_point: To, 27 | ) -> Self { 28 | assert_eq!(src.len(), dest.len()); 29 | Quantize { 30 | src, 31 | dest, 32 | inv_scale, 33 | zero_point, 34 | } 35 | } 36 | } 37 | 38 | impl<'d> SimdOp for Quantize<'_, 'd, u8> { 39 | type Output = &'d mut [u8]; 40 | 41 | #[inline(always)] 42 | fn eval<I: Isa>(self, isa: I) -> Self::Output { 43 | let src_ops = isa.f32(); 44 | let i32_ops = isa.i32(); 45 | 46 | let zp_vec = i32_ops.splat(self.zero_point as i32); 47 | let scale_vec = src_ops.splat(self.inv_scale); 48 | let f32_v_len = src_ops.len(); 49 | 50 | // Generate one vector of u8 elements in each iteration by quantizing 51 | // 4 vectors of f32 elements. 52 | let mut src_chunks = self.src.chunks_exact(f32_v_len * 4); 53 | let mut dest_writer = SliceWriter::new(self.dest); 54 | 55 | for src_chunk in src_chunks.by_ref() { 56 | let src = src_ops.load_many::<4>(src_chunk); 57 | let quant_i32 = src.map(|x| { 58 | let y = src_ops.mul(x, scale_vec); 59 | let y = src_ops.to_int_round(y); 60 | i32_ops.add(y, zp_vec) 61 | }); 62 | let quant_i16_low = i32_ops.narrow_saturate(quant_i32[0], quant_i32[1]); 63 | let quant_i16_high = i32_ops.narrow_saturate(quant_i32[2], quant_i32[3]); 64 | let quant_u8 = isa.i16().narrow_saturate(quant_i16_low, quant_i16_high); 65 | dest_writer.write_vec(isa.u8(), quant_u8); 66 | } 67 | 68 | // Quantize tail elements. 69 | for src in src_chunks.remainder() { 70 | let y = (src * self.inv_scale).round_ties_even() as i32; 71 | let y = (y + self.zero_point as i32).clamp(0, u8::MAX as i32); 72 | dest_writer.write_scalar(y as u8); 73 | } 74 | 75 | dest_writer.into_mut_slice() 76 | } 77 | } 78 | 79 | #[cfg(test)] 80 | mod tests { 81 | use rten_simd::ops::NumOps; 82 | use rten_simd::{Isa, SimdOp}; 83 | 84 | use super::Quantize; 85 | 86 | fn reference_quantize(src: &[f32], inv_scale: f32, zero_point: u8) -> Vec<u8> { 87 | src.iter() 88 | .map(|x| { 89 | let tmp = (x * inv_scale).round_ties_even() + zero_point as f32; 90 | tmp as u8 // Saturating cast 91 | }) 92 | .collect() 93 | } 94 | 95 | /// Return number of u8 lanes supported in a SIMD vector.
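// Worked example of the `Quantize` formula above (values chosen purely for
// illustration): with inv_scale = 5.2 and zero_point = 10, x = 0.5 maps to
// round(0.5 * 5.2) + 10 = 3 + 10 = 13, while x = 60.0 gives
// round(312.0) + 10 = 322, which saturates to u8::MAX = 255.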
-------------------------------------------------------------------------------- /rten-vecmath/src/quantize.rs: --------------------------------------------------------------------------------
use std::mem::MaybeUninit;

use rten_simd::ops::{FloatOps, NarrowSaturate, NumOps};
use rten_simd::{Isa, SimdOp, SliceWriter};

/// Quantize a slice of `f32` elements to 8-bit integers using the formula:
///
/// ```text
/// y = saturate(round(x * inv_scale) + zero_point)
/// ```
///
/// Where `round` rounds to the nearest `i32` value with ties to even and
/// `saturate` converts `i32` to the small integer type `To` with saturation.
pub struct Quantize<'s, 'd, To> {
    src: &'s [f32],
    dest: &'d mut [MaybeUninit<To>],
    inv_scale: f32,
    zero_point: To,
}

impl<'s, 'd, To> Quantize<'s, 'd, To> {
    pub fn new(
        src: &'s [f32],
        dest: &'d mut [MaybeUninit<To>],
        inv_scale: f32,
        zero_point: To,
    ) -> Self {
        assert_eq!(src.len(), dest.len());
        Quantize {
            src,
            dest,
            inv_scale,
            zero_point,
        }
    }
}

impl<'d> SimdOp for Quantize<'_, 'd, u8> {
    type Output = &'d mut [u8];

    #[inline(always)]
    fn eval<I: Isa>(self, isa: I) -> Self::Output {
        let src_ops = isa.f32();
        let i32_ops = isa.i32();

        let zp_vec = i32_ops.splat(self.zero_point as i32);
        let scale_vec = src_ops.splat(self.inv_scale);
        let f32_v_len = src_ops.len();

        // Generate one vector of u8 elements in each iteration by quantizing
        // 4 vectors of f32 elements.
        let mut src_chunks = self.src.chunks_exact(f32_v_len * 4);
        let mut dest_writer = SliceWriter::new(self.dest);

        for src_chunk in src_chunks.by_ref() {
            let src = src_ops.load_many::<4>(src_chunk);
            let quant_i32 = src.map(|x| {
                let y = src_ops.mul(x, scale_vec);
                let y = src_ops.to_int_round(y);
                i32_ops.add(y, zp_vec)
            });
            let quant_i16_low = i32_ops.narrow_saturate(quant_i32[0], quant_i32[1]);
            let quant_i16_high = i32_ops.narrow_saturate(quant_i32[2], quant_i32[3]);
            let quant_u8 = isa.i16().narrow_saturate(quant_i16_low, quant_i16_high);
            dest_writer.write_vec(isa.u8(), quant_u8);
        }

        // Quantize tail elements.
        for src in src_chunks.remainder() {
            let y = (src * self.inv_scale).round_ties_even() as i32;
            let y = (y + self.zero_point as i32).clamp(0, u8::MAX as i32);
            dest_writer.write_scalar(y as u8);
        }

        dest_writer.into_mut_slice()
    }
}

#[cfg(test)]
mod tests {
    use rten_simd::ops::NumOps;
    use rten_simd::{Isa, SimdOp};

    use super::Quantize;

    fn reference_quantize(src: &[f32], inv_scale: f32, zero_point: u8) -> Vec<u8> {
        src.iter()
            .map(|x| {
                let tmp = (x * inv_scale).round_ties_even() + zero_point as f32;
                tmp as u8 // Saturating cast
            })
            .collect()
    }

    /// Return number of u8 lanes supported in a SIMD vector.
    fn u8_vec_len() -> usize {
        struct U8VecLen {}
        impl SimdOp for U8VecLen {
            type Output = usize;
            fn eval<I: Isa>(self, isa: I) -> usize {
                isa.u8().len()
            }
        }
        U8VecLen {}.dispatch()
    }

    #[test]
    fn test_quantize() {
        let mut rng = fastrand::Rng::with_seed(1234);

        // Larger than max u8 SIMD vector length, and not an exact multiple, so
        // we have a tail.
        let len = u8_vec_len() + 1;
        let src: Vec<f32> = std::iter::from_fn(|| Some(rng.f32())).take(len).collect();
        let inv_scale = 5.2;
        let zero_point = 10;
        let expected = reference_quantize(&src, inv_scale, zero_point);

        let mut buf = Vec::with_capacity(src.len());
        let actual = &mut buf.spare_capacity_mut();
        let actual = Quantize::new(&src, actual, inv_scale, zero_point).dispatch();

        assert_eq!(actual, expected);
    }
}
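A worked scalar example of the formula above, using the same parameters as the test (`inv_scale = 5.2`, `zero_point = 10`):

    // round_ties_even(1.0 * 5.2) + 10 = 5 + 10 = 15
    // round_ties_even(60.0 * 5.2) + 10 = 312 + 10 = 322, which saturates to 255
    let q = |x: f32| ((x * 5.2).round_ties_even() as i32 + 10).clamp(0, 255) as u8;
    assert_eq!(q(1.0), 15);
    assert_eq!(q(60.0), 255);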
-------------------------------------------------------------------------------- /rten-vecmath/src/tanh.rs: --------------------------------------------------------------------------------
#![allow(clippy::excessive_precision)]

use rten_simd::ops::{FloatOps, NumOps};
use rten_simd::{Isa, Simd, SimdUnaryOp};

use crate::Exp;

/// Vectorized tanh implementation.
#[derive(Default)]
pub struct Tanh {}

impl SimdUnaryOp<f32> for Tanh {
    #[inline(always)]
    fn eval<I: Isa, S: Simd<Elem = f32>>(&self, isa: I, x: S) -> S {
        let ops = isa.f32();
        let x = x.same_cast();

        let x_negative = ops.le(x, ops.zero());
        let abs_x = ops.abs(x);

        // Cutoff beyond which `f32::tanh(x)` saturates at +/- 1.0.
        let x_cutoff = ops.ge(abs_x, ops.splat(9.02));

        // tanh(x) ~ x when |x| is very small.
        let x_tiny = ops.le(abs_x, ops.splat(0.0004));

        // Threshold below which `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)` method
        // produces errors >= 2 ULP.
        let x_small = ops.le(abs_x, ops.splat(0.55));

        // For small x, use polynomial approximation. Computed using Sollya with
        // `P = fpminimax(f, [|1, 3, 5, 7, 9|], [|SG...|], [0, 0.6])`.
        const P1: f32 = 0.999999940395355224609375;
        const P3: f32 = -0.33332359790802001953125;
        const P5: f32 = 0.13310669362545013427734375;
        const P7: f32 = -5.21197654306888580322265625e-2;
        const P9: f32 = 1.5497927553951740264892578125e-2;

        let p1 = ops.splat(P1);
        let p3 = ops.splat(P3);
        let p5 = ops.splat(P5);
        let p7 = ops.splat(P7);
        let p9 = ops.splat(P9);

        let x_sqr = ops.mul(x, x);
        let y_small = ops.mul_add(p9, x_sqr, p7);
        let y_small = ops.mul_add(y_small, x_sqr, p5);
        let y_small = ops.mul_add(y_small, x_sqr, p3);
        let y_small = ops.mul_add(y_small, x_sqr, p1);
        let y_small = ops.mul(y_small, abs_x);

        // For medium x, compute `tanh(x) = (exp(2x) - 1) / (exp(2x) + 1)`.
        let x2 = ops.mul(abs_x, ops.splat(2.0));
        let exp_2x = Exp::apply(isa, x2);
        let exp_2x_m1 = ops.sub(exp_2x, ops.one());
        let exp_2x_p1 = ops.add(exp_2x, ops.one());
        let y_medium = ops.div(exp_2x_m1, exp_2x_p1);

        // Select output to use depending on |x|.
        let y = ops.select(ops.one(), y_medium, x_cutoff);
        let y = ops.select(y_small, y, x_small);
        let y = ops.select(abs_x, y, x_tiny);

        // Flip sign if input was negative.
        ops.select(ops.neg(y), y, x_negative).same_cast()
    }
}

#[cfg(test)]
mod tests {
    use rten_simd::SimdUnaryOp;

    use crate::testing::{
        arange, benchmark_op, check_f32s_are_equal_ulps, check_with_all_f32s, AsUninit,
    };
    use crate::Tanh;

    // Maximum error of `vec_tanh` compared to `f32::tanh`.
    const MAX_TANH_ERROR_ULPS: f32 = 3.0;

    #[test]
    #[ignore] // Ignored by default due to long runtime
    fn test_tanh_exhaustive() {
        check_with_all_f32s(
            |x| {
                let mut y = [0.; 1];
                Tanh {}.map(&[x], y.as_mut().as_uninit());
                (y[0], x.tanh())
            },
            MAX_TANH_ERROR_ULPS,
            "testing vec_tanh",
        );
    }

    #[test]
    fn test_tanh() {
        let cases: Vec<f32> = arange(-8., 8., 0.001f32).collect();
        let expected: Vec<_> = cases.iter().copied().map(|x| x.tanh()).collect();
        let mut actual = cases.clone();
        Tanh {}.map(&cases, actual.as_mut_slice().as_uninit());

        let results = cases
            .iter()
            .zip(actual.iter().zip(expected.iter()))
            .map(|(x, (actual, expected))| (*x, *actual, *expected));
        check_f32s_are_equal_ulps(results, MAX_TANH_ERROR_ULPS);
    }

    #[test]
    #[ignore]
    fn bench_tanh() {
        benchmark_op(
            |xs, ys| {
                xs.iter()
                    .zip(ys.iter_mut())
                    .for_each(|(x, y)| *y = x.tanh())
            },
            |xs, ys| Tanh {}.map(xs, ys),
        );
    }
}
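The SIMD code above evaluates every branch and blends the results with `select`. A scalar mirror of the same piecewise strategy (a sketch, with the P1..P9 coefficients rounded for brevity; not part of the crate):

    fn tanh_piecewise(x: f32) -> f32 {
        let a = x.abs();
        let y = if a <= 0.0004 {
            a // tanh(x) ~ x for tiny |x|
        } else if a <= 0.55 {
            // Horner evaluation of the odd polynomial in x^2.
            let a2 = a * a;
            let mut y = 1.5497928e-2f32;
            y = y * a2 - 5.2119765e-2;
            y = y * a2 + 1.3310669e-1;
            y = y * a2 - 3.333236e-1;
            y = y * a2 + 0.99999994;
            y * a
        } else if a >= 9.02 {
            1.0 // f32::tanh saturates here
        } else {
            let e = (2.0 * a).exp();
            (e - 1.0) / (e + 1.0)
        };
        if x <= 0.0 { -y } else { y }
    }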
-------------------------------------------------------------------------------- /rten-vecmath/src/ulp.rs: --------------------------------------------------------------------------------
/// Trait for obtaining the size of the Unit in the Last Place (ULP) for floats.
pub trait Ulp {
    /// Return the size of the ULP for a given value.
    fn ulp(self) -> Self;

    /// Return the difference between this value and `other` in units of
    /// `other.ulp()`.
    fn diff_ulps(self, other: Self) -> Self;
}

impl Ulp for f32 {
    /// Return the size of the Unit in the Last Place for a given value.
    ///
    /// Handling of special cases (NaN, infinity, zero, min/max) follows `Math.ulp`
    /// in Java [1].
    ///
    /// [1] https://docs.oracle.com/en/java/javase/21/docs/api/java.base/java/lang/Math.html#ulp(float)
    fn ulp(self: f32) -> f32 {
        if self.is_nan() {
            self
        } else if self.is_infinite() {
            f32::INFINITY
        } else if self == 0. {
            f32::MIN
        } else if self == f32::MIN || self == f32::MAX {
            f32::from_bits((127 + 104) << 23) // 2^104
        } else {
            let bits = self.to_bits();
            let next_up = f32::from_bits(bits + 1);
            (next_up - self).abs()
        }
    }

    fn diff_ulps(self: f32, other: f32) -> f32 {
        (self - other).abs() / other.ulp()
    }
}

/// Assert that the difference between two values is less than or equal to a
/// given number of [ULPs](Ulp).
macro_rules! assert_ulp_diff_le {
    ($actual:expr, $expected:expr, $max_diff:expr) => {{
        use crate::ulp::Ulp;

        let ulp_diff = ($actual).diff_ulps($expected);
        assert!(
            ulp_diff <= $max_diff,
            "difference between {} and {} is {} ULPs which exceeds {}",
            $actual,
            $expected,
            ulp_diff,
            $max_diff
        );
    }};
}

pub(crate) use assert_ulp_diff_le;

#[cfg(test)]
mod tests {
    use super::Ulp;

    #[test]
    fn test_f32_ulp() {
        assert_eq!((1.0f32).ulp(), f32::EPSILON);

        // Special cases. See the Java `Math.ulp` docs.
        assert!(f32::NAN.ulp().is_nan());
        assert_eq!(f32::INFINITY.ulp(), f32::INFINITY);
        assert_eq!(f32::NEG_INFINITY.ulp(), f32::INFINITY);
        assert_eq!((0.0f32).ulp(), f32::MIN);
        assert_eq!((-0.0f32).ulp(), f32::MIN);
        assert_eq!(f32::MAX.ulp(), (104f32).exp2());
        assert_eq!(f32::MIN.ulp(), (104f32).exp2());
    }

    #[test]
    fn test_f32_diff_ulps() {
        let x = 1.0f32;
        let y = 1.001f32;

        let diff_bits = y.to_bits() - x.to_bits();
        assert_eq!(y.diff_ulps(x), diff_bits as f32);
    }
}
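For normal values of the same sign, `diff_ulps` approximates the number of representable `f32` values between the two arguments, which makes it a resolution-aware tolerance. A usage sketch, mirroring how the numeric kernels in this crate are checked against scalar references:

    // Accept `actual` if it is within 3 ULPs of the reference value.
    assert_ulp_diff_le!(actual, expected, 3.0);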
-------------------------------------------------------------------------------- /src/downcast.rs: --------------------------------------------------------------------------------
use std::any::Any;

/// Allows downcasting a trait object to a concrete type.
///
/// This can be implemented for `dyn SomeTrait`, where `SomeTrait` is a trait
/// that has `Any` as a supertrait, by using [`impl_downcastdyn`].
///
/// When the `trait_upcasting` feature is stabilized, this can be removed
/// as callers can upcast the trait object to `dyn Any` and then downcast.
pub(crate) trait DowncastDyn {
    fn is<T: Any>(&self) -> bool;
    fn downcast_ref<T: Any>(&self) -> Option<&T>;
}

/// Implement [`DowncastDyn`] for a trait. The trait must have `Any` as a
/// supertrait.
macro_rules! impl_downcastdyn {
    ($trait:ident) => {
        // Trigger compile error if `Any` is not a supertrait of `$trait`.
        //
        // Credit: https://stackoverflow.com/a/64826111/434243
        fn _assert_any_supertrait() {
            fn has_any_supertrait<T: std::any::Any + ?Sized>() {}
            let _ = has_any_supertrait::<dyn $trait>;
        }

        // The implementation approach was taken from how `downcast_ref` is implemented
        // for `dyn Error` in the standard library.
        impl $crate::downcast::DowncastDyn for dyn $trait {
            fn is<T: std::any::Any>(&self) -> bool {
                // nb. If `$trait` does not have an `Any` supertrait, this code
                // still compiles, but `type_id` will return a different value.
                // Hence the `_assert_any_supertrait` check above.
                std::any::TypeId::of::<T>() == <Self as std::any::Any>::type_id(self)
            }

            fn downcast_ref<T: std::any::Any>(&self) -> Option<&T> {
                if self.is::<T>() {
                    // SAFETY: `is` ensures the cast is correct.
                    Some(unsafe { &*(self as *const dyn $trait as *const T) })
                } else {
                    None
                }
            }
        }
    };
}

pub(crate) use impl_downcastdyn;

#[cfg(test)]
mod tests {
    use std::any::Any;

    use super::{impl_downcastdyn, DowncastDyn};

    trait Foo: Any {}
    impl_downcastdyn!(Foo);

    struct TypeA {}
    impl Foo for TypeA {}

    struct TypeB {}
    impl Foo for TypeB {}

    #[test]
    fn test_downcast_ref() {
        let type_a = TypeA {};
        let type_b = TypeB {};

        let type_a_dyn: &dyn Foo = &type_a;
        let type_b_dyn: &dyn Foo = &type_b;

        assert!(type_a_dyn.is::<TypeA>());
        assert!(!type_a_dyn.is::<TypeB>());
        assert!(std::ptr::eq(
            type_a_dyn.downcast_ref::<TypeA>().unwrap(),
            &type_a
        ));
        assert!(type_a_dyn.downcast_ref::<TypeB>().is_none());

        assert!(type_b_dyn.is::<TypeB>());
        assert!(!type_b_dyn.is::<TypeA>());
        assert!(std::ptr::eq(
            type_b_dyn.downcast_ref::<TypeB>().unwrap(),
            &type_b
        ));
        assert!(type_b_dyn.downcast_ref::<TypeA>().is_none());
    }
}
-------------------------------------------------------------------------------- /src/env.rs: --------------------------------------------------------------------------------
/// Interpret a string value such as "1" or "no" as a boolean.
pub fn str_as_bool(s: &str) -> bool {
    match s {
        "1" | "true" | "t" | "yes" | "y" => true,
        "0" | "false" | "f" | "no" | "n" => false,
        _ => {
            eprintln!("Unrecognized boolean value \"{}\"", s);
            false
        }
    }
}

/// Return whether a feature flag controlled by an environment variable is
/// enabled.
pub fn env_flag(name: &str, default: bool) -> bool {
    std::env::var(name)
        .as_ref()
        .map(|s| str_as_bool(s))
        .unwrap_or(default)
}
-------------------------------------------------------------------------------- /src/gemm/errors.rs: --------------------------------------------------------------------------------
use std::error::Error;
use std::fmt;
use std::fmt::Display;

/// Errors with matrix multiplication inputs.
#[derive(Clone, Debug, PartialEq)]
pub enum GemmError {
    /// Number of columns in LHS does not match rows of RHS.
    KSizeMismatch,
    /// Bias vector length does not match the corresponding output matrix size.
    WrongBiasSize,
    /// Quantization parameter size does not match corresponding input size.
    WrongQuantParamSize,
    /// The buffer provided for the output is too short.
    OutputNotLargeEnough,
    /// The data was packed with a kernel that uses a different layout than
    /// the current kernel.
    PackedDataKernelMismatch,
    /// The data was packed with different cache blocking parameters than are
    /// currently being used.
    PackedDataBlockingMismatch,
}

impl Display for GemmError {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::KSizeMismatch => {
                write!(fmt, "columns of matrix `a` must match rows of matrix `b`")
            }
            Self::WrongBiasSize => write!(fmt, "bias vector length is incorrect"),
            Self::WrongQuantParamSize => {
                write!(fmt, "quantization parameter size does not match input")
            }
            Self::OutputNotLargeEnough => write!(fmt, "output buffer is too small"),
            Self::PackedDataKernelMismatch => {
                write!(fmt, "matrix was packed with a different kernel")
            }
            Self::PackedDataBlockingMismatch => {
                write!(fmt, "matrix was packed with a different blocking size")
            }
        }
    }
}

impl Error for GemmError {}
-------------------------------------------------------------------------------- /src/gemm/reduced_range_rng.rs: --------------------------------------------------------------------------------
use rten_tensor::rng::XorShiftRng;
use rten_tensor::RandomSource;

/// Random number generator which produces values with an optionally reduced
/// range.
///
/// This works around an issue under AVX2 where the `vpmaddubsw` instruction
/// can encounter saturation when adding two signed 16-bit values into a
/// 16-bit result. Each of the two 16-bit inputs are the result of a `u8 x
/// i8` multiplication. By limiting the range of either the u8 or i8 input,
/// saturation is avoided. This issue does not affect the VNNI instruction
/// used on newer x64 systems. It also does not affect Arm.
///
/// To match the behavior in ONNX Runtime's quantizer when
/// `reduce_range=True` is enabled, the range of whichever input holds the
/// weights (usually the RHS) should be limited.
///
/// To avoid saturation we require `i16::MIN <= u8_val * i8_val * 2 <=
/// i16::MAX`. A suitable choice is to use i7/u7 values with ranges [-64,
/// 63] and [0, 127].
///
/// See also https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html.
pub struct ReducedRangeRng {
    reduce_range: bool,
    rng: XorShiftRng,
}

impl ReducedRangeRng {
    pub fn new(reduce_range: bool, seed: u64) -> Self {
        Self {
            rng: XorShiftRng::new(seed),
            reduce_range,
        }
    }
}

impl RandomSource<i8> for ReducedRangeRng {
    /// Return a random value in `[-64, 63]` (the i7 range).
    fn next(&mut self) -> i8 {
        if self.reduce_range {
            ((self.rng.next_u64() % 128) as i16 - 64i16) as i8
        } else {
            self.rng.next_u64() as i8
        }
    }
}

impl RandomSource<u8> for ReducedRangeRng {
    /// Return a random value in `[0, 127]` (the u7 range).
    fn next(&mut self) -> u8 {
        if self.reduce_range {
            (self.rng.next_u64() % 128) as u8
        } else {
            self.rng.next_u64() as u8
        }
    }
}

#[cfg(test)]
mod tests {
    use rten_tensor::RandomSource;

    use super::ReducedRangeRng;

    #[test]
    fn test_reduced_range_rng() {
        let mut rng = ReducedRangeRng::new(true, 1234);
        for _ in 0..100 {
            let x: i8 = rng.next();
            assert!(x >= -64 && x <= 63);

            let x: u8 = rng.next();
            assert!(x <= 127);
        }
    }
}
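The arithmetic behind the reduced range, as a small sketch:

    // Full-range inputs can overflow the signed 16-bit accumulator used by
    // `vpmaddubsw`, which sums u8 * i8 products as i16:
    assert!(255i32 * 127 * 2 > i16::MAX as i32); // 64770 > 32767
    // Restricting one operand to 7 bits keeps every pair sum in range:
    assert!(127i32 * 127 * 2 <= i16::MAX as i32); // 32258 <= 32767
    assert!(127i32 * -128 * 2 >= i16::MIN as i32); // -32512 >= -32768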
-------------------------------------------------------------------------------- /src/gemm/tiles.rs: --------------------------------------------------------------------------------
use std::marker::PhantomData;

use rten_tensor::prelude::*;
use rten_tensor::{MatrixLayout, MatrixMut, StorageMut};

/// Wrapper around the GEMM output matrix which divides it into a grid of tiles.
/// This can be shared across threads, but each individual tile must only be
/// operated on by one thread at a time.
pub struct OutputTiles<'a, T> {
    data: *mut T,

    // Size and stride of the output matrix.
    rows: usize,
    cols: usize,
    row_stride: usize,

    // Maximum size of each tile.
    tile_rows: usize,
    tile_cols: usize,

    // Precomputed number of tiles along each axis.
    n_row_tiles: usize,
    n_col_tiles: usize,

    _marker: PhantomData<&'a mut [T]>,
}

/// Safety: Caller must ensure they do not operate on overlapping tiles
/// concurrently.
unsafe impl<T> Sync for OutputTiles<'_, T> {}

impl<'a, T> OutputTiles<'a, T> {
    /// Expose `data` as a grid of tiles, each with a maximum size of
    /// `tile_rows` * `tile_cols`.
    pub fn new(
        mut data: MatrixMut<'a, T>,
        tile_rows: usize,
        tile_cols: usize,
    ) -> OutputTiles<'a, T> {
        OutputTiles {
            data: data.storage_mut().as_mut_ptr(),
            rows: data.rows(),
            cols: data.cols(),
            row_stride: data.stride(0),
            tile_rows,
            tile_cols,
            n_row_tiles: data.rows().div_ceil(tile_rows),
            n_col_tiles: data.cols().div_ceil(tile_cols),
            _marker: PhantomData,
        }
    }

    /// Return the output tile with the given coordinates in the grid of
    /// output tiles.
    ///
    /// Safety: The caller must guarantee that every tile is operated on by
    /// only a single thread at a time.
    pub unsafe fn tile(&self, row: usize, col: usize) -> OutputTile<T> {
        assert!(row < self.n_row_tiles && col < self.n_col_tiles);

        let start_row = row * self.tile_rows;
        let start_col = col * self.tile_cols;

        OutputTile {
            ptr: self.data.add(start_row * self.row_stride + start_col),
            row_stride: self.row_stride,
            used_rows: (self.rows - start_row).min(self.tile_rows),
            used_cols: (self.cols - start_col).min(self.tile_cols),
            _marker: PhantomData,
        }
    }
}

/// A single tile of the output matrix.
pub struct OutputTile<'a, T> {
    /// Pointer to first element in this tile.
    pub ptr: *mut T,

    /// Stride between rows of this tile. Note the column stride is always 1.
    pub row_stride: usize,

    /// Number of rows in this tile. Will be <= the [`Kernel`]'s `MR` constant.
    pub used_rows: usize,

    /// Number of columns in this tile. Will be <= the [`Kernel`]'s `NR` constant.
    pub used_cols: usize,

    _marker: PhantomData<&'a mut [T]>,
}
-------------------------------------------------------------------------------- /src/graph/node_id.rs: --------------------------------------------------------------------------------
use std::num::NonZero;

/// ID of a node in a [`Model`](crate::Model) graph.
///
/// This is used to identify input and output values as well as internal nodes.
///
/// Node IDs are u32 values <= `i32::MAX`.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct NodeId(NonZero<u32>);

impl NodeId {
    /// Return the underlying u32 value of the ID.
    pub fn as_u32(self) -> u32 {
        self.0.get() - 1
    }

    /// Return the underlying ID value as a usize, for slice indexing.
    pub fn as_usize(self) -> usize {
        self.as_u32() as usize
    }

    /// Construct a node ID from a u32 value.
    ///
    /// Panics if the value exceeds `i32::MAX`.
    pub fn from_u32(value: u32) -> NodeId {
        // Node IDs are limited to `i32::MAX` because the `OperatorNode` type
        // in the FlatBuffers schema represents operator input and output IDs
        // as `i32`. Negative values are used as a niche to represent missing
        // optional inputs.
        assert!(value <= i32::MAX as u32);

        // Valid node IDs are in the range `[0, i32::MAX]`, so we store them as
        // values in `[1, i32::MAX + 1]` internally and reserve 0 as a niche to
        // make `Option<NodeId>` the same size as `NodeId`.
        NodeId(unsafe {
            // Safety: `value + 1` cannot be zero
            NonZero::new_unchecked(value + 1)
        })
    }
}

impl std::fmt::Display for NodeId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.as_u32().fmt(f)
    }
}
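A sketch of the niche encoding's observable effects:

    // IDs round-trip through the internal +1 offset ...
    let id = NodeId::from_u32(0);
    assert_eq!(id.as_u32(), 0);
    // ... and the reserved zero bit pattern lets `Option<NodeId>` stay the
    // same size as `NodeId` itself.
    assert_eq!(
        std::mem::size_of::<Option<NodeId>>(),
        std::mem::size_of::<NodeId>()
    );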
-------------------------------------------------------------------------------- /src/graph/noop_hash.rs: --------------------------------------------------------------------------------
use std::collections::HashMap;
use std::hash::{BuildHasherDefault, Hasher};

/// A hash map which uses keys directly as hash values.
///
/// This is intended for use with u32 keys that come from a sequence, so the
/// keys can be used directly as hash values without encountering too many
/// collisions.
pub type NoopHashMap<K, V> = HashMap<K, V, BuildHasherDefault<NoopHasher>>;

#[derive(Default)]
pub struct NoopHasher {
    hash: u64,
}

impl Hasher for NoopHasher {
    fn finish(&self) -> u64 {
        self.hash
    }

    /// Hash bytes such that hashing `unsigned_int.to_ne_bytes()` sets the
    /// hash value to `unsigned_int as u64` (on a little-endian arch at least).
    fn write(&mut self, bytes: &[u8]) {
        let mut new_hash = 0;
        for (i, b) in bytes.iter().enumerate() {
            new_hash |= (*b as u64) << (i * 8);
        }
        self.hash = new_hash;
    }

    // Implement u32 hashing directly, since `NodeId`s are u32 values.
    fn write_u32(&mut self, i: u32) {
        self.hash = i as u64;
    }
}

#[cfg(test)]
mod tests {
    use std::hash::Hasher;

    use super::NoopHasher;

    #[test]
    fn test_noop_hasher() {
        let mut hasher = NoopHasher::default();
        hasher.write_u32(1234);
        assert_eq!(hasher.finish(), 1234);

        let mut hasher = NoopHasher::default();
        hasher.write(&(4567u64).to_ne_bytes());
        assert_eq!(hasher.finish(), 4567);
    }
}
-------------------------------------------------------------------------------- /src/ops/control_flow.rs: --------------------------------------------------------------------------------
use rten_tensor::TensorView;
use smallvec::SmallVec;

use crate::graph::{CaptureEnv, Graph, RunError, RunOptions};
use crate::ops::{OpError, OpRunContext, Operator, OutputList, Value};
use crate::timing::Profiler;
use crate::weight_cache::WeightCache;

fn output_list_from_vec(xs: Vec<Value>) -> OutputList {
    xs.into_iter().collect()
}

pub struct If {
    pub then_branch: Graph,
    pub else_branch: Graph,
}

impl std::fmt::Debug for If {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
        write!(f, "If {{ ... }}")
    }
}

impl Operator for If {
    fn name(&self) -> &str {
        "If"
    }

    fn run(&self, _ctx: &OpRunContext) -> Result<OutputList, OpError> {
        Err(OpError::InvalidValue(
            "operator must be run with `run_subgraph`",
        ))
    }

    fn subgraphs(&self) -> SmallVec<[&Graph; 2]> {
        [&self.then_branch, &self.else_branch].into()
    }

    fn run_subgraph<'a>(
        &'a self,
        ctx: &OpRunContext,
        captures: CaptureEnv,
        weight_caches: Option<&[WeightCache]>,
        profiler: Option<&mut Profiler<'a>>,
        run_opts: Option<RunOptions>,
    ) -> Result<OutputList, RunError> {
        let node_name = ctx.name().unwrap_or_default();
        let cond: TensorView<i32> = ctx
            .inputs()
            .require_as(0)
            .map_err(|e| RunError::op_error(node_name, e, ctx))?;
        let Some(cond_bool) = cond.item().copied() else {
            return Err(RunError::op_error(
                node_name,
                OpError::InvalidValue("cond must be a single value"),
                ctx,
            ));
        };

        if cond_bool != 0 {
            self.then_branch
                .run_subgraph(
                    Vec::new(),
                    self.then_branch.output_ids(),
                    captures,
                    Some(ctx.pool()),
                    weight_caches.map(|wcs| &wcs[0]),
                    profiler,
                    run_opts,
                )
                .map(output_list_from_vec)
        } else {
            self.else_branch
                .run_subgraph(
                    Vec::new(),
                    self.else_branch.output_ids(),
                    captures,
                    Some(ctx.pool()),
                    weight_caches.map(|wcs| &wcs[1]),
                    profiler,
                    run_opts,
                )
                .map(output_list_from_vec)
        }
    }
}
-------------------------------------------------------------------------------- /src/ops/identity.rs: --------------------------------------------------------------------------------
use rten_tensor::prelude::*;
use rten_tensor::{Tensor, TensorView};

use crate::ops::{
    map_value_view, IntoOpResult, OpError, OpRunContext, Operator, OutputList, Value, ValueView,
};
use crate::tensor_pool::TensorPool;

fn identity<T: Clone>(pool: &TensorPool, src: TensorView<T>) -> Tensor<T> {
    src.to_tensor_in(pool)
}

#[derive(Debug)]
pub struct Identity {}

impl Operator for Identity {
    fn name(&self) -> &str {
        "Identity"
    }

    fn run(&self, ctx: &OpRunContext) -> Result<OutputList, OpError> {
        let input = ctx.inputs().require(0)?;
        map_value_view!(input, x, { identity(ctx.pool(), x).into_op_result() })
    }

    fn can_run_in_place(&self) -> bool {
        true
    }

    fn run_in_place(&self, input: Value, _ctx: &OpRunContext) -> Result<Value, OpError> {
        Ok(input)
    }
}

#[cfg(test)]
mod tests {
    use std::error::Error;

    use rten_tensor::test_util::expect_equal;
    use rten_tensor::Tensor;

    use crate::ops::{Identity, OperatorExt};

    #[test]
    fn test_identity() -> Result<(), Box<dyn Error>> {
        let id_op = Identity {};

        let int_input = Tensor::from([1, 2, 3]);
        let result: Tensor<i32> = id_op.run_simple(&int_input).unwrap();
        assert_eq!(result, int_input);

        let float_input = Tensor::from([1.0, 2.0, 3.0]);
        let result: Tensor = id_op.run_simple(&float_input).unwrap();
        expect_equal(&result, &float_input)?;

        Ok(())
    }
}
-------------------------------------------------------------------------------- /src/shift_cast.rs: --------------------------------------------------------------------------------
use rten_tensor::prelude::*;
use rten_tensor::{Alloc, CowTensor, TensorView};

/// Conversion from one numeric type to another that preserves the value's
/// offset from the minimum value.
///
/// This trait is also implemented for some collection types, which can
/// avoid allocating a new collection if the source and target types are the
/// same.
///
/// Converting `0i8` to `u8` via a normal cast returns 0, but a shift cast
/// returns `128u8`, since `0i8 - i8::MIN = 128` and `128u8 - u8::MIN = 128`.
pub trait ShiftCast<T> {
    /// Return a value of type T that has the same difference from `T::MIN`
    /// as `self` has from `Self::MIN`.
    fn shift_cast(self) -> T;

    /// Variant of [`shift_cast`](ShiftCast::shift_cast) that takes an allocator.
    fn shift_cast_in(self, _alloc: impl Alloc) -> T
    where
        Self: Sized,
    {
        self.shift_cast()
    }
}

macro_rules! impl_noop_cast {
    ($type:ty) => {
        impl ShiftCast<$type> for $type {
            fn shift_cast(self) -> Self {
                self
            }
        }
    };
}

impl_noop_cast!(i8);
impl ShiftCast<u8> for i8 {
    fn shift_cast(self) -> u8 {
        (self as u8) ^ 0x80
    }
}

impl_noop_cast!(u8);
impl ShiftCast<i8> for u8 {
    fn shift_cast(self) -> i8 {
        (self ^ 0x80) as i8
    }
}

impl<'a, T> ShiftCast<CowTensor<'a, T>> for TensorView<'a, T> {
    fn shift_cast(self) -> CowTensor<'a, T> {
        self.as_cow()
    }
}

impl<'a> ShiftCast<CowTensor<'a, u8>> for TensorView<'a, i8> {
    fn shift_cast(self) -> CowTensor<'a, u8> {
        self.map(|&x| x.shift_cast()).into_cow()
    }

    fn shift_cast_in(self, alloc: impl Alloc) -> CowTensor<'a, u8> {
        self.map_in(alloc, |&x| x.shift_cast()).into_cow()
    }
}

impl<'a> ShiftCast<CowTensor<'a, i8>> for TensorView<'a, u8> {
    fn shift_cast(self) -> CowTensor<'a, i8> {
        self.map(|&x| x.shift_cast()).into_cow()
    }

    fn shift_cast_in(self, alloc: impl Alloc) -> CowTensor<'a, i8> {
        self.map_in(alloc, |&x| x.shift_cast()).into_cow()
    }
}

impl<T, U> ShiftCast<Vec<U>> for Vec<T>
where
    T: ShiftCast<U>,
{
    fn shift_cast(self) -> Vec<U> {
        self.into_iter().map(|x| x.shift_cast()).collect()
    }
}

#[cfg(test)]
mod tests {
    use rten_tensor::prelude::*;
    use rten_tensor::Tensor;

    use super::{CowTensor, ShiftCast};

    #[test]
    fn test_shift_cast_scalar() {
        const LEN: usize = 5;

        let input = [-128i8, -1, 0, 1, 127];
        let expected = [0u8, 127, 128, 129, 255];

        let actual: [u8; LEN] = input.map(|x| x.shift_cast());
        assert_eq!(actual, expected);

        let actual_noop: [u8; LEN] = actual.map(|x| x.shift_cast());
        assert_eq!(actual_noop, expected);

        let actual_inverse: [i8; LEN] = expected.map(|x| x.shift_cast());
        assert_eq!(actual_inverse, input);

        let actual_inverse_noop: [i8; LEN] = input.map(|x| x.shift_cast());
        assert_eq!(actual_inverse_noop, input);
    }

    #[test]
    fn test_shift_cast_tensor() {
        let input = Tensor::from([-128i8, -1, 0, 1, 127]);
        let expected = Tensor::from([0u8, 127, 128, 129, 255]);

        let actual: CowTensor<u8> = input.view().shift_cast();
        assert_eq!(actual, expected);

        let noop_cast: CowTensor<u8> = actual.view().shift_cast();
        assert_eq!(noop_cast, actual);

        let actual_inverse: CowTensor<i8> = expected.view().shift_cast();
        assert_eq!(actual_inverse, input);
    }

    #[test]
    fn test_shift_cast_vec() {
        let input: Vec<_> = [-128i8, -1, 0, 1, 127].into();
        let expected = [0u8, 127, 128, 129, 255];
        let actual: Vec<u8> = input.shift_cast();
        assert_eq!(actual, expected);
    }
}
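The `^ 0x80` trick works because flipping the sign bit adds 128 modulo 256, which maps the i8 range onto the u8 range while preserving order:

    let y: u8 = (-128i8).shift_cast();
    assert_eq!(y, 0); // i8::MIN -> u8::MIN
    let y: u8 = 0i8.shift_cast();
    assert_eq!(y, 128); // 0 sits 128 above i8::MIN
    let y: u8 = 127i8.shift_cast();
    assert_eq!(y, 255); // i8::MAX -> u8::MAX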
-------------------------------------------------------------------------------- /src/slice_reductions.rs: --------------------------------------------------------------------------------
//! Optimized reductions of slices and iterators of numbers.
//!
//! Library APIs like `std::iter::Sum` reduce elements in-order. For float
//! values this is not optimal for performance as each step has a dependency on
//! the previous step, inhibiting Instruction Level Parallelism and
//! autovectorization. The functions in this module re-order operations to
//! enable better performance.
//!
//! Related reading:
//!
//! -

use crate::number::MinMax;

/// Return the maximum of a slice of numbers.
pub fn slice_max<T: Copy + MinMax>(xs: &[T]) -> T {
    const CHUNK_SIZE: usize = 8;
    xs.chunks(CHUNK_SIZE)
        .map(|chunk| {
            if chunk.len() == CHUNK_SIZE {
                // Writing the code this way encourages better autovectorization.
                let a0 = chunk[0].max(chunk[1]);
                let a1 = chunk[2].max(chunk[3]);
                let a2 = chunk[4].max(chunk[5]);
                let a3 = chunk[6].max(chunk[7]);

                let b0 = a0.max(a1);
                let b1 = a2.max(a3);

                b0.max(b1)
            } else {
                chunk.iter().copied().fold(T::min_val(), |x, y| x.max(y))
            }
        })
        .fold(T::min_val(), |x, y| x.max(y))
}

/// Return the sum of a slice of numbers.
pub fn slice_sum<T: Copy + Default + std::ops::Add<Output = T>>(xs: &[T]) -> T {
    slice_map_sum(xs, |x| x)
}

/// Apply `map` to each element in `xs` and sum the results.
pub fn slice_map_sum<T: Copy + Default + std::ops::Add<Output = T>, M: Fn(T) -> T>(
    xs: &[T],
    map: M,
) -> T {
    const CHUNK_SIZE: usize = 8;
    xs.chunks(CHUNK_SIZE)
        .map(|chunk| {
            if chunk.len() == CHUNK_SIZE {
                // Writing the code this way encourages better autovectorization.
                let x = [chunk[0], chunk[1], chunk[2], chunk[3]].map(&map);
                let y = [chunk[4], chunk[5], chunk[6], chunk[7]].map(&map);
                let z = [x[0] + y[0], x[1] + y[1], x[2] + y[2], x[3] + y[3]];
                z[0] + z[1] + z[2] + z[3]
            } else {
                chunk
                    .iter()
                    .copied()
                    .fold(T::default(), |acc, x| acc + map(x))
            }
        })
        .fold(T::default(), |acc, x| acc + x)
}

#[cfg(test)]
mod tests {
    use rten_tensor::rng::XorShiftRng;
    use rten_tensor::test_util::ApproxEq;

    use super::{slice_max, slice_sum};

    #[test]
    fn test_slice_max() {
        let mut rng = XorShiftRng::new(1234);
        let xs: Vec<_> = std::iter::from_fn(|| Some(rng.next_f32()))
            .take(256)
            .collect();
        let expected = xs.iter().fold(f32::NEG_INFINITY, |x, y| x.max(*y));
        let actual = slice_max(&xs);
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_slice_sum() {
        let mut rng = XorShiftRng::new(1234);
        let xs: Vec<_> = std::iter::from_fn(|| Some(rng.next_f32()))
            .take(256)
            .collect();
        assert!(xs.iter().sum::<f32>().approx_eq(&slice_sum(&xs)));
    }
}
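A miniature of the reordering that enables instruction-level parallelism (a sketch, not part of the crate): the in-order sum is one long dependency chain, while the re-associated form splits it into independent chains. For floats the two orders may round differently, which is why `test_slice_sum` above compares with `approx_eq` rather than `==`.

    let xs = [0.1f32, 0.2, 0.3, 0.4];
    // In-order: each add waits for the previous one.
    let in_order = ((xs[0] + xs[1]) + xs[2]) + xs[3];
    // Re-associated: the two inner adds can execute concurrently.
    let reassociated = (xs[0] + xs[1]) + (xs[2] + xs[3]);
    // Both approximate the same sum; only the rounding of intermediates
    // can differ.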
-------------------------------------------------------------------------------- /src/threading.rs: --------------------------------------------------------------------------------
use std::env;
use std::sync::OnceLock;

/// A wrapper around the Rayon thread pool used to run models.
///
/// On platforms where threads are not supported (eg. WebAssembly) this runs
/// operations directly on the main thread.
pub struct ThreadPool {
    /// The wrapped thread pool, or None if we failed to construct one.
    pool: Option<rayon::ThreadPool>,
}

impl ThreadPool {
    /// Run a function in the thread pool.
    ///
    /// This corresponds to [`rayon::ThreadPool::install`], except on platforms
    /// where threading is not supported, where it just runs `op` directly.
    pub fn run<R: Send, Op: FnOnce() -> R + Send>(&self, op: Op) -> R {
        if let Some(pool) = self.pool.as_ref() {
            pool.install(op)
        } else {
            op()
        }
    }

    /// Create a thread pool with a given number of threads.
    pub fn with_num_threads(num_threads: usize) -> ThreadPool {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .thread_name(|index| format!("rten-{}", index))
            .build();

        ThreadPool { pool: pool.ok() }
    }
}

/// Return the optimal number of cores to use for maximum performance.
///
/// This may be less than the total number of cores on systems with
/// heterogeneous cores (eg. a mix of performance and efficiency cores).
fn optimal_core_count() -> u32 {
    #[allow(unused_mut)]
    let mut core_count = num_cpus::get_physical().max(1) as u32;

    #[cfg(target_os = "macos")]
    {
        use rten_simd::isa_detection::macos::sysctl_int;
        if let Ok(perf_core_count) = sysctl_int(c"hw.perflevel0.physicalcpu") {
            core_count = core_count.clamp(1, perf_core_count as u32);
        }
    }

    core_count
}

/// Return the [Rayon][rayon] thread pool which is used to execute RTen models.
///
/// This differs from Rayon's default global thread pool in that it is tuned for
/// CPU rather than IO-bound work by choosing a thread count based on the number
/// of physical rather than logical cores.
///
/// The thread count can be overridden at the process level by setting the
/// `RTEN_NUM_THREADS` environment variable, whose value must be a number
/// between 1 and the logical core count.
///
/// The thread count can be overridden for each model run by configuring a
/// custom thread pool in [`RunOptions`](crate::RunOptions).
///
/// To run your own tasks in this thread pool, you can use
/// [`ThreadPool::run`].
///
/// [rayon]: https://github.com/rayon-rs/rayon
pub fn thread_pool() -> &'static ThreadPool {
    static THREAD_POOL: OnceLock<ThreadPool> = OnceLock::new();
    THREAD_POOL.get_or_init(|| {
        let physical_cpus = optimal_core_count();

        let num_threads = if let Some(threads_var) = env::var_os("RTEN_NUM_THREADS") {
            let requested_threads: Result<u32, _> = threads_var.to_string_lossy().parse();
            match requested_threads {
                Ok(n_threads) => n_threads.clamp(1, num_cpus::get() as u32),
                Err(_) => physical_cpus,
            }
        } else {
            physical_cpus
        };

        ThreadPool::with_num_threads(num_threads as usize)
    })
}

#[cfg(test)]
mod tests {
    use super::optimal_core_count;

    #[test]
    fn test_optimal_core_count() {
        let max_cores = num_cpus::get_physical() as u32;
        let opt_cores = optimal_core_count();
        assert!(opt_cores >= 1 && opt_cores <= max_cores);
    }
}
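A caller sketch for sharing this pool (`heavy_compute` is a hypothetical CPU-bound function, and `thread_pool` refers to the function defined just above):

    // Run Rayon-based work on RTen's tuned pool rather than Rayon's global
    // pool, so the thread count follows the physical-core heuristics above.
    let result = thread_pool().run(|| heavy_compute());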
-------------------------------------------------------------------------------- /src/weight_cache.rs: --------------------------------------------------------------------------------
use rustc_hash::FxHashMap;

use crate::graph::NodeId;
use crate::ops::PrepackedInput;

/// A cache of prepacked weights for graph operators.
///
/// The weight cache has a hierarchical structure which mirrors the model
/// graph. At the top level is the root graph. For each operator with a
/// subgraph (eg. control flow operators) there are separate sub-caches.
pub struct WeightCache {
    /// Map of constant node ID to prepacked weights.
    cache: FxHashMap<NodeId, PrepackedInput>,

    /// Map of operator ID to caches for the operator's subgraphs.
    subgraph_caches: FxHashMap<NodeId, Vec<WeightCache>>,
}

impl WeightCache {
    /// Create an empty cache.
    pub fn new() -> WeightCache {
        WeightCache {
            cache: FxHashMap::default(),
            subgraph_caches: FxHashMap::default(),
        }
    }

    /// Check if a pre-packed weight exists for a given constant node ID.
    pub fn contains(&self, node: NodeId) -> bool {
        self.cache.contains_key(&node)
    }

    /// Add a prepacked weight to the cache.
    pub fn insert(&mut self, node: NodeId, packed: PrepackedInput) {
        self.cache.insert(node, packed);
    }

    /// Look up weight in the cache.
    pub fn get(&self, node: NodeId) -> Option<&PrepackedInput> {
        self.cache.get(&node)
    }

    /// Add caches for subgraphs belonging to an operator.
    pub fn insert_subgraph_caches(&mut self, operator_id: NodeId, caches: Vec<WeightCache>) {
        self.subgraph_caches.insert(operator_id, caches);
    }

    /// Look up caches for an operator's subgraphs.
    pub fn get_subgraph_caches(&self, operator_id: NodeId) -> Option<&[WeightCache]> {
        self.subgraph_caches
            .get(&operator_id)
            .map(|wcs| wcs.as_slice())
    }

    /// Return the total number of cached weights, including in subgraphs.
    pub fn len(&self) -> usize {
        self.cache.len()
            + self
                .subgraph_caches
                .values()
                .flat_map(|caches| caches.iter())
                .map(|cache| cache.len())
                .sum::<usize>()
    }
}

impl Default for WeightCache {
    fn default() -> Self {
        WeightCache::new()
    }
}
-------------------------------------------------------------------------------- /tools/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/robertknight/rten/c646a5ac1bed69d0717fc37eb53e3ba290b89233/tools/__init__.py
-------------------------------------------------------------------------------- /tools/add-node-outputs-to-model.py: --------------------------------------------------------------------------------
from argparse import ArgumentParser

import onnx


# This script adds all of the operator output value nodes to the model's list
# of outputs.
#
# This enables examining intermediate outputs from the model when run using
# ONNX Runtime. See https://github.com/microsoft/onnxruntime/issues/1455#issuecomment-979901463.
#
# For RTen this step is not necessary since any value node in the graph can
# be specified as the output for a model execution.
def main():
    parser = ArgumentParser(
        description="Add intermediate outputs in an ONNX graph to the model's outputs"
    )
    parser.add_argument("onnx_model")
    parser.add_argument("out_model")
    args = parser.parse_args()

    model = onnx.load(args.onnx_model)
    initial_outputs = [val.name for val in model.graph.output]

    for node in model.graph.node:
        for output in node.output:
            if output not in initial_outputs:
                model.graph.output.extend([onnx.ValueInfoProto(name=output)])

    onnx.save(model, args.out_model)


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- /tools/benchmarks/wasm-gemm.js: --------------------------------------------------------------------------------
#!/usr/bin/env node

// Matrix multiplication performance test script for WebAssembly.
//
// The matmul implementation in the Rust crate has a similar set of tests for
// the native environment.

import { readFileSync } from "node:fs";

// TensorFlow.js dependencies. These will need to be installed separately via
// npm before you can run this script.
import * as tf from "@tensorflow/tfjs";
import "@tensorflow/tfjs-backend-wasm";

import { Tensor, initSync } from "../../index.js";

// Init ML libs
await tf.setBackend("wasm");

const wasmPath = new URL("../../dist/rten_bg.wasm", import.meta.url);
const wasmBin = readFileSync(wasmPath);
initSync(wasmBin);

// Run tests
const cases = [
  { m: 512, n: 512, k: 512 }, // Square
  { m: 128, n: 2048, k: 512 }, // Wide
  { m: 2048, n: 128, k: 512 }, // Tall
  { m: 1, n: 4096, k: 512 }, // Vector
];

function logResult(engine, elapsedMs, m, n, k, iters) {
  const elapsedSecs = elapsedMs / 1000.0;
  const flops = (2 * m * n * k * iters) / elapsedSecs;
  const gflops = flops / 10 ** 9;

  const round = (f) => f.toFixed(2);

  console.log(
    `engine ${engine} m ${m} n ${n} k ${k} iters ${iters}. Duration ${round(
      elapsedMs
    )}ms (${round(elapsedMs / iters)} ms/iter). GFLOPS ${round(gflops)}`
  );
}

function timeIt(iters, callback) {
  const start = performance.now();
  for (let i = 0; i < iters; i++) {
    callback();
  }
  const end = performance.now();
  return end - start;
}

/**
 * Run a benchmark of `iters` iterations of matrix multiplication using random
 * inputs of size `[M, K]` and `[K, N]`.
 */
function testRTenMatmul(m, n, k, iters) {
  const seedA = 1234n;
  const seedB = 4567n;
  const a = Tensor.rand([m, k], seedA);
  const b = Tensor.rand([k, n], seedB);

  const elapsed = timeIt(iters, () => {
    const c = a.matmul(b);

    // Free the output immediately so the memory can be re-used in the next
    // iteration.
    c.free();
  });

  logResult("RTen", elapsed, m, n, k, iters);
}

/**
 * Run a benchmark of `iters` iterations of matrix multiplication using random
 * inputs of size `[M, K]` and `[K, N]`.
 */
function testTensorflowMatMul(m, n, k, iters) {
  const a = tf.randomUniform([m, k]);
  const b = tf.randomUniform([k, n]);

  const elapsed = timeIt(iters, () => {
    const c = tf.matMul(a, b);

    // Free the output immediately so the memory can be re-used in the next
    // iteration.
    c.dispose();
  });

  logResult("TF.js", elapsed, m, n, k, iters);
}

for (const { m, n, k } of cases) {
  const iters = 100;

  testRTenMatmul(m, n, k, iters);
  testTensorflowMatMul(m, n, k, iters);
}
-------------------------------------------------------------------------------- /tools/compare-tensors.py: --------------------------------------------------------------------------------
from argparse import ArgumentParser
import json
import sys

import numpy as np

from debug_utils import read_tensor

def read_json_tensor(path: str):
    """
    Load a tensor from a JSON file.

    The JSON data format is `{ "data": [elements...], "shape": [dims...] }`.
    This matches rten-tensor's serde serialization for the `Tensor` type.
15 | """ 16 | with open(path) as tensor_fp: 17 | tensor_json = json.load(tensor_fp) 18 | return np.array(tensor_json["data"]).reshape(tensor_json["shape"]) 19 | 20 | 21 | def main(): 22 | parser = ArgumentParser(description="Compare two binary tensors") 23 | parser.add_argument('tensor_a', help="File containing first tensor") 24 | parser.add_argument('tensor_b', help="File containing second_tensor") 25 | args = parser.parse_args() 26 | 27 | if args.tensor_a.endswith(".json"): 28 | x = read_json_tensor(args.tensor_a) 29 | else: 30 | x = read_tensor(args.tensor_a) 31 | 32 | if args.tensor_b.endswith(".json"): 33 | y = read_json_tensor(args.tensor_b) 34 | else: 35 | y = read_tensor(args.tensor_b) 36 | 37 | print(f"X shape {x.shape} Y shape {y.shape}") 38 | 39 | if x.shape != y.shape: 40 | print("Tensor shapes do not match") 41 | sys.exit(1) 42 | 43 | abs_diff = np.absolute(x - y) 44 | print(f"Average diff {abs_diff.sum() / x.size}") 45 | print(f"Max diff {abs_diff.max()}") 46 | print(f"Total diff {abs_diff.sum()}") 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /tools/debug_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import struct 3 | 4 | import numpy as np 5 | 6 | def read_tensor(path: str) -> np.ndarray: 7 | """ 8 | Read a tensor from a file. 9 | 10 | The file is expected to contain the tensor data in the little-endian 11 | binary format: 12 | 13 | [rank:u32][dims:u32 * rank][data:f32 * product(dims)] 14 | """ 15 | with open(path, 'rb') as file: 16 | ndim, = struct.unpack('" 10 | exit 1 11 | fi 12 | 13 | if [ -z "$WASMOPT_BIN" ]; then 14 | echo 'Skipping post-compilation optimization because `wasm-opt` binary was not found.' 15 | exit 16 | fi 17 | 18 | if [ -n "${SKIP_WASM_OPT:-}" ]; then 19 | echo "Skipping post-compilation optimization because SKIP_WASM_OPT is set" 20 | exit 21 | fi 22 | 23 | wasm-opt --enable-simd --enable-reference-types -O2 "$BIN_PATH" -o "$BIN_PATH".optimized 24 | mv "$BIN_PATH.optimized" "$BIN_PATH" 25 | -------------------------------------------------------------------------------- /tools/ort-quantize.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import onnx 4 | from onnxruntime.quantization import quantize_dynamic 5 | 6 | parser = ArgumentParser(description="Quantize ONNX models using dynamic quantization.") 7 | parser.add_argument("input") 8 | parser.add_argument("output", nargs="?") 9 | parser.add_argument( 10 | "--quantize-conv", 11 | action="store_true", 12 | help=""" 13 | Enable quantization of `Conv` operators. 14 | 15 | This is disabled by default to avoid producing models that don't work 16 | in ONNX Runtime. See https://github.com/microsoft/onnxruntime/issues/15888. 17 | """, 18 | ) 19 | args = parser.parse_args() 20 | 21 | output = args.output or args.input.replace(".onnx", ".quant.onnx") 22 | 23 | # Quantized operation types we support. 24 | # 25 | # See https://github.com/microsoft/onnxruntime/blob/1fc9c4823d7c2e8f0d07a09315a0755dd7c58ef8/onnxruntime/python/tools/quantization/quantize.py#L828 for the default list that ORT uses. 26 | # 27 | # See https://github.com/microsoft/onnxruntime/blob/1fc9c4823d7c2e8f0d07a09315a0755dd7c58ef8/onnxruntime/python/tools/quantization/registry.py#L66 for registries of different ops that 28 | # will be quantized depending on the quantization type. 
op_types_to_quantize = [
    # Supported ops from `CommonOpsRegistry`. These support int8 types directly.
    #
    # There are other operators which support int8 types that we could list
    # here but don't because `quantize_dynamic` doesn't attempt to quantize them.
    "Gather",
    "Transpose",
    # Supported ops from `IntegerOpsRegistry`. These get replaced during quantization.
    "MatMul",  # Replaced by MatMulInteger
]

if args.quantize_conv:
    op_types_to_quantize.append("Conv")  # Replaced by ConvInteger

quantize_dynamic(
    args.input,
    output,
    op_types_to_quantize=op_types_to_quantize,
    # Avoid a saturation issue on x86-64 systems that don't support VNNI by
    # reducing the range of quantized values from 8 to 7 bits.
    #
    # Specifically the VPMADDUBSW instruction used in int8 matmul operations
    # can saturate when adding pairs of signed i16 values.
    #
    # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#when-to-use-reduce-range-and-per-channel-quantization.
    reduce_range=True,
    # Use per-channel rather than per-tensor quantization.
    #
    # The effect of this is that separate zero points and scales are used per
    # row or column of an input matrix in quantized matmuls (`MatMulInteger`
    # ops).
    #
    # Turning this on increases compute slightly, but allows tolerating a
    # wider range of weight values in a tensor. Since transformer models are
    # prone to having outlier weights, this seems like a good idea. Also
    # RTen internally broadcasts scalar zero points to vectors anyway.
    per_channel=True,
    extra_options={
        # Enable quantization of models with control flow operators. This
        # includes Hugging Face "merged" transformer decoder models, which is
        # what various RTen examples use.
        "EnableSubgraph": True,
    },
)
-------------------------------------------------------------------------------- /tools/requirements.txt: --------------------------------------------------------------------------------
# General dependencies
Pillow
numpy
onnx

# For scripts that read/write this library's models
flatbuffers

# For scripts that use PyTorch to export models or produce reference outputs.
torch

# For scripts that use ONNX Runtime for transforming the model or to produce
# reference outputs.
onnxruntime

# For fetching and exporting various image models
timm
-------------------------------------------------------------------------------- /tools/test-images/README.md: --------------------------------------------------------------------------------
# Test images

Reference inputs used for quick sanity checks of models.
-------------------------------------------------------------------------------- /tools/test-images/horses.jpeg: https://raw.githubusercontent.com/robertknight/rten/c646a5ac1bed69d0717fc37eb53e3ba290b89233/tools/test-images/horses.jpeg
-------------------------------------------------------------------------------- /tools/update-onnx-model.py: --------------------------------------------------------------------------------
from argparse import ArgumentParser

import onnx
from onnx import version_converter


def main():
    parser = ArgumentParser()
    parser.add_argument("input_model", help="Input ONNX model")
    parser.add_argument("output_model", help="Output ONNX model")
    parser.add_argument(
        "--opset_version", type=int, default=11, help="ONNX opset version to upgrade to"
    )
    args = parser.parse_args()

    original_model = onnx.load(args.input_model)

    # A full list of supported adapters can be found here:
    # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21
    converted_model = version_converter.convert_version(
        original_model, args.opset_version
    )

    onnx.save(converted_model, args.output_model)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------