├── .cargo └── config.toml ├── .config └── nextest.toml ├── .github ├── actions │ └── setup │ │ └── action.yml ├── ratchet.png ├── trufflehog.yaml └── workflows │ ├── ratbot.yml │ └── rust.yml ├── .gitignore ├── ARCHITECTURE.md ├── CONTRIBUTING.md ├── Cargo.toml ├── LICENSE ├── README.md ├── config ├── webdriver-linux.json ├── webdriver-macos.json └── webdriver-win.json ├── crates ├── ratchet-cli │ ├── Cargo.toml │ └── src │ │ ├── bin │ │ └── cli.rs │ │ └── lib.rs ├── ratchet-core │ ├── Cargo.toml │ ├── src │ │ ├── compiled_op.rs │ │ ├── cpu │ │ │ ├── binary.rs │ │ │ ├── gemm.rs │ │ │ ├── mod.rs │ │ │ ├── norm.rs │ │ │ ├── reindex.rs │ │ │ ├── rope.rs │ │ │ ├── softmax.rs │ │ │ ├── unary.rs │ │ │ └── utils.rs │ │ ├── device.rs │ │ ├── dtype │ │ │ ├── blocks.rs │ │ │ └── mod.rs │ │ ├── enforcer.rs │ │ ├── executable.rs │ │ ├── gpu │ │ │ ├── align.rs │ │ │ ├── buffer_allocator │ │ │ │ ├── allocator.rs │ │ │ │ ├── mod.rs │ │ │ │ └── tensor_usage_record.rs │ │ │ ├── device.rs │ │ │ ├── mod.rs │ │ │ ├── pools │ │ │ │ ├── bind_group_layout_pool.rs │ │ │ │ ├── bind_group_pool.rs │ │ │ │ ├── buffer_pool.rs │ │ │ │ ├── dynamic_resource_pool.rs │ │ │ │ ├── kernel_module_pool.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── pipeline_layout_pool.rs │ │ │ │ ├── pipeline_pool.rs │ │ │ │ └── static_resource_pool.rs │ │ │ ├── profiler.rs │ │ │ ├── uniform.rs │ │ │ ├── wgsl │ │ │ │ ├── access_granularity.rs │ │ │ │ ├── dtype.rs │ │ │ │ ├── kernel.rs │ │ │ │ ├── kernel_binding.rs │ │ │ │ ├── kernel_builder.rs │ │ │ │ └── mod.rs │ │ │ └── workload.rs │ │ ├── lib.rs │ │ ├── ndarray_ext.rs │ │ ├── op.rs │ │ ├── ops │ │ │ ├── binary.rs │ │ │ ├── cache.rs │ │ │ ├── cast.rs │ │ │ ├── concat.rs │ │ │ ├── conv.rs │ │ │ ├── index_write.rs │ │ │ ├── matmul │ │ │ │ ├── gemm.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── quantized.rs │ │ │ │ ├── subgroup_gemv.rs │ │ │ │ └── workgroup_gemv.rs │ │ │ ├── mod.rs │ │ │ ├── norm │ │ │ │ ├── groupnorm.rs │ │ │ │ └── mod.rs │ │ │ ├── reindex │ │ │ │ ├── broadcast.rs │ │ │ │ ├── mod.rs │ │ │ │ ├── permute.rs │ │ │ │ └── slice.rs │ │ │ ├── rope.rs │ │ │ ├── select.rs │ │ │ ├── softmax.rs │ │ │ ├── unary.rs │ │ │ └── view.rs │ │ ├── plot.rs │ │ ├── quant.rs │ │ ├── shape.rs │ │ ├── storage │ │ │ ├── cpu_buffer.rs │ │ │ ├── gpu_buffer.rs │ │ │ └── mod.rs │ │ ├── strides.rs │ │ ├── tensor.rs │ │ └── tensor_id.rs │ └── tests │ │ └── attn_tests.rs ├── ratchet-hub │ ├── Cargo.toml │ └── src │ │ ├── lib.rs │ │ └── util.rs ├── ratchet-loader │ ├── Cargo.toml │ ├── src │ │ ├── error.rs │ │ ├── gguf │ │ │ ├── dtype.rs │ │ │ ├── gguf.rs │ │ │ ├── mod.rs │ │ │ └── utils.rs │ │ ├── k_quants.rs │ │ └── lib.rs │ └── test-data │ │ └── nano-llama-q4k.gguf ├── ratchet-macros │ ├── Cargo.toml │ └── src │ │ ├── lib.rs │ │ └── wgsl_metadata.rs ├── ratchet-models │ ├── Cargo.toml │ ├── src │ │ ├── lib.rs │ │ ├── moondream │ │ │ ├── generate.rs │ │ │ ├── mlp.rs │ │ │ ├── mod.rs │ │ │ ├── model.rs │ │ │ ├── text_model.rs │ │ │ └── vision_encoder.rs │ │ ├── phi2 │ │ │ ├── attn.rs │ │ │ ├── generate.rs │ │ │ ├── mlp.rs │ │ │ ├── mod.rs │ │ │ └── model.rs │ │ ├── phi3 │ │ │ ├── attn.rs │ │ │ ├── generate.rs │ │ │ ├── mlp.rs │ │ │ ├── mod.rs │ │ │ └── model.rs │ │ ├── registry.rs │ │ ├── token_stream.rs │ │ └── whisper │ │ │ ├── config.rs │ │ │ ├── decoder.rs │ │ │ ├── encoder.rs │ │ │ ├── logit_mutators │ │ │ ├── mod.rs │ │ │ └── timestamp_rules.rs │ │ │ ├── mha.rs │ │ │ ├── mlp.rs │ │ │ ├── mod.rs │ │ │ ├── model.rs │ │ │ ├── options.rs │ │ │ ├── residual_block.rs │ │ │ ├── samplers │ │ │ ├── greedy.rs │ │ │ └── mod.rs │ │ │ ├── 
spectrogram.rs │ │ │ ├── task.rs │ │ │ ├── tokenizer.rs │ │ │ ├── transcribe.rs │ │ │ └── transcript.rs │ ├── tests │ │ └── whisper.rs │ └── webdriver.json ├── ratchet-nn │ ├── Cargo.toml │ └── src │ │ ├── embedding.rs │ │ ├── groupnorm.rs │ │ ├── kv_cache.rs │ │ ├── lib.rs │ │ ├── linear.rs │ │ ├── norm.rs │ │ └── rope.rs └── ratchet-web │ ├── .gitignore │ ├── Cargo.toml │ ├── README.md │ └── src │ ├── db.rs │ ├── lib.rs │ └── model.rs ├── examples ├── ratchet-moondream │ ├── .gitignore │ ├── README.md │ ├── package-lock.json │ ├── package.json │ ├── public │ │ ├── index.html │ │ ├── manifest.json │ │ └── robots.txt │ └── src │ │ ├── App.css │ │ ├── App.js │ │ ├── index.css │ │ └── index.js ├── ratchet-phi │ ├── .gitignore │ ├── README.md │ ├── next.config.mjs │ ├── package.json │ ├── postcss.config.js │ ├── src │ │ └── app │ │ │ ├── components │ │ │ ├── WebGPUModal.tsx │ │ │ ├── progressBar.tsx │ │ │ └── warningModal.tsx │ │ │ ├── favicon.ico │ │ │ ├── globals.css │ │ │ ├── layout.tsx │ │ │ ├── page.module.css │ │ │ └── page.tsx │ ├── tailwind.config.js │ └── tsconfig.json └── ratchet-whisper │ ├── .gitignore │ ├── README.md │ ├── next.config.mjs │ ├── package.json │ ├── postcss.config.js │ ├── src │ └── app │ │ ├── audio.ts │ │ ├── components │ │ ├── WebGPUModal.tsx │ │ ├── configModal.tsx │ │ ├── languageDropdown.tsx │ │ ├── micButton.tsx │ │ ├── modelSelector.tsx │ │ ├── progressBar.tsx │ │ ├── suppressSelector.tsx │ │ └── taskSelector.tsx │ │ ├── favicon.ico │ │ ├── globals.css │ │ ├── layout.tsx │ │ ├── page.module.css │ │ └── page.tsx │ ├── tailwind.config.js │ └── tsconfig.json ├── justfile ├── package.json ├── pnpm-lock.yaml ├── pnpm-workspace.yaml ├── requirements.txt ├── rust-toolchain.toml └── scripts ├── phi3.py └── understanding_matmul.py /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | # Needed for WASM unstable features 2 | [build] 3 | rustflags = [ "--cfg=web_sys_unstable_apis" ] 4 | rustdocflags = [ "--cfg=web_sys_unstable_apis" ] 5 | #target = "wasm32-unknown-unknown" 6 | -------------------------------------------------------------------------------- /.config/nextest.toml: -------------------------------------------------------------------------------- 1 | [profile.default] 2 | slow-timeout = { period = "60s", terminate-after = 2 } 3 | -------------------------------------------------------------------------------- /.github/ratchet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/ratchet/136da4d5216910bfd015b27a17b837c21f17163a/.github/ratchet.png -------------------------------------------------------------------------------- /.github/trufflehog.yaml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | permissions: 7 | contents: read 8 | 9 | jobs: 10 | trufflehog: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v4 15 | with: 16 | fetch-depth: 0 17 | - name: Secret Scanning 18 | uses: trufflesecurity/trufflehog@main 19 | -------------------------------------------------------------------------------- /.github/workflows/ratbot.yml: -------------------------------------------------------------------------------- 1 | name: Ratbot 2 | on: 3 | pull_request_target 4 | 5 | jobs: 6 | comment: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout code 10 | uses: actions/checkout@v4 11 | 12 | - name: Install 
Rust and Cargo 13 | run: | 14 | curl -sSf https://sh.rustup.rs | sh -s -- -y 15 | source $HOME/.cargo/env 16 | 17 | - name: Install Tokei 18 | run: cargo install tokei 19 | 20 | - name: Run Tokei and get the lines of code 21 | run: tokei crates/ratchet-core > tokei_output.txt 22 | 23 | - name: Comment or Update PR 24 | uses: actions/github-script@v6 25 | with: 26 | script: | 27 | const fs = require('fs'); 28 | const tokeiOutput = fs.readFileSync('tokei_output.txt', 'utf8'); 29 | const uniqueIdentifier = 'Code Metrics Report'; 30 | const codeReport = ` 31 |
<details> 32 | <summary>${uniqueIdentifier}</summary> 33 | <pre> 34 |               ${tokeiOutput} 35 |               </pre> 36 | </details>
37 | `; 38 | 39 | const issue_number = context.issue.number; 40 | const { owner, repo } = context.repo; 41 | 42 | const comments = await github.rest.issues.listComments({ 43 | issue_number, 44 | owner, 45 | repo 46 | }); 47 | 48 | const existingComment = comments.data.find(comment => comment.body.includes(uniqueIdentifier)); 49 | 50 | if (existingComment) { 51 | await github.rest.issues.updateComment({ 52 | owner, 53 | repo, 54 | comment_id: existingComment.id, 55 | body: codeReport 56 | }); 57 | } else { 58 | await github.rest.issues.createComment({ 59 | issue_number, 60 | owner, 61 | repo, 62 | body: codeReport 63 | }); 64 | } 65 | 66 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | 9 | env: 10 | RUST_LOG: "info" 11 | MESA_VERSION: "23.3.1" # Sourced from https://archive.mesa3d.org/ 12 | WARP_VERSION: "1.0.8" # Sourced from https://www.nuget.org/packages/Microsoft.Direct3D.WARP 13 | VULKAN_SDK_VERSION: "1.3.268" # Sourced from https://vulkan.lunarg.com/sdk/home#linux 14 | CARGO_TERM_COLOR: always 15 | WGPU_DX12_COMPILER: dxc 16 | RUSTFLAGS: --cfg=web_sys_unstable_apis 17 | RUST_BACKTRACE: 1 18 | DXC_RELEASE: "v1.7.2308" 19 | DXC_FILENAME: "dxc_2023_08_14.zip" 20 | WASM_BINDGEN_TEST_TIMEOUT: 300 # 5 minutes 21 | CI_BINARY_BUILD: "build18" # Corresponds to https://github.com/gfx-rs/ci-build/releases 22 | RATCHET_FORCE_F32: 1 23 | 24 | jobs: 25 | check: 26 | name: Check 27 | runs-on: ${{ matrix.os }} 28 | strategy: 29 | matrix: 30 | include: 31 | - name: Windows x86_64 32 | os: windows-2022 33 | target: x86_64-pc-windows-msvc 34 | platform: win64 35 | 36 | - name: Linux x86_64 37 | os: ubuntu-22.04 38 | target: x86_64-unknown-linux-gnu 39 | platform: linux64 40 | 41 | - name: MacOS aarch64 42 | os: macos-14 43 | target: aarch64-apple-darwin 44 | platform: mac-arm64 45 | 46 | steps: 47 | - uses: actions/checkout@v4 48 | 49 | - uses: ./.github/actions/setup 50 | 51 | - name: run tests 52 | shell: bash 53 | run: | 54 | set -e 55 | cargo nextest run -j 1 --no-fail-fast --features=ci,pyo3 56 | 57 | - name: Set up WebDriver for Ubuntu 58 | if: matrix.os == 'ubuntu-22.04' 59 | run: cp config/webdriver-linux.json crates/ratchet-models/webdriver.json 60 | 61 | - name: Set up WebDriver for Windows 62 | if: matrix.os == 'windows-2022' 63 | run: cp config/webdriver-win.json crates/ratchet-models/webdriver.json 64 | 65 | - name: Set up WebDriver for macOS 66 | if: matrix.os == 'macos-14' 67 | run: cp config/webdriver-macos.json crates/ratchet-models/webdriver.json 68 | 69 | - name: Run wasm-bindgen-test integration tests 70 | run: | 71 | just wasm-test ratchet-models chrome 72 | just wasm-test ratchet-hub chrome 73 | just wasm-test ratchet-web chrome 74 | 75 | build: 76 | name: Build & Publish Web 77 | runs-on: ${{ matrix.os }} 78 | strategy: 79 | matrix: 80 | include: 81 | - name: MacOS aarch64 82 | os: macos-14 83 | target: aarch64-apple-darwin 84 | platform: mac-arm64 85 | 86 | steps: 87 | - uses: actions/checkout@v4 88 | 89 | - uses: ./.github/actions/setup 90 | 91 | - name: Build ratchet-web 92 | shell: bash 93 | run: just wasm ratchet-web 94 | 95 | - name: Publish ratchet-web 96 | shell: bash 97 | run: just wasm-publish-pr ratchet-web 98 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | .python-version 16 | 17 | crates/ratchet-core/kernel-generated/** 18 | **/*.svg 19 | fixtures/ 20 | **/.DS_Store 21 | 22 | /node_modules/ 23 | **/*.bin 24 | ffmpeg-wasm 25 | models/** 26 | 27 | # Python local env 28 | venv/ 29 | .venv/ 30 | 31 | # proptest regression tests 32 | proptest-regressions/ 33 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Running Tests for the Ratchet Rust Package 2 | 3 | This guide outlines the steps necessary to set up and run tests for the Ratchet Rust package. Please follow these steps carefully to ensure a smooth testing process. 4 | 5 | ## Setup Instructions 6 | 7 | ### Clone the Repository 8 | 9 | First, ensure you have Git installed. Clone the Ratchet repository from GitHub and navigate into the project directory: 10 | 11 | ```sh 12 | git clone https://github.com/FL33TW00D/ratchet.git 13 | cd ratchet/ 14 | ``` 15 | 16 | ### Setup Rust and Cargo 17 | 18 | Ensure you have Rust and Cargo installed. If not, please refer to the Rust installation guide to set up Rust and Cargo. 19 | 20 | ### Setup `just` 21 | 22 | Ensure you have `just`, a command runner that simplifies running project-specific commands, installed. If `just` is not already installed on your system, you can install it using Cargo, Rust's package manager: 23 | 24 | ```sh 25 | cargo install just 26 | ``` 27 | 28 | ### Setup Python 29 | 30 | There are two ways to set up Python for the project: using `pyenv` or using `conda`. 31 | 32 | #### Option 1: Using pyenv 33 | 34 | ##### Step 1: Install `pyenv` 35 | 36 | First, make sure to install [pyenv](https://github.com/pyenv/pyenv#getting-pyenv). `pyenv` lets you manage multiple versions of Python. Please make sure you follow the install guide and source the correct environment variables. 37 | 38 | ##### Step 2: Install Python 3.10.6 39 | 40 | Use `just` to install `python3.10.6` and enable it as the local Python version for the project. 41 | 42 | > **NOTE**: `PyO3` needs Python to be built with the `enable-shared` flag. 43 | 44 | ```sh 45 | just install-pyo3 46 | ``` 47 | 48 | ##### Step 3: Create virtual environment (Optional) 49 | 50 | This step is optional but _highly_ recommended. You should create and source a virtual environment using your favorite tool (`uv`, `venv`, `virtualenv`...). We'll use the built-in `venv` module: 51 | 52 | ```sh 53 | python -m venv venv 54 | source venv/bin/activate 55 | ``` 56 | 57 | ##### Step 4: Install Python dependencies 58 | 59 | Install the Python dependencies from the requirements file: 60 | 61 | ```sh 62 | python -m pip install -r requirements.txt 63 | ``` 64 | 65 | ##### Step 5: Configure Python Environment for PyO3 66 | 67 | PyO3 uses a build script to determine the Python version and set the correct linker arguments. 
To point PyO3 at the virtual environment's interpreter, set the `PYO3_PYTHON` override: 68 | 69 | ```sh 70 | export PYO3_PYTHON=$(which python) 71 | echo $PYO3_PYTHON 72 | ``` 73 | 74 | #### Option 2: Using conda 75 | 76 | ##### Step 1: Create a new conda environment 77 | 78 | ``` 79 | conda create -n ratchet python=3.10 80 | ``` 81 | 82 | ##### Step 2: Install dependencies 83 | 84 | ``` 85 | pip install -r requirements.txt 86 | ``` 87 | 88 | ##### Step 3: Configure Cargo 89 | 90 | Edit `.cargo/config.toml` to add the linker config: 91 | 92 | ``` 93 | # .cargo/config.toml 94 | [build] 95 | rustflags = [ 96 | "--cfg=web_sys_unstable_apis", 97 | # Add these two lines and replace <PATH_TO_CONDA> with your conda directory: 98 | "-C", 99 | "link-args=-Wl,-rpath,<PATH_TO_CONDA>/envs/ratchet/lib/", 100 | ] 101 | ``` 102 | 103 | ### Setup Node.js 104 | 105 | Ensure you have Node.js v18 or later installed. If not, please refer to the Node.js installation guide to set up Node.js. 106 | 107 | After installing Node.js, run `corepack enable` to enable the Node.js [corepack](https://github.com/nodejs/corepack) feature. 108 | 109 | Then run `pnpm install` to install the Node.js dependencies. 110 | 111 | ## Test config 112 | 113 | We'll first verify that your PyO3 config is correctly set up: 114 | 115 | ``` 116 | PYO3_PRINT_CONFIG=1 cargo build 117 | ``` 118 | 119 | Building the project will intentionally fail and print the config: 120 | 121 | ``` 122 | (exit status: 101) 123 | --- stdout 124 | cargo:rerun-if-env-changed=PYO3_PRINT_CONFIG 125 | 126 | -- PYO3_PRINT_CONFIG=1 is set, printing configuration and halting compile -- 127 | implementation=CPython 128 | version=3.10 129 | shared=true 130 | abi3=false 131 | lib_name=python3.10 132 | lib_dir= 133 | executable= 134 | pointer_width=64 135 | build_flags= 136 | suppress_build_script_link_lines=false 137 | ``` 138 | 139 | If your output looks like this, you are good to go 🎉 140 | 141 | ## Run Tests 142 | 143 | Finally, run the tests for the package using Cargo: 144 | 145 | ```sh 146 | cargo test 147 | ``` 148 | 149 | To run the `PyO3` tests, enable the `pyo3` feature: 150 | 151 | ```sh 152 | cargo test --features pyo3 153 | ``` 154 | 155 | ## Run WASM Tests 156 | 157 | To run WASM tests (e.g., the whisper test) run: 158 | 159 | ```sh 160 | just wasm-test ratchet-models chrome 161 | ``` 162 | 163 | And check the result at: 164 | 165 | ``` 166 | http://localhost:8000 167 | ``` 168 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | "crates/ratchet-hub", 4 | "crates/ratchet-core", 5 | "crates/ratchet-web", 6 | "crates/ratchet-loader", 7 | "crates/ratchet-models", 8 | "crates/ratchet-nn", 9 | 10 | "crates/ratchet-cli", 11 | "crates/ratchet-macros", 12 | ] 13 | resolver = "2" 14 | edition = "2021" 15 | 16 | [profile.test] 17 | debug = 2 18 | debug-assertions = true 19 | 20 | [profile.release] 21 | panic = 'abort' 22 | lto = "fat" 23 | codegen-units = 1 24 | 25 | [profile.profiling] 26 | inherits = "release" 27 | debug = 2 28 | 29 | [workspace.dependencies] 30 | wgpu = { git = "https://github.com/FL33TW00D/wgpu", branch = "feature/multi-dim-compute-subgroups", features = ["fragile-send-sync-non-atomic-wasm"] } 31 | bytemuck = { version = "1.14.0", features=["wasm_simd", "aarch64_simd", "extern_crate_alloc"] } 32 | num-traits = "0.2.17" 33 | half = { version = "2.3.1", features = ["num-traits", "bytemuck"] } 34 |
derive-new = "0.6.0" 35 | log = "0.4.20" 36 | thiserror = "1.0.56" 37 | byteorder = "1.5.0" 38 | npyz = { version = "0.8.3"} 39 | hf-hub = "0.3.2" 40 | serde = "1.0" 41 | anyhow = "1.0.79" 42 | tokenizers = "0.19.1" 43 | 44 | js-sys = "0.3.64" 45 | wasm-bindgen = "0.2.91" 46 | wasm-bindgen-test = "0.3.34" 47 | cfg-if = "1.0.0" 48 | chrono = "0.4.35" 49 | clap = "4.5.3" 50 | console_error_panic_hook = "0.1.7" 51 | console_log = "1.0.0" 52 | dot3 = "0.1.0" 53 | encase = { git = "https://github.com/cwfitzgerald/encase", branch = "add-member" } 54 | env_logger = "0.11.3" 55 | fern = "0.6.2" 56 | getrandom = "0.2" 57 | glam = "0.28.0" 58 | globwalk = "0.8.1" 59 | gloo-net = { version = "0.5.0", default-features = false } 60 | hound = "3.5.1" 61 | image = { version = "0.25.1", default-features = false, features = ["jpeg", "png"] } 62 | indexed_db_futures = "0.4.1" 63 | itertools = "0.12.1" 64 | lazy_static = "1.4.0" 65 | ndarray = "0.15.6" 66 | ndarray-stats = "0.5.1" 67 | num = "0.4.1" 68 | numpy = "0.20.0" 69 | parking_lot = "0.12.1" 70 | pathdiff = "0.2.1" 71 | pollster = "0.3.0" 72 | proptest = "1.4.0" 73 | pyo3 = "0.20.2" 74 | rand = "0.8.4" 75 | rand_distr = "0.4.3" 76 | realfft = "3.3.0" 77 | regex = "1.10.3" 78 | rustc-hash = "1.1.0" 79 | serde-wasm-bindgen = "0.6.5" 80 | serde_bytes = "0.11.14" 81 | serde_json = "1.0.114" 82 | slotmap = "1.0.7" 83 | smallvec = "1.11.2" 84 | strum = "0.26" 85 | strum_macros = "0.26" 86 | tabled = "0.15.0" 87 | tempfile = "3.3.0" 88 | tera = "1.19.0" 89 | test-strategy = "0.3.1" 90 | tokio = "1.36.0" 91 | uuid = "1.5.0" 92 | wasm-bindgen-futures = "0.4.42" 93 | web-sys = "0.3.69" 94 | web-time = "1.0.0" 95 | futures-intrusive = "0.5.0" 96 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Christopher Fleetwood 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
<div align="center"> 2 | <img src=".github/ratchet.png" alt="Ratchet"> 3 | Demo Site | Discord | Roadmap 4 | <br> 5 | A web-first, cross-platform ML developer toolkit 6 | <br> 7 | <br> 8 | </div>
 9 | 10 | We are on a mission to bring fast, cross-platform, GPU-accelerated inference to native and browser. 11 | 12 | > [!NOTE] 13 | > Ratchet is currently in active development. We are working on the engine, adding more models and improving compatibility. Please reach out if you'd like to help! 14 | 15 | ## Getting Started 16 | 17 | The easiest way to experience Ratchet is to check out our [Hugging Face spaces](https://huggingface.co/FL33TW00D-HF): 18 | - [Whisper](https://huggingface.co/spaces/FL33TW00D-HF/ratchet-whisper) 19 | - [Phi](https://huggingface.co/spaces/FL33TW00D-HF/ratchet-phi) 20 | 21 | To dig deeper, check out the [examples](https://github.com/FL33TW00D/ratchet/tree/master/examples). 22 | 23 | We welcome contributions from the community. If you have any ideas or suggestions, please feel free to open an issue or pull request. 24 | 25 | ### JavaScript 26 | 27 | ```javascript 28 | // Asynchronous loading & caching with IndexedDB 29 | let model = await Model.load(AvailableModels.WHISPER_TINY, Quantization.Q8, (p: number) => setProgress(p)) 30 | let result = await model.run({ input }); 31 | ``` 32 | 33 | ### Rust 34 | 35 | Rust crate & CLI coming soon... 36 | 37 | ## Philosophy 38 | 39 | We want to give developers a toolkit that makes it easy to integrate performant AI functionality into existing production applications. 40 | The following principles will help us accomplish this: 41 | 1. **Inference only** 42 | 2. **WebGPU/CPU only** 43 | 3. First-class quantization support 44 | 4. Lazy computation 45 | 5. In-place by default 46 | 47 | ## Supported Models 48 | - Whisper 49 | - Phi 2 & 3 50 | - Moondream 51 | 52 | ## Upcoming Models 53 | - Gemma 2 2B 54 | -------------------------------------------------------------------------------- /config/webdriver-linux.json: -------------------------------------------------------------------------------- 1 | { 2 | "goog:chromeOptions": { 3 | "args": [ 4 | "--no-sandbox", 5 | "--headless=new", 6 | "--use-angle=vulkan", 7 | "--enable-features=Vulkan", 8 | "--enable-unsafe-webgpu" 9 | ] 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /config/webdriver-macos.json: -------------------------------------------------------------------------------- 1 | { 2 | "goog:chromeOptions": { 3 | "args": [ 4 | "--no-sandbox", 5 | "--headless=new", 6 | "--use-angle=metal", 7 | "--enable-features=Metal", 8 | "--enable-unsafe-webgpu" 9 | ] 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /config/webdriver-win.json: -------------------------------------------------------------------------------- 1 | { 2 | "goog:chromeOptions": { 3 | "args": [ 4 | "--no-sandbox", 5 | "--headless=new", 6 | "--use-angle=d3d12", 7 | "--enable-features=D3D12", 8 | "--enable-unsafe-webgpu" 9 | ] 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /crates/ratchet-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ratchet-cli" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [[bin]] 7 | name = "ratchet" 8 | path = "src/bin/cli.rs" 9 | 10 | [dependencies] 11 | ratchet = { path = "../ratchet-core" } 12 | ratchet-loader = { path = "../ratchet-loader" } 13 | ratchet-models = { path = "../ratchet-models" } 14 | ratchet-hub = { path = "../ratchet-hub" } 15 | ratchet-nn = { path = "../ratchet-nn" } 16 | log.workspace = true 17 | clap = { workspace = true, features = ["derive"] } 18 |
hf-hub = { workspace = true } 19 | serde_json = { workspace = true } 20 | env_logger = { workspace = true } 21 | fern = { workspace = true } 22 | chrono = { workspace = true } 23 | tokenizers = { workspace = true } 24 | ndarray = { workspace = true } 25 | ndarray-stats = { workspace = true } 26 | anyhow.workspace = true 27 | -------------------------------------------------------------------------------- /crates/ratchet-cli/src/lib.rs: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /crates/ratchet-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ratchet" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [features] 7 | default = ["rand", "testing"] 8 | gpu-profiling = ["dep:tabled", "dep:itertools"] 9 | rand = ["dep:rand", "dep:rand_distr"] 10 | plotting = ["dep:dot3", "dep:tempfile"] 11 | testing = ["dep:npyz", "dep:ndarray"] 12 | pyo3 = ["dep:pyo3", "dep:numpy", "dep:regex"] 13 | debug = [] #dump every node 14 | 15 | [dependencies] 16 | ratchet-macros = { path = "../ratchet-macros" } 17 | inline-wgsl = { git = "https://github.com/FL33TW00D/inline-wgsl.git", branch = "master" } 18 | wgpu = { workspace = true } 19 | bytemuck = { workspace = true } 20 | half = { workspace = true } 21 | derive-new = { workspace = true } 22 | num-traits = { workspace = true } 23 | log = { workspace = true } 24 | thiserror = { workspace = true } 25 | serde = { workspace = true, features = ["derive"] } 26 | anyhow.workspace = true 27 | 28 | rustc-hash = { workspace = true } 29 | slotmap = { workspace = true } 30 | parking_lot = { workspace = true } 31 | smallvec = { workspace = true } 32 | encase = { workspace = true, features = ["smallvec", "glam"] } 33 | pollster = { workspace = true } 34 | getrandom = { workspace = true, features = [ 35 | "js", 36 | ] } # Needed for wasm support in `num` trait 37 | num = { workspace = true } 38 | rand_distr = { workspace = true, optional = true } 39 | rand = { workspace = true, optional = true } 40 | glam = { workspace = true } 41 | npyz = { workspace = true, optional = true } 42 | ndarray = { workspace = true, optional = true } 43 | 44 | strum = { workspace = true } 45 | strum_macros = { workspace = true } 46 | 47 | #Plotting 48 | dot3 = { workspace = true, optional = true } 49 | tempfile = { workspace = true, optional = true } 50 | 51 | # Profiling 52 | tabled = { workspace = true, optional = true } 53 | itertools = { workspace = true, optional = true } 54 | 55 | pyo3 = { workspace = true, features = ["auto-initialize"], optional = true } 56 | regex = { workspace = true, optional = true } 57 | numpy = { workspace = true, optional = true, features = ["half"] } 58 | gemm = { version = "0.18.0", features = ["nightly", "wasm-simd128-enable"] } 59 | 60 | [target.'cfg(target_arch = "wasm32")'.dependencies] 61 | wasm-bindgen.workspace = true 62 | futures-intrusive.workspace = true 63 | wasm-bindgen-futures.workspace = true 64 | 65 | async-trait = "0.1.77" 66 | smallvec = { workspace = true, features = ["serde"] } 67 | 68 | [dev-dependencies] 69 | env_logger = { workspace = true } 70 | rand = { workspace = true } 71 | test-strategy = { workspace = true } 72 | ndarray = { workspace = true } 73 | proptest = { workspace = true } 74 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/compiled_op.rs: 
-------------------------------------------------------------------------------- 1 | use crate::gpu::{ 2 | BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup, WgpuDevice, 3 | WorkgroupCount, 4 | }; 5 | use crate::{drvec, rvec, KernelKey, OperationError, RVec, Tensor}; 6 | use derive_new::new; 7 | use wgpu::DynamicOffset; 8 | 9 | //Compiled op represents a single kernel invocation 10 | //TODO: We need to be more general here, enum with encoder.copy_buffer_to_buffer as a COPY, and 11 | //compiledOp as compute 12 | #[derive(Debug, new)] 13 | pub struct CompiledOp { 14 | pipeline_handle: ComputePipelineHandle, 15 | workgroup_count: WorkgroupCount, 16 | storage_groups: RVec<GpuBindGroup>, 17 | offset: DynamicOffset, //offset into the metadata uniform buffer 18 | pub kernel_key: KernelKey, 19 | #[cfg(feature = "debug")] 20 | pub debug_buffer: Option<std::sync::Arc<wgpu::Buffer>>, 21 | } 22 | 23 | impl CompiledOp { 24 | const MAX_BINDINGS_PER_GROUP: usize = 8; 25 | 26 | pub fn create_storage_bind_groups( 27 | srcs: &[&Tensor], 28 | dst: &Tensor, 29 | bind_group_layouts: RVec<BindGroupLayoutHandle>, 30 | device: &WgpuDevice, 31 | inplace: bool, 32 | ) -> Result<RVec<GpuBindGroup>, OperationError> { 33 | let mut bind_group_entries = drvec![]; 34 | 35 | for tensor in srcs.iter() { 36 | bind_group_entries.append(&mut tensor.bind_group_entries()); 37 | } 38 | 39 | if !inplace { 40 | bind_group_entries.append(&mut dst.bind_group_entries()); 41 | } 42 | 43 | let mut storage_groups = rvec![]; 44 | for (group_index, bind_group_layout) in bind_group_layouts.iter().enumerate() { 45 | let group_range = Self::group_range(group_index, bind_group_entries.len()); 46 | let entries = bind_group_entries[group_range].into(); 47 | let layout = *bind_group_layout; 48 | 49 | let bg = device.get_or_create_bind_group(&BindGroupDescriptor { entries, layout })?; 50 | storage_groups.push(bg); 51 | } 52 | Ok(storage_groups) 53 | } 54 | 55 | /// Determines which bindings belong to which bind group 56 | fn group_range(group_index: usize, binding_counter: usize) -> std::ops::Range<usize> { 57 | let group_end = usize::min( 58 | (group_index + 1) * Self::MAX_BINDINGS_PER_GROUP, 59 | binding_counter, 60 | ); 61 | group_index * Self::MAX_BINDINGS_PER_GROUP..group_end 62 | } 63 | 64 | pub fn workgroup_count(&self) -> &WorkgroupCount { 65 | &self.workgroup_count 66 | } 67 | 68 | pub fn offset(&self) -> DynamicOffset { 69 | self.offset 70 | } 71 | 72 | pub fn storage_groups(&self) -> &RVec<GpuBindGroup> { 73 | &self.storage_groups 74 | } 75 | 76 | pub fn pipeline_handle(&self) -> ComputePipelineHandle { 77 | self.pipeline_handle 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/cpu/binary.rs: -------------------------------------------------------------------------------- 1 | use crate::cpu::cpu_store_result; 2 | use crate::{Binary, BinaryOp, CPUOperation, DType, OperationError, Tensor, TensorDType}; 3 | use core::marker::PhantomData; 4 | use half::{bf16, f16}; 5 | use num_traits::NumOps; 6 | 7 | #[inline] 8 | pub(crate) fn binary_map<T: TensorDType, U: TensorDType>( 9 | lhs: &[T], 10 | rhs: &[T], 11 | dst: &mut [U], 12 | f: fn(T, T) -> U, 13 | ) { 14 | assert_eq!(lhs.len(), dst.len()); 15 | assert_eq!(rhs.len(), dst.len()); 16 | for ((l, r), d) in lhs 17 | .iter() 18 | .copied() 19 | .zip(rhs.iter().copied()) 20 | .zip(dst.iter_mut()) 21 | { 22 | *d = f(l, r); 23 | } 24 | } 25 | 26 | #[inline] 27 | pub(crate) fn binary_map_inplace<T: TensorDType>(lhs: &mut [T], rhs: &[T], f: fn(T, T) -> T) { 28 | assert_eq!(lhs.len(), rhs.len()); 29 | lhs.iter_mut().zip(rhs.iter()).for_each(|(l, r)|
{ 30 | *l = f(*l, *r); 31 | }); 32 | } 33 | 34 | #[inline] 35 | pub(crate) fn binary_apply<T: TensorDType, U: TensorDType>( 36 | lhs: &Tensor, 37 | rhs: &Tensor, 38 | dst: &Tensor, 39 | f: fn(T, T) -> U, 40 | ) -> Result<(), OperationError> { 41 | let lhs = lhs.to_vec::<T>()?; 42 | let rhs = rhs.to_vec::<T>()?; 43 | let mut result = vec![U::zero(); dst.shape().numel()]; 44 | binary_map(&lhs, &rhs, &mut result, f); 45 | cpu_store_result(dst, &result); 46 | Ok(()) 47 | } 48 | 49 | #[inline] 50 | pub(crate) fn binary_apply_inplace<T: TensorDType>( 51 | lhs: &Tensor, 52 | rhs: &Tensor, 53 | dst: &Tensor, 54 | f: fn(T, T) -> T, 55 | ) -> Result<(), OperationError> { 56 | let mut lhs = lhs.to_vec::<T>()?; 57 | let rhs = rhs.to_vec::<T>()?; 58 | binary_map_inplace(&mut lhs, &rhs, f); 59 | cpu_store_result(dst, &lhs); 60 | Ok(()) 61 | } 62 | 63 | pub struct BinaryOps<T: TensorDType> { 64 | dtype: PhantomData<T>, 65 | } 66 | 67 | macro_rules! impl_cpu_binary_op { 68 | ($method_name:ident, $dtype:ident, $op:expr) => { 69 | fn $method_name(lhs: &Tensor, rhs: &Tensor, dst: Tensor) -> Result<Tensor, OperationError> { 70 | binary_apply_inplace::<$dtype>(lhs, rhs, &dst, $op)?; 71 | Ok(dst) 72 | } 73 | }; 74 | } 75 | 76 | macro_rules! cpu_binary_op_fn { 77 | ($method_name:ident, $op:expr) => { 78 | #[inline] 79 | pub(crate) fn $method_name<T: TensorDType + NumOps>(lhs: &mut [T], rhs: &[T]) { 80 | binary_map_inplace::<T>(lhs, rhs, $op); 81 | } 82 | }; 83 | } 84 | 85 | cpu_binary_op_fn!(add, |lhs, rhs| lhs + rhs); 86 | cpu_binary_op_fn!(sub, |lhs, rhs| lhs - rhs); 87 | cpu_binary_op_fn!(mul, |lhs, rhs| lhs * rhs); 88 | cpu_binary_op_fn!(div, |lhs, rhs| lhs / rhs); 89 | 90 | macro_rules! impl_cpu_binary { 91 | ($dtype:ident) => { 92 | impl BinaryOps<$dtype> { 93 | impl_cpu_binary_op!(add, $dtype, |lhs, rhs| lhs + rhs); 94 | impl_cpu_binary_op!(sub, $dtype, |lhs, rhs| lhs - rhs); 95 | impl_cpu_binary_op!(mul, $dtype, |lhs, rhs| lhs * rhs); 96 | impl_cpu_binary_op!(div, $dtype, |lhs, rhs| lhs / rhs); 97 | 98 | pub fn apply(op: &Binary, dst: Tensor) -> Result<Tensor, OperationError> { 99 | match op.op() { 100 | BinaryOp::Add => Self::add(op.lhs(), op.rhs(), dst), 101 | BinaryOp::Sub => Self::sub(op.lhs(), op.rhs(), dst), 102 | BinaryOp::Mul => Self::mul(op.lhs(), op.rhs(), dst), 103 | BinaryOp::Div => Self::div(op.lhs(), op.rhs(), dst), 104 | } 105 | } 106 | } 107 | }; 108 | } 109 | 110 | impl CPUOperation for Binary { 111 | fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> { 112 | match dst.dt() { 113 | DType::F32 => BinaryOps::<f32>::apply(self, dst), 114 | DType::F16 => BinaryOps::<f16>::apply(self, dst), 115 | DType::BF16 => BinaryOps::<bf16>::apply(self, dst), 116 | _ => todo!(), 117 | } 118 | } 119 | } 120 | 121 | impl_cpu_binary!(f32); 122 | impl_cpu_binary!(f16); 123 | impl_cpu_binary!(bf16); 124 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/cpu/rope.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | concat, 3 | cpu::{cpu_store_result, gemm::gemm, reindex::slice}, 4 | shape, DType, OperationError, RoPE, Shape, Strides, Tensor, 5 | }; 6 | 7 | pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> { 8 | match op.input().dt() { 9 | DType::F32 => { 10 | let dim = op.dim(); 11 | let base = op.base(); 12 | let offset = op.offset(); 13 | let src = op.input().to_vec::<f32>()?; 14 | let result = rope(src, op.input().shape(), dim, base, offset)?; 15 | cpu_store_result(&dst, &result) 16 | } 17 | _ => todo!(), 18 | } 19 | 20 | Ok(dst) 21 | } 22 | 23 | fn compute_theta( 24 | dim: usize, 25 | seq_len: usize, 26 | base: f32, 27 | offset: usize, 28 | ) -> Result<Vec<f32>, OperationError> { 29 | let
half_dim = dim / 2; 30 | 31 | let positions = (offset..seq_len + offset) 32 | .map(|x| x as f32) 33 | .collect::<Vec<f32>>(); 34 | 35 | let inv_freqs = (0..half_dim) 36 | .map(|i| -(i as f32)) 37 | .map(|i| i * base.ln() / half_dim as f32) 38 | .map(f32::exp) 39 | .collect::<Vec<f32>>(); 40 | 41 | let p_shape = shape!(seq_len, 1); 42 | let p_strides = Strides::from(&p_shape); 43 | let i_shape = shape!(1, half_dim); 44 | let i_strides = Strides::from(&i_shape); 45 | let dst_strides = Strides::from(&shape!(seq_len, half_dim)); 46 | let theta = gemm( 47 | &positions, 48 | &p_shape, 49 | &p_strides, 50 | &inv_freqs, 51 | &i_shape, 52 | &i_strides, 53 | &dst_strides, 54 | 1, 55 | seq_len, 56 | half_dim, 57 | 1, 58 | )?; 59 | 60 | Ok(theta) 61 | } 62 | 63 | fn rope( 64 | src: Vec<f32>, 65 | shape: &Shape, 66 | dim: usize, 67 | base: f32, 68 | offset: usize, 69 | ) -> Result<Vec<f32>, OperationError> { 70 | let [batches, num_heads, seq_len, head_dim] = shape.try_into().unwrap(); 71 | 72 | let half_dim = dim / 2; 73 | let theta = compute_theta(dim, seq_len, base, offset)?; 74 | let (sin, cos): (Vec<f32>, Vec<f32>) = theta.iter().map(|i| i.sin_cos()).unzip(); 75 | let src_strides = Strides::from(shape); 76 | let x1 = slice( 77 | &src, 78 | &src_strides, 79 | &[0, 0, 0, 0], 80 | &[batches, num_heads, seq_len, half_dim], 81 | ); 82 | let x2 = slice( 83 | &src, 84 | &src_strides, 85 | &[0, 0, 0, half_dim], 86 | &[batches, num_heads, seq_len, dim], 87 | ); 88 | 89 | //`multiply` as an operation that deals with broadcasting 90 | let x1_cos = x1 91 | .iter() 92 | .zip(cos.iter().cycle()) 93 | .map(|(x, c)| x * c) 94 | .collect::<Vec<f32>>(); 95 | let x2_sin = x2 96 | .iter() 97 | .zip(sin.iter().cycle()) 98 | .map(|(x, s)| x * s) 99 | .collect::<Vec<f32>>(); 100 | 101 | let mut r1 = x1_cos 102 | .iter() 103 | .zip(x2_sin.iter()) 104 | .map(|(x1, x2)| x1 - x2) 105 | .collect::<Vec<f32>>(); 106 | r1.extend(vec![0.0; shape.numel() - r1.len()]); 107 | 108 | let x1_sin = x1 109 | .iter() 110 | .zip(sin.iter().cycle()) 111 | .map(|(x, s)| x * s) 112 | .collect::<Vec<f32>>(); 113 | let x2_cos = x2 114 | .iter() 115 | .zip(cos.iter().cycle()) 116 | .map(|(x, c)| x * c) 117 | .collect::<Vec<f32>>(); 118 | let mut r2 = x1_sin 119 | .iter() 120 | .zip(x2_cos.iter()) 121 | .map(|(x1, x2)| x1 + x2) 122 | .collect::<Vec<f32>>(); 123 | r2.extend(vec![0.0; shape.numel() - r2.len()]); 124 | 125 | let mut to_cat = vec![ 126 | (shape![batches, num_heads, seq_len, half_dim], r1), 127 | (shape![batches, num_heads, seq_len, half_dim], r2), 128 | ]; 129 | if dim < shape[3] { 130 | let r3 = slice( 131 | &src, 132 | &src_strides, 133 | &[0, 0, 0, dim], 134 | &[batches, num_heads, seq_len, head_dim], 135 | ); 136 | to_cat.push((shape![batches, num_heads, seq_len, head_dim - dim], r3)); 137 | } 138 | 139 | let dst_shape = shape![batches, num_heads, seq_len, head_dim]; 140 | let mut dst = vec![0.0f32; dst_shape.numel()]; 141 | concat(to_cat.as_slice(), 3, &dst_shape, &mut dst)?; 142 | Ok(dst) 143 | } 144 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/cpu/softmax.rs: -------------------------------------------------------------------------------- 1 | use crate::cpu::utils::cpu_store_result; 2 | use crate::{CPUOperation, DType, OperationError, Softmax, Tensor, TensorDType}; 3 | use half::{bf16, f16}; 4 | use num::Float; 5 | use num_traits::NumAssignOps; 6 | 7 | impl CPUOperation for Softmax { 8 | fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> { 9 | let Softmax { input, dim } = self; 10 | match input.dt() { 11 | DType::F32 => softmax::<f32>(input, *dim, &dst)?, 12 | DType::F16
=> softmax::<f16>(input, *dim, &dst)?, 13 | DType::BF16 => softmax::<bf16>(input, *dim, &dst)?, 14 | _ => todo!(), 15 | } 16 | 17 | Ok(dst) 18 | } 19 | } 20 | 21 | fn softmax<T>(input: &Tensor, dim: usize, dst: &Tensor) -> Result<(), OperationError> 22 | where 23 | T: TensorDType + Float + NumAssignOps, 24 | { 25 | let src_shape = input.shape(); 26 | let mut input = input.to_vec::<T>()?; 27 | let N = src_shape[dim]; 28 | input.chunks_mut(N).for_each(|chunk| { 29 | let mut sum = T::zero(); 30 | for j in 0..N { 31 | chunk[j] = chunk[j].exp(); 32 | sum += chunk[j]; 33 | } 34 | for j in 0..N { 35 | chunk[j] /= sum; 36 | } 37 | }); 38 | 39 | cpu_store_result(dst, &input); 40 | 41 | Ok(()) 42 | } 43 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/cpu/unary.rs: -------------------------------------------------------------------------------- 1 | use crate::cpu::cpu_store_result; 2 | use crate::{CPUOperation, DType, OperationError, Tensor, TensorDType, Unary, UnaryOp}; 3 | use core::marker::PhantomData; 4 | use half::{bf16, f16}; 5 | use num_traits::Float; 6 | 7 | #[inline] 8 | pub(crate) fn unary_apply_fn_helper<T: TensorDType, U: TensorDType>( 9 | src: &[T], 10 | dst: &mut [U], 11 | f: fn(T) -> U, 12 | ) { 13 | assert_eq!(src.len(), dst.len()); 14 | for (s, d) in src.iter().copied().zip(dst.iter_mut()) { 15 | *d = f(s); 16 | } 17 | } 18 | 19 | #[inline] 20 | pub(crate) fn unary_map_inplace<T: TensorDType>(src: &mut [T], f: fn(T) -> T) { 21 | for s in src.iter_mut() { 22 | *s = f(*s); 23 | } 24 | } 25 | 26 | #[inline] 27 | pub(crate) fn unary_apply_fn<T: TensorDType, U: TensorDType>( 28 | input: &Tensor, 29 | dst: &Tensor, 30 | f: fn(T) -> U, 31 | ) -> Result<(), OperationError> { 32 | let input = input.to_vec::<T>()?; 33 | let mut result = vec![U::zero(); dst.shape().numel()]; 34 | unary_apply_fn_helper(&input, &mut result, f); 35 | cpu_store_result(dst, &result); 36 | Ok(()) 37 | } 38 | 39 | struct UnaryOps<T: TensorDType> { 40 | dtype: PhantomData<T>, 41 | } 42 | 43 | macro_rules!
impl_unary_ops { 44 | ($dtype:ident, $conv:expr) => { 45 | impl UnaryOps<$dtype> { 46 | impl_cpu_unary_op!(gelu, |x: $dtype| $conv(0.5) 47 | * x 48 | * ($conv(1.0) 49 | + $dtype::tanh( 50 | $conv(0.797_884_6) * x * ($conv(1.0) + $conv(0.044715) * x * x) 51 | ))); 52 | 53 | impl_cpu_unary_op!(tanh, |x: $dtype| x.tanh()); 54 | impl_cpu_unary_op!(exp, |x: $dtype| x.exp()); 55 | impl_cpu_unary_op!(log, |x: $dtype| x.ln()); 56 | impl_cpu_unary_op!(sin, |x: $dtype| x.sin()); 57 | impl_cpu_unary_op!(cos, |x: $dtype| x.cos()); 58 | impl_cpu_unary_op!(abs, |x: $dtype| x.abs()); 59 | impl_cpu_unary_op!(sqrt, |x: $dtype| x.sqrt()); 60 | impl_cpu_unary_op!(relu, |x: $dtype| x.max($conv(0.0))); 61 | impl_cpu_unary_op!(floor, |x: $dtype| x.floor()); 62 | impl_cpu_unary_op!(ceil, |x: $dtype| x.ceil()); 63 | impl_cpu_unary_op!(neg, |x: $dtype| -x); 64 | impl_cpu_unary_op!(silu, |x: $dtype| x / ($conv(1.0) + (-x).exp())); 65 | impl_cpu_unary_op!(sigmoid, |x: $dtype| $conv(1.0) / ($conv(1.0) + (-x).exp())); 66 | 67 | fn apply(op: &Unary, dst: Tensor) -> Result<Tensor, OperationError> { 68 | match op.op() { 69 | UnaryOp::Gelu => Self::gelu(op.input(), dst), 70 | UnaryOp::Tanh => Self::tanh(op.input(), dst), 71 | UnaryOp::Exp => Self::exp(op.input(), dst), 72 | UnaryOp::Log => Self::log(op.input(), dst), 73 | UnaryOp::Sin => Self::sin(op.input(), dst), 74 | UnaryOp::Cos => Self::cos(op.input(), dst), 75 | UnaryOp::Abs => Self::abs(op.input(), dst), 76 | UnaryOp::Sqrt => Self::sqrt(op.input(), dst), 77 | UnaryOp::Relu => Self::relu(op.input(), dst), 78 | UnaryOp::Floor => Self::floor(op.input(), dst), 79 | UnaryOp::Ceil => Self::ceil(op.input(), dst), 80 | UnaryOp::Neg => Self::neg(op.input(), dst), 81 | UnaryOp::Silu => Self::silu(op.input(), dst), 82 | UnaryOp::Sigmoid => Self::sigmoid(op.input(), dst), 83 | } 84 | } 85 | } 86 | }; 87 | } 88 | 89 | macro_rules! impl_cpu_unary_op { 90 | ($method_name:ident, $op:expr) => { 91 | fn $method_name(input: &Tensor, dst: Tensor) -> Result<Tensor, OperationError> { 92 | unary_apply_fn(input, &dst, $op)?; 93 | Ok(dst) 94 | } 95 | }; 96 | } 97 | 98 | impl CPUOperation for Unary { 99 | fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> { 100 | match dst.dt() { 101 | DType::F32 => UnaryOps::<f32>::apply(self, dst), 102 | DType::F16 => UnaryOps::<f16>::apply(self, dst), 103 | DType::BF16 => UnaryOps::<bf16>::apply(self, dst), 104 | _ => todo!(), 105 | } 106 | } 107 | } 108 | 109 | macro_rules! impl_cpu_unary { 110 | ($dtype:ident) => { 111 | impl_cpu_unary!($dtype, |x| x); 112 | }; 113 | ($dtype:ident, $conv:expr) => { 114 | impl_unary_ops!($dtype, $conv); 115 | }; 116 | } 117 | 118 | impl_cpu_unary!(f32); 119 | impl_cpu_unary!(f16, f16::from_f32); 120 | impl_cpu_unary!(bf16, bf16::from_f32); 121 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/device.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | gpu::{AllocatorError, PoolError, WgpuDevice}, 3 | DType, 4 | }; 5 | 6 | #[derive(Clone, Debug, thiserror::Error)] 7 | pub enum DeviceError { 8 | #[error("Failed to acquire device with error: {0:?}")] 9 | DeviceAcquisitionFailed(#[from] wgpu::RequestDeviceError), 10 | #[error("Failed to request adapter required for WebGPU. Please ensure that your browser supports WebGPU. 
11 | (Chrome 121+, Firefox Nightly, Edge & all Chromium based browsers)")] 12 | AdapterRequestFailed, 13 | #[error("Failed to create storage with error: {0:?}")] 14 | StorageCreationFailed(#[from] PoolError), //TODO: shouldn't be PoolError 15 | #[error("Device mismatch, requested device: {0:?}, actual device: {1:?}")] 16 | DeviceMismatch(String, String), 17 | #[error("Failed to allocate buffer with error: {0:?}")] 18 | BufferAllocationFailed(#[from] AllocatorError), 19 | #[error("Invalid GPU Buffer Usage, current: {0:?}, required: {1:?}")] 20 | InvalidBufferUsage(wgpu::BufferUsages, wgpu::BufferUsages), 21 | #[error("Failed to transfer buffer with error: {0:?}")] 22 | BufferTransferFailed(#[from] wgpu::BufferAsyncError), 23 | } 24 | 25 | pub enum DeviceRequest { 26 | CPU, 27 | GPU, 28 | } 29 | 30 | #[derive(Clone, Default, PartialEq)] 31 | pub enum Device { 32 | #[default] 33 | CPU, 34 | GPU(WgpuDevice), 35 | } 36 | 37 | impl std::fmt::Debug for Device { 38 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 39 | match self { 40 | Device::CPU => write!(f, "CPU"), 41 | Device::GPU(gpu) => write!(f, "GPU:{}", gpu.ordinal()), 42 | } 43 | } 44 | } 45 | 46 | impl Device { 47 | pub fn is_cpu(&self) -> bool { 48 | matches!(self, Device::CPU) 49 | } 50 | 51 | pub fn is_gpu(&self) -> bool { 52 | matches!(self, Device::GPU(_)) 53 | } 54 | 55 | pub fn compute_precision(&self) -> DType { 56 | match self { 57 | Device::CPU => DType::F16, 58 | Device::GPU(gpu) => gpu.compute_features().compute_precision(), 59 | } 60 | } 61 | 62 | #[cfg(target_arch = "wasm32")] 63 | pub async fn request_device(request: DeviceRequest) -> Result<Device, DeviceError> { 64 | match request { 65 | DeviceRequest::CPU => Ok(Device::CPU), 66 | DeviceRequest::GPU => Ok(Device::GPU(WgpuDevice::new().await?)), 67 | } 68 | } 69 | 70 | #[cfg(not(target_arch = "wasm32"))] 71 | pub fn request_device(request: DeviceRequest) -> Result<Device, DeviceError> { 72 | match request { 73 | DeviceRequest::CPU => Ok(Device::CPU), 74 | DeviceRequest::GPU => Ok(Device::GPU(pollster::block_on(async { 75 | WgpuDevice::new().await 76 | })?)), 77 | } 78 | } 79 | 80 | pub fn label(&self) -> String { 81 | format!("{:?}", self) 82 | } 83 | 84 | pub fn try_gpu(&self) -> Result<&WgpuDevice, DeviceError> { 85 | match self { 86 | Device::GPU(gpu) => Ok(gpu), 87 | Device::CPU => Err(DeviceError::DeviceMismatch( 88 | "GPU".to_string(), 89 | "CPU".to_string(), 90 | )), 91 | } 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/enforcer.rs: -------------------------------------------------------------------------------- 1 | use std::ops::RangeInclusive; 2 | 3 | use crate::{DType, Shape}; 4 | 5 | #[derive(Debug, thiserror::Error)] 6 | pub enum InvariantError { 7 | #[error("Shape mismatch at {left},{right}, {a} != {b}.")] 8 | ShapeMismatch { 9 | left: usize, 10 | right: usize, 11 | a: usize, //TODO: RDim 12 | b: usize, 13 | }, 14 | #[error("Rank mismatch. {accepted:?} != {actual}.")] 15 | RankMismatch { 16 | accepted: RangeInclusive<usize>, 17 | actual: usize, 18 | }, 19 | #[error("Wrong input arity. Allowed range is {accepted:?}, node has {actual}.")] 20 | InputArity { 21 | accepted: RangeInclusive<usize>, 22 | actual: usize, 23 | }, 24 | #[error("Wrong output arity. 
Allowed range is {accepted:?}, node has {actual}.")] 25 | OutputArity { 26 | accepted: RangeInclusive<usize>, 27 | actual: usize, 28 | }, 29 | #[error("DType mismatch, expected {expected:?}, got {actual:?}.")] 30 | DTypeMismatch { expected: DType, actual: DType }, 31 | #[error("Unsupported DType {0:?}.")] 32 | UnsupportedDType(DType), 33 | #[error("Duplicate dims in permutation.")] 34 | DuplicateDims, 35 | #[error("Broadcasting failed: {0:?}")] 36 | BroadcastingFailed(Vec<Shape>), 37 | #[error("Dim out of range {dim} in shape {shape:?}.")] 38 | DimOutOfRange { dim: usize, shape: Shape }, 39 | } 40 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/gpu/align.rs: -------------------------------------------------------------------------------- 1 | ///WebGPU is very specific about buffer alignment. 2 | ///Since in Ratchet, any buffer may be copied back from GPU -> CPU, all buffers have a size 3 | ///that is a multiple of COPY_BUFFER_ALIGNMENT (4 bytes). 4 | /// 5 | ///However, WebGPU also has more stringent alignment for storage buffer offsets. 6 | ///This is controlled by `min_storage_buffer_offset_alignment` in wgpu::Limits. 7 | ///This defaults to 256. 8 | /// 9 | ///For quantized data types in Ratchet, each "segment" of a quantized block (mins, scales, qs, zero 10 | ///point etc.) is extracted and put into separate segments. Thus, these segments must be aligned to 11 | ///256. 12 | 13 | ///The `Align` trait provides methods to calculate the alignment of a usize, and to align a usize 14 | pub trait Align { 15 | const STORAGE_BUFFER_OFFSET_ALIGNMENT: usize = 256; 16 | const COPY_BUFFER_ALIGNMENT: usize = 4; 17 | 18 | fn calculate_alignment(&self, alignment: usize) -> usize; 19 | fn align_for_copy(&self) -> usize; 20 | fn align_for_offset(&self) -> usize; 21 | } 22 | 23 | impl Align for usize { 24 | fn calculate_alignment(&self, alignment: usize) -> usize { 25 | let remainder = self % alignment; 26 | if remainder == 0 { 27 | 0 28 | } else { 29 | alignment - remainder 30 | } 31 | } 32 | 33 | fn align_for_copy(&self) -> usize { 34 | self + self.calculate_alignment(Self::COPY_BUFFER_ALIGNMENT) 35 | } 36 | 37 | fn align_for_offset(&self) -> usize { 38 | self + self.calculate_alignment(Self::STORAGE_BUFFER_OFFSET_ALIGNMENT) 39 | } 40 | } 41 | 42 | pub trait Padding { 43 | //Pad the vector to the next multiple of 256. 44 | fn pad_to_offset(&mut self) -> usize; 45 | 46 | //Pad the vector to the next multiple of 4.
 47 | fn pad_to_copy(&mut self) -> usize; 48 | } 49 | 50 | impl<T: Clone + Default> Padding for Vec<T> { 51 | fn pad_to_copy(&mut self) -> usize { 52 | let length = &self.len(); 53 | let alignment = length.calculate_alignment(4); 54 | if alignment != 0 { 55 | let default_value: T = Default::default(); 56 | let mut padding = vec![default_value; alignment]; 57 | self.append(&mut padding); 58 | alignment 59 | } else { 60 | 0 61 | } 62 | } 63 | 64 | fn pad_to_offset(&mut self) -> usize { 65 | let length = &self.len(); 66 | let alignment = length.calculate_alignment(256); 67 | if alignment != 0 { 68 | let default_value: T = Default::default(); 69 | let mut padding = vec![default_value; alignment]; 70 | self.append(&mut padding); 71 | alignment 72 | } else { 73 | 0 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/gpu/buffer_allocator/mod.rs: -------------------------------------------------------------------------------- 1 | mod allocator; 2 | mod tensor_usage_record; 3 | 4 | pub use allocator::*; 5 | pub use tensor_usage_record::*; 6 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/gpu/buffer_allocator/tensor_usage_record.rs: -------------------------------------------------------------------------------- 1 | use crate::TensorId; 2 | use rustc_hash::FxHashMap; 3 | use std::cmp::Reverse; 4 | 5 | /// Records the interval for which a tensor is used 6 | /// producer & last_consumer as indices into the topologically sorted execution order 7 | #[derive(Debug, Clone, PartialEq, Eq)] 8 | pub struct TensorUsageRecord { 9 | pub id: Option<TensorId>, 10 | pub producer: Option<usize>, 11 | pub last_consumer: usize, 12 | #[cfg(debug_assertions)] 13 | pub last_consumer_id: TensorId, 14 | pub size: usize, 15 | } 16 | 17 | impl std::ops::Index<usize> for TensorUsageRecords { 18 | type Output = TensorUsageRecord; 19 | 20 | fn index(&self, index: usize) -> &Self::Output { 21 | &self.0[index] 22 | } 23 | } 24 | 25 | impl std::ops::IndexMut<usize> for TensorUsageRecords { 26 | fn index_mut(&mut self, index: usize) -> &mut Self::Output { 27 | &mut self.0[index] 28 | } 29 | } 30 | 31 | #[derive(Debug, Clone)] 32 | pub struct TensorUsageRecords(pub Vec<TensorUsageRecord>); 33 | 34 | impl From<FxHashMap<TensorId, TensorUsageRecord>> for TensorUsageRecords { 35 | fn from(mut map: FxHashMap<TensorId, TensorUsageRecord>) -> Self { 36 | let mut records = map.drain().map(|(_, v)| v).collect::<Vec<_>>(); 37 | records.sort_unstable_by_key(|r| Reverse(r.size)); 38 | TensorUsageRecords(records) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/gpu/mod.rs: -------------------------------------------------------------------------------- 1 | mod align; 2 | mod buffer_allocator; 3 | mod device; 4 | mod pools; 5 | mod uniform; 6 | mod wgsl; 7 | mod workload; 8 | 9 | #[cfg(feature = "gpu-profiling")] 10 | mod profiler; 11 | 12 | pub use align::*; 13 | pub use buffer_allocator::*; 14 | pub use device::*; 15 | pub use pools::*; 16 | pub use uniform::*; 17 | pub use wgsl::*; 18 | pub use workload::*; 19 | 20 | #[cfg(feature = "gpu-profiling")] 21 | pub use profiler::*; 22 | 23 | pub const MIN_STORAGE_BUFFER_SIZE: usize = 16; 24 | pub const STORAGE_BUFFER_ALIGN: usize = 256; //TODO: should be a device limit 25 | 26 | /// Usages we use everywhere 27 | pub trait BufferUsagesExt { 28 | fn standard() -> Self; 29 | } 30 | 31 | impl BufferUsagesExt for wgpu::BufferUsages { 32 | fn standard() -> Self { 33 | Self::COPY_DST | Self::COPY_SRC | Self::STORAGE 34 | } 35 | } 36 |
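// --- Editor's sketch, not part of the original file ---
// A minimal illustration of the `Align` and `Padding` traits re-exported above,
// assuming the 4-byte copy alignment and 256-byte storage-offset alignment
// defined in gpu/align.rs.
#[cfg(test)]
mod align_sketch {
    use super::*;

    #[test]
    fn alignment_examples() {
        // 13 bytes round up to 16 for copies (COPY_BUFFER_ALIGNMENT = 4)...
        assert_eq!(13usize.align_for_copy(), 16);
        // ...and up to 256 for storage buffer offsets.
        assert_eq!(13usize.align_for_offset(), 256);

        // `pad_to_copy` appends zeroed elements and returns how many were added.
        let mut v = vec![1u8; 10];
        let added = v.pad_to_copy();
        assert_eq!((v.len(), added), (12, 2));
    }
}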
-------------------------------------------------------------------------------- /crates/ratchet-core/src/gpu/pools/bind_group_layout_pool.rs: -------------------------------------------------------------------------------- 1 | use crate::{gpu::WgpuDevice, rvec, RVec}; 2 | 3 | use super::{static_resource_pool::StaticResourcePool, StaticResourcePoolReadLockAccessor}; 4 | 5 | pub trait BindGroupLayoutEntryExt { 6 | fn compute_storage_buffer(binding: u32, read_only: bool) -> Self; 7 | fn dynamic_uniform_buffer() -> Self; 8 | } 9 | 10 | impl BindGroupLayoutEntryExt for wgpu::BindGroupLayoutEntry { 11 | fn compute_storage_buffer(binding: u32, read_only: bool) -> Self { 12 | Self { 13 | binding, 14 | visibility: wgpu::ShaderStages::COMPUTE, 15 | ty: wgpu::BindingType::Buffer { 16 | ty: wgpu::BufferBindingType::Storage { read_only }, 17 | min_binding_size: None, 18 | has_dynamic_offset: false, 19 | }, 20 | count: None, 21 | } 22 | } 23 | 24 | fn dynamic_uniform_buffer() -> Self { 25 | Self { 26 | binding: 0, 27 | visibility: wgpu::ShaderStages::COMPUTE, 28 | ty: wgpu::BindingType::Buffer { 29 | ty: wgpu::BufferBindingType::Uniform, 30 | min_binding_size: None, 31 | has_dynamic_offset: true, 32 | }, 33 | count: None, 34 | } 35 | } 36 | } 37 | 38 | slotmap::new_key_type! { pub struct BindGroupLayoutHandle; } 39 | 40 | #[derive(Debug, Clone, Hash, PartialEq, Eq, Default)] 41 | pub struct BindGroupLayoutDescriptor { 42 | pub entries: RVec<wgpu::BindGroupLayoutEntry>, 43 | } 44 | 45 | impl BindGroupLayoutDescriptor { 46 | //Used for unary, binary, ternary (NOT INPLACE) 47 | fn entries(ro_length: usize) -> RVec<wgpu::BindGroupLayoutEntry> { 48 | let mut read_only: RVec<wgpu::BindGroupLayoutEntry> = (0..ro_length) 49 | .map(|idx| wgpu::BindGroupLayoutEntry::compute_storage_buffer(idx as u32, true)) 50 | .collect(); 51 | read_only.push(wgpu::BindGroupLayoutEntry::compute_storage_buffer( 52 | ro_length as u32, 53 | false, 54 | )); 55 | read_only 56 | } 57 | 58 | pub fn unary() -> Self { 59 | Self { 60 | entries: Self::entries(1), 61 | } 62 | } 63 | 64 | pub fn unary_inplace() -> Self { 65 | Self { 66 | entries: rvec![wgpu::BindGroupLayoutEntry::compute_storage_buffer(0, false)], 67 | } 68 | } 69 | 70 | pub fn binary() -> Self { 71 | Self { 72 | entries: Self::entries(2), 73 | } 74 | } 75 | 76 | pub fn binary_inplace() -> Self { 77 | Self { 78 | entries: rvec![ 79 | wgpu::BindGroupLayoutEntry::compute_storage_buffer(0, false), 80 | wgpu::BindGroupLayoutEntry::compute_storage_buffer(1, true) 81 | ], 82 | } 83 | } 84 | 85 | pub fn ternary() -> Self { 86 | Self { 87 | entries: Self::entries(3), 88 | } 89 | } 90 | 91 | pub fn nthary(ro: usize) -> Self { 92 | Self { 93 | entries: Self::entries(ro), 94 | } 95 | } 96 | 97 | pub fn uniform() -> Self { 98 | Self { 99 | entries: rvec![wgpu::BindGroupLayoutEntry::dynamic_uniform_buffer()], 100 | } 101 | } 102 | } 103 | 104 | pub struct BindGroupLayoutPool { 105 | inner: 106 | StaticResourcePool<BindGroupLayoutHandle, BindGroupLayoutDescriptor, wgpu::BindGroupLayout>, 107 | } 108 | 109 | impl Default for BindGroupLayoutPool { 110 | fn default() -> Self { 111 | Self::new() 112 | } 113 | } 114 | 115 | impl BindGroupLayoutPool { 116 | pub fn new() -> Self { 117 | Self { 118 | inner: StaticResourcePool::default(), 119 | } 120 | } 121 | } 122 | 123 | impl BindGroupLayoutPool { 124 | pub fn get_or_create( 125 | &self, 126 | descriptor: &BindGroupLayoutDescriptor, 127 | device: &WgpuDevice, 128 | ) -> BindGroupLayoutHandle { 129 | self.inner.get_or_create(descriptor, |desc| { 130 | device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { 131 | label: None, 132 | entries: &desc.entries, 133 | }) 134 | }) 135 | } 136 |
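    // A hedged usage sketch (editor's addition): `BindGroupLayoutDescriptor::binary()`
    // yields three storage entries, read-only inputs at bindings 0 and 1 plus a
    // read-write output at binding 2, so repeated ops with the same binding shape
    // resolve to one cached `wgpu::BindGroupLayout` via `get_or_create` above.
    //
    //     let desc = BindGroupLayoutDescriptor::binary();
    //     assert_eq!(desc.entries.len(), 3);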
137 |     /// Locks the resource pool for resolving handles.
138 |     ///
139 |     /// While it is locked, no new resources can be added.
140 |     pub fn resources(
141 |         &self,
142 |     ) -> StaticResourcePoolReadLockAccessor<'_, BindGroupLayoutHandle, wgpu::BindGroupLayout> {
143 |         self.inner.resources()
144 |     }
145 | }
146 |
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/pools/buffer_pool.rs:
--------------------------------------------------------------------------------
1 | use std::sync::Arc;
2 |
3 | // Adapted from https://github.com/rerun-io/rerun MIT licensed
4 | use super::{DynamicResource, DynamicResourcePool, DynamicResourcesDesc, PoolError};
5 | use crate::{
6 |     gpu::{WgpuDevice, MIN_STORAGE_BUFFER_SIZE},
7 |     RawGPUBuffer,
8 | };
9 |
10 | #[derive(Clone, Hash, PartialEq, Eq, Debug, derive_new::new)]
11 | pub struct BufferDescriptor {
12 |     pub size: wgpu::BufferAddress,
13 |     pub usage: wgpu::BufferUsages,
14 |     pub mapped_at_creation: bool,
15 | }
16 |
17 | impl BufferDescriptor {
18 |     pub fn fields(&self) -> (wgpu::BufferAddress, wgpu::BufferUsages, bool) {
19 |         (self.size, self.usage, self.mapped_at_creation)
20 |     }
21 | }
22 |
23 | //All slotmap keys are COPY
24 | slotmap::new_key_type! { pub struct GpuBufferHandle; }
25 |
26 | /// A reference-counted buffer.
27 | /// Once all instances are dropped, the buffer will be marked for reclamation in the following pass.
28 | #[derive(Debug, Clone)]
29 | pub struct PooledGPUBuffer(Arc<DynamicResource<GpuBufferHandle, BufferDescriptor, RawGPUBuffer>>);
30 |
31 | impl std::ops::Deref for PooledGPUBuffer {
32 |     type Target = Arc<DynamicResource<GpuBufferHandle, BufferDescriptor, RawGPUBuffer>>;
33 |
34 |     fn deref(&self) -> &Self::Target {
35 |         &self.0
36 |     }
37 | }
38 |
39 | impl PartialEq for PooledGPUBuffer {
40 |     fn eq(&self, other: &Self) -> bool {
41 |         self.0.inner.global_id() == other.0.inner.global_id()
42 |     }
43 | }
44 |
45 | impl DynamicResourcesDesc for BufferDescriptor {
46 |     fn resource_size_in_bytes(&self) -> u64 {
47 |         self.size
48 |     }
49 |
50 |     fn allow_reuse(&self) -> bool {
51 |         if std::env::var("RATCHET_DEBUG").is_ok() {
52 |             false
53 |         } else {
54 |             !self.mapped_at_creation
55 |         }
56 |     }
57 | }
58 |
59 | pub struct BufferPool {
60 |     inner: DynamicResourcePool<GpuBufferHandle, BufferDescriptor, RawGPUBuffer>,
61 | }
62 |
63 | impl Default for BufferPool {
64 |     fn default() -> Self {
65 |         Self::new()
66 |     }
67 | }
68 |
69 | impl BufferPool {
70 |     pub fn new() -> Self {
71 |         Self {
72 |             inner: DynamicResourcePool::default(),
73 |         }
74 |     }
75 |
76 |     pub fn get_or_create(
77 |         &self,
78 |         desc: &BufferDescriptor,
79 |         device: &WgpuDevice,
80 |         immediate: bool,
81 |     ) -> PooledGPUBuffer {
82 |         let size = if (desc.size as usize) < MIN_STORAGE_BUFFER_SIZE {
83 |             //All buffers must be minimum 16 bytes
84 |             MIN_STORAGE_BUFFER_SIZE as _
85 |         } else {
86 |             //Round all buffer sizes up to a multiple of COPY_BUFFER_ALIGNMENT (4 bytes),
87 |             //as any buffer may be read back to the CPU, which requires a copy
88 |             if desc.size % wgpu::COPY_BUFFER_ALIGNMENT == 0 {
89 |                 desc.size
90 |             } else {
91 |                 desc.size + wgpu::COPY_BUFFER_ALIGNMENT - (desc.size % wgpu::COPY_BUFFER_ALIGNMENT)
92 |             }
93 |         };
94 |
95 |         let descriptor = BufferDescriptor {
96 |             size,
97 |             usage: desc.usage,
98 |             mapped_at_creation: desc.mapped_at_creation,
99 |         };
100 |
101 |         PooledGPUBuffer(self.inner.get_or_create(&descriptor, |descriptor| {
102 |             let (size, usage, mapped_at_creation) = descriptor.fields();
103 |             let buf = device.create_buffer(&wgpu::BufferDescriptor {
104 |                 label: None,
105 |                 size,
106 |                 usage,
107 |                 mapped_at_creation,
108 |             });
109 |             if immediate {
110 |                 device.queue().submit(None);
111 |                 device.poll(wgpu::Maintain::Wait);
112 |             }
113 |             buf
114 |         }))
115 |     }
116 |
117 |     pub fn begin_pass(&mut self, pass_index: u64) {
118 |         self.inner.begin_pass(pass_index, |res| res.destroy());
119 |     }
120 |
121 |     /// Method to retrieve a resource from a weak handle (used by [`super::GpuBindGroupPool`])
122 |     pub fn get(&self, handle: GpuBufferHandle) -> Result<PooledGPUBuffer, PoolError> {
123 |         Ok(PooledGPUBuffer(self.inner.get_from_handle(handle)?))
124 |     }
125 |
126 |     pub fn all_resources(&self) -> Vec<PooledGPUBuffer> {
127 |         self.inner
128 |             .all_resources()
129 |             .into_iter()
130 |             .map(PooledGPUBuffer)
131 |             .collect::<Vec<_>>()
132 |     }
133 |
134 |     pub fn num_resources(&self) -> usize {
135 |         self.inner.num_resources()
136 |     }
137 |
138 |     pub fn total_gpu_size_in_bytes(&self) -> u64 {
139 |         self.inner.total_resource_size_in_bytes()
140 |     }
141 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/pools/kernel_module_pool.rs:
--------------------------------------------------------------------------------
1 | use crate::{Kernel, KernelKey, KernelSource, OperationError, Tensor, WgpuDevice, WorkgroupSize};
2 |
3 | use super::static_resource_pool::{StaticResourcePool, StaticResourcePoolReadLockAccessor};
4 | use std::hash::Hash;
5 |
6 | slotmap::new_key_type! { pub struct KernelModuleHandle; }
7 |
8 | #[derive(Clone, PartialEq, Eq, Debug, Hash)]
9 | pub struct KernelModuleDesc {
10 |     /// Unique identifier for the kernel module,
11 |     /// e.g. softmax_vec4_f32_128_1_1
12 |     pub key: KernelKey,
13 | }
14 |
15 | impl KernelModuleDesc {
16 |     #[track_caller]
17 |     pub fn create_kernel_source<O: Kernel>(
18 |         &self,
19 |         op: &O,
20 |         inplace: bool,
21 |         dst: &Tensor,
22 |         workgroup_size: &WorkgroupSize,
23 |     ) -> Result<KernelSource, OperationError> {
24 |         op.build_kernel(inplace, dst, workgroup_size)
25 |     }
26 | }
27 |
28 | #[derive(Default)]
29 | pub struct KernelModulePool {
30 |     pool: StaticResourcePool<KernelModuleHandle, KernelModuleDesc, wgpu::ShaderModule>,
31 | }
32 |
33 | impl KernelModulePool {
34 |     pub fn new() -> Self {
35 |         Self {
36 |             pool: StaticResourcePool::default(),
37 |         }
38 |     }
39 |
40 |     pub fn get_or_create<K: Kernel>(
41 |         &self,
42 |         desc: &KernelModuleDesc,
43 |         kernel: &K,
44 |         inplace: bool,
45 |         dst: &Tensor,
46 |         workgroup_size: &WorkgroupSize,
47 |         device: &WgpuDevice,
48 |     ) -> KernelModuleHandle {
49 |         self.pool.get_or_create(desc, |desc| {
50 |             log::info!("Creating kernel module: {}", desc.key);
51 |             let source = desc
52 |                 .create_kernel_source(kernel, inplace, dst, workgroup_size)
53 |                 .expect("Failed to create kernel source");
54 |
55 |             let shader_module_desc = wgpu::ShaderModuleDescriptor {
56 |                 label: Some(desc.key.as_str()),
57 |                 source: source.into(),
58 |             };
59 |
60 |             if std::env::var("RATCHET_CHECKED").is_ok() {
61 |                 log::warn!("Using checked shader compilation");
62 |                 device.create_shader_module(shader_module_desc)
63 |             } else {
64 |                 unsafe { device.create_shader_module_unchecked(shader_module_desc) }
65 |             }
66 |         })
67 |     }
68 |
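    // Note that the pool is keyed on `KernelKey` alone, so any two ops that
    // lower to the same key (reading the example key above as op name, access
    // granularity, dtype and workgroup size) share a single compiled
    // wgpu::ShaderModule; the constructor closure only runs on a cache miss.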
69 |     /// Locks the resource pool for resolving handles.
70 |     ///
71 |     /// While it is locked, no new resources can be added.
72 |     pub fn resources(
73 |         &self,
74 |     ) -> StaticResourcePoolReadLockAccessor<'_, KernelModuleHandle, wgpu::ShaderModule> {
75 |         self.pool.resources()
76 |     }
77 |
78 |     pub fn num_resources(&self) -> usize {
79 |         self.pool.num_resources()
80 |     }
81 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/pools/mod.rs:
--------------------------------------------------------------------------------
1 | mod bind_group_layout_pool;
2 | mod bind_group_pool;
3 | mod buffer_pool;
4 | mod dynamic_resource_pool;
5 | mod kernel_module_pool;
6 | mod pipeline_layout_pool;
7 | mod pipeline_pool;
8 | mod static_resource_pool;
9 |
10 | pub use bind_group_layout_pool::*;
11 | pub use bind_group_pool::*;
12 | pub use buffer_pool::*;
13 | pub use dynamic_resource_pool::*;
14 | pub use kernel_module_pool::*;
15 | pub use pipeline_layout_pool::*;
16 | pub use pipeline_pool::*;
17 | pub use static_resource_pool::*;
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/pools/pipeline_layout_pool.rs:
--------------------------------------------------------------------------------
1 | use crate::{gpu::WgpuDevice, RVec};
2 |
3 | use super::{
4 |     static_resource_pool::{
5 |         StaticResourcePool, StaticResourcePoolAccessor as _, StaticResourcePoolReadLockAccessor,
6 |     },
7 |     BindGroupLayoutHandle,
8 | };
9 |
10 | slotmap::new_key_type! { pub struct PipelineLayoutHandle; }
11 |
12 | #[derive(Debug, Clone, Hash, PartialEq, Eq)]
13 | pub struct PipelineLayoutDescriptor {
14 |     pub entries: RVec<BindGroupLayoutHandle>,
15 | }
16 |
17 | #[derive(Default)]
18 | pub(crate) struct PipelineLayoutPool {
19 |     inner: StaticResourcePool<PipelineLayoutHandle, PipelineLayoutDescriptor, wgpu::PipelineLayout>,
20 | }
21 |
22 | impl PipelineLayoutPool {
23 |     pub fn new() -> Self {
24 |         Self {
25 |             inner: StaticResourcePool::default(),
26 |         }
27 |     }
28 |
29 |     pub fn get_or_create(
30 |         &self,
31 |         desc: &PipelineLayoutDescriptor,
32 |         device: &WgpuDevice,
33 |     ) -> PipelineLayoutHandle {
34 |         self.inner.get_or_create(desc, |desc| {
35 |             let bind_groups = device.bind_group_layout_resources();
36 |
37 |             device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
38 |                 label: None,
39 |                 bind_group_layouts: &desc
40 |                     .entries
41 |                     .iter()
42 |                     .map(|handle| bind_groups.get(*handle).unwrap())
43 |                     .collect::<Vec<_>>(),
44 |                 push_constant_ranges: &[],
45 |             })
46 |         })
47 |     }
48 |
49 |     /// Locks the resource pool for resolving handles.
50 |     ///
51 |     /// While it is locked, no new resources can be added.
52 |     pub fn resources(
53 |         &self,
54 |     ) -> StaticResourcePoolReadLockAccessor<'_, PipelineLayoutHandle, wgpu::PipelineLayout> {
55 |         self.inner.resources()
56 |     }
57 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/pools/pipeline_pool.rs:
--------------------------------------------------------------------------------
1 | use crate::{gpu::WgpuDevice, KernelKey, KernelModuleHandle};
2 |
3 | use super::{
4 |     PipelineLayoutHandle, StaticResourcePool, StaticResourcePoolAccessor,
5 |     StaticResourcePoolReadLockAccessor,
6 | };
7 |
8 | slotmap::new_key_type! { pub struct ComputePipelineHandle; }
9 |
10 | #[derive(Debug, Clone, Hash, PartialEq, Eq)]
11 | pub struct ComputePipelineDescriptor {
12 |     pub pipeline_layout: PipelineLayoutHandle,
13 |     pub kernel_key: KernelKey,
14 |     pub kernel_module: KernelModuleHandle,
15 | }
16 |
17 | pub struct ComputePipelinePool {
18 |     inner:
19 |         StaticResourcePool<ComputePipelineHandle, ComputePipelineDescriptor, wgpu::ComputePipeline>,
20 | }
21 |
22 | impl Default for ComputePipelinePool {
23 |     fn default() -> Self {
24 |         Self::new()
25 |     }
26 | }
27 |
28 | impl ComputePipelinePool {
29 |     pub fn new() -> Self {
30 |         Self {
31 |             inner: StaticResourcePool::default(),
32 |         }
33 |     }
34 |
35 |     pub fn get_or_create(
36 |         &self,
37 |         desc: &ComputePipelineDescriptor,
38 |         device: &WgpuDevice,
39 |     ) -> ComputePipelineHandle {
40 |         self.inner.get_or_create(desc, |desc| {
41 |             let label = Some(desc.kernel_key.as_str());
42 |             //println!("LABEL: {:?}", label);
43 |             let kernel_resources = device.kernel_module_resources();
44 |
45 |             let module = kernel_resources.get(desc.kernel_module).unwrap();
46 |
47 |             let pipeline_layouts = device.pipeline_layout_resources();
48 |             let pipeline_layout = pipeline_layouts.get(desc.pipeline_layout).unwrap();
49 |
50 |             device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
51 |                 label,
52 |                 layout: Some(pipeline_layout),
53 |                 module,
54 |                 entry_point: Some("main"),
55 |                 compilation_options: wgpu::PipelineCompilationOptions {
56 |                     zero_initialize_workgroup_memory: false,
57 |                     ..Default::default()
58 |                 },
59 |                 cache: None,
60 |             })
61 |         })
62 |     }
63 |
64 |     /// Locks the resource pool for resolving handles.
65 |     ///
66 |     /// While it is locked, no new resources can be added.
67 |     pub fn resources(
68 |         &self,
69 |     ) -> StaticResourcePoolReadLockAccessor<'_, ComputePipelineHandle, wgpu::ComputePipeline> {
70 |         self.inner.resources()
71 |     }
72 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/pools/static_resource_pool.rs:
--------------------------------------------------------------------------------
1 | //Adapted from https://github.com/rerun-io/rerun MIT licensed.
2 | use std::hash::Hash;
3 |
4 | use parking_lot::{RwLock, RwLockReadGuard};
5 | use rustc_hash::FxHashMap;
6 | use slotmap::{Key, SlotMap};
7 |
8 | #[derive(thiserror::Error, Debug, PartialEq, Eq, Clone)]
9 | pub enum PoolError {
10 |     #[error("Requested resource isn't available because the handle is no longer valid")]
11 |     ResourceNotAvailable,
12 |
13 |     #[error("The passed resource handle was null")]
14 |     NullHandle,
15 |
16 |     #[error("The passed descriptor doesn't refer to a known resource")]
17 |     UnknownDescriptor,
18 | }
19 |
20 | /// Generic resource pool for all resources that are fully described upon creation, i.e. never have any variable content.
21 | ///
22 | /// This implies, a resource is uniquely defined by its description.
23 | /// We call these resources "static" because they never change their content over their lifetime.
24 | ///
25 | /// Lookup is queried to determine if a resource with the given descriptor already exists.
26 | pub(super) struct StaticResourcePool<Handle: Key, Descriptor, Resource> {
27 |     resources: RwLock<SlotMap<Handle, Resource>>,
28 |     lookup: RwLock<FxHashMap<Descriptor, Handle>>,
29 | }
30 |
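// Because `lookup` is keyed on the whole descriptor, descriptor equality is the
// caching contract: asking for the same descriptor twice yields the same handle,
// and resources live for the lifetime of the pool (hence "static"). One subtle
// consequence of the lock choreography in `get_or_create` below: two threads
// that miss simultaneously will both run the constructor, and the loser's
// resource stays in the slotmap but becomes unreachable through `lookup`.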
31 | /// We cannot #derive(Default) as that would require Handle/Desc/Res to implement Default too.
32 | impl<Handle: Key, Descriptor, Resource> Default for StaticResourcePool<Handle, Descriptor, Resource> {
33 |     fn default() -> Self {
34 |         Self {
35 |             resources: Default::default(),
36 |             lookup: Default::default(),
37 |         }
38 |     }
39 | }
40 |
41 | impl<Handle, Descriptor, Resource> StaticResourcePool<Handle, Descriptor, Resource>
42 | where
43 |     Handle: Key,
44 |     Descriptor: std::fmt::Debug + Clone + Eq + Hash,
45 | {
46 |     pub fn get_or_create<C: FnOnce(&Descriptor) -> Resource>(
47 |         &self,
48 |         descriptor: &Descriptor,
49 |         constructor: C,
50 |     ) -> Handle {
51 |         // Ensure the lock isn't held in the creation case.
52 |         if let Some(handle) = self.lookup.read().get(descriptor) {
53 |             return *handle;
54 |         }
55 |
56 |         let resource = constructor(descriptor);
57 |         let handle = self.resources.write().insert(resource);
58 |         self.lookup.write().insert(descriptor.clone(), handle);
59 |
60 |         handle
61 |     }
62 |
63 |     /// Locks the resource pool for resolving handles.
64 |     ///
65 |     /// While it is locked, no new resources can be added.
66 |     pub fn resources(&self) -> StaticResourcePoolReadLockAccessor<'_, Handle, Resource> {
67 |         StaticResourcePoolReadLockAccessor {
68 |             resources: self.resources.read(),
69 |         }
70 |     }
71 |
72 |     pub fn num_resources(&self) -> usize {
73 |         self.resources.read().len()
74 |     }
75 | }
76 |
77 | /// Accessor to the resource pool, either by taking a read lock or by moving out the resources.
78 | pub trait StaticResourcePoolAccessor<Handle: Key, Res> {
79 |     fn get(&self, handle: Handle) -> Result<&Res, PoolError>;
80 | }
81 |
82 | /// Accessor to the resource pool by taking a read lock.
83 | pub struct StaticResourcePoolReadLockAccessor<'a, Handle: Key, Res> {
84 |     resources: RwLockReadGuard<'a, SlotMap<Handle, Res>>,
85 | }
86 |
87 | fn to_pool_error<T>(get_result: Option<T>, handle: impl Key) -> Result<T, PoolError> {
88 |     get_result.ok_or_else(|| {
89 |         if handle.is_null() {
90 |             PoolError::NullHandle
91 |         } else {
92 |             PoolError::ResourceNotAvailable
93 |         }
94 |     })
95 | }
96 |
97 | impl<'a, Handle: Key, Res> StaticResourcePoolAccessor<Handle, Res>
98 |     for StaticResourcePoolReadLockAccessor<'a, Handle, Res>
99 | {
100 |     fn get(&self, handle: Handle) -> Result<&Res, PoolError> {
101 |         to_pool_error(self.resources.get(handle), handle)
102 |     }
103 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/uniform.rs:
--------------------------------------------------------------------------------
1 | use std::num::NonZeroU64;
2 |
3 | use crate::{
4 |     gpu::{BindGroupEntry, BindGroupLayoutDescriptor},
5 |     rvec, OperationError,
6 | };
7 |
8 | use super::{BindGroupDescriptor, GpuBindGroup, PooledGPUBuffer, WgpuDevice};
9 | use encase::DynamicUniformBuffer;
10 |
11 | ///We use a single uniform buffer for all operations to hold their parameters.
12 | ///Every operation writes its metadata into this buffer, and an offset is returned.
13 | ///This offset is used when binding the buffer.
14 | pub struct CpuUniform(DynamicUniformBuffer<Vec<u8>>);
15 |
16 | ///Uniforms must be 256-byte aligned, encase handles this for us.
17 | pub const UNIFORM_ALIGN: usize = 256;
18 | pub const DEFAULT_UNIFORM_SIZE: usize = 16384;
19 |
20 | impl Default for CpuUniform {
21 |     fn default() -> Self {
22 |         Self::new()
23 |     }
24 | }
25 |
26 | impl CpuUniform {
27 |     pub fn new() -> Self {
28 |         Self(DynamicUniformBuffer::new(Vec::with_capacity(
29 |             DEFAULT_UNIFORM_SIZE,
30 |         )))
31 |     }
32 |
33 |     pub fn into_inner(self) -> Vec<u8> {
34 |         self.0.into_inner()
35 |     }
36 |
37 |     /// Consumes the CPU repr of the uniform buffer and writes to the GPU.
38 |     pub(crate) fn into_gpu(self, device: &WgpuDevice) -> Result<GpuUniform, OperationError> {
39 |         let buf = device.create_uniform_init(self);
40 |         let layout =
41 |             device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
42 |         let bind_group = device.get_or_create_bind_group(&BindGroupDescriptor {
43 |             entries: rvec![BindGroupEntry {
44 |                 handle: buf.handle,
45 |                 offset: 0,
46 |                 size: NonZeroU64::new(UNIFORM_ALIGN as u64),
47 |             }],
48 |             layout,
49 |         })?;
50 |
51 |         Ok(GpuUniform { buf, bind_group })
52 |     }
53 | }
54 |
55 | pub struct GpuUniform {
56 |     buf: PooledGPUBuffer,
57 |     bind_group: GpuBindGroup,
58 | }
59 |
60 | impl GpuUniform {
61 |     pub fn bind_group(&self) -> &GpuBindGroup {
62 |         &self.bind_group
63 |     }
64 | }
65 |
66 | impl std::ops::Deref for CpuUniform {
67 |     type Target = DynamicUniformBuffer<Vec<u8>>;
68 |
69 |     fn deref(&self) -> &Self::Target {
70 |         &self.0
71 |     }
72 | }
73 |
74 | impl std::ops::DerefMut for CpuUniform {
75 |     fn deref_mut(&mut self) -> &mut Self::Target {
76 |         &mut self.0
77 |     }
78 | }
79 |
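To make the offset discipline concrete, here is a minimal sketch of the write side (hypothetical metadata struct; illustrative caller code only, not part of the crate):

    use ratchet::CpuUniform;

    #[derive(encase::ShaderType)]
    struct SoftmaxMeta {
        m: u32,
        n: u32,
    }

    fn write_meta(uniform: &mut CpuUniform) -> u64 {
        // CpuUniform derefs to encase's DynamicUniformBuffer, whose `write`
        // appends the value at the next aligned offset and returns that offset.
        // The same offset is later used as the dynamic offset when binding the
        // single uniform bind group built in `into_gpu`.
        uniform
            .write(&SoftmaxMeta { m: 128, n: 512 })
            .expect("uniform write")
    }
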
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/wgsl/access_granularity.rs:
--------------------------------------------------------------------------------
1 | use super::dtype::WgslDType;
2 |
3 | /// WGSL types which are used to access buffers.
4 | pub trait WgslPrimitive: std::fmt::Display + Default + Clone + Copy {
5 |     type T: WgslDType;
6 |     const W: usize;
7 |
8 |     fn render_type() -> String;
9 | }
10 |
11 | #[derive(Clone, Copy)]
12 | pub struct WgslVec<T: WgslDType, const N: usize> {
13 |     inner: [T; N],
14 | }
15 |
16 | impl<T: WgslDType, const N: usize> Default for WgslVec<T, N> {
17 |     fn default() -> Self {
18 |         WgslVec {
19 |             inner: [T::default(); N],
20 |         }
21 |     }
22 | }
23 |
24 | impl<T: WgslDType, const N: usize> std::fmt::Display for WgslVec<T, N> {
25 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 |         for i in 0..N - 1 {
27 |             write!(f, "{}, ", self.inner[i].render())?;
28 |         }
29 |         write!(f, "{}", self.inner[N - 1].render())?;
30 |         Ok(())
31 |     }
32 | }
33 |
34 | pub type Scalar<T> = WgslVec<T, 1>;
35 | pub type Vec2<T> = WgslVec<T, 2>;
36 | pub type Vec3<T> = WgslVec<T, 3>;
37 | pub type Vec4<T> = WgslVec<T, 4>;
38 |
39 | impl<T: WgslDType> WgslPrimitive for Vec4<T> {
40 |     type T = T;
41 |     const W: usize = 4;
42 |
43 |     fn render_type() -> String {
44 |         format!("vec4<{}>", T::DT)
45 |     }
46 | }
47 |
48 | impl<T: WgslDType> WgslPrimitive for Vec3<T> {
49 |     type T = T;
50 |     const W: usize = 3;
51 |     fn render_type() -> String {
52 |         format!("vec3<{}>", T::DT)
53 |     }
54 | }
55 |
56 | impl<T: WgslDType> WgslPrimitive for Vec2<T> {
57 |     type T = T;
58 |     const W: usize = 2;
59 |     fn render_type() -> String {
60 |         format!("vec2<{}>", T::DT)
61 |     }
62 | }
63 |
64 | impl<T: WgslDType> WgslPrimitive for Scalar<T> {
65 |     type T = T;
66 |     const W: usize = 1;
67 |     fn render_type() -> String {
68 |         T::DT.to_string()
69 |     }
70 | }
71 |
72 | #[derive(Default, Clone, Copy)]
73 | pub struct Array<P: WgslPrimitive> {
74 |     _p1: std::marker::PhantomData<P>,
75 | }
76 |
77 | impl<P: WgslPrimitive> std::fmt::Display for Array<P> {
78 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 |         write!(f, "array<{}>", P::render_type())
80 |     }
81 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/wgsl/dtype.rs:
--------------------------------------------------------------------------------
1 | use half::f16;
2 | use std::fmt::{Debug, Display};
3 |
4 | /// Supported data types in WGSL.
5 | ///
6 | /// This can be mapped to and from the Ratchet DType.
7 | pub trait WgslDType: Debug + Display + Default + Copy + num_traits::Num + num_traits::Zero {
8 |     const DT: &'static str;
9 |     const MIN: Self;
10 |
11 |     fn render(&self) -> String;
12 | }
13 | //RENDER IS CONFUSING HERE
14 |
15 | impl WgslDType for f32 {
16 |     const DT: &'static str = "f32";
17 |     const MIN: Self = -3e10; //ranges for wgsl and rust are diff
18 |
19 |     fn render(&self) -> String {
20 |         format!("{}f", self)
21 |     }
22 | }
23 |
24 | impl WgslDType for f16 {
25 |     const DT: &'static str = "f16";
26 |     const MIN: Self = f16::MIN;
27 |
28 |     fn render(&self) -> String {
29 |         format!("{}h", self)
30 |     }
31 | }
32 |
33 | impl WgslDType for i32 {
34 |     const DT: &'static str = "i32";
35 |     const MIN: Self = i32::MIN;
36 |
37 |     fn render(&self) -> String {
38 |         format!("{}i", self)
39 |     }
40 | }
41 |
42 | impl WgslDType for u32 {
43 |     const DT: &'static str = "u32";
44 |     const MIN: Self = u32::MIN;
45 |
46 |     fn render(&self) -> String {
47 |         format!("{}u", self)
48 |     }
49 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/wgsl/kernel_binding.rs:
--------------------------------------------------------------------------------
1 | use crate::{Ident, RenderFragment};
2 | use inline_wgsl::wgsl;
3 |
4 | #[derive(Debug, PartialEq, Eq)]
5 | pub(crate) enum BindingType {
6 |     Storage,
7 |     Uniform,
8 | }
9 |
10 | #[derive(Debug, Copy, Clone)]
11 | pub(crate) enum BindingMode {
12 |     ReadOnly,
13 |     ReadWrite,
14 | }
15 |
16 | impl BindingMode {
17 |     pub fn as_str(&self) -> &'static str {
18 |         match self {
19 |             BindingMode::ReadOnly => "read",
20 |             BindingMode::ReadWrite => "read_write",
21 |         }
22 |     }
23 | }
24 |
25 | #[derive(Debug, derive_new::new)]
26 | pub struct KernelBinding {
27 |     name: Ident,
28 |     group: usize,
29 |     binding: usize,
30 |     ty: BindingType,
31 |     mode: BindingMode,
32 |     accessor: String,
33 | }
34 |
35 | impl From<KernelBinding> for wgpu::BindGroupLayoutEntry {
36 |     fn from(val: KernelBinding) -> Self {
37 |         let (binding_type, has_dynamic_offset) = match val.ty {
38 |             BindingType::Storage => (
39 |                 wgpu::BufferBindingType::Storage {
40 |                     read_only: matches!(val.mode, BindingMode::ReadOnly),
41 |                 },
42 |                 false,
43 |             ),
44 |             BindingType::Uniform => (wgpu::BufferBindingType::Uniform, true),
45 |         };
46 |
47 |         wgpu::BindGroupLayoutEntry {
48 |             binding: val.binding as u32,
49 |             visibility: wgpu::ShaderStages::COMPUTE,
50 |             ty: wgpu::BindingType::Buffer {
51 |                 ty: binding_type,
52 |                 min_binding_size: None,
53 |                 has_dynamic_offset,
54 |             },
55 |             count: None,
56 |         }
57 |     }
58 | }
59 |
60 | impl RenderFragment for KernelBinding {
61 |     fn render(&self) -> crate::WgslFragment {
62 |         let KernelBinding {
63 |             name,
64 |             group,
65 |             binding,
66 |             accessor,
67 |             ..
68 |         } = self;
69 |         let mode = self.mode.as_str();
70 |
71 |         let result = match self.ty {
72 |             BindingType::Storage => wgsl! {
73 |                 @group('group) @binding('binding) var<storage, 'mode> 'name: 'accessor;
74 |             },
75 |             BindingType::Uniform => wgsl! {
76 |                 @group('group) @binding('binding) var<uniform> 'name: 'accessor;
77 |             },
78 |         };
79 |         result.into()
80 |     }
81 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/wgsl/mod.rs:
--------------------------------------------------------------------------------
1 | mod access_granularity;
2 | pub mod dtype;
3 | mod kernel;
4 | mod kernel_binding;
5 | mod kernel_builder;
6 |
7 | pub use access_granularity::*;
8 | pub use kernel::*;
9 | pub use kernel_binding::*;
10 | pub use kernel_builder::*;
--------------------------------------------------------------------------------
/crates/ratchet-core/src/gpu/workload.rs:
--------------------------------------------------------------------------------
1 | use derive_new::new;
2 | use inline_wgsl::wgsl;
3 |
4 | use crate::KernelElement;
5 |
6 | #[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
7 | pub struct WorkgroupSize {
8 |     pub x: u32,
9 |     pub y: u32,
10 |     pub z: u32,
11 | }
12 |
13 | impl WorkgroupSize {
14 |     pub fn product(&self) -> u32 {
15 |         self.x * self.y * self.z
16 |     }
17 |
18 |     pub fn as_key(&self) -> String {
19 |         format!("{}_{}_{}", self.x, self.y, self.z)
20 |     }
21 | }
22 |
23 | #[macro_export]
24 | macro_rules! wgs {
25 |     ($x:expr, $y:expr, $z:expr) => {
26 |         $crate::gpu::WorkgroupSize::new($x, $y, $z)
27 |     };
28 | }
29 |
30 | impl std::fmt::Display for WorkgroupSize {
31 |     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
32 |         let WorkgroupSize { x, y, z } = self;
33 |         write!(f, "{}", wgsl! { @compute @workgroup_size('x, 'y, 'z) })
34 |     }
35 | }
36 |
37 | #[macro_export]
38 | macro_rules! wgc {
39 |     ($x:expr, $y:expr, $z:expr) => {
40 |         $crate::gpu::WorkgroupCount::new($x, $y, $z)
41 |     };
42 | }
43 |
44 | #[derive(Debug, Clone, new, PartialEq, Eq, Hash)]
45 | pub struct WorkgroupCount {
46 |     x: u32,
47 |     y: u32,
48 |     z: u32,
49 | }
50 |
51 | impl WorkgroupCount {
52 |     pub const MAX_WORKGROUP_SIZE_X: usize = 256;
53 |     pub const MAX_WORKGROUP_SIZE_Y: usize = 256;
54 |     pub const MAX_WORKGROUP_SIZE_Z: usize = 64;
55 |     pub const MAX_WGS_PER_DIM: usize = 65535;
56 |     pub const MAX_THREADS_PER_WG: usize = 256;
57 |
58 |     pub fn x(&self) -> u32 {
59 |         self.x
60 |     }
61 |
62 |     pub fn y(&self) -> u32 {
63 |         self.y
64 |     }
65 |
66 |     pub fn z(&self) -> u32 {
67 |         self.z
68 |     }
69 |
70 |     pub fn as_slice(&self) -> [u32; 3] {
71 |         [self.x, self.y, self.z]
72 |     }
73 |
74 |     pub fn product(&self) -> u32 {
75 |         self.x * self.y * self.z
76 |     }
77 |
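    // Worked example for the helper defined below: div_ceil(10, 3) == 4, since
    // 10 / 3 == 3 with a remainder and `(num % div != 0) as usize` contributes
    // the final partial workgroup; div_ceil(9, 3) == 3 exactly.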
78 |     /// Divides `num` by `div`, rounding the quotient up if there is a remainder.
79 |     pub fn div_ceil(num: usize, div: usize) -> usize {
80 |         num / div + (num % div != 0) as usize
81 |     }
82 | }
83 |
84 | impl std::fmt::Display for WorkgroupCount {
85 |     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
86 |         write!(f, "x{}_y{}_z{}", self.x, self.y, self.z)
87 |     }
88 | }
89 |
90 | impl Default for WorkgroupCount {
91 |     fn default() -> Self {
92 |         Self::new(1, 1, 1)
93 |     }
94 | }
95 |
96 | #[derive(Debug)]
97 | pub struct Workload {
98 |     pub workgroup_size: WorkgroupSize,
99 |     pub workgroup_count: WorkgroupCount,
100 | }
101 |
102 | impl Workload {
103 |     pub fn std(numel: usize, ke: KernelElement) -> Workload {
104 |         let workgroup_size = wgs![8, 8, 1];
105 |
106 |         let numel = numel / ke.as_size();
107 |         let x_groups = WorkgroupCount::div_ceil(numel as _, workgroup_size.product() as _);
108 |         let (x_groups, y_groups) = if x_groups > WorkgroupCount::MAX_WGS_PER_DIM {
109 |             let y_groups = WorkgroupCount::div_ceil(x_groups, WorkgroupCount::MAX_WGS_PER_DIM);
110 |             (WorkgroupCount::MAX_WGS_PER_DIM, y_groups)
111 |         } else {
112 |             (x_groups, 1)
113 |         };
114 |
115 |         Workload {
116 |             workgroup_count: wgc![x_groups as _, y_groups as _, 1],
117 |             workgroup_size,
118 |         }
119 |     }
120 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/lib.rs:
--------------------------------------------------------------------------------
1 | #![allow(non_snake_case)]
2 | mod compiled_op;
3 | mod cpu;
4 | mod device;
5 | mod dtype;
6 | mod enforcer;
7 | mod executable;
8 | mod gpu;
9 | mod ndarray_ext;
10 | mod op;
11 | mod ops;
12 | mod plot;
13 | mod quant;
14 | mod shape;
15 | mod storage;
16 | mod strides;
17 | mod tensor;
18 | mod tensor_id;
19 |
20 | pub use compiled_op::*;
21 | pub use cpu::*;
22 | pub use device::*;
23 | pub use dtype::*;
24 | pub use enforcer::*;
25 | pub use executable::*;
26 | pub use gpu::*;
27 | pub use ndarray_ext::*;
28 | pub use op::*;
29 | pub use ops::*;
30 | pub use quant::*;
31 | pub use shape::*;
32 | pub use storage::*;
33 | pub use strides::*;
34 | pub use tensor::*;
35 | pub use tensor_id::*;
36 |
37 | #[cfg(feature = "plotting")]
38 | pub use plot::render_to_file;
39 |
40 | use smallvec::SmallVec;
41 | pub type RVec<T> = SmallVec<[T; 4]>;
42 | pub type DRVec<T> = SmallVec<[T; 8]>; //Double RVec
43 | pub type RawGPUBuffer = wgpu::Buffer;
44 |
45 | //https://github.com/sonos/tract/blob/main/data/src/macros.rs#L2
46 | #[macro_export]
47 | macro_rules! rvec {
48 |     (@one $x:expr) => (1usize);
49 |     ($elem:expr; $n:expr) => ({
50 |         $crate::RVec::from_elem($elem, $n)
51 |     });
52 |     ($($x:expr),*$(,)*) => ({
53 |         let count = 0usize $(+ rvec![@one $x])*;
54 |         #[allow(unused_mut)]
55 |         let mut vec = $crate::RVec::new();
56 |         if count <= vec.inline_size() {
57 |             $(vec.push($x);)*
58 |             vec
59 |         } else {
60 |             $crate::RVec::from_vec(vec![$($x,)*])
61 |         }
62 |     });
63 | }
64 |
65 | #[macro_export]
66 | macro_rules! drvec {
67 |     (@one $x:expr) => (1usize);
68 |     ($elem:expr; $n:expr) => ({
69 |         $crate::DRVec::from_elem($elem, $n)
70 |     });
71 |     ($($x:expr),*$(,)*) => ({
72 |         let count = 0usize $(+ rvec![@one $x])*;
73 |         #[allow(unused_mut)]
74 |         let mut vec = $crate::DRVec::new();
75 |         if count <= vec.inline_size() {
76 |             $(vec.push($x);)*
77 |             vec
78 |         } else {
79 |             $crate::DRVec::from_vec(vec![$($x,)*])
80 |         }
81 |     });
82 | }
83 |
84 | #[macro_export]
85 | macro_rules! shape! {
shape { 86 | ($($x:expr),*$(,)*) => ({ 87 | use $crate::rvec; 88 | $crate::Shape::new(rvec![$($x,)*]) 89 | }); 90 | } 91 | 92 | pub mod prelude { 93 | pub use crate::{rvec, shape, Device, DeviceRequest, Tensor}; 94 | } 95 | 96 | #[cfg(feature = "pyo3")] 97 | pub mod test_util { 98 | use crate::{DType, Tensor}; 99 | use half::f16; 100 | use regex::Regex; 101 | use { 102 | numpy::PyArrayDyn, 103 | pyo3::{prelude::*, types::PyTuple}, 104 | }; 105 | 106 | /// It's a bit of a hack, but it's useful for testing. 107 | pub fn run_py_prg( 108 | prg: String, 109 | tensors: &[&Tensor], 110 | args: &[&dyn ToPyObject], 111 | dst_dtype: DType, 112 | ) -> anyhow::Result { 113 | let re = Regex::new(r"def\s+(\w+)\s*\(").unwrap(); 114 | let func = match re.captures(&prg) { 115 | Some(caps) => caps.get(1).map(|m| m.as_str()).unwrap(), 116 | None => return Err(anyhow::anyhow!("No function name found")), 117 | }; 118 | 119 | Python::with_gil(|py| { 120 | let prg = PyModule::from_code(py, &prg, "x.py", "x")?; 121 | let py_tensors = tensors.iter().map(|t| match t.dt() { 122 | DType::F32 => t.to_py::(&py).to_object(py), 123 | DType::I32 => t.to_py::(&py).to_object(py), 124 | DType::F16 => t.to_py::(&py).to_object(py), 125 | _ => unimplemented!(), 126 | }); 127 | let py_args = py_tensors 128 | .chain(args.iter().map(|a| a.to_object(py))) 129 | .collect::>(); 130 | let py_args = PyTuple::new(py, py_args); 131 | let py_result = prg.getattr(func)?.call1(py_args)?; 132 | let result: Tensor = match dst_dtype { 133 | DType::F32 => py_result.extract::<&PyArrayDyn>()?.into(), 134 | DType::F16 => py_result.extract::<&PyArrayDyn>()?.into(), 135 | DType::I32 => py_result.extract::<&PyArrayDyn>()?.into(), 136 | DType::U32 => py_result.extract::<&PyArrayDyn>()?.into(), 137 | _ => unimplemented!(), 138 | }; 139 | Ok(result) 140 | }) 141 | } 142 | } 143 | -------------------------------------------------------------------------------- /crates/ratchet-core/src/ndarray_ext.rs: -------------------------------------------------------------------------------- 1 | use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, Slice}; 2 | use num_traits::{Float, FromPrimitive}; 3 | 4 | pub trait NDArrayExt 5 | where 6 | S: Data, 7 | D: Dimension, 8 | { 9 | fn logsumexp(&self, axis: usize) -> A 10 | where 11 | A: Float + FromPrimitive, 12 | D: RemoveAxis; 13 | 14 | fn log_softmax(&self, axis: usize) -> Array 15 | where 16 | A: Float + FromPrimitive, 17 | D: RemoveAxis; 18 | 19 | fn softmax(&self, axis: usize) -> Array 20 | where 21 | A: Float + FromPrimitive, 22 | D: RemoveAxis; 23 | fn pad(&self, pad_width: Vec<[usize; 2]>, const_value: A) -> Array; 24 | } 25 | 26 | impl NDArrayExt for ArrayBase 27 | where 28 | A: Clone, 29 | S: Data, 30 | D: Dimension, 31 | { 32 | fn logsumexp(&self, axis: usize) -> A 33 | where 34 | A: Float + FromPrimitive, 35 | D: RemoveAxis, 36 | { 37 | self.lanes(Axis(axis)) 38 | .into_iter() 39 | .fold(A::neg_infinity(), |log_sum_exp, lane| { 40 | let max = lane.fold(A::neg_infinity(), |a, &b| a.max(b)); 41 | let log_sum_exp_lane = max 42 | + lane 43 | .mapv(|x| { 44 | if x.is_infinite() { 45 | A::zero() 46 | } else { 47 | (x - max).exp() 48 | } 49 | }) 50 | .sum() 51 | .ln(); 52 | A::max(log_sum_exp, log_sum_exp_lane) 53 | }) 54 | } 55 | 56 | fn log_softmax(&self, axis: usize) -> Array 57 | where 58 | A: Float + FromPrimitive, 59 | D: RemoveAxis, 60 | { 61 | let mut output = self.to_owned(); 62 | output 63 | .lanes_mut(Axis(axis)) 64 | .into_iter() 65 | .for_each(|mut lane| { 66 | let max = 
--------------------------------------------------------------------------------
/crates/ratchet-core/src/ndarray_ext.rs:
--------------------------------------------------------------------------------
1 | use ndarray::{Array, ArrayBase, Axis, Data, Dimension, RemoveAxis, Slice};
2 | use num_traits::{Float, FromPrimitive};
3 |
4 | pub trait NDArrayExt<A, S, D>
5 | where
6 |     S: Data<Elem = A>,
7 |     D: Dimension,
8 | {
9 |     fn logsumexp(&self, axis: usize) -> A
10 |     where
11 |         A: Float + FromPrimitive,
12 |         D: RemoveAxis;
13 |
14 |     fn log_softmax(&self, axis: usize) -> Array<A, D>
15 |     where
16 |         A: Float + FromPrimitive,
17 |         D: RemoveAxis;
18 |
19 |     fn softmax(&self, axis: usize) -> Array<A, D>
20 |     where
21 |         A: Float + FromPrimitive,
22 |         D: RemoveAxis;
23 |     fn pad(&self, pad_width: Vec<[usize; 2]>, const_value: A) -> Array<A, D>;
24 | }
25 |
26 | impl<A, S, D> NDArrayExt<A, S, D> for ArrayBase<S, D>
27 | where
28 |     A: Clone,
29 |     S: Data<Elem = A>,
30 |     D: Dimension,
31 | {
32 |     fn logsumexp(&self, axis: usize) -> A
33 |     where
34 |         A: Float + FromPrimitive,
35 |         D: RemoveAxis,
36 |     {
37 |         self.lanes(Axis(axis))
38 |             .into_iter()
39 |             .fold(A::neg_infinity(), |log_sum_exp, lane| {
40 |                 let max = lane.fold(A::neg_infinity(), |a, &b| a.max(b));
41 |                 let log_sum_exp_lane = max
42 |                     + lane
43 |                         .mapv(|x| {
44 |                             if x.is_infinite() {
45 |                                 A::zero()
46 |                             } else {
47 |                                 (x - max).exp()
48 |                             }
49 |                         })
50 |                         .sum()
51 |                         .ln();
52 |                 A::max(log_sum_exp, log_sum_exp_lane)
53 |             })
54 |     }
55 |
56 |     fn log_softmax(&self, axis: usize) -> Array<A, D>
57 |     where
58 |         A: Float + FromPrimitive,
59 |         D: RemoveAxis,
60 |     {
61 |         let mut output = self.to_owned();
62 |         output
63 |             .lanes_mut(Axis(axis))
64 |             .into_iter()
65 |             .for_each(|mut lane| {
66 |                 let max = lane.fold(A::neg_infinity(), |a, &b| a.max(b));
67 |                 lane.mapv_inplace(|x| x - max);
68 |                 let log_sum_exp = lane
69 |                     .mapv(|x| if x.is_infinite() { A::zero() } else { x.exp() })
70 |                     .sum()
71 |                     .ln();
72 |                 lane.mapv_inplace(move |x| x - log_sum_exp);
73 |             });
74 |         output
75 |     }
76 |
77 |     fn softmax(&self, axis: usize) -> Array<A, D>
78 |     where
79 |         A: Float + FromPrimitive,
80 |         D: RemoveAxis,
81 |     {
82 |         let mut output = self.to_owned();
83 |         output
84 |             .lanes_mut(Axis(axis))
85 |             .into_iter()
86 |             .for_each(|mut lane| {
87 |                 let max = lane.fold(A::neg_infinity(), |a, &b| a.max(b));
88 |                 lane.mapv_inplace(|x| x - max);
89 |                 let sum_exp = lane.mapv(|x| x.exp()).sum();
90 |                 lane.mapv_inplace(move |x| x.exp() / sum_exp);
91 |             });
92 |         output
93 |     }
94 |
95 |     fn pad(&self, pad_width: Vec<[usize; 2]>, const_value: A) -> Array<A, D> {
96 |         assert_eq!(
97 |             self.ndim(),
98 |             pad_width.len(),
99 |             "Array ndim must match length of `pad_width`."
100 |         );
101 |
102 |         // Compute shape of final padded array.
103 |         let mut padded_shape = self.raw_dim();
104 |         for (ax, (&ax_len, &[pad_lo, pad_hi])) in self.shape().iter().zip(&pad_width).enumerate() {
105 |             padded_shape[ax] = ax_len + pad_lo + pad_hi;
106 |         }
107 |
108 |         let mut padded = Array::from_elem(padded_shape, const_value);
109 |         let padded_dim = padded.raw_dim();
110 |         {
111 |             // Select portion of padded array that needs to be copied from the
112 |             // original array.
113 |             let mut orig_portion = padded.view_mut();
114 |             for (ax, &[pad_lo, pad_hi]) in pad_width.iter().enumerate() {
115 |                 orig_portion.slice_axis_inplace(
116 |                     Axis(ax),
117 |                     Slice::from(pad_lo as isize..padded_dim[ax] as isize - (pad_hi as isize)),
118 |                 );
119 |             }
120 |             // Copy the data from the original array.
121 |             orig_portion.assign(self);
122 |         }
123 |         padded
124 |     }
125 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/ops/mod.rs:
--------------------------------------------------------------------------------
1 | mod binary;
2 | mod cache;
3 | mod cast;
4 | mod concat;
5 | mod conv;
6 | mod index_write;
7 | mod matmul;
8 | mod norm;
9 | mod reindex;
10 | mod rope;
11 | mod select;
12 | mod softmax;
13 | mod unary;
14 | mod view;
15 |
16 | pub use binary::*;
17 | pub use cache::*;
18 | pub use cast::*;
19 | pub use concat::*;
20 | pub use conv::*;
21 | pub use index_write::*;
22 | pub use matmul::*;
23 | pub use norm::*;
24 | pub use reindex::*;
25 | pub use rope::*;
26 | pub use select::*;
27 | pub use softmax::*;
28 | pub use unary::*;
29 | pub use view::*;
30 |
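// Illustrative numbers for the enum documented below: for an f32 kernel whose
// dimension of interest is 1024, Vec4 is admissible (1024 % 4 == 0); 510 only
// admits Vec2; 17 falls back to Scalar.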
31 | /// # KernelElement
32 | ///
33 | /// Used to select the widest vector type a kernel can safely use.
34 | /// If (dimension of interest % KE) == 0, it is safe to use.
35 | #[derive(Debug, Clone, Hash, PartialEq, Eq)]
36 | pub enum KernelElement {
37 |     Vec4,
38 |     Vec2,
39 |     Scalar,
40 | }
41 |
42 | impl KernelElement {
43 |     pub fn as_size(&self) -> usize {
44 |         self.into()
45 |     }
46 |
47 |     pub fn as_str(&self) -> &'static str {
48 |         match self {
49 |             KernelElement::Vec4 => "vec4",
50 |             KernelElement::Vec2 => "vec2",
51 |             KernelElement::Scalar => "scalar",
52 |         }
53 |     }
54 | }
55 |
56 | impl From<&KernelElement> for usize {
57 |     fn from(item: &KernelElement) -> Self {
58 |         match item {
59 |             KernelElement::Vec4 => 4,
60 |             KernelElement::Vec2 => 2,
61 |             KernelElement::Scalar => 1,
62 |         }
63 |     }
64 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/ops/norm/groupnorm.rs:
--------------------------------------------------------------------------------
1 | use derive_new::new;
2 |
3 | use super::*;
4 |
5 | #[derive(new, Debug, Clone)]
6 | pub struct GroupNorm {
7 |     pub norm: Norm,
8 |     pub num_groups: usize,
9 | }
10 |
11 | #[cfg(all(test, feature = "pyo3"))]
12 | mod tests {
13 |     use test_strategy::{proptest, Arbitrary};
14 |
15 |     use crate::test_util::run_py_prg;
16 |     use crate::{rvec, shape, Device, DeviceRequest, Tensor};
17 |
18 |     fn ground_truth(
19 |         input: &Tensor,
20 |         scale: &Tensor,
21 |         bias: Option<&Tensor>,
22 |         num_groups: usize,
23 |     ) -> anyhow::Result<Tensor> {
24 |         let prg = r#"
25 | import torch
26 | import torch.nn.functional as F
27 |
28 | def manual_group_norm(input, scale, bias, num_groups):
29 |     (input, scale, bias) = (torch.from_numpy(input), torch.from_numpy(scale), torch.from_numpy(bias))
30 |     return F.group_norm(input, num_groups, weight=scale, bias=bias).numpy()
31 | "#;
32 |
33 |         let inputs = match bias {
34 |             Some(bias) => rvec![input, scale, bias],
35 |             None => rvec![input, scale],
36 |         };
37 |         run_py_prg(prg.to_string(), &inputs, &[&num_groups], input.dt())
38 |     }
39 |
40 |     fn run_norm_trial(device: &Device, problem: GroupNormProblem) -> anyhow::Result<()> {
41 |         let GroupNormProblem {
42 |             num_groups,
43 |             B,
44 |             C,
45 |             N,
46 |         } = problem;
47 |
48 |         let input = Tensor::randn::<f32>(shape![B, C, N], Device::CPU);
49 |         let scale = Tensor::randn::<f32>(shape![C], Device::CPU);
50 |         let bias = Some(Tensor::randn::<f32>(shape![C], Device::CPU));
51 |
52 |         let ground = ground_truth(&input, &scale, bias.as_ref(), num_groups)?;
53 |
54 |         let input_gpu = input.to(device)?;
55 |         let scale_gpu = scale.to(device)?;
56 |         let bias_gpu = bias.map(|b| b.to(device)).transpose()?;
57 |
58 |         let result = input_gpu
59 |             .group_norm(num_groups, scale_gpu, bias_gpu, 1e-5)?
60 |             .resolve()?;
61 |
62 |         let ours = result.to(&Device::CPU)?;
63 |
64 |         ground.all_close(&ours, 1e-4, 1e-4)?;
65 |         Ok(())
66 |     }
67 |
68 |     #[derive(Arbitrary, Debug)]
69 |     struct GroupNormProblem {
70 |         #[map(|num_groups: u32| #C/2 )]
71 |         num_groups: usize,
72 |         #[strategy(1..=1usize)]
73 |         B: usize,
74 |         #[strategy(2..=4usize)]
75 |         #[filter(#C % 2 != 0)]
76 |         C: usize,
77 |         #[strategy(1..=1usize)]
78 |         N: usize,
79 |     }
80 |
81 |     #[proptest(cases = 64)]
82 |     fn test_groupnorm(prob: GroupNormProblem) {
83 |         let device = Device::request_device(DeviceRequest::GPU).unwrap();
84 |         println!("prob = {:#?}", prob);
85 |         run_norm_trial(&device, prob).unwrap();
86 |     }
87 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/ops/reindex/permute.rs:
--------------------------------------------------------------------------------
1 | use std::collections::HashSet;
2 |
3 | use derive_new::new;
4 | use encase::ShaderType;
5 | use ratchet_macros::WgslMetadata;
6 |
7 | use crate::{
8 |     rvec, InvariantError, OpGuards, Operation, OperationError, RVec, StorageView, Strides, Tensor,
9 | };
10 |
11 | #[derive(Debug, derive_new::new, WgslMetadata, ShaderType)]
12 | pub struct PermuteMeta {
13 |     src_shape: glam::UVec4,
14 |     dst_shape: glam::UVec4,
15 |     src_stride: glam::UVec4,
16 |     dst_stride: glam::UVec4,
17 |     src_numel: u32,
18 |     dst_numel: u32,
19 |     perm: glam::UVec4,
20 | }
21 |
22 | #[derive(new, Debug, Clone)]
23 | pub struct Permute {
24 |     pub src: Tensor,
25 |     pub dims: Vec<usize>,
26 | }
27 |
28 | impl Permute {
29 |     pub fn promote(&self) -> Vec<usize> {
30 |         let pad_len = 4 - self.dims.len();
31 |
32 |         let mut perm = self.dims.clone();
33 |         for p in perm.iter_mut() {
34 |             *p += pad_len;
35 |         }
36 |         (0..pad_len).for_each(|x| perm.insert(0, x));
37 |         perm
38 |     }
39 | }
40 |
41 | impl Operation for Permute {
42 |     fn name(&self) -> &'static str {
43 |         "Permute"
44 |     }
45 |
46 |     fn compute_view(&self) -> Result<StorageView, OperationError> {
47 |         let input_shape = self.src.shape();
48 |         let dup_set: HashSet<usize> = HashSet::from_iter(self.dims.iter().cloned());
49 |         if dup_set.len() != self.dims.len() {
50 |             return Err(InvariantError::DuplicateDims)?;
51 |         }
52 |
53 |         let mut output_shape = input_shape.clone();
54 |         for i in 0..input_shape.rank() {
55 |             output_shape[i] = input_shape[self.dims[i]];
56 |         }
57 |         let strides = Strides::from(&output_shape);
58 |         Ok(StorageView::new(output_shape, self.src.dt(), strides))
59 |     }
60 |
61 |     fn srcs(&self) -> RVec<&Tensor> {
62 |         rvec![&self.src]
63 |     }
64 | }
65 |
66 | impl OpGuards for Permute {
67 |     fn check_shapes(&self) {
68 |         assert!(self.src.shape().rank() == self.dims.len());
69 |         assert!(self.dims.iter().all(|&x| x < 4)); //Only support 4D for now
70 |     }
71 |
72 |     fn check_dtypes(&self) {}
73 | }
74 |
75 | #[cfg(all(test, feature = "pyo3"))]
76 | mod tests {
77 |     use crate::{test_util::run_py_prg, Device, DeviceRequest, Permute, Shape, Tensor};
78 |     use proptest::prelude::*;
79 |     use test_strategy::{proptest, Arbitrary};
80 |
81 |     impl Arbitrary for Permute {
82 |         type Parameters = ();
83 |         type Strategy = BoxedStrategy<Self>;
84 |
85 |         fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
86 |             let ranges = vec![1..=2, 1..=4, 1..=256, 1..=256];
87 |             Shape::arbitrary_with(ranges)
88 |                 .prop_flat_map(|shape| (Just(shape.clone()), Just(vec![0, 1, 2, 3]).prop_shuffle()))
89 |                 .prop_map(|(shape, perm)| {
90 |                     Permute::new(Tensor::randn::<f32>(shape, Device::CPU), perm)
91 |                 })
92 |                 .boxed()
93 |         }
94 |     }
95 |
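    // For orientation: `Permute::promote` pads dim lists to rank 4, e.g.
    // dims [1, 0] on a 2-D tensor become [1, 0, 3, 2] after promotion. The two
    // leading entries index the padded unit axes, so their relative order has
    // no effect on the data.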
96 |     #[derive(Arbitrary, Debug)]
97 |     struct PermuteProblem {
98 |         op: Permute,
99 |     }
100 |
101 |     fn ground_truth(a: &Tensor, args: &str) -> anyhow::Result<Tensor> {
102 |         let prg = format!(
103 |             r#"
104 | import torch
105 | import numpy as np
106 | def permute(a):
107 |     return np.ascontiguousarray(torch.permute(torch.from_numpy(a), {}).numpy())
108 | "#,
109 |             args
110 |         );
111 |         run_py_prg(prg.to_string(), &[a], &[], a.dt())
112 |     }
113 |
114 |     fn run_reindex_trial(prob: PermuteProblem, device: Device) -> anyhow::Result<()> {
115 |         let PermuteProblem { op } = prob;
116 |         let a = op.src.clone();
117 |
118 |         let a_gpu = a.to(&device)?;
119 |         let ground = ground_truth(&a, format!("{:?}", op.dims).as_str())?;
120 |         let ours = a_gpu.permute(&op.dims)?.resolve()?;
121 |         let d_gpu = ours.to(&Device::CPU)?;
122 |         ground.all_close(&d_gpu, 1e-5, 1e-5)?;
123 |         Ok(())
124 |     }
125 |
126 |     #[proptest(cases = 16)]
127 |     fn test_permute_gpu(prob: PermuteProblem) {
128 |         let device = Device::request_device(DeviceRequest::GPU).unwrap();
129 |         run_reindex_trial(prob, device).unwrap();
130 |     }
131 |
132 |     #[proptest(cases = 16)]
133 |     fn test_permute_cpu(prob: PermuteProblem) {
134 |         let device = Device::request_device(DeviceRequest::CPU).unwrap();
135 |         run_reindex_trial(prob, device).unwrap();
136 |     }
137 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/ops/view.rs:
--------------------------------------------------------------------------------
1 | use crate::{rvec, OpGuards, Operation, OperationError, Shape, StorageView, Strides, Tensor};
2 |
3 | #[derive(Debug, derive_new::new, Clone)]
4 | pub struct View {
5 |     src: Tensor,
6 |     shape: Shape,
7 | }
8 |
9 | impl View {
10 |     pub fn input(&self) -> &Tensor {
11 |         &self.src
12 |     }
13 | }
14 |
15 | impl OpGuards for View {
16 |     fn check_shapes(&self) {
17 |         let (src_shape, dst_shape) = (self.src.shape(), &self.shape);
18 |         assert_eq!(src_shape.rank(), dst_shape.rank());
19 |         assert_eq!(src_shape.numel(), dst_shape.numel());
20 |     }
21 |
22 |     fn check_dtypes(&self) {}
23 | }
24 |
25 | impl Operation for View {
26 |     fn name(&self) -> &'static str {
27 |         "View"
28 |     }
29 |
30 |     fn compute_view(&self) -> Result<StorageView, OperationError> {
31 |         let strides = Strides::from(&self.shape);
32 |         Ok(StorageView::new(self.shape.clone(), self.src.dt(), strides))
33 |     }
34 |
35 |     fn srcs(&self) -> crate::RVec<&Tensor> {
36 |         rvec![&self.src]
37 |     }
38 | }
--------------------------------------------------------------------------------
/crates/ratchet-core/src/storage/mod.rs:
--------------------------------------------------------------------------------
1 | mod cpu_buffer;
2 | mod gpu_buffer;
3 |
4 | use std::io::{BufRead, Seek};
5 |
6 | use bytemuck::NoUninit;
7 | pub use cpu_buffer::*;
8 | pub use gpu_buffer::*;
9 |
10 | use crate::{Device, DeviceError, Shape, TensorDType};
11 |
12 | use crate::DType;
13 |
14 | #[derive(Debug)]
15 | pub enum Storage {
16 |     CPU(CPUBuffer),
17 |     GPU(GPUBuffer),
18 | }
19 |
20 | impl Storage {
21 |     /// # Safety
22 |     ///
23 |     /// Inherited from the `from_quantized` method of the `CPUBuffer` and `GPUBuffer` structs.
24 |     pub unsafe fn from_quantized<T: NoUninit>(data: &[T], device: &Device) -> Self {
25 |         match device {
26 |             Device::CPU => Storage::CPU(unsafe { CPUBuffer::from_quantized(data) }),
27 |             Device::GPU(g) => Storage::GPU(unsafe { GPUBuffer::from_quantized(data, g) }),
28 |         }
29 |     }
30 |
31 |     pub fn from_disk<T: TensorDType, R: BufRead + Seek>(
32 |         reader: &mut R,
33 |         shape: &Shape,
34 |         device: &Device,
35 |     ) -> Result<Self, DeviceError> {
36 |         match device {
37 |             Device::CPU => Ok(Storage::CPU(CPUBuffer::from_disk::<T, R>(reader, shape)?)),
38 |             Device::GPU(_) => Ok(Storage::GPU(GPUBuffer::from_disk::<T, R>(
39 |                 reader, shape, device,
40 |             )?)),
41 |         }
42 |     }
43 |
44 |     pub fn zeros<T: TensorDType>(shape: &Shape, device: &Device) -> Self {
45 |         match device {
46 |             Device::CPU => Storage::CPU(CPUBuffer::zeros::<T>(shape)),
47 |             Device::GPU(g) => Storage::GPU(GPUBuffer::zeros::<T>(shape, g)),
48 |         }
49 |     }
50 |
51 |     pub fn from_slice<T: NoUninit>(data: &[T], shape: &Shape, device: &Device) -> Self {
52 |         match device {
53 |             Device::CPU => Storage::CPU(CPUBuffer::from_slice(data, shape)),
54 |             Device::GPU(g) => Storage::GPU(GPUBuffer::from_slice(data, shape, g)),
55 |         }
56 |     }
57 |
58 |     pub fn from_bytes(data: &[u8], alignment: usize, device: &Device) -> Self {
59 |         match device {
60 |             Device::CPU => Storage::CPU(CPUBuffer::from_bytes(data, alignment)),
61 |             Device::GPU(g) => Storage::GPU(GPUBuffer::from_bytes(data, alignment, g)),
62 |         }
63 |     }
64 |
65 |     pub unsafe fn into_bytes(self) -> Vec<u8> {
66 |         match self {
67 |             Storage::CPU(c) => unsafe { c.into_bytes() },
68 |             _ => todo!(),
69 |         }
70 |     }
71 |
72 |     pub fn dump(&self, dt: DType, full: bool) -> String {
73 |         match self {
74 |             Storage::CPU(c) => c.dump(dt, full),
75 |             Storage::GPU(g) => g.dump(dt, full),
76 |         }
77 |     }
78 |
79 |     pub fn try_cpu(&self) -> Result<&CPUBuffer, DeviceError> {
80 |         match self {
81 |             Storage::CPU(c) => Ok(c),
82 |             Storage::GPU(_g) => Err(DeviceError::DeviceMismatch(
83 |                 "CPU".to_string(),
84 |                 "GPU".to_string(),
85 |             )),
86 |         }
87 |     }
88 |
89 |     pub fn try_gpu(&self) -> Result<&GPUBuffer, DeviceError> {
90 |         match self {
91 |             Storage::GPU(g) => Ok(g),
92 |             Storage::CPU(_c) => Err(DeviceError::DeviceMismatch(
93 |                 "GPU".to_string(),
94 |                 "CPU".to_string(),
95 |             )),
96 |         }
97 |     }
98 |
99 |     pub fn deep_clone(&self, device: &Device) -> Result<Self, DeviceError> {
100 |         match self {
101 |             Storage::CPU(c) => {
102 |                 assert!(device.is_cpu());
103 |                 Ok(Storage::CPU(c.deep_clone()?))
104 |             }
105 |             Storage::GPU(g) => {
106 |                 let wgpu_device = device.try_gpu()?;
107 |                 Ok(Storage::GPU(g.deep_clone(wgpu_device)))
108 |             }
109 |         }
110 |     }
111 |
112 |     #[cfg(feature = "plotting")]
113 |     pub fn plot_fmt(&self) -> String {
114 |         match self {
115 |             Storage::CPU(c) => c.plot_fmt(),
116 |             Storage::GPU(g) => g.plot_fmt(),
117 |         }
118 |     }
119 | }
120 |
121 | #[cfg_attr(target_arch = "wasm32", async_trait::async_trait)]
122 | pub trait DeviceStorage: std::fmt::Debug + Clone + 'static {
123 |     // To be expanded to other devices
124 |     fn to_device(&self, device: &Device) -> Result<Self, DeviceError>;
125 |     /// Creates a copy of the device buffer on the CPU
126 |     #[cfg(target_arch = "wasm32")]
127 |     async fn to_cpu(&self, device: &Device) -> Result<CPUBuffer, DeviceError>;
128 |     #[cfg(not(target_arch = "wasm32"))]
129 |     fn to_cpu(&self, device: &Device) -> Result<CPUBuffer, DeviceError>;
130 |     fn n_bytes(&self) -> usize;
131 |     fn dump(&self, dt: DType, full: bool) -> String;
132 | }
133 |
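A minimal usage sketch for the device dispatch above (illustrative caller code, CPU path only):

    use ratchet::{shape, DType, Device, Storage};

    fn roundtrip() -> Result<(), ratchet::DeviceError> {
        let shape = shape![2, 2];
        let storage = Storage::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &shape, &Device::CPU);
        // try_cpu succeeds because we allocated on the CPU; asking a CPU
        // buffer for its GPU view would yield DeviceError::DeviceMismatch.
        let _cpu = storage.try_cpu()?;
        println!("{}", storage.dump(DType::F32, false));
        Ok(())
    }
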
--------------------------------------------------------------------------------
/crates/ratchet-core/src/tensor_id.rs:
--------------------------------------------------------------------------------
1 | /// Unique identifier for tensors.
2 | #[derive(Clone, Copy, PartialEq, Eq, Hash)]
3 | pub struct TensorId(usize);
4 |
5 | impl std::fmt::Debug for TensorId {
6 |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
7 |         write!(f, "T{}", self.0)
8 |     }
9 | }
10 |
11 | impl TensorId {
12 |     pub(crate) fn new() -> Self {
13 |         // https://users.rust-lang.org/t/idiomatic-rust-way-to-generate-unique-id/33805
14 |         use std::sync::atomic;
15 |         static COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(1);
16 |         Self(COUNTER.fetch_add(1, atomic::Ordering::Relaxed))
17 |     }
18 | }
--------------------------------------------------------------------------------
/crates/ratchet-hub/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "ratchet-hub"
3 | version = "0.1.0"
4 | edition = "2021"
5 | license = "MIT"
6 | description = "A web-first, cross-platform ML framework."
7 | keywords = ["llm","wasm","transformers","webgpu","ml","machine-learning","deep-learning"]
8 | repository = "https://github.com/FL33TW00D/ratchet"
9 |
10 | [lib]
11 | crate-type = ["cdylib", "rlib"]
12 |
13 | [package.metadata.wasm-pack.profile.dev.wasm-bindgen]
14 | debug-js-glue = true
15 | demangle-name-section = true
16 | dwarf-debug-info = true
17 |
18 | [package.metadata.wasm-pack.profile.release]
19 | wasm-opt = ['-O3', '--enable-simd']
20 |
21 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
22 | [dependencies]
23 | ratchet = { path = "../ratchet-core" }
24 | ratchet-loader = { path = "../ratchet-loader" }
25 | js-sys.workspace = true
26 | thiserror.workspace = true
27 | anyhow.workspace = true
28 | log.workspace = true
29 | wasm-bindgen.workspace = true
30 | serde.workspace = true
31 |
32 | wasm-bindgen-futures = { workspace = true }
33 | indexed_db_futures = { workspace = true }
34 | serde-wasm-bindgen = { workspace = true }
35 | serde_bytes = { workspace = true }
36 | console_error_panic_hook = { workspace = true }
37 | console_log = { workspace = true }
38 | fern = { workspace = true }
39 | chrono = { workspace = true }
40 | gloo-net = { workspace = true, features = ["http"] }
41 |
42 | [dependencies.web-sys]
43 | features = [
44 |     'console',
45 |     'Headers',
46 |     'Request',
47 |     'RequestInit',
48 |     'RequestMode',
49 |     'Response',
50 |     'ReadableStream',
51 |     'ReadableStreamGetReaderOptions',
52 |     'ReadableStreamReaderMode',
53 |     'ReadableStreamDefaultReader',
54 |     'Window',
55 |     'Navigator',
56 |     'StorageManager',
57 |     'Cache',
58 |     'CacheStorage',
59 |     'IdbKeyRange',
60 | ]
61 | workspace = true
62 |
63 | [target.'cfg(target_arch = "wasm32")'.dependencies]
64 | getrandom = { version = "0.2.6", features = ["js"] }
65 |
66 | [dev-dependencies]
67 | wasm-bindgen-test.workspace = true
68 |
--------------------------------------------------------------------------------
/crates/ratchet-hub/src/util.rs:
--------------------------------------------------------------------------------
1 | #![cfg(target_arch = "wasm32")]
2 | use js_sys::JSON;
3 | use wasm_bindgen::prelude::*;
4 | use wasm_bindgen_futures::JsFuture;
5 | use web_sys::{Request, RequestInit, RequestMode, Response};
6 |
7 | pub(crate) fn js_to_js_error(value: JsValue) -> JsError {
8 |     JsError::new(
9 |         JSON::stringify(&value)
10 |             .map(|js_string| {
11 |                 js_string
12 |                     .as_string()
13 |                     .unwrap_or(String::from("An unknown error occurred."))
14 |             })
15 |             .unwrap_or(String::from("An unknown error occurred."))
16 |             .as_str(),
17 |     )
18 | }
19 |
20 | pub(crate) fn js_error(message: &str) -> JsError {
21 |     JsError::new(message)
22 | }
23 |
24 | pub(crate) async fn to_future<T>(promise: js_sys::Promise) -> Result<T, JsValue>
25 | where
26 |     T: JsCast,
27 | {
28 |     let result = JsFuture::from(promise).await?;
29 |     result.dyn_into::<T>()
30 | }
31 |
32 | pub(crate) async fn fetch(url: &str) -> Result<Response, JsValue> {
33 |     let mut opts = RequestInit::new();
34 |     opts.method("GET");
35 |     opts.mode(RequestMode::Cors);
36 |
37 |     let request = Request::new_with_str_and_init(url, &opts)?;
38 |
39 |     let window = web_sys::window().unwrap();
40 |     let promise = window.fetch_with_request(&request);
41 |     to_future(promise).await
42 | }
--------------------------------------------------------------------------------
/crates/ratchet-loader/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "ratchet-loader"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7 |
8 | [dependencies]
9 | ratchet = { path = "../ratchet-core" }
10 | half.workspace = true
11 | byteorder.workspace = true
12 | anyhow.workspace = true
13 | bytemuck.workspace = true
14 | thiserror.workspace = true
15 | log.workspace = true
16 | itertools = { workspace = true }
17 | env_logger.workspace = true
18 |
19 | [target.'cfg(target_arch = "wasm32")'.dependencies]
20 | wasm-bindgen = "0.2.84"
21 | serde = { workspace = true, features = ["derive"] }
22 |
23 | [dev-dependencies]
24 | wasm-bindgen-test.workspace = true
25 | hf-hub.workspace = true
26 | tokio = { workspace = true, features = ["sync", "macros", "io-util", "rt", "time"] }
--------------------------------------------------------------------------------
/crates/ratchet-loader/src/error.rs:
--------------------------------------------------------------------------------
1 | /// Main library error type.
2 | #[derive(thiserror::Error, Debug)]
3 | pub enum Error {
4 |     /// I/O error.
5 |     #[error(transparent)]
6 |     Io(#[from] std::io::Error),
7 |
8 |     /// Arbitrary errors wrapping.
9 |     #[error(transparent)]
10 |     Wrapped(Box<dyn std::error::Error + Send + Sync>),
11 |
12 |     /// User generated error message, typically created via `bail!`.
13 |     #[error("{0}")]
14 |     Msg(String),
15 | }
16 |
17 | impl Error {
18 |     pub fn wrap(err: impl std::error::Error + Send + Sync + 'static) -> Self {
19 |         Self::Wrapped(Box::new(err))
20 |     }
21 |
22 |     pub fn msg(err: impl std::error::Error + Send + Sync + 'static) -> Self {
23 |         Self::Msg(err.to_string())
24 |     }
25 | }
26 |
27 | #[macro_export]
28 | macro_rules! bail {
29 |     ($msg:literal $(,)?) => {
30 |         return Err($crate::error::Error::Msg(format!($msg).into()))
31 |     };
32 |     ($err:expr $(,)?) => {
33 |         return Err($crate::error::Error::Msg(format!($err).into()))
34 |     };
35 |     ($fmt:expr, $($arg:tt)*) => {
36 |         return Err($crate::error::Error::Msg(format!($fmt, $($arg)*).into()))
37 |     };
38 | }
39 |
40 | pub type Result<T> = std::result::Result<T, Error>;
--------------------------------------------------------------------------------
/crates/ratchet-loader/src/gguf/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod dtype;
2 | pub mod gguf;
3 | pub mod utils;
--------------------------------------------------------------------------------
/crates/ratchet-loader/src/gguf/utils.rs:
--------------------------------------------------------------------------------
1 | // Adapted from https://github.com/huggingface/candle/blob/fc67d878bb4a25cbeba361d0a31290f14beb9344/candle-core/src/quantized/utils.rs
2 |
3 | use half::f16;
4 |
5 | use crate::error::Result;
6 |
7 | pub trait ReadHalf {
8 |     fn read_f16(&mut self) -> Result<f16>;
9 | }
10 |
11 | impl<R: std::io::Read> ReadHalf for R {
12 |     fn read_f16(&mut self) -> Result<f16> {
13 |         let mut d = [0u8; 2];
14 |         self.read_exact(&mut d)?;
15 |         let f16_value = half::f16::from_le_bytes(d);
16 |         Ok(f16_value)
17 |     }
18 | }
19 |
20 | pub trait WriteHalf {
21 |     fn write_f16(&mut self, input: f16) -> Result<usize>;
22 | }
23 |
24 | impl<W: std::io::Write> WriteHalf for W {
25 |     fn write_f16(&mut self, input: f16) -> Result<usize> {
26 |         let bytes = input.to_le_bytes();
27 |         let num_written = self.write(&bytes)?;
28 |         Ok(num_written)
29 |     }
30 | }
31 |
32 | pub trait ReadInto<Other> {
33 |     fn read_u8s_into(&mut self, other: &mut Other, length: usize) -> Result<()>;
34 | }
35 |
36 | impl<Other: std::io::Write, R: std::io::Read> ReadInto<Other> for R {
37 |     fn read_u8s_into(&mut self, other: &mut Other, length: usize) -> Result<()> {
38 |         let mut temp = vec![0u8; length];
39 |         self.read_exact(&mut temp)?;
40 |         other.write_all(&temp)?;
41 |         Ok(())
42 |     }
43 | }
44 |
45 | pub trait ReadLen {
46 |     fn read_len_bytes(&mut self, length: usize) -> Result<Vec<u8>>;
47 | }
48 |
49 | impl<R: std::io::Read> ReadLen for R {
50 |     fn read_len_bytes(&mut self, length: usize) -> Result<Vec<u8>> {
51 |         let mut temp = vec![0u8; length];
52 |         self.read_exact(&mut temp)?;
53 |         Ok(temp)
54 |     }
55 | }
--------------------------------------------------------------------------------
/crates/ratchet-loader/test-data/nano-llama-q4k.gguf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/ratchet/136da4d5216910bfd015b27a17b837c21f17163a/crates/ratchet-loader/test-data/nano-llama-q4k.gguf
--------------------------------------------------------------------------------
/crates/ratchet-macros/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "ratchet-macros"
3 | authors = ["Chris Fleetwood "]
4 | version = "0.1.0"
5 | keywords = []
6 | edition = "2021"
7 | readme.workspace = true
8 | repository = "https://github.com/huggingface/ratchet"
9 |
10 | [lib]
11 | proc-macro = true
12 |
13 | [features]
14 | default = []
15 | std = []
16 |
17 | [dependencies]
18 | proc-macro2 = { version = "1.0" }
19 | quote = { version = "1.0" }
20 | syn = { version = "2.0" }
--------------------------------------------------------------------------------
/crates/ratchet-macros/src/lib.rs:
--------------------------------------------------------------------------------
1 | mod wgsl_metadata;
2 |
3 | use proc_macro::TokenStream;
4 | use syn::parse_macro_input;
5 |
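// Illustrative input/output pair for the derive below, using the field-type
// mapping implemented in wgsl_metadata.rs (UVec4 -> vec4<u32>, etc.):
//
//     #[derive(WgslMetadata)]
//     struct PermuteMeta { src_shape: glam::UVec4, src_numel: u32 }
//
// renders, via the generated `render_meta()`, a WGSL fragment along the lines of:
//
//     struct Meta {
//         src_shape: vec4<u32>,
//         src_numel: u32,
//     }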
7 | /// 8 | /// Generates a `.render()` method that converts a Rust struct into a WGSL struct. 9 | #[proc_macro_derive(WgslMetadata, attributes(builder))] 10 | pub fn derive_wgsl_metadata(input: TokenStream) -> TokenStream { 11 | let input = parse_macro_input!(input); 12 | wgsl_metadata::derive(input).into() 13 | } 14 | -------------------------------------------------------------------------------- /crates/ratchet-macros/src/wgsl_metadata.rs: -------------------------------------------------------------------------------- 1 | use proc_macro2::TokenStream; 2 | use quote::quote; 3 | use syn::{parse2, DeriveInput}; 4 | 5 | pub fn derive(input: TokenStream) -> TokenStream { 6 | let _input = parse2::<DeriveInput>(input).unwrap(); 7 | let struct_name = _input.ident; 8 | 9 | let syn::Data::Struct(syn::DataStruct { fields, .. }) = _input.data else { 10 | unimplemented!("Only structs are supported"); 11 | }; 12 | 13 | let transformed_fields = fields.iter().map(|field| { 14 | let Some(ident) = &field.ident else { 15 | unimplemented!("tuple structs"); 16 | }; 17 | 18 | let ty = &field.ty; 19 | 20 | match ty { 21 | syn::Type::Path(p) => { 22 | let path = &p.path; 23 | let t = path.segments.last().unwrap().ident.to_string(); 24 | 25 | match t.as_str() { 26 | "UVec4" => { 27 | quote!(#ident: vec4<u32>) 28 | } 29 | "IVec4" => { 30 | quote!(#ident: vec4<i32>) 31 | } 32 | "UVec3" => { 33 | quote!(#ident: vec3<u32>) 34 | } 35 | "IVec3" => { 36 | quote!(#ident: vec3<i32>) 37 | } 38 | _ => quote!(#ident: #ty), 39 | } 40 | } 41 | _ => todo!(), 42 | } 43 | }); 44 | 45 | let expanded = quote! ( 46 | use crate::StaticKernelMetadata; 47 | impl StaticKernelMetadata for #struct_name {} 48 | 49 | impl crate::KernelMetadata for #struct_name { 50 | fn render_meta(&self) -> crate::WgslFragment { 51 | let mut fragment = crate::WgslFragment::new(512); 52 | fragment.write("struct Meta {\n"); 53 | #( 54 | fragment.write(" "); 55 | fragment.write(stringify!(#transformed_fields)); 56 | fragment.write(",\n"); 57 | )* 58 | fragment.write("}\n"); 59 | fragment 60 | } 61 | 62 | fn write(&self, uniform: &mut crate::CpuUniform) -> Result<u64, crate::OperationError> { 63 | self.write_static(uniform) 64 | } 65 | } 66 | ); 67 | 68 | expanded 69 | } 70 | -------------------------------------------------------------------------------- /crates/ratchet-models/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ratchet-models" 3 | version = "0.1.0" 4 | edition = "2021" 5 | resolver = "2" 6 | 7 | [features] 8 | ci = [] 9 | pyo3 = [] 10 | 11 | [lib] 12 | crate-type = ["cdylib", "lib"] 13 | 14 | [package.metadata.wasm-pack.profile.dev.wasm-bindgen] 15 | debug-js-glue = true 16 | demangle-name-section = true 17 | dwarf-debug-info = true 18 | 19 | [package.metadata.wasm-pack.profile.release] 20 | wasm-opt = ['-O3', '--enable-simd'] 21 | 22 | [dependencies] 23 | ratchet = { path = "../ratchet-core" } 24 | ratchet-nn = { path = "../ratchet-nn" } 25 | ratchet-loader = { path = "../ratchet-loader" } 26 | byteorder.workspace = true 27 | anyhow.workspace = true 28 | thiserror.workspace = true 29 | derive-new = { workspace = true } 30 | log.workspace = true 31 | ndarray-stats = { workspace = true } 32 | num = { workspace = true } 33 | realfft = { workspace = true } 34 | ndarray = { workspace = true } 35 | cfg-if = { workspace = true } 36 | serde = { workspace = true } 37 | tokenizers = { version = "0.19.1", default-features = false, features=["unstable_wasm"] } 38 | lazy_static = { workspace = true } 39 | web-time = { workspace = true } 40 |
clap = { workspace = true, features = [ "derive" ] } 41 | serde_json.workspace = true 42 | half.workspace = true 43 | image = { workspace = true } 44 | pollster.workspace = true 45 | wasm-bindgen-futures = "0.4.42" 46 | 47 | [target.'cfg(target_arch = "wasm32")'.dependencies] 48 | wasm-bindgen = { workspace = true } 49 | serde-wasm-bindgen = "0.4.5" 50 | ratchet-hub = { path = "../ratchet-hub" } 51 | tsify = "0.4.5" 52 | js-sys = { workspace = true } 53 | 54 | [target.'cfg(not(target_arch = "wasm32"))'.dependencies] 55 | ratchet = { path = "../ratchet-core", features = ["pyo3"] } 56 | hf-hub.workspace = true 57 | 58 | [dev-dependencies] 59 | ratchet = { path = "../ratchet-core" } 60 | console_error_panic_hook = { workspace = true } 61 | console_log = { workspace = true } 62 | wasm-bindgen-test = { workspace = true } 63 | wasm-bindgen = { workspace = true } 64 | wasm-bindgen-futures = { workspace = true } 65 | npyz = { workspace = true } 66 | hound = { workspace = true } 67 | env_logger = { workspace = true } 68 | 69 | [target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies] 70 | ratchet = { path = "../ratchet-core", features = ["pyo3"] } 71 | pyo3 = "0.20.2" 72 | numpy = "0.20.0" 73 | 74 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::upper_case_acronyms)] 2 | pub mod moondream; 3 | pub mod phi2; 4 | pub mod phi3; 5 | pub mod registry; 6 | mod token_stream; 7 | pub mod whisper; 8 | pub use token_stream::TokenOutputStream; 9 | 10 | #[cfg(target_arch = "wasm32")] 11 | #[derive(Debug, derive_new::new)] 12 | pub struct WebTensor { 13 | ggml_dtype: ratchet_loader::GgmlDType, 14 | data: js_sys::Uint8Array, 15 | shape: ratchet::Shape, 16 | } 17 | 18 | #[cfg(target_arch = "wasm32")] 19 | pub type TensorMap = std::collections::HashMap<String, WebTensor>; 20 | 21 | #[cfg(target_arch = "wasm32")] 22 | pub fn ratchet_from_gguf_web( 23 | wt: WebTensor, 24 | device: &ratchet::Device, 25 | ) -> anyhow::Result<ratchet::Tensor> { 26 | use ratchet_loader::gguf::gguf::ratchet_from_gguf; 27 | let shape = wt.shape.clone(); 28 | let data = wt.data.to_vec(); 29 | ratchet_from_gguf(wt.ggml_dtype, &data, shape, device) 30 | } 31 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/moondream/mlp.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | use ratchet_nn::{Linear, Module}; 3 | 4 | #[derive(Debug, derive_new::new)] 5 | pub struct MLP { 6 | pub fc1: Linear, 7 | pub fc2: Linear, 8 | } 9 | 10 | impl Module for MLP { 11 | type Input = Tensor; 12 | 13 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 14 | let input_dt = input.dt(); 15 | self.fc2 16 | .schedule(self.fc1.schedule(input)?.full()?.gelu()?.cast(input_dt)?)
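// `full()` appears to upcast the activation to f32 before the GELU, and
// `cast(input_dt)` returns it to the original storage dtype (e.g. f16);
// LayerNorm/RMSNorm in ratchet-nn use the same upcast-compute-downcast pattern.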
17 | } 18 | } 19 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/moondream/mod.rs: -------------------------------------------------------------------------------- 1 | mod generate; 2 | mod mlp; 3 | pub mod model; 4 | mod text_model; 5 | mod vision_encoder; 6 | 7 | pub use generate::generate; 8 | pub use model::Moondream; 9 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/phi2/generate.rs: -------------------------------------------------------------------------------- 1 | #![cfg(target_arch = "wasm32")] 2 | use crate::phi2::Phi2; 3 | use crate::TokenOutputStream; 4 | use ndarray::Axis; 5 | use ndarray_stats::QuantileExt; 6 | use ratchet::{shape, Device, Tensor}; 7 | use ratchet_nn::Module; 8 | use tokenizers::Tokenizer; 9 | 10 | pub async fn generate( 11 | model: &mut Phi2, 12 | tokenizer: Tokenizer, 13 | prompt: String, 14 | callback: impl Fn(String), 15 | ) -> anyhow::Result<()> { 16 | use web_time::Instant; 17 | log::warn!("Prompt: {}", prompt); 18 | 19 | let mut tos = TokenOutputStream::new(tokenizer); 20 | let encoding = tos.tokenizer().encode(prompt, true).unwrap(); 21 | let mut tokens = encoding 22 | .get_ids() 23 | .iter() 24 | .map(|&x| x as i32) 25 | .collect::<Vec<_>>(); 26 | let mut all_tokens = tokens.clone(); 27 | let mut loop_cnt = 0; 28 | let start = Instant::now(); 29 | while tokens[tokens.len() - 1] != 50256 && loop_cnt < 256 { 30 | let input = Tensor::from_data( 31 | tokens.clone(), 32 | shape![1, tokens.len()], 33 | model.device.clone(), 34 | ); 35 | let result = model.schedule(input)?.resolve()?; 36 | let logits = result.to(&Device::CPU).await?; 37 | model.cache_mut().update(tokens.len()); 38 | 39 | tokens = logits 40 | .to_ndarray_view::<f32>() 41 | .map_axis(Axis(2), |row| row.argmax_skipnan().unwrap()) 42 | .iter() 43 | .map(|&x| x as i32) 44 | .collect::<Vec<_>>(); 45 | 46 | if let Some(t) = tos.next_token(tokens[0] as u32)? { 47 | callback(t); 48 | } 49 | all_tokens.extend(tokens.clone()); 50 | loop_cnt += 1; 51 | } 52 | let elapsed = start.elapsed(); 53 | log::warn!("Elapsed: {:?}", elapsed); 54 | log::warn!("Tok/s {}", all_tokens.len() as f64 / elapsed.as_secs_f64()); 55 | model.reset(); 56 | Ok(()) 57 | } 58 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/phi2/mlp.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | use ratchet_nn::{Linear, Module}; 3 | 4 | #[derive(Debug, derive_new::new)] 5 | pub struct MLP { 6 | l1: Linear, 7 | l2: Linear, 8 | } 9 | 10 | impl Module for MLP { 11 | type Input = Tensor; 12 | 13 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 14 | self.l2.schedule(self.l1.schedule(input)?.gelu()?)
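// Unlike the Moondream MLP above, this one applies GELU directly in the input
// dtype, with no `full()`/`cast` round-trip around the activation.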
15 | } 16 | } 17 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/phi2/mod.rs: -------------------------------------------------------------------------------- 1 | mod attn; 2 | mod generate; 3 | mod mlp; 4 | mod model; 5 | 6 | pub use model::Phi2; 7 | 8 | #[cfg(target_arch = "wasm32")] 9 | pub use generate::generate; 10 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/phi3/generate.rs: -------------------------------------------------------------------------------- 1 | use crate::phi3::Phi3; 2 | use crate::TokenOutputStream; 3 | use ndarray::Axis; 4 | use ndarray_stats::QuantileExt; 5 | use ratchet::{shape, Device, Tensor}; 6 | use ratchet_nn::Module; 7 | use tokenizers::Tokenizer; 8 | 9 | #[cfg(target_arch = "wasm32")] 10 | pub async fn generate( 11 | model: &mut Phi3, 12 | tokenizer: Tokenizer, 13 | prompt: String, 14 | callback: impl Fn(String), 15 | ) -> anyhow::Result<()> { 16 | use web_time::Instant; 17 | log::warn!("Prompt: {}", prompt); 18 | 19 | let prompt = format!( 20 | r#"<|user|> 21 | {}<|end|> 22 | <|assistant|>"#, 23 | prompt 24 | ); 25 | 26 | let mut tos = TokenOutputStream::new(tokenizer); 27 | 28 | let encoding = tos.tokenizer().encode(prompt, true).unwrap(); 29 | let mut tokens = encoding 30 | .get_ids() 31 | .iter() 32 | .map(|&x| x as i32) 33 | .collect::<Vec<_>>(); 34 | tokens.insert(0, 1); 35 | let mut all_tokens = tokens.clone(); 36 | let start = Instant::now(); 37 | while tokens[tokens.len() - 1] != 32007 && all_tokens.len() < 2048 { 38 | let input = Tensor::from_data( 39 | tokens.clone(), 40 | shape![1, tokens.len()], 41 | model.device.clone(), 42 | ); 43 | let result = model.schedule(input)?.resolve()?; 44 | let logits = result.to(&Device::CPU).await?; 45 | model.cache_mut().update(tokens.len()); 46 | 47 | tokens = logits 48 | .to_ndarray_view::<f32>() 49 | .map_axis(Axis(2), |row| row.argmax_skipnan().unwrap()) 50 | .iter() 51 | .map(|&x| x as i32) 52 | .collect::<Vec<_>>(); 53 | all_tokens.extend(tokens.clone()); 54 | if let Some(t) = tos.next_token(tokens[0] as u32)?
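// (`next_token` yields `Some` only once a complete printable chunk has been decoded)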
{ 55 | callback(t); 56 | } 57 | } 58 | let elapsed = start.elapsed(); 59 | log::warn!("Elapsed: {:?}", elapsed); 60 | log::warn!("Tok/s {}", all_tokens.len() as f64 / elapsed.as_secs_f64()); 61 | model.reset(); 62 | Ok(()) 63 | } 64 | 65 | #[cfg(not(target_arch = "wasm32"))] 66 | pub fn generate( 67 | model: &mut Phi3, 68 | tokenizer: Tokenizer, 69 | prompt: String, 70 | callback: impl Fn(String), 71 | ) -> anyhow::Result<()> { 72 | use web_time::Instant; 73 | log::warn!("Prompt: {}", prompt); 74 | 75 | let prompt = format!( 76 | r#"<|user|> 77 | {}<|end|> 78 | <|assistant|>"#, 79 | prompt 80 | ); 81 | 82 | let mut tos = TokenOutputStream::new(tokenizer); 83 | 84 | let encoding = tos.tokenizer().encode(prompt, true).unwrap(); 85 | let mut tokens = encoding 86 | .get_ids() 87 | .iter() 88 | .map(|&x| x as i32) 89 | .collect::<Vec<_>>(); 90 | tokens.insert(0, 1); 91 | let mut all_tokens = tokens.clone(); 92 | let start = Instant::now(); 93 | while tokens[tokens.len() - 1] != 32007 && all_tokens.len() < 2048 { 94 | let input = Tensor::from_data( 95 | tokens.clone(), 96 | shape![1, tokens.len()], 97 | model.device.clone(), 98 | ); 99 | let result = model.schedule(input)?.resolve()?; 100 | let logits = result.to(&Device::CPU)?; 101 | model.cache_mut().update(tokens.len()); 102 | 103 | tokens = logits 104 | .to_ndarray_view::<f32>() 105 | .map_axis(Axis(2), |row| row.argmax_skipnan().unwrap()) 106 | .iter() 107 | .map(|&x| x as i32) 108 | .collect::<Vec<_>>(); 109 | all_tokens.extend(tokens.clone()); 110 | if let Some(t) = tos.next_token(tokens[0] as u32)? { 111 | callback(t); 112 | } 113 | } 114 | let elapsed = start.elapsed(); 115 | log::warn!("Elapsed: {:?}", elapsed); 116 | log::warn!("Tok/s {}", all_tokens.len() as f64 / elapsed.as_secs_f64()); 117 | model.reset(); 118 | Ok(()) 119 | } 120 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/phi3/mlp.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | use ratchet_nn::{Linear, Module}; 3 | 4 | #[derive(Debug, derive_new::new)] 5 | pub struct MLP { 6 | up_proj: Linear, 7 | down_proj: Linear, 8 | } 9 | 10 | //class Phi3MLP(nn.Module): 11 | // def __init__(self, config): 12 | // super().__init__() 13 | // 14 | // self.config = config 15 | // self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) 16 | // self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) 17 | // 18 | // self.activation_fn = ACT2FN[config.hidden_act] 19 | // 20 | // def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: 21 | // up_states = self.gate_up_proj(hidden_states) 22 | // 23 | // gate, up_states = up_states.chunk(2, dim=-1) 24 | // up_states = up_states * self.activation_fn(gate) 25 | // 26 | // return self.down_proj(up_states) 27 | 28 | impl Module for MLP { 29 | type Input = Tensor; 30 | 31 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 32 | let input_dt = input.dt(); 33 | let up_states = self.up_proj.schedule(input)?; 34 | let [x, y, z]: [usize; 3] = up_states.shape().try_into()?; 35 | let gate = up_states.clone().slice(&[0..x, 0..y, 0..z / 2])?; 36 | let up_states = up_states.clone().slice(&[0..x, 0..y, z / 2..z])?; 37 | let up_states = up_states.mul(gate.full()?.silu()?.cast(input_dt)?)?; 38 | self.down_proj.schedule(up_states) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/phi3/mod.rs:
-------------------------------------------------------------------------------- 1 | mod attn; 2 | mod generate; 3 | mod mlp; 4 | mod model; 5 | 6 | pub use generate::generate; 7 | pub use model::Phi3; 8 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/registry.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_local_definitions)] 2 | //! # Registry 3 | //! 4 | //! The registry is responsible for surfacing available models to the user in both the CLI & WASM interfaces. 5 | 6 | #[cfg(target_arch = "wasm32")] 7 | use wasm_bindgen::prelude::wasm_bindgen; 8 | 9 | #[derive(Debug, Clone)] 10 | #[cfg_attr( 11 | target_arch = "wasm32", 12 | derive(tsify::Tsify, serde::Serialize, serde::Deserialize), 13 | tsify(from_wasm_abi), 14 | serde(rename_all = "snake_case") 15 | )] 16 | #[cfg_attr(not(target_arch = "wasm32"), derive(clap::ValueEnum))] 17 | pub enum WhisperVariants { 18 | Tiny, 19 | Base, 20 | Small, 21 | Medium, 22 | LargeV2, 23 | LargeV3, 24 | DistilLargeV3, 25 | } 26 | 27 | impl WhisperVariants { 28 | pub fn repo_id(&self) -> &str { 29 | match self { 30 | WhisperVariants::Tiny => "FL33TW00D-HF/whisper-tiny", 31 | WhisperVariants::Base => "FL33TW00D-HF/whisper-base", 32 | WhisperVariants::Small => "FL33TW00D-HF/whisper-small", 33 | WhisperVariants::Medium => "FL33TW00D-HF/whisper-medium", 34 | WhisperVariants::LargeV2 => "FL33TW00D-HF/whisper-large-v2", 35 | WhisperVariants::LargeV3 => "FL33TW00D-HF/whisper-large-v3", 36 | WhisperVariants::DistilLargeV3 => "FL33TW00D-HF/distil-whisper-large-v3", 37 | } 38 | } 39 | } 40 | 41 | #[derive(Debug, Clone)] 42 | #[cfg_attr( 43 | target_arch = "wasm32", 44 | derive(tsify::Tsify, serde::Serialize, serde::Deserialize), 45 | tsify(from_wasm_abi), 46 | serde(rename_all = "snake_case") 47 | )] 48 | #[cfg_attr(not(target_arch = "wasm32"), derive(clap::ValueEnum))] 49 | pub enum PhiVariants { 50 | Phi2, 51 | Phi3, 52 | } 53 | 54 | /// # Available Models 55 | /// 56 | /// This is a type safe way to surface models to users, 57 | /// providing autocomplete **within** model families. 
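/// For example, `AvailableModels::Whisper(WhisperVariants::Tiny)` resolves to the
/// `FL33TW00D-HF/whisper-tiny` repo via `repo_id`, and to a concrete GGUF file such
/// as `tiny_q8_0.gguf` via `model_id(Quantization::Q8_0)` (see below).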
58 | #[derive(Debug, Clone)] 59 | #[non_exhaustive] 60 | #[cfg_attr( 61 | target_arch = "wasm32", 62 | derive(tsify::Tsify, serde::Serialize, serde::Deserialize) 63 | )] 64 | #[cfg_attr(target_arch = "wasm32", tsify(from_wasm_abi))] 65 | pub enum AvailableModels { 66 | Whisper(WhisperVariants), 67 | Phi(PhiVariants), 68 | Moondream, 69 | } 70 | 71 | impl AvailableModels { 72 | pub fn repo_id(&self) -> String { 73 | let id = match self { 74 | AvailableModels::Whisper(w) => w.repo_id(), 75 | AvailableModels::Phi(p) => match p { 76 | PhiVariants::Phi2 => "FL33TW00D-HF/phi2", 77 | PhiVariants::Phi3 => "FL33TW00D-HF/phi3", 78 | }, 79 | AvailableModels::Moondream => "ratchet-community/ratchet-moondream-2", 80 | }; 81 | id.to_string() 82 | } 83 | 84 | pub fn model_id(&self, quantization: Quantization) -> String { 85 | let model_stem = match self { 86 | AvailableModels::Whisper(w) => match w { 87 | WhisperVariants::Tiny => "tiny", 88 | WhisperVariants::Base => "base", 89 | WhisperVariants::Small => "small", 90 | WhisperVariants::Medium => "medium", 91 | WhisperVariants::LargeV2 => "large-v2", 92 | WhisperVariants::LargeV3 => "large-v3", 93 | WhisperVariants::DistilLargeV3 => "distil-large-v3", 94 | }, 95 | AvailableModels::Phi(p) => match p { 96 | PhiVariants::Phi2 => "phi2", 97 | PhiVariants::Phi3 => "phi3-mini-4k", 98 | }, 99 | AvailableModels::Moondream => "moondream", 100 | }; 101 | match quantization { 102 | Quantization::Q8_0 => format!("{}_q8_0.gguf", model_stem), 103 | Quantization::F16 => format!("{}_f16.gguf", model_stem), 104 | Quantization::F32 => format!("{}_f32.gguf", model_stem), 105 | } 106 | } 107 | } 108 | 109 | #[derive(Debug, Clone)] 110 | #[cfg_attr(target_arch = "wasm32", wasm_bindgen)] 111 | #[cfg_attr(not(target_arch = "wasm32"), derive(clap::ValueEnum))] 112 | pub enum Quantization { 113 | Q8_0, 114 | F16, 115 | F32, 116 | } 117 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/token_stream.rs: -------------------------------------------------------------------------------- 1 | // Taken from Candle: https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs 2 | use anyhow::Result; 3 | 4 | /// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a 5 | /// streaming way rather than having to wait for the full decoding. 6 | pub struct TokenOutputStream { 7 | tokenizer: tokenizers::Tokenizer, 8 | tokens: Vec<u32>, 9 | prev_index: usize, 10 | current_index: usize, 11 | } 12 | 13 | impl TokenOutputStream { 14 | pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { 15 | Self { 16 | tokenizer, 17 | tokens: Vec::new(), 18 | prev_index: 0, 19 | current_index: 0, 20 | } 21 | } 22 | 23 | pub fn into_inner(self) -> tokenizers::Tokenizer { 24 | self.tokenizer 25 | } 26 | 27 | fn decode(&self, tokens: &[u32]) -> Result<String> { 28 | match self.tokenizer.decode(tokens, true) { 29 | Ok(str) => Ok(str), 30 | Err(err) => anyhow::bail!("cannot decode: {err}"), 31 | } 32 | } 33 | 34 | // https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68 35 | pub fn next_token(&mut self, token: u32) -> Result<Option<String>> { 36 | let prev_text = if self.tokens.is_empty() { 37 | String::new() 38 | } else { 39 | let tokens = &self.tokens[self.prev_index..self.current_index]; 40 | self.decode(tokens)?
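// prev_text is the decode of the previous window; a new suffix beyond it is only
// emitted once it ends in an alphanumeric char (checked below), so partially
// decoded multi-byte tokens are never streamed to the caller.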
41 | }; 42 | self.tokens.push(token); 43 | let text = self.decode(&self.tokens[self.prev_index..])?; 44 | if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { 45 | let text = text.split_at(prev_text.len()); 46 | self.prev_index = self.current_index; 47 | self.current_index = self.tokens.len(); 48 | Ok(Some(text.1.to_string())) 49 | } else { 50 | Ok(None) 51 | } 52 | } 53 | 54 | pub fn decode_rest(&self) -> Result<Option<String>> { 55 | let prev_text = if self.tokens.is_empty() { 56 | String::new() 57 | } else { 58 | let tokens = &self.tokens[self.prev_index..self.current_index]; 59 | self.decode(tokens)? 60 | }; 61 | let text = self.decode(&self.tokens[self.prev_index..])?; 62 | if text.len() > prev_text.len() { 63 | let text = text.split_at(prev_text.len()); 64 | Ok(Some(text.1.to_string())) 65 | } else { 66 | Ok(None) 67 | } 68 | } 69 | 70 | pub fn decode_all(&self) -> Result<String> { 71 | self.decode(&self.tokens) 72 | } 73 | 74 | pub fn get_token(&self, token_s: &str) -> Option<u32> { 75 | self.tokenizer.get_vocab(true).get(token_s).copied() 76 | } 77 | 78 | pub fn tokenizer(&self) -> &tokenizers::Tokenizer { 79 | &self.tokenizer 80 | } 81 | 82 | pub fn clear(&mut self) { 83 | self.tokens.clear(); 84 | self.prev_index = 0; 85 | self.current_index = 0; 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/config.rs: -------------------------------------------------------------------------------- 1 | #[derive(Debug, Clone, PartialEq, serde::Deserialize)] 2 | pub struct Config { 3 | #[serde(alias = "num_mel_bins")] 4 | pub n_mels: usize, 5 | #[serde(alias = "max_source_positions")] 6 | pub n_audio_ctx: usize, 7 | #[serde(alias = "d_model")] 8 | pub n_audio_state: usize, 9 | #[serde(alias = "encoder_attention_heads")] 10 | pub n_audio_head: usize, 11 | #[serde(alias = "encoder_layers")] 12 | pub n_audio_layer: usize, 13 | #[serde(alias = "vocab_size")] 14 | pub n_vocab: usize, 15 | #[serde(alias = "max_target_positions")] 16 | pub n_text_ctx: usize, 17 | #[serde(alias = "decoder_attention_heads")] 18 | pub n_text_head: usize, 19 | #[serde(alias = "decoder_layers")] 20 | pub n_text_layer: usize, 21 | #[serde(alias = "torch_dtype")] 22 | pub dtype: String, 23 | #[serde(default)] 24 | pub suppress_tokens: Vec<i32>, 25 | } 26 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/logit_mutators/mod.rs: -------------------------------------------------------------------------------- 1 | mod timestamp_rules; 2 | pub use timestamp_rules::*; 3 | 4 | use crate::whisper::tokenizer::WhisperTokenizer; 5 | use ratchet::Tensor; 6 | 7 | pub trait LogitMutator { 8 | fn apply( 9 | &self, 10 | logits: Tensor, 11 | tokenizer: &WhisperTokenizer, 12 | tokens: Option<&Tensor>, 13 | ) -> anyhow::Result<Tensor>; 14 | } 15 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/logit_mutators/timestamp_rules.rs: -------------------------------------------------------------------------------- 1 | use ndarray::s; 2 | use ndarray_stats::QuantileExt; 3 | use ratchet::{NDArrayExt, Tensor}; 4 | 5 | use super::LogitMutator; 6 | use crate::whisper::tokenizer::WhisperTokenizer; 7 | 8 | #[derive(Debug, derive_new::new)] 9 | pub struct ApplyTimestampRules { 10 | pub sample_begin: usize, 11 | pub max_initial_timestamp_index: Option<usize>, 12 | } 13 | 14 | impl LogitMutator for ApplyTimestampRules { 15 | fn apply( 16 | &self, 17 | logits:
Tensor, 18 | tokenizer: &WhisperTokenizer, 19 | tokens: Option<&Tensor>, 20 | ) -> anyhow::Result<Tensor> { 21 | let nd_tokens = tokens.unwrap().clone().into_ndarray::<i32>(); 22 | let mut nd_logits = logits.into_ndarray::<f32>(); 23 | 24 | nd_logits 25 | .slice_mut(s![.., tokenizer.notimestamps() as usize]) 26 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 27 | 28 | for k in 0..nd_tokens.shape()[0] { 29 | let sampled_tokens = nd_tokens.slice(s![k, self.sample_begin..]); 30 | let sample_len = sampled_tokens.len(); 31 | 32 | let last_was_timestamp = !sampled_tokens.is_empty() 33 | && sampled_tokens[sample_len - 1] >= tokenizer.timestamp_begin(); 34 | let penultimate_was_timestamp = sampled_tokens.len() < 2 35 | || sampled_tokens[sample_len - 2] >= tokenizer.timestamp_begin(); 36 | 37 | if last_was_timestamp { 38 | if penultimate_was_timestamp { 39 | nd_logits 40 | .slice_mut(s![k, tokenizer.timestamp_begin()..]) 41 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 42 | } else { 43 | nd_logits 44 | .slice_mut(s![k, ..WhisperTokenizer::EOT]) 45 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 46 | } 47 | } 48 | 49 | let timestamps = sampled_tokens 50 | .iter() 51 | .filter(|x| **x >= tokenizer.timestamp_begin()) 52 | .collect::<Vec<_>>(); 53 | 54 | if !timestamps.is_empty() { 55 | // timestamps shouldn't decrease; forbid timestamp tokens smaller than the last 56 | // also force each segment to have a nonzero length, to prevent infinite looping 57 | let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp { 58 | *timestamps[timestamps.len() - 1] 59 | } else { 60 | timestamps[timestamps.len() - 1] + 1 61 | }; 62 | nd_logits 63 | .slice_mut(s![k, tokenizer.timestamp_begin()..timestamp_last]) 64 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 65 | } 66 | } 67 | if nd_tokens.shape()[1] == self.sample_begin { 68 | // suppress generating non-timestamp tokens at the beginning 69 | nd_logits 70 | .slice_mut(s![.., ..tokenizer.timestamp_begin()]) 71 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 72 | 73 | if self.max_initial_timestamp_index.is_some() { 74 | let last_allowed = (tokenizer.timestamp_begin() as usize) 75 | + self.max_initial_timestamp_index.unwrap(); 76 | nd_logits 77 | .slice_mut(s![.., last_allowed + 1..]) 78 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 79 | } 80 | } 81 | 82 | let logprobs = nd_logits.log_softmax(1); 83 | for _k in 0..nd_tokens.shape()[0] { 84 | let timestamp_logprob = logprobs 85 | .slice(s![.., tokenizer.timestamp_begin()..]) 86 | .logsumexp(1); 87 | let text_logprobs = logprobs.slice(s![.., ..tokenizer.timestamp_begin()]); 88 | let max_text_token_logprob = text_logprobs.max()?; 89 | if timestamp_logprob > *max_text_token_logprob { 90 | nd_logits 91 | .slice_mut(s![.., ..tokenizer.timestamp_begin()]) 92 | .map_inplace(move |el| *el = f32::NEG_INFINITY); 93 | } 94 | } 95 | Ok(Tensor::from(nd_logits)) 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/mha.rs: -------------------------------------------------------------------------------- 1 | use half::f16; 2 | use num::traits::real::Real; 3 | use ratchet::{rvec, shape, Tensor}; 4 | use ratchet_nn::{KVEntry, Linear, Module}; 5 | 6 | #[derive(Debug)] 7 | pub struct MultiHeadAttention { 8 | q: Linear, 9 | k: Linear, 10 | v: Linear, 11 | o: Linear, 12 | n_heads: usize, 13 | dk: Tensor, 14 | } 15 | 16 | impl MultiHeadAttention { 17 | pub fn new(q: Linear, k: Linear, v: Linear, o: Linear, n_heads: usize) -> MultiHeadAttention { 18 |
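// dk = (n_state / n_heads)^(-1/4); it is applied to both q and k below, so their
// product carries the standard 1/sqrt(d_k) attention scaling.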
let n_state = q.w.shape()[1]; 19 | let dk = match q.w.dt().activation_dt() { 20 | ratchet::DType::F16 => { 21 | let dk = f16::from_f32((n_state / n_heads) as f32); 22 | Tensor::from_data( 23 | [dk.powf(f16::from_f32(-0.25))], 24 | shape![1], 25 | q.w.device().clone(), 26 | ) 27 | } 28 | ratchet::DType::F32 => { 29 | let dk = (n_state / n_heads) as f32; 30 | Tensor::from_data([dk.powf(-0.25)], shape![1], q.w.device().clone()) 31 | } 32 | _ => unimplemented!(), 33 | }; 34 | MultiHeadAttention { 35 | q, 36 | k, 37 | v, 38 | o, 39 | n_heads, 40 | dk, 41 | } 42 | } 43 | } 44 | 45 | #[derive(Debug, derive_new::new)] 46 | pub struct MHAInputs { 47 | x: Tensor, 48 | xa: Option<Tensor>, 49 | mask: Option<Tensor>, 50 | cache: Option<KVEntry>, 51 | is_causal: bool, 52 | } 53 | 54 | impl Module for MultiHeadAttention { 55 | type Input = MHAInputs; 56 | 57 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 58 | let MHAInputs { 59 | x, 60 | xa, 61 | mask, 62 | cache, 63 | is_causal, 64 | } = input; 65 | 66 | let q = self.q.schedule(x.clone())?; 67 | 68 | let to_project = xa.unwrap_or(x); 69 | let k = self.k.schedule(to_project.clone())?; 70 | let v = self.v.schedule(to_project)?; 71 | 72 | let (k, v) = if let Some(kv) = cache { 73 | let prev_entries = kv.entries; 74 | let k_cache = kv.k_cache.cache(k, 1, prev_entries)?; 75 | let v_cache = kv.v_cache.cache(v, 1, prev_entries)?; 76 | (k_cache, v_cache) 77 | } else { 78 | (k, v) 79 | }; 80 | 81 | self.qkv_attention(q, k, v, mask, is_causal) 82 | } 83 | } 84 | 85 | impl MultiHeadAttention { 86 | fn qkv_attention( 87 | &self, 88 | q: Tensor, 89 | k: Tensor, 90 | v: Tensor, 91 | mask: Option<Tensor>, 92 | is_causal: bool, 93 | ) -> anyhow::Result<Tensor> { 94 | let [bs, n_ctx, n_state]: [usize; 3] = q.shape().try_into()?; 95 | let [k0, k1, _]: [usize; 3] = k.shape().try_into()?; 96 | let [v0, v1, _]: [usize; 3] = v.shape().try_into()?; 97 | let q_dt = q.dt(); 98 | 99 | let hdim = n_state / self.n_heads; 100 | 101 | let qs = shape![bs, n_ctx, self.n_heads, hdim]; 102 | let ks = shape![k0, k1, self.n_heads, hdim]; 103 | let vs = shape![v0, v1, self.n_heads, hdim]; 104 | 105 | let q = q.view(qs)?.permute(&[0, 2, 1, 3])?.mul(self.dk.clone())?; 106 | let k = k.view(ks)?.permute(&[0, 2, 3, 1])?.mul(self.dk.clone())?; 107 | let v = v.view(vs)?.permute(&[0, 2, 1, 3])?; 108 | 109 | let mut qk = q.matmul(k, false, false)?; 110 | 111 | if let Some(m) = mask { 112 | let prepared_mask = if is_causal { 113 | m.slice(&[0..n_ctx, 0..n_ctx])? 114 | } else { 115 | m.clone() 116 | }; 117 | qk = qk.add(prepared_mask)?; 118 | } 119 | qk = qk.full()?; 120 | 121 | let w = qk.softmax(3)?.cast(q_dt)?; 122 | 123 | let s = shape![bs, n_ctx, n_state]; 124 | let wv = w.matmul(v, false, false)?.permute(&[0, 2, 1, 3])?.view(s)?; 125 | 126 | self.o.schedule(wv) 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/mlp.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | use ratchet_nn::{Linear, Module}; 3 | 4 | #[derive(Debug, derive_new::new)] 5 | pub struct MLP { 6 | l1: Linear, 7 | l2: Linear, 8 | } 9 | 10 | impl MLP { 11 | pub fn activation_dt(&self) -> ratchet::DType { 12 | self.l1.w.dt().activation_dt() 13 | } 14 | } 15 | 16 | impl Module for MLP { 17 | type Input = Tensor; 18 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 19 | let input_dt = input.dt(); 20 | self.l2 21 | .schedule(self.l1.schedule(input)?.full()?.gelu()?.cast(input_dt)?)
22 | } 23 | } 24 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/mod.rs: -------------------------------------------------------------------------------- 1 | mod config; 2 | mod decoder; 3 | mod encoder; 4 | mod logit_mutators; 5 | mod mha; 6 | mod mlp; 7 | mod model; 8 | mod residual_block; 9 | mod samplers; 10 | mod spectrogram; 11 | mod task; 12 | 13 | pub mod options; 14 | pub mod tokenizer; 15 | pub mod transcribe; 16 | pub mod transcript; 17 | 18 | pub use config::Config; 19 | pub use decoder::WhisperDecoder; 20 | pub use encoder::WhisperEncoder; 21 | pub use model::Whisper; 22 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/samplers/greedy.rs: -------------------------------------------------------------------------------- 1 | use crate::whisper::task::DecodeError; 2 | use crate::whisper::tokenizer::WhisperTokenizer; 3 | 4 | use ndarray::Axis; 5 | use ndarray_stats::QuantileExt; 6 | use ratchet::Tensor; 7 | 8 | pub struct GreedySampler; 9 | 10 | impl GreedySampler { 11 | pub fn sample( 12 | mut tokens: Vec<i32>, 13 | logits: Tensor, 14 | ) -> Result<(Tensor, Vec<i32>, bool), DecodeError> { 15 | let nd_logits = logits.to_ndarray_view::<f32>(); 16 | let next_tokens = nd_logits 17 | .map_axis(Axis(1), |row| row.argmax_skipnan()) 18 | .iter() 19 | .map(|r| { 20 | r.as_ref() 21 | .map_err(|_| DecodeError::InvalidLogits) 22 | .map(|v| *v as i32) 23 | }) 24 | .collect::<Result<Vec<i32>, DecodeError>>()?; 25 | 26 | tokens.extend_from_slice(&next_tokens); 27 | let completed = tokens[tokens.len() - 1] == WhisperTokenizer::EOT; 28 | Ok((logits, tokens, completed)) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /crates/ratchet-models/src/whisper/samplers/mod.rs: -------------------------------------------------------------------------------- 1 | mod greedy; 2 | 3 | pub use greedy::*; 4 | -------------------------------------------------------------------------------- /crates/ratchet-models/webdriver.json: -------------------------------------------------------------------------------- 1 | { 2 | "goog:chromeOptions": { 3 | "args": [ 4 | "--no-sandbox", 5 | "--headless=new", 6 | "--use-angle=vulkan", 7 | "--enable-features=Vulkan", 8 | "--disable-vulkan-surface", 9 | "--enable-unsafe-webgpu" 10 | ] 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /crates/ratchet-nn/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ratchet-nn" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [features] 9 | pyo3 = ["ratchet/pyo3"] 10 | 11 | [dependencies] 12 | anyhow.workspace = true 13 | derive-new = { workspace = true } 14 | ratchet = { path = "../ratchet-core" } 15 | half = {workspace = true} 16 | 17 | [dev-dependencies] 18 | proptest = { workspace = true } 19 | test-strategy = { workspace = true } 20 | hf-hub = { workspace = true } 21 | ratchet-loader = { path = "../ratchet-loader" } 22 | tokenizers.workspace = true 23 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/embedding.rs: -------------------------------------------------------------------------------- 1 | use crate::Module; 2 | use ratchet::{shape, Tensor}; 3 | 4 | /// # Embedding 5 | /// 6 | /// Standard
`torch.nn.Embedding` module. 7 | #[derive(Debug, derive_new::new)] 8 | pub struct Embedding { 9 | pub weight: Tensor, 10 | } 11 | 12 | impl Module for Embedding { 13 | type Input = Tensor; 14 | 15 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 16 | let mut output_shape = input.shape().clone(); 17 | let weight_rank = self.weight.rank(); 18 | let weight_dim = weight_rank - 1; 19 | output_shape.push(self.weight.shape()[weight_dim]); 20 | 21 | let flat_shape = shape![input.shape().numel()]; 22 | let flat = input.view(flat_shape)?; 23 | let indexed = self.weight.clone().index_select(flat, 0)?; 24 | indexed.view(output_shape) 25 | } 26 | } 27 | 28 | #[cfg(all(test, feature = "pyo3"))] 29 | mod tests { 30 | use hf_hub::api::sync::Api; 31 | use proptest::arbitrary::Arbitrary; 32 | use proptest::strategy::{BoxedStrategy, Just, Strategy}; 33 | use ratchet_loader::gguf::gguf::Header; 34 | use test_strategy::proptest; 35 | use tokenizers::Tokenizer; 36 | 37 | use ratchet::test_util::run_py_prg; 38 | use ratchet::{rvec, shape, Device, DeviceRequest, Shape, Tensor}; 39 | 40 | use crate::{Embedding, Module}; 41 | 42 | impl Arbitrary for EmbeddingProblem { 43 | type Parameters = (); 44 | type Strategy = BoxedStrategy<Self>; 45 | 46 | fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { 47 | { 48 | let args = vec![1..512usize, 1..16usize]; 49 | args.prop_map(Into::<Shape>::into).boxed() 50 | } 51 | .prop_flat_map(|vocab_shape| (Just(vocab_shape), 1..64usize)) 52 | .prop_map(|(vocab_shape, num_indices)| { 53 | let indices = 54 | Tensor::randint(0, vocab_shape[0] as i32, shape![num_indices], Device::CPU); 55 | EmbeddingProblem { 56 | vocab_shape, 57 | indices, 58 | } 59 | }) 60 | .boxed() 61 | } 62 | } 63 | 64 | fn ground_truth(weight: &Tensor, indices: &Tensor) -> anyhow::Result<Tensor> { 65 | let arg = "torch.from_numpy(weight)"; 66 | 67 | let prg = format!( 68 | r#" 69 | import torch 70 | def embedding(weight, indices): 71 | embedding = torch.nn.Embedding.from_pretrained({}) 72 | return embedding(torch.from_numpy(indices)).numpy() 73 | "#, 74 | arg 75 | ); 76 | run_py_prg(prg.to_string(), &[weight, indices], &[], weight.dt()) 77 | } 78 | 79 | fn run_embedding_trial(problem: EmbeddingProblem) { 80 | let device = Device::request_device(DeviceRequest::GPU).unwrap(); 81 | println!("Embedding problem: {:?}", problem); 82 | let EmbeddingProblem { 83 | vocab_shape, 84 | indices, 85 | } = problem; 86 | let weight = Tensor::randn::<f32>(vocab_shape, Device::CPU); 87 | 88 | let ground_truth = ground_truth(&weight, &indices).unwrap(); 89 | 90 | let weight = weight.to(&device).unwrap(); 91 | let indices = indices.to(&device).unwrap(); 92 | 93 | let embedding = Embedding::new(weight); 94 | let result = embedding.schedule(indices).unwrap().resolve().unwrap(); 95 | let x = result.to(&Device::CPU).unwrap(); 96 | ground_truth.all_close(&x, 1e-6, 1e-6).unwrap(); 97 | } 98 | 99 | #[derive(Debug, Clone)] 100 | struct EmbeddingProblem { 101 | vocab_shape: Shape, 102 | indices: Tensor, 103 | } 104 | 105 | #[test] 106 | fn debug_embedding() { 107 | let prob = EmbeddingProblem { 108 | vocab_shape: shape![10000, 384], 109 | indices: Tensor::from_data([400i32, 9001i32, 5555i32], shape![1, 3], Device::CPU), 110 | }; 111 | run_embedding_trial(prob); 112 | } 113 | 114 | #[proptest(cases = 16)] 115 | fn test_embedding(prob: EmbeddingProblem) { 116 | run_embedding_trial(prob); 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/groupnorm.rs:
-------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | 3 | #[derive(Debug, Clone, Copy, PartialEq)] 4 | pub struct GroupNormConfig { 5 | pub eps: f32, 6 | pub num_groups: usize, 7 | } 8 | 9 | impl Default for GroupNormConfig { 10 | fn default() -> Self { 11 | Self { 12 | eps: 1e-5, 13 | num_groups: 1, 14 | } 15 | } 16 | } 17 | 18 | #[derive(Clone, Debug)] 19 | pub struct GroupNorm { 20 | weight: Tensor, 21 | bias: Option<Tensor>, 22 | num_groups: usize, 23 | eps: f32, 24 | } 25 | 26 | impl GroupNorm { 27 | pub fn new(weight: Tensor, bias: Option<Tensor>, num_groups: usize, eps: f32) -> Self { 28 | Self { 29 | weight, 30 | bias, 31 | num_groups, 32 | eps, 33 | } 34 | } 35 | 36 | pub fn weight(&self) -> &Tensor { 37 | &self.weight 38 | } 39 | 40 | pub fn bias(&self) -> Option<&Tensor> { 41 | self.bias.as_ref() 42 | } 43 | } 44 | 45 | impl crate::Module for GroupNorm { 46 | type Input = Tensor; 47 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 48 | input.group_norm( 49 | self.num_groups, 50 | self.weight.clone(), 51 | self.bias.clone(), 52 | self.eps, 53 | ) 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/kv_cache.rs: -------------------------------------------------------------------------------- 1 | use ratchet::{Device, Shape, Tensor, TensorDType}; 2 | 3 | #[derive(Clone, Debug)] 4 | pub struct KVEntry { 5 | pub k_cache: Tensor, 6 | pub v_cache: Tensor, 7 | pub entries: usize, 8 | } 9 | 10 | impl KVEntry { 11 | pub fn allocate<T: TensorDType>(shape: &Shape, device: &Device) -> Self { 12 | KVEntry { 13 | k_cache: Tensor::zeros::<T>(shape, device), 14 | v_cache: Tensor::zeros::<T>(shape, device), 15 | entries: 0, 16 | } 17 | } 18 | } 19 | 20 | #[derive(Clone, Debug)] 21 | pub struct KVCache(Vec<KVEntry>); 22 | 23 | impl std::ops::Index<usize> for KVCache { 24 | type Output = KVEntry; 25 | 26 | fn index(&self, index: usize) -> &Self::Output { 27 | &self.0[index] 28 | } 29 | } 30 | 31 | impl KVCache { 32 | pub fn new<T: TensorDType>(n_layers: i32, shape: Shape, device: &Device) -> Self { 33 | let mut entries = Vec::with_capacity(n_layers as _); 34 | for _ in 0..n_layers { 35 | entries.push(KVEntry::allocate::<T>(&shape, device)); 36 | } 37 | KVCache(entries) 38 | } 39 | 40 | pub fn update(&mut self, offset: usize) { 41 | for entry in &mut self.0 { 42 | entry.entries += offset; 43 | } 44 | } 45 | 46 | pub fn entries(&self, layer: usize) -> usize { 47 | self.0[layer].entries 48 | } 49 | 50 | pub fn reset(&mut self) { 51 | for entry in &mut self.0 { 52 | entry.entries = 0; 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/lib.rs: -------------------------------------------------------------------------------- 1 | mod embedding; 2 | mod groupnorm; 3 | mod kv_cache; 4 | mod linear; 5 | mod norm; 6 | mod rope; 7 | 8 | pub use embedding::*; 9 | pub use groupnorm::*; 10 | pub use kv_cache::*; 11 | pub use linear::*; 12 | pub use norm::*; 13 | pub use rope::*; 14 | 15 | use ratchet::Tensor; 16 | 17 | /// # Module 18 | /// 19 | /// Analogous to `torch.nn.Module` in PyTorch, a `Module` is a trait that represents a neural network 20 | /// module. However, it has one key difference. 21 | /// 22 | /// In PyTorch, `forward` performs the computation when called. In Ratchet, `schedule` is used to 23 | /// schedule the computation for future execution.
The Tensor returned is lazy, in that it 24 | /// represents the result of the computation, but the computation itself has not been performed. 25 | /// 26 | /// If you want to immediately access the result of the computation (say for debugging), call 27 | /// `.resolve()` on the Tensor to execute the work. 28 | pub trait Module { 29 | type Input; 30 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor>; 31 | } 32 | 33 | /// # MutableModule 34 | /// 35 | /// Ditto above, but can mutate self. 36 | pub trait MutableModule { 37 | type Input; 38 | fn schedule(&mut self, input: Self::Input) -> anyhow::Result<Tensor>; 39 | } 40 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/linear.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | 3 | use crate::Module; 4 | 5 | /// # Linear 6 | /// 7 | /// PyTorch case: y = xW^T + b 8 | /// If your weights are already in the correct layout, you can set `transpose` to `false` to avoid the transpose operation. 9 | #[derive(derive_new::new, Debug)] 10 | pub struct Linear { 11 | pub w: Tensor, 12 | b: Option<Tensor>, 13 | } 14 | 15 | impl Module for Linear { 16 | type Input = Tensor; 17 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 18 | let b = if let Some(b) = &self.b { 19 | Some(b.clone().cast(input.dt())?) 20 | } else { 21 | None 22 | }; 23 | self.w.clone().gemm(input, b, false, true, true) 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/norm.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | 3 | #[derive(Debug, Clone, Copy, PartialEq)] 4 | pub struct LayerNormConfig { 5 | pub eps: f32, 6 | pub remove_mean: bool, 7 | } 8 | 9 | impl Default for LayerNormConfig { 10 | fn default() -> Self { 11 | Self { 12 | eps: 1e-5, 13 | remove_mean: true, 14 | } 15 | } 16 | } 17 | 18 | #[derive(Clone, Debug, derive_new::new)] 19 | pub struct LayerNorm { 20 | weight: Tensor, 21 | bias: Option<Tensor>, 22 | eps: f32, 23 | } 24 | 25 | impl LayerNorm { 26 | pub fn weight(&self) -> &Tensor { 27 | &self.weight 28 | } 29 | 30 | pub fn bias(&self) -> Option<&Tensor> { 31 | self.bias.as_ref() 32 | } 33 | } 34 | 35 | impl crate::Module for LayerNorm { 36 | type Input = Tensor; 37 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 38 | let src_dt = input.dt(); 39 | input 40 | .full()? 41 | .layer_norm(self.weight.clone(), self.bias.clone(), self.eps)? 42 | .cast(src_dt) 43 | } 44 | } 45 | 46 | /// RMSNorm 47 | /// 48 | /// https://github.com/NVIDIA/apex/pull/1274/files 49 | #[derive(Clone, Debug, derive_new::new)] 50 | pub struct RMSNorm { 51 | weight: Tensor, 52 | eps: f32, 53 | } 54 | 55 | impl RMSNorm { 56 | pub fn weight(&self) -> &Tensor { 57 | &self.weight 58 | } 59 | } 60 | 61 | impl crate::Module for RMSNorm { 62 | type Input = Tensor; 63 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 64 | let src_dt = input.dt(); 65 | input 66 | .full()? 67 | .rms_norm(self.weight.clone(), self.eps)? 68 | .cast(src_dt) 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /crates/ratchet-nn/src/rope.rs: -------------------------------------------------------------------------------- 1 | use ratchet::Tensor; 2 | 3 | use crate::Module; 4 | 5 | /// """Implements the rotary positional encoding.
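/// (Sketch of the idea: at position `p`, feature pair `i` is rotated by the angle
/// `theta = p * base^(-2i / dims)`; see the reference below.)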
6 | /// 7 | /// The traditional implementation rotates consecutive pairs of elements in the 8 | /// feature dimension while the default implementation rotates pairs with 9 | /// stride half the feature dimensions for efficiency. 10 | /// 11 | /// For more details see `RoFormer: Enhanced Transformer with Rotary Position 12 | /// Embedding <https://arxiv.org/abs/2104.09864>`_. 13 | /// 14 | /// Args: 15 | /// dims (int): The feature dimensions to be rotated. If the input feature 16 | /// is larger than dims then the rest is left unchanged. 17 | /// traditional (bool, optional): If set to ``True`` choose the traditional 18 | /// implementation which is slightly less efficient. Default: ``False``. 19 | /// base (float, optional): The base used to compute angular frequency for 20 | /// each dimension in the positional encodings. Default: ``10000``. 21 | /// scale (float, optional): The scale used to scale the positions. Default: ``1.0``. 22 | /// """ 23 | #[derive(Clone, Debug, derive_new::new)] 24 | pub struct RotaryEmbedding { 25 | dim: usize, 26 | traditional: bool, 27 | base: f32, 28 | scale: f32, 29 | } 30 | 31 | pub struct RotaryInput { 32 | pub input: Tensor, 33 | pub offset: usize, 34 | } 35 | 36 | impl Module for RotaryEmbedding { 37 | type Input = RotaryInput; 38 | 39 | fn schedule(&self, input: Self::Input) -> anyhow::Result<Tensor> { 40 | let RotaryInput { input, offset } = input; 41 | input.rope(self.dim, self.base, offset) 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /crates/ratchet-web/.gitignore: -------------------------------------------------------------------------------- 1 | /test-data 2 | -------------------------------------------------------------------------------- /crates/ratchet-web/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ratchet-web" 3 | version = "0.3.0" 4 | edition = "2021" 5 | license = "MIT" 6 | description = "A web-first, cross-platform ML framework."
7 | keywords = ["llm","wasm","transformers","webgpu","ml","machine-learning","deep-learning"] 8 | repository = "https://github.com/FL33TW00D/ratchet" 9 | 10 | [lib] 11 | crate-type = ["cdylib", "rlib"] 12 | 13 | [package.metadata.docs.rs] 14 | default-target = "wasm32-unknown-unknown" 15 | 16 | [package.metadata.wasm-pack.profile.dev.wasm-bindgen] 17 | debug-js-glue = true 18 | demangle-name-section = true 19 | dwarf-debug-info = true 20 | 21 | [package.metadata.wasm-pack.profile.release] 22 | wasm-opt = ['-O3', '--enable-simd'] 23 | 24 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 25 | [dependencies] 26 | ratchet-models = { path = "../ratchet-models" } 27 | ratchet-hub = { path = "../ratchet-hub" } 28 | ratchet-loader = { path = "../ratchet-loader" } 29 | wasm-bindgen = { workspace = true } 30 | wasm-bindgen-futures = { workspace = true } 31 | js-sys = { workspace = true } 32 | indexed_db_futures = { workspace = true } 33 | thiserror.workspace = true 34 | anyhow.workspace = true 35 | serde = { workspace = true } 36 | serde-wasm-bindgen = { workspace = true } 37 | console_error_panic_hook = { workspace = true } 38 | console_log = { workspace = true, features = ["color"] } 39 | log.workspace = true 40 | hound = { workspace = true } 41 | fern = { workspace = true } 42 | chrono = { workspace = true } 43 | uuid = { workspace = true, features = ["v4", "serde"] } 44 | tokenizers = { version = "0.19.1", default-features = false, features=["unstable_wasm"] } 45 | futures = "0.3.30" 46 | [dependencies.web-sys] 47 | features = [ 48 | 'console', 49 | 'Headers', 50 | 'Request', 51 | 'RequestInit', 52 | 'RequestMode', 53 | 'Response', 54 | 'ReadableStream', 55 | 'ReadableStreamGetReaderOptions', 56 | 'ReadableStreamReaderMode', 57 | 'Window', 58 | 'Navigator', 59 | 'StorageManager', 60 | 'Cache', 61 | 'CacheStorage', 62 | 'IdbKeyRange', 63 | ] 64 | workspace = true 65 | 66 | 67 | [target.'cfg(target_arch = "wasm32")'.dependencies] 68 | getrandom = { version = "0.2.6", features = ["js"] } 69 | 70 | [dev-dependencies] 71 | wasm-bindgen-test.workspace = true 72 | ratchet-hub = { path = "../ratchet-hub" } 73 | 74 | -------------------------------------------------------------------------------- /crates/ratchet-web/README.md: -------------------------------------------------------------------------------- 1 | # ratchet-web 2 | 3 | -------------------------------------------------------------------------------- /crates/ratchet-web/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![cfg(target_arch = "wasm32")] 2 | mod db; 3 | mod model; 4 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started with Create React App 2 | 3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app). 4 | 5 | ## Available Scripts 6 | 7 | In the project directory, you can run: 8 | 9 | ### `npm start` 10 | 11 | Runs the app in the development mode.\ 12 | Open [http://localhost:3000](http://localhost:3000) to view it in your browser. 13 | 14 | The page will reload when you make changes.\ 15 | You may also see any lint errors in the console. 16 | 17 | ### `npm test` 18 | 19 | Launches the test runner in the interactive watch mode.\ 20 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information. 21 | 22 | ### `npm run build` 23 | 24 | Builds the app for production to the `build` folder.\ 25 | It correctly bundles React in production mode and optimizes the build for the best performance. 26 | 27 | The build is minified and the filenames include the hashes.\ 28 | Your app is ready to be deployed! 29 | 30 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information. 31 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ratchet-moondream", 3 | "version": "0.1.0", 4 | "private": true, 5 | "dependencies": { 6 | "@emotion/react": "^11.11.4", 7 | "@emotion/styled": "^11.11.5", 8 | "@mui/icons-material": "^5.15.19", 9 | "@mui/material": "^5.15.19", 10 | "@ratchet-ml/ratchet-web": "file:../../target/pkg/ratchet-web", 11 | "react": "^18.3.1", 12 | "react-dom": "^18.3.1", 13 | "react-scripts": "5.0.1", 14 | "web-vitals": "^2.1.4" 15 | }, 16 | "scripts": { 17 | "start": "react-scripts start", 18 | "build": "react-scripts build", 19 | "test": "react-scripts test", 20 | "eject": "react-scripts eject" 21 | }, 22 | "eslintConfig": { 23 | "extends": [ 24 | "react-app", 25 | "react-app/jest" 26 | ] 27 | }, 28 | "browserslist": { 29 | "production": [ 30 | ">0.2%", 31 | "not dead", 32 | "not op_mini all" 33 | ], 34 | "development": [ 35 | "last 1 chrome version", 36 | "last 1 firefox version", 37 | "last 1 safari version" 38 | ] 39 | }, 40 | "devDependencies": { 41 | "prettier": "3.3.1" 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 11 | Ratchet Moondream 12 | 13 | 14 | 15 |

16 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "React App", 3 | "name": "Create React App Sample", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "logo192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "logo512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/src/App.css: -------------------------------------------------------------------------------- 1 | .App { 2 | text-align: center; 3 | } 4 | 5 | .App-logo { 6 | height: 40vmin; 7 | pointer-events: none; 8 | } 9 | 10 | @media (prefers-reduced-motion: no-preference) { 11 | .App-logo { 12 | animation: App-logo-spin infinite 20s linear; 13 | } 14 | } 15 | 16 | .App-header { 17 | background-color: #282c34; 18 | min-height: 100vh; 19 | display: flex; 20 | flex-direction: column; 21 | align-items: center; 22 | justify-content: center; 23 | font-size: calc(10px + 2vmin); 24 | color: white; 25 | } 26 | 27 | .App-link { 28 | color: #61dafb; 29 | } 30 | 31 | @keyframes App-logo-spin { 32 | from { 33 | transform: rotate(0deg); 34 | } 35 | to { 36 | transform: rotate(360deg); 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/src/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | margin: 0; 3 | font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "Roboto", "Oxygen", 4 | "Ubuntu", "Cantarell", "Fira Sans", "Droid Sans", "Helvetica Neue", 5 | sans-serif; 6 | -webkit-font-smoothing: antialiased; 7 | -moz-osx-font-smoothing: grayscale; 8 | } 9 | 10 | code { 11 | font-family: source-code-pro, Menlo, Monaco, Consolas, "Courier New", 12 | monospace; 13 | } 14 | -------------------------------------------------------------------------------- /examples/ratchet-moondream/src/index.js: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import ReactDOM from "react-dom/client"; 3 | import "./index.css"; 4 | import App from "./App"; 5 | 6 | const root = ReactDOM.createRoot(document.getElementById("root")); 7 | root.render( 8 | 9 | 10 | , 11 | ); 12 | -------------------------------------------------------------------------------- /examples/ratchet-phi/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | .yarn/install-state.gz 8 | 9 | # testing 10 | /coverage 11 | 12 | # next.js 13 | /.next/ 14 | /out/ 15 | 16 | # production 17 | /build 18 | 19 | # misc 20 | .DS_Store 21 | *.pem 22 | 23 | # debug 24 | npm-debug.log* 25 | yarn-debug.log* 26 | yarn-error.log* 27 | 28 | # local env files 29 | .env*.local 30 | 31 | # vercel 32 | .vercel 33 | 34 | # typescript 35 | *.tsbuildinfo 36 | next-env.d.ts 37 | -------------------------------------------------------------------------------- /examples/ratchet-phi/README.md: -------------------------------------------------------------------------------- 1 | This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). 2 | 3 | ## Getting Started 4 | 5 | First, run the development server: 6 | 7 | ```bash 8 | npm run dev 9 | # or 10 | yarn dev 11 | # or 12 | pnpm dev 13 | # or 14 | bun dev 15 | ``` 16 | 17 | Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. 18 | 19 | You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. 20 | 21 | This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font. 22 | 23 | ## Learn More 24 | 25 | To learn more about Next.js, take a look at the following resources: 26 | 27 | - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. 28 | - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. 29 | 30 | You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome! 31 | 32 | ## Deploy on Vercel 33 | 34 | The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js. 35 | 36 | Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details. 
37 | -------------------------------------------------------------------------------- /examples/ratchet-phi/next.config.mjs: -------------------------------------------------------------------------------- 1 | /** @type {import('next').NextConfig} */ 2 | const nextConfig = { 3 | output: 'export' 4 | }; 5 | 6 | export default nextConfig; 7 | -------------------------------------------------------------------------------- /examples/ratchet-phi/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ratchet-phi", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "dev": "next dev", 7 | "build": "next build", 8 | "start": "next start", 9 | "lint": "next lint" 10 | }, 11 | "dependencies": { 12 | "@ffmpeg/ffmpeg": "0.12.6", 13 | "@ffmpeg/util": "^0.12.1", 14 | "@ratchet-ml/ratchet-web": "link:../../target/pkg/ratchet-web", 15 | "fix-webm-duration": "^1.0.5", 16 | "next": "14.1.0", 17 | "react": "^18.2.0", 18 | "react-dom": "^18.2.0", 19 | "react-hot-toast": "^2.4.1", 20 | "react-responsive-modal": "^6.4.2" 21 | }, 22 | "devDependencies": { 23 | "@types/node": "^20.11.24", 24 | "@types/react": "^18.2.61", 25 | "@types/react-dom": "^18.2.19", 26 | "autoprefixer": "^10.4.18", 27 | "postcss": "^8.4.35", 28 | "tailwindcss": "^3.4.1", 29 | "typescript": "^5.3.3" 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /examples/ratchet-phi/postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | } 7 | -------------------------------------------------------------------------------- /examples/ratchet-phi/src/app/components/WebGPUModal.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | import Modal from "react-responsive-modal"; 3 | 4 | const WebGPUModal = () => { 5 | const [hasWebGPU, setHasWebGPU] = useState(false); 6 | const [isModalOpen, setIsModalOpen] = useState(false); 7 | 8 | useEffect(() => { 9 | //@ts-ignore 10 | if (!navigator.gpu) { 11 | setIsModalOpen(true); 12 | return; 13 | } 14 | setHasWebGPU(true); 15 | }, []); 16 | 17 | const handleModalClose = () => { 18 | setIsModalOpen(false); 19 | }; 20 | 21 | const closeIcon = ( 22 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | ); 47 | 48 | return ( 49 | <> 50 | {!hasWebGPU ? ( 51 | 60 |
63 |
64 |

65 | Uh oh! It looks like your browser doesn't 66 | support WebGPU. Please try again in a supported browser (Chrome 121+).

68 |
69 |
70 |
71 | ) : ( 72 | <> 73 | )} 74 | 75 | ); 76 | }; 77 | 78 | export default WebGPUModal; 79 | 80 | -------------------------------------------------------------------------------- /examples/ratchet-phi/src/app/components/progressBar.tsx: -------------------------------------------------------------------------------- 1 | const ProgressBar = ({ progress }: any) => { 2 | return ( 3 | <> 4 | {progress > 0 && progress < 100 && ( 5 |
6 |
7 |
11 |
12 |
13 | )} 14 | 15 | ); 16 | }; 17 | 18 | export default ProgressBar; 19 | 20 | -------------------------------------------------------------------------------- /examples/ratchet-phi/src/app/components/warningModal.tsx: -------------------------------------------------------------------------------- 1 | import React from "react"; 2 | import Modal from "react-responsive-modal"; 3 | 4 | interface WarningModalProps { 5 | isModalOpen: boolean; 6 | setIsModalOpen: (value: boolean) => void; 7 | loadModel: () => void; 8 | } 9 | 10 | const WarningModal = ({ isModalOpen, setIsModalOpen, loadModel }: WarningModalProps) => { 11 | const handleModalClose = () => { 12 | setIsModalOpen(false); 13 | }; 14 | 15 | const closeIcon = ( 16 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | ); 41 | 42 | return ( 43 | <> 44 | {isModalOpen ? ( 45 | 54 |
57 |
58 |

59 | ⚠️ You are about to download a 2.9 GB file. Click to confirm.

61 | 70 |
71 |
72 |
73 | ) : ( 74 | <> 75 | )} 76 | 77 | ); 78 | }; 79 | 80 | export default WarningModal; 81 | 82 | -------------------------------------------------------------------------------- /examples/ratchet-phi/src/app/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/ratchet/136da4d5216910bfd015b27a17b837c21f17163a/examples/ratchet-phi/src/app/favicon.ico -------------------------------------------------------------------------------- /examples/ratchet-phi/src/app/layout.tsx: -------------------------------------------------------------------------------- 1 | import type { Metadata } from "next"; 2 | import { Inter } from "next/font/google"; 3 | import "./globals.css"; 4 | import { Toaster } from "react-hot-toast"; 5 | import "react-responsive-modal/styles.css"; 6 | 7 | const inter = Inter({ subsets: ["latin"] }); 8 | 9 | export const metadata: Metadata = { 10 | title: "Ratchet + Phi", 11 | description: "Simple demo of Phi.", 12 | }; 13 | 14 | export default function RootLayout({ 15 | children, 16 | }: Readonly<{ 17 | children: React.ReactNode; 18 | }>) { 19 | return ( 20 | 21 | 22 |
23 | 24 |
{children}
25 |
26 | 27 | 28 | ); 29 | } 30 | -------------------------------------------------------------------------------- /examples/ratchet-phi/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: [ 4 | "./src/**/*.{js,ts,jsx,tsx,mdx}", 5 | ], 6 | theme: { 7 | extend: {}, 8 | }, 9 | plugins: [], 10 | } 11 | -------------------------------------------------------------------------------- /examples/ratchet-phi/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "lib": ["dom", "dom.iterable", "esnext"], 4 | "allowJs": true, 5 | "skipLibCheck": true, 6 | "strict": true, 7 | "noEmit": true, 8 | "esModuleInterop": true, 9 | "module": "esnext", 10 | "moduleResolution": "bundler", 11 | "resolveJsonModule": true, 12 | "isolatedModules": true, 13 | "jsx": "preserve", 14 | "incremental": true, 15 | "plugins": [ 16 | { 17 | "name": "next" 18 | } 19 | ], 20 | "paths": { 21 | "@/*": ["./src/*"] 22 | } 23 | }, 24 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], 25 | "exclude": ["node_modules"] 26 | } 27 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | .yarn/install-state.gz 8 | 9 | # testing 10 | /coverage 11 | 12 | # next.js 13 | /.next/ 14 | /out/ 15 | 16 | # production 17 | /build 18 | 19 | # misc 20 | .DS_Store 21 | *.pem 22 | 23 | # debug 24 | npm-debug.log* 25 | yarn-debug.log* 26 | yarn-error.log* 27 | 28 | # local env files 29 | .env*.local 30 | 31 | # vercel 32 | .vercel 33 | 34 | # typescript 35 | *.tsbuildinfo 36 | next-env.d.ts 37 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/README.md: -------------------------------------------------------------------------------- 1 | This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). 2 | 3 | ## Getting Started 4 | 5 | First, run the development server: 6 | 7 | ```bash 8 | npm run dev 9 | # or 10 | yarn dev 11 | # or 12 | pnpm dev 13 | # or 14 | bun dev 15 | ``` 16 | 17 | Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. 18 | 19 | You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file. 20 | 21 | This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font. 22 | 23 | ## Learn More 24 | 25 | To learn more about Next.js, take a look at the following resources: 26 | 27 | - [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API. 28 | - [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial. 29 | 30 | You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome! 
31 | 32 | ## Deploy on Vercel 33 | 34 | The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js. 35 | 36 | Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details. 37 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/next.config.mjs: -------------------------------------------------------------------------------- 1 | /** @type {import('next').NextConfig} */ 2 | const nextConfig = { 3 | output: 'export' 4 | }; 5 | 6 | export default nextConfig; 7 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ratchet-whisper", 3 | "version": "0.1.0", 4 | "private": true, 5 | "scripts": { 6 | "dev": "next dev", 7 | "build": "next build", 8 | "start": "next start", 9 | "lint": "next lint" 10 | }, 11 | "dependencies": { 12 | "@ffmpeg/ffmpeg": "0.12.6", 13 | "@ffmpeg/util": "^0.12.1", 14 | "@ratchet-ml/ratchet-web": "link:../../target/pkg/ratchet-web", 15 | "fix-webm-duration": "^1.0.5", 16 | "next": "14.1.0", 17 | "react": "^18.2.0", 18 | "react-dom": "^18.2.0", 19 | "react-hot-toast": "^2.4.1", 20 | "react-responsive-modal": "^6.4.2" 21 | }, 22 | "devDependencies": { 23 | "@types/node": "^20.11.24", 24 | "@types/react": "^18.2.61", 25 | "@types/react-dom": "^18.2.19", 26 | "autoprefixer": "^10.4.18", 27 | "postcss": "^8.4.35", 28 | "tailwindcss": "^3.4.1", 29 | "typescript": "^5.3.3" 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/postcss.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | plugins: { 3 | tailwindcss: {}, 4 | autoprefixer: {}, 5 | }, 6 | } 7 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/audio.ts: -------------------------------------------------------------------------------- 1 | import fixWebmDuration from "fix-webm-duration"; 2 | 3 | export interface Recording { 4 | blob: Blob; 5 | buffer: ArrayBuffer; 6 | } 7 | 8 | export class MicRecorder { 9 | private currentStart: number | null = null; 10 | private currentStream: MediaStream | null = null; 11 | private inner: MediaRecorder | null = null; 12 | private audioChunks: Blob[] = []; 13 | private static readonly supportedMimes = [ 14 | "audio/webm", // Chrome 15 | "audio/ogg", // Firefox 16 | ]; 17 | 18 | private constructor(recorder: MediaRecorder) { 19 | this.inner = recorder; 20 | } 21 | 22 | public static async start(): Promise { 23 | if (!navigator.mediaDevices) { 24 | throw new Error("Media device not available"); 25 | } 26 | 27 | const stream = await navigator.mediaDevices.getUserMedia({ 28 | audio: true, 29 | }); 30 | let mimeType = MicRecorder.supportedMimes.find((mime: string) => 31 | MediaRecorder.isTypeSupported(mime) 32 | ); 33 | const inner = new MediaRecorder(stream, { 34 | mimeType 35 | }); 36 | const recorder = new MicRecorder(inner); 37 | recorder.currentStream = stream; 38 | 39 | inner.addEventListener("dataavailable", (event) => { 40 | recorder.audioChunks.push(event.data); 41 | }); 42 | inner.start(); 43 | recorder.currentStart = Date.now(); 44 | return recorder; 
45 | } 46 | 47 | public isRecording(): boolean { 48 | return this.inner !== null && this.inner.state === "recording"; 49 | } 50 | 51 | public async stop(): Promise { 52 | if (!this.inner) { 53 | throw new Error("Please start the recorder first"); 54 | } 55 | 56 | const promise: Promise = new Promise( 57 | (resolve) => { 58 | this.inner!.addEventListener("stop", async () => { 59 | const duration = Date.now() - this.currentStart!; 60 | let blob = new Blob(this.audioChunks, { 61 | type: this.inner!.mimeType, 62 | }); 63 | 64 | if (this.inner!.mimeType.includes("webm")) { 65 | blob = await fixWebmDuration(blob, duration, { 66 | logger: false, 67 | }); 68 | } 69 | 70 | const buffer = await blob.arrayBuffer(); 71 | 72 | resolve({ 73 | blob, 74 | buffer, 75 | }); 76 | }); 77 | this.inner!.stop(); 78 | this.currentStream!.getTracks().forEach((track) => 79 | track.stop() 80 | ); 81 | } 82 | ); 83 | return promise; 84 | } 85 | } 86 | 87 | export default MicRecorder; 88 | 89 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/WebGPUModal.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | import Modal from "react-responsive-modal"; 3 | 4 | const WebGPUModal = () => { 5 | const [hasWebGPU, setHasWebGPU] = useState(false); 6 | const [isModalOpen, setIsModalOpen] = useState(false); 7 | 8 | useEffect(() => { 9 | //@ts-ignore 10 | if (!navigator.gpu) { 11 | setIsModalOpen(true); 12 | return; 13 | } 14 | setHasWebGPU(true); 15 | }, []); 16 | 17 | const handleModalClose = () => { 18 | setIsModalOpen(false); 19 | }; 20 | 21 | const closeIcon = ( 22 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | ); 47 | 48 | return ( 49 | <> 50 | {!hasWebGPU ? ( 51 | 60 |
63 |
64 |

65 | Uh oh! It looks like your browser doesn't 66 | support WebGPU. Please try again in a supported browser (Chrome 121+).

68 |
69 |
70 |
71 | ) : ( 72 | <> 73 | )} 74 | 75 | ); 76 | }; 77 | 78 | export default WebGPUModal; 79 | 80 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/configModal.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState, useEffect } from "react"; 2 | import Modal from "react-responsive-modal"; 3 | import LanguageDropdown from "./languageDropdown"; 4 | import SuppressComponent from "./suppressSelector"; 5 | import TaskComponent from "./taskSelector"; 6 | import { Task } from "@ratchet-ml/ratchet-web"; 7 | 8 | interface ConfigModalProps { 9 | isModalOpen: boolean; 10 | setIsModalOpen: React.Dispatch>; 11 | configOptions: ConfigOptions; 12 | setConfigOptions: React.Dispatch>; 13 | } 14 | 15 | export interface ConfigOptions { 16 | language: string | null; 17 | task: Task; 18 | suppress_non_speech: boolean; 19 | } 20 | 21 | const ConfigModal = (props: ConfigModalProps) => { 22 | const handleModalClose = () => { 23 | props.setIsModalOpen(false); 24 | }; 25 | 26 | return ( 27 | <> 28 | 37 |
40 |
41 | 42 | 43 | 44 |
45 |
46 |
47 | 48 | ); 49 | }; 50 | 51 | export default ConfigModal; 52 | 53 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/micButton.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | import MicRecorder from "../audio"; 3 | 4 | const SAMPLE_RATE = 16000; 5 | 6 | interface MicButtonProps { 7 | setBlobUrl: (blobUrl: string) => void; 8 | setAudioData: (audioData: Float32Array) => void; 9 | setAudioMetadata: (audioMetadata: AudioMetadata) => void; 10 | } 11 | 12 | export interface AudioMetadata { 13 | file: File; 14 | fromMic: boolean; 15 | } 16 | 17 | const MicButton = (props: MicButtonProps) => { 18 | const [mic, setMic] = useState(null); 19 | const [isRecording, setIsRecording] = useState(false); 20 | 21 | const handleRecord = async () => { 22 | setMic(await MicRecorder.start()); 23 | }; 24 | 25 | const handleStop = async () => { 26 | if (!mic) { 27 | return; 28 | } 29 | let recording = await mic.stop(); 30 | let ctx = new AudioContext({ sampleRate: SAMPLE_RATE }); 31 | let resampled = await ctx.decodeAudioData(recording.buffer); 32 | let ch0 = resampled.getChannelData(0); 33 | props.setAudioData(new Float32Array(ch0.buffer)); 34 | 35 | let blob = recording.blob; 36 | props.setAudioMetadata({ 37 | file: new File([blob], "recording.wav"), 38 | fromMic: true, 39 | }); 40 | props.setBlobUrl(URL.createObjectURL(blob)); 41 | setMic(null); 42 | }; 43 | 44 | const handleClick = async () => { 45 | if (isRecording) { 46 | await handleStop(); 47 | } else { 48 | await handleRecord(); 49 | } 50 | setIsRecording(!isRecording); 51 | }; 52 | 53 | return ( 54 |
55 | 89 |
90 | ); 91 | }; 92 | 93 | export default MicButton; 94 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/modelSelector.tsx: -------------------------------------------------------------------------------- 1 | import { useState } from "react"; 2 | import { AvailableModels, Whisper } from "@ratchet-ml/ratchet-web"; 3 | 4 | interface ModelSelectorProps { 5 | selectedModel: AvailableModels | null; 6 | setSelectedModel: (model: AvailableModels) => void; 7 | loaded: boolean; 8 | progress: number; 9 | } 10 | 11 | const UNITS = [ 12 | "byte", 13 | "kilobyte", 14 | "megabyte", 15 | "gigabyte", 16 | ]; 17 | const BYTES_PER_KB = 1000; 18 | 19 | export function humanFileSize(sizeBytes: number | bigint): string { 20 | let size = Math.abs(Number(sizeBytes)); 21 | 22 | let u = 0; 23 | while (size >= BYTES_PER_KB && u < UNITS.length - 1) { 24 | size /= BYTES_PER_KB; 25 | ++u; 26 | } 27 | 28 | return new Intl.NumberFormat([], { 29 | style: "unit", 30 | unit: UNITS[u], 31 | unitDisplay: "short", 32 | maximumFractionDigits: 1, 33 | }).format(size); 34 | } 35 | 36 | export function availableModelToString(model: AvailableModels): string { 37 | if ("Whisper" in model) { 38 | return model.Whisper; 39 | } else if ("Llama" in model) { 40 | return model.Llama; 41 | } 42 | return ""; 43 | } 44 | 45 | const ModelSelector = (props: ModelSelectorProps) => { 46 | const { selectedModel, setSelectedModel, loaded, progress } = props; 47 | const [dropdownOpen, setDropdownOpen] = useState(false); 48 | 49 | const whisper = ["tiny", "base", "small", "medium", "large_v2", "large_v3", "distil_large_v3"] as const; 50 | type WhisperIter = typeof whisper[number]; 51 | 52 | const modelNames = [ 53 | ...whisper, 54 | ]; 55 | 56 | const displayModels = () => { 57 | return modelNames.map((model, idx) => ( 58 |
  • 59 | { 63 | const isWhisper = whisper.includes(model as WhisperIter); 64 | if (isWhisper) { 65 | setSelectedModel({ Whisper: model as Whisper }); 66 | } 67 | setDropdownOpen(false); 68 | }} 69 | > 70 | {model} 71 | 72 |
  • 73 | )); 74 | }; 75 | 76 | return ( 77 | <> 78 | {progress > 0 && !loaded && ( 79 |
    80 | 83 |
    84 | )} 85 |
    86 | 103 |
      109 | {displayModels()} 110 |
    111 |
    112 | 113 | ); 114 | }; 115 | 116 | export default ModelSelector; 117 | 118 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/progressBar.tsx: -------------------------------------------------------------------------------- 1 | const ProgressBar = ({ progress }: any) => { 2 | return ( 3 | <> 4 | {progress > 0 && progress < 100 && ( 5 |
    6 |
    7 |
    11 |
    12 |
    13 | )} 14 | 15 | ); 16 | }; 17 | 18 | export default ProgressBar; 19 | 20 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/suppressSelector.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState } from "react"; 2 | import { ConfigOptions } from "./configModal"; 3 | 4 | interface SuppressComponentProps { 5 | configOptions: ConfigOptions; 6 | setConfigOptions: React.Dispatch>; 7 | } 8 | 9 | const SuppressComponent = (props: SuppressComponentProps) => { 10 | const [checkedState, setCheckedState] = useState({ 11 | suppress_non_speech: props.configOptions.suppress_non_speech 12 | }); 13 | 14 | const handleOnChange = (event: React.ChangeEvent) => { 15 | setCheckedState({ 16 | ...checkedState, 17 | [event.target.name]: event.target.checked 18 | }); 19 | 20 | props.setConfigOptions({ 21 | ...props.configOptions, 22 | suppress_non_speech: event.target.checked 23 | }); 24 | }; 25 | 26 | return ( 27 |
    28 | 29 |
    30 |
    31 | 34 | 42 |
    43 |
    44 |
    45 | ); 46 | }; 47 | 48 | export default SuppressComponent; 49 | 50 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/components/taskSelector.tsx: -------------------------------------------------------------------------------- 1 | import React, { useState } from "react"; 2 | import { ConfigOptions } from "./configModal"; 3 | import { Task } from "@ratchet-ml/ratchet-web"; 4 | 5 | interface TaskComponentProps { 6 | configOptions: ConfigOptions; 7 | setConfigOptions: React.Dispatch>; 8 | } 9 | 10 | const TaskComponent = (props: TaskComponentProps) => { 11 | let state = { 12 | translate: props.configOptions.task === Task.Translate, 13 | transcribe: props.configOptions.task === Task.Transcribe, 14 | }; 15 | 16 | const [checkedState, setCheckedState] = useState(state); 17 | 18 | const handleOnChange = (event: React.ChangeEvent) => { 19 | setCheckedState({ 20 | ...checkedState, 21 | [event.target.name]: event.target.checked, 22 | }); 23 | if (event.target.name === "translate") 24 | setCheckedState({ 25 | translate: event.target.checked, 26 | transcribe: !event.target.checked, 27 | }); 28 | if (event.target.name === "transcribe") 29 | setCheckedState({ 30 | translate: !event.target.checked, 31 | transcribe: event.target.checked, 32 | }); 33 | props.setConfigOptions((prev: ConfigOptions) => ({ 34 | ...prev, 35 | task: 36 | event.target.name === "translate" 37 | ? Task.Translate 38 | : Task.Transcribe, 39 | })); 40 | }; 41 | 42 | return ( 43 |
    44 | 45 |
    46 |
    47 | 50 | 58 |
    59 | 60 |
    61 | 64 | 72 |
    73 |
    74 |
    75 | ); 76 | }; 77 | 78 | export default TaskComponent; 79 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/ratchet/136da4d5216910bfd015b27a17b837c21f17163a/examples/ratchet-whisper/src/app/favicon.ico -------------------------------------------------------------------------------- /examples/ratchet-whisper/src/app/layout.tsx: -------------------------------------------------------------------------------- 1 | import type { Metadata } from "next"; 2 | import { Inter } from "next/font/google"; 3 | import "./globals.css"; 4 | import { Toaster } from "react-hot-toast"; 5 | import "react-responsive-modal/styles.css"; 6 | 7 | const inter = Inter({ subsets: ["latin"] }); 8 | 9 | export const metadata: Metadata = { 10 | title: "Whisper by Ratchet", 11 | description: "Simple demo of Whisper.", 12 | }; 13 | 14 | export default function RootLayout({ 15 | children, 16 | }: Readonly<{ 17 | children: React.ReactNode; 18 | }>) { 19 | return ( 20 | 21 | 22 |
    23 | 24 |
    {children}
    25 |
    26 | 27 | 28 | ); 29 | } 30 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/tailwind.config.js: -------------------------------------------------------------------------------- 1 | /** @type {import('tailwindcss').Config} */ 2 | module.exports = { 3 | content: [ 4 | "./src/**/*.{js,ts,jsx,tsx,mdx}", 5 | ], 6 | theme: { 7 | extend: {}, 8 | }, 9 | plugins: [], 10 | } 11 | -------------------------------------------------------------------------------- /examples/ratchet-whisper/tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "lib": ["dom", "dom.iterable", "esnext"], 4 | "allowJs": true, 5 | "skipLibCheck": true, 6 | "strict": true, 7 | "noEmit": true, 8 | "esModuleInterop": true, 9 | "module": "esnext", 10 | "moduleResolution": "bundler", 11 | "resolveJsonModule": true, 12 | "isolatedModules": true, 13 | "jsx": "preserve", 14 | "incremental": true, 15 | "plugins": [ 16 | { 17 | "name": "next" 18 | } 19 | ], 20 | "paths": { 21 | "@/*": ["./src/*"] 22 | } 23 | }, 24 | "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"], 25 | "exclude": ["node_modules"] 26 | } 27 | -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | line-count: 2 | cd ./crates/ratchet-core && scc -irs 3 | install-pyo3: 4 | env PYTHON_CONFIGURE_OPTS="--enable-shared" pyenv install --verbose 3.10.6 5 | pyenv local 3.10.6 6 | echo $(python --version) 7 | wasm CRATE: 8 | node_modules/.bin/wasm-pack build -s ratchet --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release 9 | wasm-dbg CRATE: 10 | node_modules/.bin/wasm-pack build -s ratchet --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --dev 11 | wasm-test CRATE BROWSER: 12 | cp ./config/webdriver-macos.json ./crates/{{CRATE}}/webdriver.json 13 | node_modules/.bin/wasm-pack test --{{BROWSER}} --headless `pwd`/crates/{{CRATE}} 14 | # Publish a new version of a crate using pkg.pr.new 15 | wasm-publish-pr CRATE: 16 | node_modules/.bin/pkg-pr-new publish --pnpm ./target/pkg/{{CRATE}} 17 | push-example EXAMPLE: 18 | git push {{ EXAMPLE }} `git subtree split --prefix=examples/{{EXAMPLE}}/out master`:main --force 19 | export-libtorch: # Install libtorch 20 | export LIBTORCH=$(python3 -c 'import torch; from pathlib import Path; print(Path(torch.__file__).parent)') 21 | export DYLD_LIBRARY_PATH=${LIBTORCH}/lib 22 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ratchet-repo", 3 | "version": "0.0.0", 4 | "packageManager": "pnpm@8.15.8", 5 | "private": true, 6 | "devDependencies": { 7 | "pkg-pr-new": "0.0.15", 8 | "wasm-pack": "0.12.1" 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /pnpm-workspace.yaml: -------------------------------------------------------------------------------- 1 | packages: 2 | - "examples/ratchet-whisper" 3 | - "examples/ratchet-phi" 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cpu 2 | numpy==1.24.3 3 | torch==2.3.0 4 | 
requests==2.26.0 5 | mlx==0.9.1; sys_platform == 'darwin' 6 | git+https://github.com/FL33TW00D/whisper.git@feature/reference#egg=openai-whisper 7 | gguf==0.6.0 8 | -------------------------------------------------------------------------------- /rust-toolchain.toml: -------------------------------------------------------------------------------- 1 | [toolchain] 2 | channel = "nightly" 3 | -------------------------------------------------------------------------------- /scripts/phi3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Phi3ForCausalLM, AutoTokenizer 3 | 4 | def ground(): 5 | model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True) 6 | print("Model: ", model) 7 | model.eval() 8 | tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True) 9 | inputs = tokenizer("""<|user|> 10 | How to explain Internet for a medieval knight?<|end|> 11 | <|assistant|>""", return_tensors="pt", return_attention_mask=False) 12 | outputs = model.generate(**inputs, max_length=500, return_dict_in_generate=True, output_logits=True) 13 | tokens = outputs[0] 14 | print("Tokens: ", tokens) 15 | print("Text: ", tokenizer.decode(tokens[0], skip_special_tokens=True)) 16 | 17 | def hooked(): 18 | model = Phi3ForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True) 19 | model.eval() 20 | 21 | first_layer_output = None 22 | def hook(module, input, output): 23 | nonlocal first_layer_output 24 | first_layer_output = output 25 | 26 | # Register the forward hook on the first decoder layer 27 | model.model.layers[0].register_forward_hook(hook) 28 | 29 | tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True) 30 | inputs = tokenizer('{}', return_tensors="pt", return_attention_mask=False) 31 | print("PROMPT TOKENS:", inputs["input_ids"]) 32 | logits = model(**inputs).logits 33 | print("FIRST LAYER OUTPUT: ", first_layer_output) 34 | return [first_layer_output[0].detach().numpy()] 35 | 36 | 37 | ground() 38 | -------------------------------------------------------------------------------- /scripts/understanding_matmul.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | # Understanding argument reordering for a matmul 4 | # nn.Linear in PyTorch is defined as: 5 | # y = x @ W.t() + b 6 | # Weights in a GGUF are stored as (out_features, in_features) 7 | # 8 | # Argument reordering 9 | # In order to have fast memory access patterns, it can sometimes be prudent to reorder the arguments of a matmul 10 | # Particularly in the case of a vector-matrix multiplication. 11 | # e.g [1, 2560] @ [10240, 2560].t() -> [1, 10240] 12 | # If everything is stored in row-major order, the above matmul will have poor memory access patterns. 13 | # However, we can swap the arguments. 14 | # [10240, 2560] @ [1, 2560].t() -> [10240, 1] 15 | # This will have good access patterns on BOTH A & B. 
16 | W = np.random.rand(10240, 2560) # (out_features, in_features), as stored in a GGUF 17 | X = np.random.rand(2, 2560) # (batch, in_features) 18 | 19 | WT = np.ascontiguousarray(np.transpose(W, (1, 0))) 20 | 21 | Y = X @ WT 22 | print("Standard case: y = xWT + b") 23 | print(f"{X.shape} @ {WT.shape} = {Y.shape}\n") 24 | 25 | XT = np.ascontiguousarray(np.transpose(X, (1, 0))) 26 | 27 | ZT = W @ XT 28 | print("Reordered case: zT = WxT + b") 29 | print(f"{W.shape} @ {XT.shape} = {ZT.shape}\n") 30 | 31 | Z = np.ascontiguousarray(np.transpose(ZT, (1, 0))) 32 | 33 | # Check that both orderings produce the same result 34 | print("Are results the same: ", np.allclose(Y, Z)) 35 | 36 | 37 | print("By performing the reordered case, we can avoid transposing W, which is not feasible for quantized W.") 38 | --------------------------------------------------------------------------------
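The script above verifies correctness but not the performance claim in its header comments. As a rough, illustrative benchmark (not part of the repository; the shapes, the single-row activation, and `number=10` are arbitrary choices, and absolute timings depend on your NumPy/BLAS build), the two orderings can be compared with `timeit`. The standard case pays for materializing a contiguous 10240×2560 transpose via `np.ascontiguousarray`, which is exactly the copy the reordered case avoids:

```python
import timeit

import numpy as np

W = np.random.rand(10240, 2560)  # (out_features, in_features)
x = np.random.rand(1, 2560)      # single-row activation (the GEMV case)

# Standard ordering: materialize a contiguous W.T, then compute x @ W.T
t_standard = timeit.timeit(lambda: x @ np.ascontiguousarray(W.T), number=10)

# Reordered: compute W @ x.T, transposing only the tiny activation vector
t_reordered = timeit.timeit(lambda: (W @ x.T).T, number=10)

print(f"standard  (x @ W.T): {t_standard:.4f}s")
print(f"reordered (W @ x.T): {t_reordered:.4f}s")
```

For quantized weights the large copy is not merely slow but unavailable, so the reordered form is the only practical option — which is the point the script's final `print` makes.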