├── .gitattributes ├── .github └── workflows │ ├── integration_tests.yml │ ├── release.yml │ └── rust.yml ├── .gitignore ├── .gitmodules ├── .rusty-hook.toml ├── .vscode ├── launch.json └── settings.json ├── CHANGELOG.md ├── Cargo.lock ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── binaries ├── generate-ggml-bindings │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── llm-cli │ ├── Cargo.toml │ └── src │ │ ├── cli_args.rs │ │ ├── interactive.rs │ │ ├── main.rs │ │ ├── snapshot.rs │ │ └── util.rs ├── llm-test │ ├── Cargo.toml │ ├── configs │ │ ├── bloom.json │ │ ├── gptj.json │ │ ├── gptneox.json │ │ ├── llama.json │ │ └── mpt.json │ └── src │ │ ├── common.rs │ │ ├── delete.rs │ │ ├── inference.rs │ │ ├── main.rs │ │ └── tokens.rs └── precommit-check │ ├── Cargo.toml │ ├── README.md │ └── src │ └── main.rs ├── crates ├── ggml │ ├── Cargo.toml │ ├── README.md │ ├── src │ │ ├── accelerator │ │ │ ├── metal.rs │ │ │ └── mod.rs │ │ ├── context.rs │ │ ├── format │ │ │ ├── loader.rs │ │ │ ├── mod.rs │ │ │ └── saver.rs │ │ ├── lib.rs │ │ ├── tensor.rs │ │ ├── tests.rs │ │ └── util.rs │ └── sys │ │ ├── Cargo.toml │ │ ├── UPDATING.md │ │ ├── build.rs │ │ └── src │ │ ├── cuda.rs │ │ ├── lib.rs │ │ ├── llama.rs │ │ ├── metal.rs │ │ └── opencl.rs ├── llm-base │ ├── Cargo.toml │ └── src │ │ ├── inference_session.rs │ │ ├── lib.rs │ │ ├── loader.rs │ │ ├── lora.rs │ │ ├── model │ │ ├── common.rs │ │ └── mod.rs │ │ ├── quantize.rs │ │ ├── samplers.rs │ │ ├── tokenizer │ │ ├── embedded.rs │ │ ├── huggingface.rs │ │ └── mod.rs │ │ └── util.rs ├── llm │ ├── Cargo.toml │ ├── examples │ │ ├── embeddings.rs │ │ ├── inference.rs │ │ └── vicuna-chat.rs │ └── src │ │ └── lib.rs └── models │ ├── bloom │ ├── Cargo.toml │ └── src │ │ └── lib.rs │ ├── falcon │ ├── Cargo.toml │ └── src │ │ └── lib.rs │ ├── gpt2 │ ├── Cargo.toml │ └── src │ │ └── lib.rs │ ├── gptj │ ├── Cargo.toml │ └── src │ │ └── lib.rs │ ├── gptneox │ ├── .gitignore │ ├── Cargo.toml │ └── src │ │ └── lib.rs │ ├── llama │ ├── Cargo.toml │ └── src │ │ └── lib.rs │ └── mpt │ ├── Cargo.toml │ └── src │ └── lib.rs ├── doc ├── CONTRIBUTING.md ├── acceleration-support.md ├── img │ └── llm-crab-llama.png └── known-good-models.md ├── flake.lock ├── flake.nix └── utils ├── Dockerfile └── prompts ├── alpaca.txt ├── pygmalion-message.txt ├── pygmalion-prelude.txt ├── vicuna-message.txt └── vicuna-prelude.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | utils/prompts/*.txt text eol=lf -------------------------------------------------------------------------------- /.github/workflows/integration_tests.yml: -------------------------------------------------------------------------------- 1 | name: Integration Tests 2 | 3 | permissions: 4 | contents: write 5 | 6 | on: 7 | push: 8 | branches: ["main"] 9 | pull_request: 10 | branches: ["main"] 11 | workflow_dispatch: 12 | 13 | env: 14 | CARGO_TERM_COLOR: always 15 | 16 | jobs: 17 | test: 18 | strategy: 19 | # Don't stop testing if an architecture fails 20 | fail-fast: false 21 | matrix: 22 | model: [llama, gptneox, gptj, mpt, bloom] 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v3 26 | with: 27 | submodules: recursive 28 | - uses: actions-rs/toolchain@v1 29 | with: 30 | toolchain: 1.65.0 31 | override: true 32 | - name: Install dependencies 33 | run: | 34 | sudo apt-get update 35 | sudo apt-get install -y \ 36 | libssl-dev \ 37 | pkg-config \ 38 | zlib1g-dev 39 | - name: Run Integration Tests for ${{ matrix.model }} 40 | 
run: cargo run --release -p llm-test -- ${{ matrix.model }} 41 | # Upload test results 42 | - uses: actions/upload-artifact@v3 43 | if: always() 44 | with: 45 | name: test-reports 46 | path: ./.tests/results/*.json 47 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # CI that: 2 | # 3 | # * checks for a Git Tag that looks like a release 4 | # * creates a Github Release™ and fills in its text 5 | # * builds artifacts with cargo-dist (executable-zips, installers) 6 | # * uploads those artifacts to the Github Release™ 7 | # 8 | # Note that the Github Release™ will be created before the artifacts, 9 | # so there will be a few minutes where the release has no artifacts 10 | # and then they will slowly trickle in, possibly failing. To make 11 | # this more pleasant we mark the release as a "draft" until all 12 | # artifacts have been successfully uploaded. This allows you to 13 | # choose what to do with partial successes and avoids spamming 14 | # anyone with notifications before the release is actually ready. 15 | name: Release 16 | 17 | permissions: 18 | contents: write 19 | 20 | # This task will run whenever you push a git tag that looks like a version 21 | # like "v1", "v1.2.0", "v0.1.0-prerelease01", "my-app-v1.0.0", etc. 22 | # The version will be roughly parsed as ({PACKAGE_NAME}-)?v{VERSION}, where 23 | # PACKAGE_NAME must be the name of a Cargo package in your workspace, and VERSION 24 | # must be a Cargo-style SemVer Version. 25 | # 26 | # If PACKAGE_NAME is specified, then we will create a Github Release™ for that 27 | # package (erroring out if it doesn't have the given version or isn't cargo-dist-able). 28 | # 29 | # If PACKAGE_NAME isn't specified, then we will create a Github Release™ for all 30 | # (cargo-dist-able) packages in the workspace with that version (this is mode is 31 | # intended for workspaces with only one dist-able package, or with all dist-able 32 | # packages versioned/released in lockstep). 33 | # 34 | # If you push multiple tags at once, separate instances of this workflow will 35 | # spin up, creating an independent Github Release™ for each one. 36 | # 37 | # If there's a prerelease-style suffix to the version then the Github Release™ 38 | # will be marked as a prerelease. 
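#
# For example, a release for the whole workspace could be cut with (illustrative
# tag name; any tag matching the pattern below works):
#
#   git tag v0.2.0
#   git push origin v0.2.0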
39 | on: 40 | push: 41 | tags: 42 | - '*-?v[0-9]+*' 43 | 44 | jobs: 45 | # Create the Github Release™ so the packages have something to be uploaded to 46 | create-release: 47 | runs-on: ubuntu-latest 48 | outputs: 49 | has-releases: ${{ steps.create-release.outputs.has-releases }} 50 | env: 51 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 52 | steps: 53 | - uses: actions/checkout@v3 54 | with: 55 | submodules: recursive 56 | - name: Install Rust 57 | run: rustup update 1.67.1 --no-self-update && rustup default 1.67.1 58 | - name: Install cargo-dist 59 | run: curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.0.6-prerelease.8/cargo-dist-installer.sh | sh 60 | - id: create-release 61 | run: | 62 | cargo dist plan --tag=${{ github.ref_name }} --output-format=json > dist-manifest.json 63 | echo "dist plan ran successfully" 64 | cat dist-manifest.json 65 | 66 | # Create the Github Release™ based on what cargo-dist thinks it should be 67 | ANNOUNCEMENT_TITLE=$(jq --raw-output ".announcement_title" dist-manifest.json) 68 | IS_PRERELEASE=$(jq --raw-output ".announcement_is_prerelease" dist-manifest.json) 69 | jq --raw-output ".announcement_github_body" dist-manifest.json > new_dist_announcement.md 70 | gh release create ${{ github.ref_name }} --draft --prerelease="$IS_PRERELEASE" --title="$ANNOUNCEMENT_TITLE" --notes-file=new_dist_announcement.md 71 | echo "created announcement!" 72 | 73 | # Upload the manifest to the Github Release™ 74 | gh release upload ${{ github.ref_name }} dist-manifest.json 75 | echo "uploaded manifest!" 76 | 77 | # Disable all the upload-artifacts tasks if we have no actual releases 78 | HAS_RELEASES=$(jq --raw-output ".releases != null" dist-manifest.json) 79 | echo "has-releases=$HAS_RELEASES" >> "$GITHUB_OUTPUT" 80 | 81 | # Build and packages all the things 82 | upload-artifacts: 83 | # Let the initial task tell us to not run (currently very blunt) 84 | needs: create-release 85 | if: ${{ needs.create-release.outputs.has-releases == 'true' }} 86 | strategy: 87 | matrix: 88 | # For these target platforms 89 | include: 90 | - os: ubuntu-20.04 91 | dist-args: --artifacts=global 92 | install-dist: curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.0.6-prerelease.8/cargo-dist-installer.sh | sh 93 | - os: macos-11 94 | dist-args: --artifacts=local --target=aarch64-apple-darwin --target=x86_64-apple-darwin 95 | install-dist: curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.0.6-prerelease.8/cargo-dist-installer.sh | sh 96 | - os: ubuntu-20.04 97 | dist-args: --artifacts=local --target=x86_64-unknown-linux-gnu 98 | install-dist: curl --proto '=https' --tlsv1.2 -LsSf https://github.com/axodotdev/cargo-dist/releases/download/v0.0.6-prerelease.8/cargo-dist-installer.sh | sh 99 | - os: windows-2019 100 | dist-args: --artifacts=local --target=x86_64-pc-windows-msvc 101 | install-dist: irm https://github.com/axodotdev/cargo-dist/releases/download/v0.0.6-prerelease.8/cargo-dist-installer.ps1 | iex 102 | 103 | runs-on: ${{ matrix.os }} 104 | env: 105 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 106 | steps: 107 | - uses: actions/checkout@v3 108 | with: 109 | submodules: recursive 110 | - name: Install Rust 111 | run: rustup update 1.67.1 --no-self-update && rustup default 1.67.1 112 | - name: Install cargo-dist 113 | run: ${{ matrix.install-dist }} 114 | - name: Run cargo-dist 115 | # This logic is a bit janky because it's trying to be a polyglot between 
116 | # powershell and bash since this will run on windows, macos, and linux! 117 | # The two platforms don't agree on how to talk about env vars but they 118 | # do agree on 'cat' and '$()' so we use that to marshal values between commands. 119 | run: | 120 | # Actually do builds and make zips and whatnot 121 | cargo dist build --tag=${{ github.ref_name }} --output-format=json ${{ matrix.dist-args }} > dist-manifest.json 122 | echo "dist ran successfully" 123 | cat dist-manifest.json 124 | 125 | # Parse out what we just built and upload it to the Github Release™ 126 | jq --raw-output ".artifacts[]?.path | select( . != null )" dist-manifest.json > uploads.txt 127 | echo "uploading..." 128 | cat uploads.txt 129 | gh release upload ${{ github.ref_name }} $(cat uploads.txt) 130 | echo "uploaded!" 131 | 132 | # Mark the Github Release™ as a non-draft now that everything has succeeded! 133 | publish-release: 134 | # Only run after all the other tasks, but it's ok if upload-artifacts was skipped 135 | needs: [create-release, upload-artifacts] 136 | if: ${{ always() && needs.create-release.result == 'success' && (needs.upload-artifacts.result == 'skipped' || needs.upload-artifacts.result == 'success') }} 137 | runs-on: ubuntu-latest 138 | env: 139 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 140 | steps: 141 | - uses: actions/checkout@v3 142 | with: 143 | submodules: recursive 144 | - name: mark release as non-draft 145 | run: | 146 | gh release edit ${{ github.ref_name }} --draft=false 147 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | build: 14 | strategy: 15 | # Don't stop building if it fails on an OS 16 | fail-fast: false 17 | matrix: 18 | os: [windows-latest, ubuntu-latest, macos-latest] 19 | runs-on: ${{ matrix.os }} 20 | steps: 21 | - uses: actions/checkout@v3 22 | with: 23 | submodules: recursive 24 | - uses: actions-rs/toolchain@v1 25 | with: 26 | toolchain: 1.65.0 27 | override: true 28 | - name: Check 29 | run: cargo check --verbose 30 | - name: Build 31 | run: cargo build --verbose 32 | - name: Run tests 33 | run: cargo test --all --verbose 34 | fmt: 35 | name: Clippy, formatting and docs 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/checkout@v3 39 | with: 40 | submodules: recursive 41 | - uses: actions-rs/toolchain@v1 42 | with: 43 | toolchain: stable 44 | components: rustfmt, clippy 45 | - name: Formatting 46 | run: cargo fmt --all -- --check 47 | - name: Clippy 48 | run: cargo clippy --workspace -- -Dclippy::all 49 | - name: Documentation 50 | env: 51 | RUSTDOCFLAGS: -Dwarnings 52 | run: cargo doc --workspace --exclude llm-cli 53 | 54 | metal: 55 | name: Build with Metal support 56 | runs-on: macos-latest 57 | steps: 58 | - uses: actions/checkout@v3 59 | with: 60 | submodules: recursive 61 | - uses: actions-rs/toolchain@v1 62 | with: 63 | toolchain: 1.65.0 64 | override: true 65 | - name: Check 66 | run: cargo check --verbose 67 | - name: Build 68 | run: cargo build --verbose --features metal 69 | 70 | cuda: 71 | name: Build with CUDA support 72 | strategy: 73 | # Don't stop building if it fails on an OS 74 | fail-fast: false 75 | matrix: 76 | os: [windows-latest, ubuntu-latest] 77 | runs-on: ${{ matrix.os }} 78 | steps: 79 | - uses: actions/checkout@v3 80 | with: 
81 | submodules: recursive 82 | - uses: Jimver/cuda-toolkit@v0.2.11 83 | name: Install CUDA toolkit on Linux 84 | if: matrix.os == 'ubuntu-latest' 85 | id: cuda-toolkit-linux 86 | with: 87 | cuda: "12.2.0" 88 | method: "network" 89 | #See e.g. https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ 90 | non-cuda-sub-packages: '["libcublas","libcublas-dev"]' 91 | sub-packages: '["nvcc","compiler","libraries","libraries-dev","cudart","cudart-dev"]' 92 | 93 | - uses: Jimver/cuda-toolkit@v0.2.11 94 | name: Install CUDA toolkit on Windows 95 | if: matrix.os == 'windows-latest' 96 | id: cuda-toolkit-windows 97 | with: 98 | cuda: "12.2.0" 99 | #See https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html#install-the-cuda-software 100 | method: "local" 101 | - uses: actions-rs/toolchain@v1 102 | with: 103 | toolchain: 1.65.0 104 | override: true 105 | - name: Check 106 | run: cargo check --verbose 107 | - name: Build 108 | run: cargo build --verbose --features cublas 109 | 110 | opencl: 111 | name: Build with OpenCL support 112 | strategy: 113 | # Don't stop building if it fails on an OS 114 | fail-fast: false 115 | matrix: 116 | os: [windows-latest, ubuntu-latest] 117 | runs-on: ${{ matrix.os }} 118 | steps: 119 | - uses: actions/checkout@v3 120 | with: 121 | submodules: recursive 122 | 123 | - name: Install CLBlast on linux 124 | if: matrix.os == 'ubuntu-latest' 125 | run: sudo apt install libclblast-dev 126 | 127 | - name: Install vcpkg on windows 128 | if: matrix.os == 'windows-latest' 129 | run: | 130 | git clone https://github.com/microsoft/vcpkg.git 131 | cd vcpkg 132 | ./bootstrap-vcpkg.sh 133 | ls -la 134 | shell: bash 135 | 136 | - name: Install OpenCL on windows 137 | if: matrix.os == 'windows-latest' 138 | run: | 139 | ${{ github.workspace }}\vcpkg\vcpkg.exe install opencl:x64-windows 140 | shell: pwsh 141 | 142 | - name: Install CLBlast on windows 143 | if: matrix.os == 'windows-latest' 144 | run: | 145 | ${{ github.workspace }}\vcpkg\vcpkg.exe install clblast:x64-windows 146 | shell: pwsh 147 | 148 | - name: Set Windows Environment Variables 149 | if: matrix.os == 'windows-latest' 150 | run: | 151 | echo "CLBLAST_PATH=${{ github.workspace }}/vcpkg/packages/clblast_x64-windows" >> $GITHUB_ENV 152 | echo "OPENCL_PATH=${{ github.workspace }}/vcpkg/packages/opencl_x64-windows" >> $GITHUB_ENV 153 | echo "${{ github.workspace }}/vcpkg/packages/clblast_x64-windows/bin" >> $GITHUB_PATH 154 | echo "${{ github.workspace }}/vcpkg/packages/opencl_x64-windows/bin" >> $GITHUB_PATH 155 | shell: bash 156 | 157 | - uses: actions-rs/toolchain@v1 158 | with: 159 | toolchain: 1.65.0 160 | override: true 161 | - name: Check 162 | run: cargo check --verbose 163 | 164 | - name: Build with OpenCL on Windows 165 | if: matrix.os == 'windows-latest' 166 | run: cargo build --verbose --features clblast 167 | env: 168 | RUSTFLAGS: "-Ctarget-feature=+crt-static" 169 | 170 | - name: Build with OpenCL on Linux 171 | if: matrix.os == 'ubuntu-latest' 172 | run: cargo build --verbose --features clblast 173 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /models 3 | .DS_Store 4 | /.tests -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "crates/ggml/sys/llama-cpp"] 2 | path = 
crates/ggml/sys/llama-cpp 3 | url = https://github.com/ggerganov/llama.cpp 4 | -------------------------------------------------------------------------------- /.rusty-hook.toml: -------------------------------------------------------------------------------- 1 | [hooks] 2 | pre-commit = "cargo run -p precommit-check" 3 | 4 | [logging] 5 | verbose = true 6 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "lldb", 9 | "request": "launch", 10 | "name": "Debug BLOOM Inference", 11 | "cargo": { 12 | "args": ["build", "--example=inference", "--package=llm"], 13 | "filter": { 14 | "name": "inference", 15 | "kind": "example" 16 | } 17 | }, 18 | "args": ["bloom", "${env:HOME}/.ggml-models/bloom-7b.bin"], 19 | "cwd": "${workspaceFolder}" 20 | }, 21 | { 22 | "type": "lldb", 23 | "request": "launch", 24 | "name": "Debug GPT-2 Inference", 25 | "cargo": { 26 | "args": ["build", "--example=inference", "--package=llm"], 27 | "filter": { 28 | "name": "inference", 29 | "kind": "example" 30 | } 31 | }, 32 | "args": ["gpt2", "${env:HOME}/.ggml-models/cerebras-gpt-13b.bin"], 33 | "cwd": "${workspaceFolder}" 34 | }, 35 | { 36 | "type": "lldb", 37 | "request": "launch", 38 | "name": "Debug GPT-J Inference", 39 | "cargo": { 40 | "args": [ 41 | "build", 42 | "--example=inference", 43 | "--package=llm" 44 | ], 45 | "filter": { 46 | "name": "inference", 47 | "kind": "example" 48 | } 49 | }, 50 | "args": ["gptj", "${env:HOME}/.ggml-models/gpt-j-6b.bin"], 51 | "cwd": "${workspaceFolder}" 52 | }, 53 | { 54 | "type": "lldb", 55 | "request": "launch", 56 | "name": "Debug LLaMA Inference", 57 | "cargo": { 58 | "args": ["build", "--example=inference", "--package=llm"], 59 | "filter": { 60 | "name": "inference", 61 | "kind": "example" 62 | } 63 | }, 64 | "args": ["llama", "${env:HOME}/.ggml-models/gpt4all-7b.bin"], 65 | "cwd": "${workspaceFolder}" 66 | }, 67 | { 68 | "type": "lldb", 69 | "request": "launch", 70 | "name": "Debug MPT Inference", 71 | "cargo": { 72 | "args": ["build", "--example=inference", "--package=llm"], 73 | "filter": { 74 | "name": "inference", 75 | "kind": "example" 76 | } 77 | }, 78 | "args": ["mpt", "${env:HOME}/.ggml-models/mpt-7b.bin"], 79 | "cwd": "${workspaceFolder}" 80 | }, 81 | { 82 | "type": "lldb", 83 | "request": "launch", 84 | "name": "Debug GPT-NeoX Inference", 85 | "cargo": { 86 | "args": ["build", "--example=inference", "--package=llm"], 87 | "filter": { 88 | "name": "inference", 89 | "kind": "example" 90 | } 91 | }, 92 | "args": ["gptneox", "${env:HOME}/.ggml-models/stablelm-base-alpha-3b.bin"], 93 | "cwd": "${workspaceFolder}" 94 | }, 95 | { 96 | "type": "lldb", 97 | "request": "launch", 98 | "name": "Debug RedPajama Inference", 99 | "cargo": { 100 | "args": ["build", "--example=inference", "--package=llm"], 101 | "filter": { 102 | "name": "inference", 103 | "kind": "example" 104 | } 105 | }, 106 | "args": ["redpajama", "${env:HOME}/.ggml-models/redpajama-incite-7b.bin"], 107 | "cwd": "${workspaceFolder}" 108 | }, 109 | { 110 | "type": "lldb", 111 | "request": "launch", 112 | "name": "Debug Vicuna Chat", 113 | "cargo": { 114 | "args": ["build", "--example=vicuna-chat", 
"--package=llm"], 115 | "filter": { 116 | "name": "vicuna-chat", 117 | "kind": "example" 118 | } 119 | }, 120 | "args": ["llama", "${env:HOME}/.ggml-models/wizardlm-7b.bin"], 121 | "cwd": "${workspaceFolder}" 122 | } 123 | ] 124 | } 125 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "rust-analyzer.cargo.features": [] 3 | } 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | `llm` is actively being iterated upon, so there will be breaking changes to the interface and to compatibility. Where possible, we will try to find ways to mitigate the breaking changes, but we do not expect to have a stable interface for some time. 4 | 5 | # 0.2.0-dev (unreleased) 6 | 7 | - `llm` now uses the latest GGML version. This limits use to older unquantized models or to models quantized with the latest version (quantization version 2, file format GGJTv3). We are investigating ways to [mitigate this breakage in the future](https://github.com/rustformers/llm/discussions/261). 8 | - `llm::InferenceRequest` no longer implements `Default::default`. 9 | - The `infer` callback now provides an `InferenceResponse` instead of a string to disambiguate the source of the token. Additionally, it now returns an `InferenceFeedback` to control whether or not the generation should continue. 10 | - Several fields have been renamed: 11 | - `n_context_tokens` -> `context_size` 12 | 13 | # 0.1.1 (2023-05-08) 14 | 15 | - Fix an issue with the binary build of `llm-cli`. 16 | 17 | # 0.1.0 (2023-05-08) 18 | 19 | Initial release. 
20 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | members = [ 3 | # Crates 4 | "crates/ggml", 5 | "crates/ggml/sys", 6 | "crates/llm", 7 | "crates/llm-base", 8 | "crates/models/*", 9 | "binaries/*", 10 | ] 11 | resolver = "2" 12 | default-members = ["binaries/llm-cli", "crates/llm"] 13 | 14 | [workspace.package] 15 | repository = "https://github.com/rustformers/llm" 16 | license = "MIT OR Apache-2.0" 17 | 18 | [workspace.dependencies] 19 | bytemuck = "1.13.1" 20 | bytesize = "1.1" 21 | env_logger = "0.10.0" 22 | log = "0.4" 23 | rand = "0.8.5" 24 | thiserror = "1.0" 25 | anyhow = "1.0" 26 | 27 | rustyline = { version = "11.0.0", features = ["derive"] } 28 | serde = { version = "1.0", features = ["derive"] } 29 | serde_json = { version = "1.0" } 30 | spinoff = { version = "0.8.0", default-features = false, features = ["dots2"] } 31 | clap = { version = "4.1.8", features = ["derive"] } 32 | memmap2 = "0.5.10" 33 | tracing-subscriber = { version = "0.3", features = ["env-filter"] } 34 | tracing = { version = "0.1", features = ["log"] } 35 | llm-samplers = "=0.0.7" 36 | 37 | # Config for 'cargo dist' 38 | [workspace.metadata.dist] 39 | # The preferred cargo-dist version to use in CI (Cargo.toml SemVer syntax) 40 | cargo-dist-version = "0.0.6-prerelease.8" 41 | # The preferred Rust toolchain to use in CI (rustup toolchain syntax) 42 | rust-toolchain-version = "1.67.1" 43 | # CI backends to support (see 'cargo dist generate-ci') 44 | ci = ["github"] 45 | # The installers to generate for each app 46 | installers = ["shell", "powershell"] 47 | # Target platforms to build apps for (Rust target-triple syntax) 48 | targets = [ 49 | "x86_64-unknown-linux-gnu", 50 | "x86_64-apple-darwin", 51 | "x86_64-pc-windows-msvc", 52 | "aarch64-apple-darwin", 53 | ] 54 | 55 | # The profile that 'cargo dist' will build with 56 | [profile.dist] 57 | inherits = "release" 58 | lto = "thin" 59 | 60 | [workspace.metadata.release] 61 | tag-prefix = "" 62 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 The llm Authors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /binaries/generate-ggml-bindings/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "generate-ggml-bindings" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [package.metadata.release] 8 | release = false 9 | 10 | [dependencies] 11 | bindgen = "0.65.1" 12 | -------------------------------------------------------------------------------- /binaries/generate-ggml-bindings/src/main.rs: -------------------------------------------------------------------------------- 1 | //! Helper tool to generate the bindings for the ggml crate. 2 | //! 3 | //! Assumed to be run from the root of the workspace. 4 | 5 | use std::{ 6 | fs, 7 | path::{Path, PathBuf}, 8 | }; 9 | 10 | fn main() { 11 | let sys_path = PathBuf::from("crates").join("ggml").join("sys"); 12 | let ggml_path = sys_path.join("llama-cpp"); 13 | let src_path = sys_path.join("src"); 14 | 15 | generate_main(&ggml_path, &src_path); 16 | generate_cuda(&ggml_path, &src_path); 17 | generate_opencl(&ggml_path, &src_path); 18 | generate_metal(&ggml_path, &src_path); 19 | generate_llama(&ggml_path, &src_path); 20 | 21 | println!("Successfully updated bindings"); 22 | } 23 | 24 | fn generate_main(ggml_path: &Path, src_path: &Path) { 25 | let bindings = bindgen::Builder::default() 26 | .header(ggml_path.join("ggml.h").to_str().unwrap().to_string()) 27 | .allowlist_file(r".*ggml.h") 28 | .header(ggml_path.join("k_quants.h").to_string_lossy()) 29 | .allowlist_file(r".*k_quants.h") 30 | // Suppress some warnings 31 | .raw_line("#![allow(non_upper_case_globals)]") 32 | .raw_line("#![allow(non_camel_case_types)]") 33 | .raw_line("#![allow(non_snake_case)]") 34 | .raw_line("#![allow(unused)]") 35 | .raw_line("pub mod llama;") 36 | .raw_line("") 37 | .raw_line(r#"#[cfg(feature = "cublas")]"#) 38 | .raw_line("pub mod cuda;") 39 | .raw_line(r#"#[cfg(feature = "metal")]"#) 40 | .raw_line("pub mod metal;") 41 | .raw_line(r#"#[cfg(feature = "clblast")]"#) 42 | .raw_line("pub mod opencl;") 43 | // Only generate code if it's from GGML 44 | .allowlist_file("crates/ggml/.*") 45 | .generate() 46 | .expect("Unable to generate bindings"); 47 | 48 | let mut generated_bindings = bindings.to_string(); 49 | if cfg!(windows) { 50 | // windows generates all ::std::os::raw::c_* enum types as i32. 51 | // We need to replace some of them with c_uint as the rust bindings expect them to be unsigned. 52 | // Temporary hack until bindgen supports defining the enum types manually. 
See https://github.com/rust-lang/rust-bindgen/issues/1907 53 | for name in &[ 54 | "type", 55 | "backend", 56 | "op", 57 | "linesearch", 58 | "opt_type", 59 | "task_type", 60 | ] { 61 | generated_bindings = generated_bindings.replace( 62 | &format!("ggml_{name} = ::std::os::raw::c_int;"), 63 | &format!("ggml_{name} = ::std::os::raw::c_uint;"), 64 | ); 65 | } 66 | } 67 | fs::write(src_path.join("lib.rs"), generated_bindings).expect("Couldn't write bindings"); 68 | } 69 | 70 | fn generate_cuda(ggml_path: &Path, src_path: &Path) { 71 | generate_extra("cuda", ggml_path, src_path, |b| { 72 | b.header(ggml_path.join("ggml-cuda.h").to_string_lossy()) 73 | .allowlist_file(r".*ggml-cuda\.h") 74 | .raw_line("use super::ggml_compute_params;") 75 | .raw_line("use super::ggml_tensor;") 76 | }) 77 | } 78 | 79 | fn generate_opencl(ggml_path: &Path, src_path: &Path) { 80 | generate_extra("opencl", ggml_path, src_path, |b| { 81 | b.header(ggml_path.join("ggml-opencl.h").to_string_lossy()) 82 | .allowlist_file(r".*ggml-opencl\.h") 83 | .raw_line("use super::ggml_tensor;") 84 | }) 85 | } 86 | 87 | fn generate_metal(ggml_path: &Path, src_path: &Path) { 88 | generate_extra("metal", ggml_path, src_path, |b| { 89 | b.header(ggml_path.join("ggml-metal.h").to_string_lossy()) 90 | .allowlist_file(r".*ggml-metal\.h") 91 | }); 92 | } 93 | 94 | fn generate_llama(ggml_path: &Path, src_path: &Path) { 95 | // We do not use `llama.cpp` for its implementation at all; 96 | // we only use it for its header file and its associated constants. 97 | generate_extra("llama", ggml_path, src_path, |b| { 98 | b.header(ggml_path.join("llama.h").to_string_lossy()) 99 | .allowlist_type("llama_ftype") 100 | .allowlist_var("LLAMA_.*") 101 | .prepend_enum_name(false) 102 | .ignore_functions() 103 | }); 104 | } 105 | 106 | fn generate_extra( 107 | name: &str, 108 | ggml_path: &Path, 109 | src_path: &Path, 110 | mut callback: impl FnMut(bindgen::Builder) -> bindgen::Builder, 111 | ) { 112 | let builder = callback( 113 | bindgen::Builder::default() 114 | .allowlist_recursively(false) 115 | .clang_arg("-I") 116 | .clang_arg(ggml_path.to_string_lossy()), 117 | ); 118 | 119 | builder 120 | .generate() 121 | .unwrap_or_else(|_| panic!("Unable to generate {name} bindings")) 122 | .write_to_file(src_path.join(format!("{name}.rs"))) 123 | .unwrap_or_else(|_| panic!("Couldn't write {name} bindings")); 124 | } 125 | -------------------------------------------------------------------------------- /binaries/llm-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | edition = "2021" 3 | name = "llm-cli" 4 | version = "0.2.0-dev" 5 | repository = { workspace = true } 6 | license = { workspace = true } 7 | description = "A CLI for running inference on supported Large Language Models. Powered by the `llm` library." 
8 | readme = "../../README.md" 9 | 10 | [[bin]] 11 | name = "llm" 12 | path = "src/main.rs" 13 | 14 | [dependencies] 15 | llm = { path = "../../crates/llm", version = "0.2.0-dev", default-features = false, features = ["models"] } 16 | 17 | bytesize = { workspace = true } 18 | env_logger = { workspace = true } 19 | log = { workspace = true } 20 | rand = { workspace = true } 21 | rustyline = { workspace = true } 22 | spinoff = { workspace = true } 23 | clap = { workspace = true } 24 | 25 | bincode = "1.3.3" 26 | num_cpus = "1.15.0" 27 | 28 | color-eyre = { version = "0.6.2", default-features = false } 29 | zstd = { version = "0.12", default-features = false } 30 | tracing-subscriber = {workspace = true } 31 | tracing = { workspace = true} 32 | tracing-appender = "0.2.2" 33 | 34 | # TEMPORARY: This was introduced in Rust 1.70, but our MSRV is below this. 35 | # Remove this once we bump our MSRV to 1.70. 36 | is-terminal = "0.4" 37 | 38 | llm-samplers = { workspace = true } 39 | 40 | [dev-dependencies] 41 | rusty-hook = "^0.11.2" 42 | 43 | [features] 44 | default = ["tokenizers-remote"] 45 | 46 | tokenizers-remote = ["llm/tokenizers-remote"] 47 | cublas = ["llm/cublas"] 48 | clblast = ["llm/clblast"] 49 | metal = ["llm/metal"] 50 | 51 | # Falcon is off by default. See `llm_falcon`'s module documentation for more information. 52 | falcon = ["llm/falcon"] 53 | -------------------------------------------------------------------------------- /binaries/llm-cli/src/interactive.rs: -------------------------------------------------------------------------------- 1 | use std::convert::Infallible; 2 | 3 | use color_eyre::eyre; 4 | use rustyline::{ 5 | error::ReadlineError, 6 | history::DefaultHistory, 7 | validate::{ValidationContext, ValidationResult, Validator}, 8 | Cmd, Completer, Helper, Highlighter, Hinter, KeyCode, KeyEvent, Modifiers, 9 | }; 10 | 11 | use crate::{ 12 | cli_args::{Chat, Repl}, 13 | snapshot, util, 14 | }; 15 | 16 | pub fn repl( 17 | Repl { 18 | generate, 19 | model_load, 20 | prompt_file, 21 | }: &Repl, 22 | ) -> eyre::Result<()> { 23 | let (inference_session_config, parameters, model, mut rng) = 24 | initialize_common_state(generate, model_load)?; 25 | 26 | let template = prompt_file.contents()?; 27 | 28 | let model = model.as_ref(); 29 | let mut session = create_session(model, inference_session_config); 30 | readline_loop(|raw_line| { 31 | let line = raw_line.replace("\\\n", "\n"); 32 | 33 | let prompt = template 34 | .as_deref() 35 | .map(|template| util::process_prompt(template, &line)) 36 | .unwrap_or(line); 37 | feed_prompt_with_spinner(model, &mut session, prompt)?; 38 | 39 | session.infer::( 40 | model, 41 | &mut rng, 42 | &llm::InferenceRequest { 43 | prompt: "".into(), 44 | parameters: ¶meters, 45 | play_back_previous_tokens: false, 46 | maximum_token_count: generate.num_predict, 47 | }, 48 | &mut Default::default(), 49 | |r| { 50 | if let llm::InferenceResponse::InferredToken(t) = r { 51 | util::print_token(t); 52 | } 53 | Ok(llm::InferenceFeedback::Continue) 54 | }, 55 | )?; 56 | 57 | if !session_ends_with_newline(&session) { 58 | println!(); 59 | } 60 | session = create_session(model, inference_session_config); 61 | 62 | Ok(()) 63 | }) 64 | } 65 | 66 | pub fn chat(args: &Chat) -> eyre::Result<()> { 67 | let Chat { 68 | model_load, 69 | prelude_prompt_file, 70 | generate, 71 | .. 
72 | } = args; 73 | 74 | let (inference_session_config, parameters, model, mut rng) = 75 | initialize_common_state(generate, model_load)?; 76 | 77 | let prelude_prompt = std::fs::read_to_string(prelude_prompt_file)?; 78 | let message_prompt_prefix = args.message_prompt_prefix()?; 79 | 80 | let model = model.as_ref(); 81 | let mut session = create_session(model, inference_session_config); 82 | feed_prompt_with_spinner(model, &mut session, prelude_prompt)?; 83 | 84 | readline_loop(|raw_line| { 85 | let prompt = { 86 | let line = raw_line.replace("\\\n", "\n"); 87 | let mut prompt = format!("{message_prompt_prefix}{line}"); 88 | // Add a newline to the end of the prompt if it doesn't end with one 89 | if !prompt.ends_with('\n') { 90 | prompt.push('\n'); 91 | } 92 | prompt 93 | }; 94 | 95 | session.infer::( 96 | model, 97 | &mut rng, 98 | &llm::InferenceRequest { 99 | prompt: (&prompt).into(), 100 | parameters: ¶meters, 101 | play_back_previous_tokens: false, 102 | maximum_token_count: generate.num_predict, 103 | }, 104 | &mut Default::default(), 105 | llm::conversation_inference_callback(&message_prompt_prefix, util::print_token), 106 | )?; 107 | 108 | if !session_ends_with_newline(&session) { 109 | println!(); 110 | } 111 | 112 | Ok(()) 113 | }) 114 | } 115 | 116 | fn initialize_common_state( 117 | generate: &crate::cli_args::Generate, 118 | model_load: &crate::cli_args::ModelLoad, 119 | ) -> eyre::Result<( 120 | llm::InferenceSessionConfig, 121 | llm::InferenceParameters, 122 | Box, 123 | rand::rngs::StdRng, 124 | )> { 125 | let model = model_load.load(generate.use_gpu)?; 126 | Ok(( 127 | generate.inference_session_config(), 128 | generate.inference_parameters(model.eot_token_id(), model.tokenizer().len())?, 129 | model, 130 | generate.rng(), 131 | )) 132 | } 133 | 134 | fn feed_prompt_with_spinner( 135 | model: &dyn llm::Model, 136 | session: &mut llm::InferenceSession, 137 | mut prompt: String, 138 | ) -> eyre::Result<()> { 139 | // Add a newline to the beginning of the prompt if the last character in the session is not a newline 140 | if !session_ends_with_newline(session) { 141 | prompt.insert(0, '\n'); 142 | } 143 | 144 | let mut sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); 145 | let result = session.feed_prompt( 146 | model, 147 | &prompt, 148 | // OutputRequest 149 | &mut Default::default(), 150 | |_| Ok::<_, Infallible>(llm::InferenceFeedback::Continue), 151 | ); 152 | sp.clear(); 153 | 154 | Ok(result?) 
155 | } 156 | 157 | fn create_session( 158 | model: &dyn llm::Model, 159 | inference_session_config: llm::InferenceSessionConfig, 160 | ) -> llm::InferenceSession { 161 | snapshot::read_or_create_session(model, None, None, inference_session_config).0 162 | } 163 | 164 | fn session_ends_with_newline(session: &llm::InferenceSession) -> bool { 165 | session 166 | .decoded_tokens() 167 | .last() 168 | .map_or(true, |t| *t == b'\n') 169 | } 170 | 171 | fn readline_loop(mut body: impl FnMut(String) -> eyre::Result<()>) -> eyre::Result<()> { 172 | let mut rl = rustyline::Editor::::new()?; 173 | rl.set_helper(Some(LineContinuationValidator)); 174 | rl.bind_sequence(force_newline_event_seq(), Cmd::Newline); 175 | 176 | loop { 177 | match rl.readline(">> ") { 178 | Ok(raw_line) => { 179 | if let Err(err) = body(raw_line) { 180 | log::error!("{err}"); 181 | break; 182 | } 183 | } 184 | Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { 185 | break; 186 | } 187 | Err(err) => { 188 | log::error!("{err}"); 189 | break; 190 | } 191 | } 192 | } 193 | 194 | Ok(()) 195 | } 196 | 197 | #[cfg(not(windows))] 198 | fn force_newline_event_seq() -> KeyEvent { 199 | KeyEvent(KeyCode::Enter, Modifiers::ALT) 200 | } 201 | 202 | // On Windows, `SHIFT+ENTER` is the key sequence for forcing a newline. This is 203 | // because `ALT+ENTER` typically maximizes the window. 204 | #[cfg(windows)] 205 | fn force_newline_event_seq() -> KeyEvent { 206 | KeyEvent(KeyCode::Enter, Modifiers::SHIFT) 207 | } 208 | 209 | #[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)] 210 | struct LineContinuationValidator; 211 | 212 | impl Validator for LineContinuationValidator { 213 | fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result { 214 | if ctx.input().ends_with('\\') { 215 | Ok(ValidationResult::Incomplete) 216 | } else { 217 | Ok(ValidationResult::Valid(None)) 218 | } 219 | } 220 | } 221 | -------------------------------------------------------------------------------- /binaries/llm-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | convert::Infallible, 3 | fs::File, 4 | io::{BufReader, BufWriter}, 5 | }; 6 | 7 | use clap::Parser; 8 | use cli_args::Args; 9 | use color_eyre::eyre::{self, Context, ContextCompat}; 10 | use is_terminal::IsTerminal; 11 | 12 | mod cli_args; 13 | mod interactive; 14 | mod snapshot; 15 | mod util; 16 | 17 | fn main() -> eyre::Result<()> { 18 | tracing_subscriber::fmt() 19 | .with_writer(std::io::stderr) 20 | .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) 21 | .with_ansi(std::io::stderr().is_terminal()) 22 | .init(); 23 | 24 | color_eyre::install()?; 25 | 26 | let args = Args::parse(); 27 | match args { 28 | Args::Infer(args) => infer(&args), 29 | Args::Perplexity(args) => perplexity(&args), 30 | Args::Info(args) => info(&args), 31 | Args::PromptTokens(args) => prompt_tokens(&args), 32 | Args::Repl(args) => interactive::repl(&args), 33 | Args::Chat(args) => interactive::chat(&args), 34 | Args::Quantize(args) => quantize(&args), 35 | } 36 | } 37 | 38 | #[tracing::instrument(skip_all)] 39 | fn infer(args: &cli_args::Infer) -> eyre::Result<()> { 40 | let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; 41 | let inference_session_config = args.generate.inference_session_config(); 42 | let model = args.model_load.load(args.generate.use_gpu)?; 43 | 44 | let (mut session, session_loaded) = snapshot::read_or_create_session( 45 | model.as_ref(), 46 | 
args.persist_session.as_deref(), 47 | args.load_session.as_deref(), 48 | inference_session_config, 49 | ); 50 | let parameters = args 51 | .generate 52 | .inference_parameters(model.eot_token_id(), model.tokenizer().len())?; 53 | 54 | let mut rng = args.generate.rng(); 55 | 56 | let span = tracing::trace_span!("infer"); 57 | 58 | span.in_scope(|| { 59 | // do work inside the span... 60 | let res = session.infer::( 61 | model.as_ref(), 62 | &mut rng, 63 | &llm::InferenceRequest { 64 | prompt: prompt.as_str().into(), 65 | parameters: ¶meters, 66 | play_back_previous_tokens: session_loaded, 67 | maximum_token_count: args.generate.num_predict, 68 | }, 69 | // OutputRequest 70 | &mut Default::default(), 71 | |r| { 72 | match r { 73 | llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => { 74 | util::print_token(t) 75 | } 76 | llm::InferenceResponse::InferredToken(t) => util::print_token(t), 77 | _ => {} 78 | } 79 | Ok(llm::InferenceFeedback::Continue) 80 | }, 81 | ); 82 | 83 | println!(); 84 | 85 | match res { 86 | Ok(stats) => { 87 | if args.stats { 88 | println!(); 89 | println!("{}", stats); 90 | println!(); 91 | } 92 | } 93 | Err(llm::InferenceError::ContextFull) => { 94 | log::warn!("Context window full, stopping inference.") 95 | } 96 | Err(llm::InferenceError::TokenizationFailed(err)) => { 97 | log::error!("A tokenization-related failure occurred: {}", err); 98 | } 99 | Err(llm::InferenceError::SamplerFailure(err)) => { 100 | log::error!("A sampling-related failure occurred: {}", err); 101 | } 102 | Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => { 103 | unreachable!("cannot fail") 104 | } 105 | } 106 | }); 107 | 108 | if let Some(session_path) = args.save_session.as_ref().or(args.persist_session.as_ref()) { 109 | // Write the memory to the cache file 110 | snapshot::write_session(session, session_path); 111 | } 112 | 113 | Ok(()) 114 | } 115 | 116 | fn perplexity(args: &cli_args::Perplexity) -> eyre::Result<()> { 117 | let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; 118 | let inference_session_config = args.generate.inference_session_config(); 119 | let model = args.model_load.load(args.generate.use_gpu)?; 120 | let (mut session, _) = 121 | snapshot::read_or_create_session(model.as_ref(), None, None, inference_session_config); 122 | 123 | session.perplexity(model.as_ref(), prompt.as_str(), |chunk, perplexity| { 124 | println!("Perplexity[{chunk}]: {perplexity}"); 125 | })?; 126 | 127 | Ok(()) 128 | } 129 | 130 | fn info(args: &cli_args::Info) -> eyre::Result<()> { 131 | struct InfoVisitor<'a>(&'a cli_args::Info); 132 | impl llm::ModelArchitectureVisitor> for InfoVisitor<'_> { 133 | fn visit(&mut self) -> eyre::Result<()> { 134 | let args = self.0; 135 | 136 | let model_path = &args.model_and_tokenizer.model_path; 137 | let tokenizer = args.model_and_tokenizer.to_source()?.retrieve(model_path)?; 138 | 139 | let file = File::open(model_path)?; 140 | let mut reader = BufReader::new(&file); 141 | let mut loader: llm::Loader = 142 | llm::Loader::new(tokenizer, |_| { 143 | // We purposely do not print progress here, as we are only interested in the metadata 144 | }); 145 | 146 | llm::ggml_format::load(&mut reader, &mut loader)?; 147 | 148 | log::info!("Container type: {:?}", loader.container_type); 149 | log::info!("Hyperparameters: {:?}", loader.hyperparameters); 150 | log::info!("Tokenizer vocabulary size: {}", loader.tokenizer.len()); 151 | 152 | if args.tokenizer { 153 | log::info!("Tokens:"); 154 | for i 
in 0..loader.tokenizer.len() { 155 | log::info!("- {}: {}", i, utf8_or_array(&loader.tokenizer.token(i))); 156 | } 157 | } 158 | 159 | if args.tensors { 160 | log::info!("Tensors:"); 161 | for (name, tensor) in &loader.tensors { 162 | log::info!("- {} ({:?} {:?})", name, tensor.element_type, tensor.dims()); 163 | } 164 | } 165 | 166 | fn utf8_or_array(token: &[u8]) -> String { 167 | std::str::from_utf8(token).map_or(format!("{:?}", token), |s| s.to_owned()) 168 | } 169 | 170 | Ok(()) 171 | } 172 | } 173 | 174 | args.model_and_tokenizer 175 | .architecture 176 | .model_architecture 177 | .wrap_err("a model architecture is required at present")? 178 | .visit(&mut InfoVisitor(args)) 179 | } 180 | 181 | fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> { 182 | let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; 183 | let model = args.model_load.load(false)?; 184 | let toks = match model.tokenizer().tokenize(&prompt, false) { 185 | Ok(toks) => toks, 186 | Err(e) => { 187 | log::error!("Could not tokenize prompt: {e}"); 188 | std::process::exit(1); 189 | } 190 | }; 191 | log::info!("=== Dumping prompt tokens:"); 192 | log::info!( 193 | "{}", 194 | toks.iter() 195 | .map(|(_, tid)| tid.to_string()) 196 | .collect::>() 197 | .join(", ") 198 | ); 199 | log::info!( 200 | "{}", 201 | toks.iter() 202 | .map(|(s, tid)| format!("{s:?}:{tid}")) 203 | .collect::>() 204 | .join(", ") 205 | ); 206 | 207 | Ok(()) 208 | } 209 | 210 | fn quantize(args: &cli_args::Quantize) -> eyre::Result<()> { 211 | use llm::QuantizeProgress; 212 | 213 | struct QuantizeVisitor<'a>(&'a cli_args::Quantize); 214 | impl llm::ModelArchitectureVisitor> for QuantizeVisitor<'_> { 215 | fn visit(&mut self) -> eyre::Result<()> { 216 | let args = self.0; 217 | 218 | let mut source: BufReader = BufReader::new(std::fs::File::open(&args.source)?); 219 | let mut destination: BufWriter = 220 | BufWriter::new(std::fs::File::create(&args.destination)?); 221 | let tokenizer: llm::Tokenizer = args.tokenizer.to_source()?.retrieve(&args.source)?; 222 | 223 | llm::quantize::( 224 | &mut source, 225 | &mut destination, 226 | tokenizer, 227 | args.container_type.into(), 228 | args.target.into(), 229 | |progress| match progress { 230 | QuantizeProgress::HyperparametersLoaded => log::info!("Loaded hyperparameters"), 231 | QuantizeProgress::TensorLoading { 232 | name, 233 | dims, 234 | element_type, 235 | n_elements, 236 | } => log::info!( 237 | "Loading tensor `{name}` ({n_elements} ({dims:?}) {element_type} elements)" 238 | ), 239 | QuantizeProgress::TensorQuantizing { name } => log::info!("Quantizing tensor `{name}`"), 240 | QuantizeProgress::TensorQuantized { 241 | name, 242 | original_size, 243 | reduced_size, 244 | history, 245 | } => log::info!( 246 | "Quantized tensor `{name}` from {original_size} to {reduced_size} bytes ({history:?})" 247 | ), 248 | QuantizeProgress::TensorSkipped { name, size } => { 249 | log::info!("Skipped tensor `{name}` ({size} bytes)") 250 | } 251 | QuantizeProgress::Finished { 252 | original_size, 253 | reduced_size, 254 | history, 255 | } => log::info!( 256 | "Finished quantization from {original_size} to {reduced_size} bytes ({history:?})" 257 | ), 258 | }, 259 | ) 260 | .wrap_err("failed to quantize model") 261 | } 262 | } 263 | 264 | args.architecture 265 | .model_architecture 266 | .wrap_err("the architecture must be known for quantization")? 
267 | .visit(&mut QuantizeVisitor(args)) 268 | } 269 | 270 | fn load_prompt_file_with_prompt( 271 | prompt_file: &cli_args::PromptFile, 272 | prompt: Option<&str>, 273 | ) -> eyre::Result { 274 | Ok(match (prompt_file.contents()?, prompt) { 275 | (Some(prompt_file), None) => prompt_file, 276 | (None, Some(prompt)) => prompt.to_owned(), 277 | (Some(prompt_file), Some(prompt)) => util::process_prompt(&prompt_file, prompt), 278 | (None, None) => eyre::bail!("No prompt or prompt file was provided. See --help"), 279 | }) 280 | } 281 | -------------------------------------------------------------------------------- /binaries/llm-cli/src/snapshot.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | error::Error, 3 | fs::File, 4 | io::{BufReader, BufWriter}, 5 | path::Path, 6 | }; 7 | 8 | use llm::{InferenceSession, InferenceSessionConfig, Model}; 9 | 10 | use zstd::{ 11 | stream::{read::Decoder, write::Encoder}, 12 | zstd_safe::CompressionLevel, 13 | }; 14 | 15 | const SNAPSHOT_COMPRESSION_LEVEL: CompressionLevel = 1; 16 | 17 | /// Read or create a session 18 | pub fn read_or_create_session( 19 | model: &dyn Model, 20 | persist_session: Option<&Path>, 21 | load_session: Option<&Path>, 22 | inference_session_config: InferenceSessionConfig, 23 | ) -> (InferenceSession, bool) { 24 | fn load(model: &dyn Model, path: &Path) -> InferenceSession { 25 | let file = unwrap_or_exit(File::open(path), || format!("Could not open file {path:?}")); 26 | let decoder = unwrap_or_exit(Decoder::new(BufReader::new(file)), || { 27 | format!("Could not create decoder for {path:?}") 28 | }); 29 | let snapshot = unwrap_or_exit(bincode::deserialize_from(decoder), || { 30 | format!("Could not deserialize inference session from {path:?}") 31 | }); 32 | let session = unwrap_or_exit(InferenceSession::from_snapshot(snapshot, model), || { 33 | format!("Could not convert snapshot from {path:?} to session") 34 | }); 35 | log::info!("Loaded inference session from {path:?}"); 36 | session 37 | } 38 | 39 | match (persist_session, load_session) { 40 | (Some(path), _) if path.exists() => (load(model, path), true), 41 | (_, Some(path)) => (load(model, path), true), 42 | _ => (model.start_session(inference_session_config), false), 43 | } 44 | } 45 | 46 | /// Write the session 47 | pub fn write_session(mut session: InferenceSession, path: &Path) { 48 | // SAFETY: the session is consumed here, so nothing else can access it. 49 | let snapshot = unsafe { session.get_snapshot() }; 50 | let file = unwrap_or_exit(File::create(path), || { 51 | format!("Could not create file {path:?}") 52 | }); 53 | let encoder = unwrap_or_exit( 54 | Encoder::new(BufWriter::new(file), SNAPSHOT_COMPRESSION_LEVEL), 55 | || format!("Could not create encoder for {path:?}"), 56 | ); 57 | unwrap_or_exit( 58 | bincode::serialize_into(encoder.auto_finish(), &snapshot), 59 | || format!("Could not serialize inference session to {path:?}"), 60 | ); 61 | log::info!("Successfully wrote session to {path:?}"); 62 | } 63 | 64 | fn unwrap_or_exit(result: Result, error_message: impl Fn() -> String) -> T { 65 | match result { 66 | Ok(t) => t, 67 | Err(err) => { 68 | log::error!("{}. 
Error: {err}", error_message()); 69 | std::process::exit(1); 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /binaries/llm-cli/src/util.rs: -------------------------------------------------------------------------------- 1 | use std::io::Write; 2 | 3 | pub fn process_prompt(raw_prompt: &str, prompt: &str) -> String { 4 | raw_prompt.replace("{{PROMPT}}", prompt) 5 | } 6 | 7 | pub fn print_token(t: String) { 8 | print!("{t}"); 9 | std::io::stdout().flush().unwrap(); 10 | } 11 | -------------------------------------------------------------------------------- /binaries/llm-test/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | edition = "2021" 3 | name = "llm-test" 4 | version = "0.2.0-dev" 5 | repository = { workspace = true } 6 | license = { workspace = true } 7 | publish = false 8 | 9 | [package.metadata.release] 10 | release = false 11 | 12 | [dependencies] 13 | llm = { path = "../../crates/llm", version = "0.2.0-dev" } 14 | 15 | anyhow = { workspace = true } 16 | clap = { workspace = true } 17 | env_logger = { workspace = true } 18 | log = { workspace = true } 19 | rand = { workspace = true } 20 | llm-samplers = { workspace = true } 21 | 22 | reqwest = "0.11.9" 23 | indicatif = "0.16.2" 24 | 25 | tokio = { version = "1.14.0", features = ["full"] } 26 | 27 | serde = "1.0.130" 28 | serde_json = "1.0.67" 29 | 30 | [dev-dependencies] 31 | rusty-hook = "^0.11.2" 32 | 33 | [features] 34 | cublas = ["llm/cublas"] 35 | clblast = ["llm/clblast"] 36 | metal = ["llm/metal"] 37 | 38 | # Falcon is off by default. See `llm_falcon`'s module documentation for more information. 39 | falcon = ["llm/falcon"] 40 | -------------------------------------------------------------------------------- /binaries/llm-test/configs/bloom.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "https://huggingface.co/rustformers/bloom-ggml/resolve/main/bloom-560m-q4_0.bin", 3 | "filename": "bloom.bin", 4 | "architecture": "bloom", 5 | "test_cases": [ 6 | { 7 | "Inference": { 8 | "input": "When a llama rides a crab, ", 9 | "output_disabled": "When a llama rides a crab, ,.-\n\n/? ', , ; A;A = (b),d e orm\n“t” + “p。n unus et les el duetant alle that are by no ... ”\n( ? ) – ‘?\n!!\n«…..’,\nS.\n\n‘l」之 attergoir à dit-on pas .. 。。 ..\n– La leçon se confond quelquefois con ce qui es vée par occident .\n( 2 ) .\nLa protestation del paysan mécontent regardait pendre eussent mœurs faillite forteresse rivières lieues forteressemelés inquiétudes crackdown brawl slaughter massacresokea .\n» » … « …\n. . . \" \" ….", 10 | "maximum_token_count": 128 11 | } 12 | }, 13 | { 14 | "Tokens": { 15 | "input": "Rustformers is", 16 | "output": 15 17 | } 18 | }, 19 | { 20 | "Delete": {} 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /binaries/llm-test/configs/gptj.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "https://huggingface.co/rustformers/gpt-j-ggml/resolve/main/gpt-j-6b-q4_0-ggjt.bin", 3 | "filename": "gptj.bin", 4 | "architecture": "gptj", 5 | "test_cases": [ 6 | { 7 | "Inference": { 8 | "input": "When a llama rides a crab, ", 9 | "output_disabled": "\"When a llama rides a crab, \nit's not the same as when an elephant does it.\" - John Steinbeck, East of Eden.\n\n \"The best way to predict your future is by looking at history.\"- Robert Kiyosaki (author). 
Rich Dad Poor dad : what 10 rules for success really mean and how you can apply them in life! The rich dads guidebook on personal finance: How To Become A Millionaire In Less Than 5 years! http://www..richdadpoordaddyguidebooksalexanderkimballblogcom/the_bestwaytopredictyourfutureislookingathistory/. You will learn about money management", 10 | "maximum_token_count": 128 11 | } 12 | }, 13 | { 14 | "Tokens": { 15 | "input": "Rustformers is", 16 | "output": 257 17 | } 18 | }, 19 | { 20 | "Delete": {} 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /binaries/llm-test/configs/gptneox.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "https://huggingface.co/rustformers/redpajama-3b-ggml/resolve/main/RedPajama-INCITE-Base-3B-v1-q4_0-ggjt.bin", 3 | "filename": "gptneox.bin", 4 | "architecture": "gptneox", 5 | "test_cases": [ 6 | { 7 | "Inference": { 8 | "input": "When a llama rides a crab, ", 9 | "output_disabled": "<|padding|>When a llama rides a crab, \n“The Greatest Show on Earth” is the title of an 1875 book by Phineas Taylor Barnum, who founded and operated The circus. He was born in Bethel Connecticut to Meshack (Meshake) Bowman Jr., from New York City; his mother’s name has not been recorded but she may have had some Native American ancestry as well.[2] His father died when he[3][4], at age three,[5]: 9–10 (p1), 11-12​—was left with relatives until they could find him work or send for them back home where there", 10 | "maximum_token_count": 128 11 | } 12 | }, 13 | { 14 | "Tokens": { 15 | "input": "Rustformers is", 16 | "output": 247 17 | } 18 | }, 19 | { 20 | "Delete": {} 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /binaries/llm-test/configs/llama.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "https://huggingface.co/rustformers/open-llama-ggml/resolve/main/open_llama_3b-q4_0-ggjt.bin", 3 | "filename": "llama.bin", 4 | "architecture": "llama", 5 | "test_cases": [ 6 | { 7 | "Inference": { 8 | "input": "When a llama rides a crab, ", 9 | "output": "When a llama rides a crab, 10-year olds are the ones who get to eat.\nTheir parents have been told that they will be eating for another year or two before their children can enjoy it again – and then only if there is enough food left over from Christmas dinner!", 10 | "maximum_token_count": 128 11 | } 12 | }, 13 | { 14 | "Tokens": { 15 | "input": "Rustformers is", 16 | "output": 260 17 | } 18 | }, 19 | { 20 | "Delete": {} 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /binaries/llm-test/configs/mpt.json: -------------------------------------------------------------------------------- 1 | { 2 | "url": "https://huggingface.co/rustformers/mpt-7b-ggml/resolve/main/mpt-7b-q4_0-ggjt.bin", 3 | "filename": "mpt.bin", 4 | "architecture": "mpt", 5 | "test_cases": [ 6 | { 7 | "Inference": { 8 | "input": "When a llama rides a crab, ", 9 | "output": "When a llama rides a crab,  the llama is called the \"crab rider\".\nThe crabs are very popular in South America, especially Brazil. 
They have been used as transportation for many years and they can carry up to five people at once!", 10 | "maximum_token_count": 128 11 | } 12 | }, 13 | { 14 | "Tokens": { 15 | "input": "Rustformers is", 16 | "output": 247 17 | } 18 | }, 19 | { 20 | "Delete": {} 21 | } 22 | ] 23 | } 24 | -------------------------------------------------------------------------------- /binaries/llm-test/src/common.rs: -------------------------------------------------------------------------------- 1 | //! Tests that are run on every model, regardless of config. 2 | 3 | pub(super) fn can_send(model: M) -> anyhow::Result { 4 | let model = std::thread::spawn(move || model) 5 | .join() 6 | .map_err(|e| anyhow::anyhow!("Failed to join thread: {e:?}")); 7 | 8 | log::info!("`can_send` test passed!"); 9 | 10 | model 11 | } 12 | 13 | pub(super) fn can_roundtrip_hyperparameters( 14 | model: &M, 15 | ) -> anyhow::Result<()> { 16 | fn test_hyperparameters(hyperparameters: &M) -> anyhow::Result<()> { 17 | let mut data = vec![]; 18 | hyperparameters.write_ggml(&mut data)?; 19 | let new_hyperparameters = 20 | ::read_ggml(&mut std::io::Cursor::new(data))?; 21 | 22 | assert_eq!(hyperparameters, &new_hyperparameters); 23 | 24 | log::info!("`can_roundtrip_hyperparameters` test passed!"); 25 | 26 | Ok(()) 27 | } 28 | 29 | test_hyperparameters(model.hyperparameters()) 30 | } 31 | -------------------------------------------------------------------------------- /binaries/llm-test/src/delete.rs: -------------------------------------------------------------------------------- 1 | //! Tests the model's token manipulation APIs: 2 | //! 3 | //! * [llm::InferenceSession::feed_prompt()] 4 | //! 5 | //! See [crate::TestCase::Tokens]. 6 | 7 | use std::convert::Infallible; 8 | 9 | use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest}; 10 | use serde::Serialize; 11 | 12 | use crate::{TestCaseReport, TestCaseReportMeta}; 13 | 14 | /// Tests that models can delete tokens without changing the model's behavior. 15 | pub(crate) fn can_delete(model: &impl Model) -> TestCaseReport { 16 | let report = DeleteReport::default(); 17 | let mut session = model.start_session(Default::default()); 18 | let mut output = OutputRequest { 19 | all_logits: Some(vec![]), 20 | ..Default::default() 21 | }; 22 | 23 | // Feed some tokens 24 | if let Err(err) = feed_prompt("The llama lived on the", &mut session, model, &mut output) { 25 | return report.failure(&err.to_string()); 26 | } 27 | 28 | // Add token and get the logits 29 | if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) { 30 | return report.failure(&err.to_string()); 31 | } 32 | let Some(original_logits) = output.all_logits.clone() else { 33 | return report.failure("Model did not return logits."); 34 | }; 35 | 36 | // Rewind, then re-add. Verify logits are the same. 
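        // `rewind(model, 1)` drops the single token that was just fed; feeding
        // the same token again should therefore reproduce the original logits,
        // which are compared below with an epsilon tolerance.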
37 | if let Err(err) = session.rewind(model, 1) { 38 | return report.failure(&err.to_string()); 39 | } 40 | if let Err(err) = feed_prompt(" ", &mut session, model, &mut output) { 41 | return report.failure(&err.to_string()); 42 | } 43 | let Some(redone_logits) = output.all_logits.clone() else { 44 | return report.failure("Second run of model did not return logits."); 45 | }; 46 | 47 | // Compare the logits 48 | for (idx, (&original, redone)) in original_logits.iter().zip(redone_logits).enumerate() { 49 | if original > redone + f32::EPSILON || original < redone - f32::EPSILON { 50 | return report.failure(&format!( 51 | "Expected logits to be the same after delete, but differed at {idx}, \ 52 | expected {original}, but was {redone}." 53 | )); 54 | } 55 | } 56 | 57 | log::info!("`can_delete` test passed!"); 58 | report.success() 59 | } 60 | 61 | fn feed_prompt( 62 | prompt: &str, 63 | session: &mut InferenceSession, 64 | model: &impl Model, 65 | output: &mut OutputRequest, 66 | ) -> Result<(), llm::InferenceError> { 67 | session.feed_prompt(model, prompt, output, always_continue) 68 | } 69 | 70 | fn always_continue(_: &[u8]) -> Result { 71 | Ok(InferenceFeedback::Continue) 72 | } 73 | 74 | #[derive(Serialize, Default)] 75 | pub struct DeleteReport { 76 | output: usize, 77 | } 78 | 79 | impl DeleteReport { 80 | fn failure(self, msg: &str) -> TestCaseReport { 81 | TestCaseReport { 82 | meta: TestCaseReportMeta::Error { 83 | error: msg.to_owned(), 84 | }, 85 | report: crate::TestCaseReportInner::Delete(self), 86 | } 87 | } 88 | 89 | fn success(self) -> TestCaseReport { 90 | TestCaseReport { 91 | meta: TestCaseReportMeta::Success, 92 | report: crate::TestCaseReportInner::Delete(self), 93 | } 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /binaries/llm-test/src/inference.rs: -------------------------------------------------------------------------------- 1 | //! Tests the model's inference APIs. 2 | //! 3 | //! See [crate::TestCase::Inference]. 
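//!
//! The test drives the model with a deterministic sampler (greedy sampling with
//! every previously seen token banned) and, when the test config provides an
//! expected output, compares the generated text against it verbatim.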
4 | 5 | use std::{ 6 | convert::Infallible, 7 | sync::{Arc, Mutex}, 8 | }; 9 | 10 | use llm::{InferenceSessionConfig, InferenceStats, TokenId}; 11 | 12 | use llm_samplers::prelude::{HasSamplerResources, Logits, SampleFlatBias, SampleGreedy, Sampler}; 13 | 14 | use crate::{ModelConfig, TestCaseReport, TestCaseReportInner, TestCaseReportMeta}; 15 | 16 | pub(crate) fn can_infer( 17 | model: &dyn llm::Model, 18 | model_config: &ModelConfig, 19 | input: &str, 20 | expected_output: Option<&str>, 21 | maximum_token_count: usize, 22 | ) -> anyhow::Result { 23 | let mut session = model.start_session(InferenceSessionConfig { 24 | n_threads: model_config.threads, 25 | ..Default::default() 26 | }); 27 | let (actual_output, res) = run_inference(model, &mut session, input, maximum_token_count); 28 | 29 | // Process the results 30 | Ok(TestCaseReport { 31 | meta: match &res { 32 | Ok(_) => match expected_output { 33 | Some(expected_output) => { 34 | if expected_output == actual_output { 35 | log::info!("`can_infer` test passed!"); 36 | TestCaseReportMeta::Success 37 | } else { 38 | TestCaseReportMeta::Error { 39 | error: "The output did not match the expected output.".to_string(), 40 | } 41 | } 42 | } 43 | None => { 44 | log::info!("`can_infer` test passed (no expected output)!"); 45 | TestCaseReportMeta::Success 46 | } 47 | }, 48 | Err(err) => TestCaseReportMeta::Error { 49 | error: err.to_string(), 50 | }, 51 | }, 52 | report: TestCaseReportInner::Inference { 53 | input: input.into(), 54 | expect_output: expected_output.map(|s| s.to_string()), 55 | actual_output, 56 | inference_stats: res.ok(), 57 | }, 58 | }) 59 | } 60 | 61 | fn run_inference( 62 | model: &dyn llm::Model, 63 | session: &mut llm::InferenceSession, 64 | input: &str, 65 | maximum_token_count: usize, 66 | ) -> (String, Result) { 67 | let mut actual_output: String = String::new(); 68 | let res = session.infer::( 69 | model, 70 | &mut rand::rngs::mock::StepRng::new(0, 1), 71 | &llm::InferenceRequest { 72 | prompt: input.into(), 73 | parameters: &llm::InferenceParameters { 74 | sampler: Arc::new(Mutex::new(DeterministicSampler::default())), 75 | }, 76 | play_back_previous_tokens: false, 77 | maximum_token_count: Some(maximum_token_count), 78 | }, 79 | &mut Default::default(), 80 | |r| match r { 81 | llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => { 82 | actual_output += &t; 83 | Ok(llm::InferenceFeedback::Continue) 84 | } 85 | _ => Ok(llm::InferenceFeedback::Continue), 86 | }, 87 | ); 88 | 89 | (actual_output, res) 90 | } 91 | 92 | // Takes the most likely element from the logits, except if they've appeared in `previous_tokens` 93 | // at all 94 | #[derive(Debug, Default)] 95 | struct DeterministicSampler(SampleGreedy); 96 | 97 | impl Sampler for DeterministicSampler { 98 | fn sample<'a>( 99 | &mut self, 100 | res: &mut dyn HasSamplerResources, 101 | logits: &'a mut Logits, 102 | ) -> anyhow::Result<&'a mut Logits> { 103 | let mut flat_bias = Default::default(); 104 | 105 | // This might look a little weird, but it's necessary because the resource 106 | // `with_` functions can't return a value. 
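        // Build a flat bias that pins every previously seen token's logit to
        // negative infinity, then take the greedy argmax. This keeps the output
        // deterministic while preventing the sampler from ever repeating a token.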
107 | res.with_last_tokens(&mut |lt| { 108 | flat_bias = SampleFlatBias::new(lt.iter().map(|tid| (*tid, f32::NEG_INFINITY))); 109 | })?; 110 | 111 | logits.sample(res, &mut flat_bias)?.sample(res, &mut self.0) 112 | } 113 | 114 | fn sampled_token_id(&self) -> Option { 115 | *self.0 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /binaries/llm-test/src/tokens.rs: -------------------------------------------------------------------------------- 1 | //! Tests the model's token manipulation APIs: 2 | //! 3 | //! * [llm::InferenceSession::feed_prompt()] 4 | //! 5 | //! See [crate::TestCase::Tokens]. 6 | 7 | use std::convert::Infallible; 8 | 9 | use llm::{InferenceFeedback, InferenceSession, Model, OutputRequest}; 10 | use serde::Serialize; 11 | 12 | use crate::{TestCaseReport, TestCaseReportMeta}; 13 | 14 | /// Tests that the model performs as expected when feeding tokens 15 | pub(crate) fn can_feed(model: &impl Model, input: &str, expected_output: usize) -> TestCaseReport { 16 | let mut report = TokensReport::default(); 17 | let mut session = model.start_session(Default::default()); 18 | let mut output = OutputRequest { 19 | all_logits: Some(vec![]), 20 | ..Default::default() 21 | }; 22 | 23 | if let Err(err) = feed_prompt(input, &mut session, model, &mut output) { 24 | return report.failure(&err.to_string()); 25 | }; 26 | 27 | let top_token; 28 | match output.all_logits { 29 | Some(logits) => { 30 | let start = logits.len() - model.tokenizer().len(); 31 | let mut iter = logits[start..].iter().enumerate(); 32 | let Some((mut max_idx, mut max)) = iter.next() else { 33 | return report.failure("Could not find any logits for last token."); 34 | }; 35 | for (idx, score) in iter { 36 | if score > max { 37 | max = score; 38 | max_idx = idx; 39 | } 40 | } 41 | top_token = max_idx; 42 | } 43 | None => return report.failure("Model did not output any logits."), 44 | } 45 | 46 | report.output = top_token; 47 | 48 | if top_token != expected_output { 49 | let tokenizer = model.tokenizer(); 50 | let top_token_str = String::from_utf8_lossy(&tokenizer.token(top_token)).to_string(); 51 | let expected_str = String::from_utf8_lossy(&tokenizer.token(expected_output)).to_string(); 52 | return report.failure(&format!( 53 | "Expected top token to be {expected_output} ({expected_str}), \ 54 | but was {top_token} ({top_token_str})" 55 | )); 56 | } 57 | 58 | log::info!("`can_feed` test passed!"); 59 | report.success() 60 | } 61 | 62 | fn feed_prompt( 63 | prompt: &str, 64 | session: &mut InferenceSession, 65 | model: &impl Model, 66 | output: &mut OutputRequest, 67 | ) -> Result<(), llm::InferenceError> { 68 | session.feed_prompt(model, prompt, output, always_continue) 69 | } 70 | 71 | fn always_continue(_: &[u8]) -> Result { 72 | Ok(InferenceFeedback::Continue) 73 | } 74 | 75 | #[derive(Serialize, Default)] 76 | pub struct TokensReport { 77 | output: usize, 78 | } 79 | 80 | impl TokensReport { 81 | fn failure(self, msg: &str) -> TestCaseReport { 82 | TestCaseReport { 83 | meta: TestCaseReportMeta::Error { 84 | error: msg.to_owned(), 85 | }, 86 | report: crate::TestCaseReportInner::Tokens(self), 87 | } 88 | } 89 | 90 | fn success(self) -> TestCaseReport { 91 | TestCaseReport { 92 | meta: TestCaseReportMeta::Success, 93 | report: crate::TestCaseReportInner::Tokens(self), 94 | } 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /binaries/precommit-check/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "precommit-check" 3 | version = "0.1.0" 4 | edition = "2021" 5 | publish = false 6 | 7 | [package.metadata.release] 8 | release = false -------------------------------------------------------------------------------- /binaries/precommit-check/README.md: -------------------------------------------------------------------------------- 1 | # precommit-check 2 | 3 | Helper script to run pre-commit checks on a repository. Used with `rusty-hook` to execute all of the checks and early exit if any of them fail. 4 | -------------------------------------------------------------------------------- /binaries/precommit-check/src/main.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | // Ensure that these match `.github/workflows/rust.yml`. 3 | cmd("cargo", &["check"], &[]); 4 | cmd("cargo", &["test", "--all"], &[]); 5 | cmd("cargo", &["fmt", "--check", "--all"], &[]); 6 | cmd( 7 | "cargo", 8 | &["doc", "--workspace", "--exclude", "llm-cli"], 9 | &[("RUSTDOCFLAGS", "-Dwarnings")], 10 | ); 11 | cmd( 12 | "cargo", 13 | &["clippy", "--workspace", "--", "-Dclippy::all"], 14 | &[], 15 | ); 16 | } 17 | 18 | fn cmd(cmd: &str, args: &[&str], env: &[(&str, &str)]) { 19 | println!("=== Running command: {cmd} {args:?}"); 20 | let mut builder = std::process::Command::new(cmd); 21 | builder.args(args); 22 | builder.envs(env.iter().copied()); 23 | let mut child = builder.spawn().unwrap(); 24 | if !child.wait().unwrap().success() { 25 | panic!("Failed to run command: {} {:?}", cmd, builder); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /crates/ggml/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ggml" 3 | version = "0.2.0-dev" 4 | repository = { workspace = true } 5 | edition = "2021" 6 | description = "Semi-idiomatic Rust bindings for the ggml library (from `ggml-sys`)." 7 | license = "MIT" 8 | 9 | [dependencies] 10 | ggml-sys = { path = "sys", version = "0.2.0-dev" } 11 | 12 | thiserror = { workspace = true } 13 | memmap2 = { workspace = true } 14 | 15 | [dev-dependencies] 16 | rand = { workspace = true } 17 | anyhow = { workspace = true } 18 | 19 | [features] 20 | cublas = ["ggml-sys/cublas"] 21 | clblast = ["ggml-sys/clblast"] 22 | metal = ["ggml-sys/metal"] 23 | -------------------------------------------------------------------------------- /crates/ggml/README.md: -------------------------------------------------------------------------------- 1 | # GGML - Large Language Models for Everyone 2 | 3 | [GGML](https://github.com/ggerganov/ggml) is a C library for machine learning 4 | (ML) - the "GG" refers to the initials of its originator 5 | ([Georgi Gerganov](https://ggerganov.com/)). In addition to defining low-level 6 | machine learning primitives (like a [tensor](#weights) type), GGML defines a 7 | binary format for distributing large language models (LLMs). This crate provides 8 | Rust [bindings](sys) into the reference implementation of GGML, as well as a 9 | collection of [native](src) Rust helpers to provide safe, idiomatic access to 10 | those bindings. GGML makes use of a technique called 11 | "[quantization]()" 12 | that allows for large language models to run on consumer hardware. 
This 13 | documents describes the basics of the GGML format, including how 14 | [quantization](#quantization) is used to democratize access to LLMs. 15 | 16 | ## Format 17 | 18 | GGML files consists of binary-encoded data that is laid out according to a 19 | specified format. The format specifies what kind of data is present in the file, 20 | how it is represented, and the order in which it appears. The first piece of 21 | information present in a valid GGML file is a GGML version number, followed by 22 | three components that define a large language model: the model's 23 | [hyperparameters](#hyperparameters), its [vocabulary](#vocabulary), and its 24 | [weights](#weights). Continue reading to learn more about GGML versions and the 25 | components of a GGML model. 26 | 27 | ### GGML Versions 28 | 29 | GGML is "bleeding-edge" technology and undergoes frequent changes. In an effort 30 | to support rapid development without sacrificing backwards-compatibility, GGML 31 | uses versioning to introduce improvements that may change the format of the 32 | encoding. For example, newer versions of GGML make use of 33 | [vocabulary](#vocabulary)-scoring, which introduces extra information into the 34 | encoding, as well as [mmap](https://en.wikipedia.org/wiki/Mmap), which enhances 35 | performance through memory-mapping. The first value that is present in a valid 36 | GGML file is a "magic number" that indicates the GGML version that was used to 37 | encode the model. 38 | 39 | ### Hyperparameters 40 | 41 | The term 42 | "[hyperparameter]()" 43 | describes a value that is used to configure the behavior of a large language 44 | model; this is in contrast to the model's **parameters**, which are the 45 | [weights](#weights) that were derived in the training process that was used to 46 | create the model. Each model defines its own hyperparameter structure that 47 | defines the hyperparameter values accepted by that model. Valid GGML files must 48 | list these values in the correct order, and each value must be represented using 49 | the correct data type. Although hyperparameters are different across models, 50 | some attributes appear in the hyperparameters for most models: 51 | 52 | - `n_vocab`: the size of the model's [vocabulary](#vocabulary) 53 | - `n_embd`: the size of the model's 54 | "[embedding](https://en.wikipedia.org/wiki/Word_embedding) layer", which is 55 | used during prompt ingestion 56 | - `n_layer`: the number of layers in the model; each layer represents a set of 57 | [weights](#weights). 58 | 59 | ### Vocabulary 60 | 61 | As the name implies, a model's vocabulary comprises components that are used by 62 | the model to generate language (text). However, unlike the vocabulary of a 63 | human, which consists of _words_, the vocabulary of a large language model 64 | consists of "tokens". A token _can_ be an entire word, but oftentimes they are 65 | word _fragments_. Just like humans can compose millions of words from just a 66 | dozen or two letters, large language models use _tokens_ to express a large 67 | number of words from a relatively smaller number of components. Consider a 68 | vocabulary with the following tokens: `whi`, `ch` `le`, `who`, and `a`; this 69 | vocabulary can be used to create the English words "which", "while", "who", "a", 70 | and "leach". How would the behavior change if the model contained the following 71 | tokens: `wh`, `ich`, `ile`, `o`, and `leach`? Choices such as these allow 72 | model-creators to tune the behavior and performance of their models. 
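
As a toy illustration of how such a vocabulary might be applied, the sketch below implements a naive greedy, longest-match tokenizer over the first example vocabulary. This is purely illustrative: it is not the algorithm used by GGML-based tokenizers (real tokenizers are typically BPE- or unigram-based and make use of the scores described below), and the function name is invented for this example.

```rust
/// Greedily match the longest known token at each position of `text`.
/// Returns `None` if some span of the text cannot be covered by the vocabulary.
fn greedy_tokenize<'a>(vocab: &[&'a str], text: &str) -> Option<Vec<&'a str>> {
    let mut tokens = Vec::new();
    let mut rest = text;
    while !rest.is_empty() {
        // Pick the longest vocabulary entry that is a prefix of the remaining text.
        let best = vocab
            .iter()
            .filter(|t| rest.starts_with(**t))
            .max_by_key(|t| t.len())?;
        tokens.push(*best);
        rest = &rest[best.len()..];
    }
    Some(tokens)
}

fn main() {
    let vocab = ["whi", "ch", "le", "who", "a"];
    assert_eq!(greedy_tokenize(&vocab, "which"), Some(vec!["whi", "ch"]));
    assert_eq!(greedy_tokenize(&vocab, "leach"), Some(vec!["le", "a", "ch"]));
}
```
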
73 | 74 | As described above, the model's [hyperparameters](#hyperparameters) typically 75 | contains a value that specifies the number of tokens in the vocabulary. The 76 | vocabulary is encoded as a list of tokens, each of which includes a 32-bit 77 | integer that specifies the length of the token. Depending on the GGML version, 78 | the token may also include a 32-bit floating point score, which represents the 79 | frequency of that token in the model's training data. 80 | 81 | ### Weights 82 | 83 | The final, and largest, component of a GGML file is the weights of the LLM that 84 | the file represents. Abstractly, a large language model is software that is used 85 | to generate language - just like software that is used to generate _images_ can 86 | be improved by increasing the number of colors with which images can be 87 | rendered, large language models can be improved by increasing the number of 88 | _weights_ in the model. The total number of weights in a model are referred to 89 | as the "size" of that model. For example, the 90 | [StableLM](https://github.com/Stability-AI/StableLM) implementation of the 91 | [GPT-NeoX](https://github.com/EleutherAI/gpt-neox) language model architecture 92 | is available in a number of sizes, like 3B and 7B, which stands for 3-billion 93 | and 7-billion, respectively. These numbers refer to the total number of weights 94 | in that model. As described in the [hyperparameters](#hyperparameters) section, 95 | weights are grouped together in sets called "layers", which, like 96 | hyperparameters, have structures that are uniquely defined by the model 97 | architecture; within a layer, weights are grouped together in structures called 98 | "tensors". So, for instance, both StableLM 3B and StableLM 7B use layers that 99 | comprise the same tensors, but StableLM 3B has relatively _fewer_ layers when 100 | compared to StableLM 7B. 101 | 102 | In GGML, a tensor consists of a number of components, including: a name, a 103 | 4-element list that represents the number of dimensions in the tensor and their 104 | lengths, and a list of the weights in that tensor. For example, consider the 105 | following 2 ⨯ 2 tensor named `tensor_a0`: 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 |
| `tensor_a0` |     |
| ----------- | --- |
| 1.0         | 0.0 |
| 0.1         | 1.1 |
120 | 121 | A simplification of the GGML representation of `tensor_a0` is 122 | `{"tensor_a0", [2, 2, 1, 1], [1.0, 0.0, 0.1, 1.1]}`. Note that the 4-element 123 | list of dimensions uses `1` as a placeholder for unused dimensions - this is 124 | because the product of the dimensions should not equal zero. 125 | 126 | The weights in a GGML file are encoded as a list of layers, the length of which 127 | is typically specified in the model's hyperparameters; each layer is encoded as 128 | an ordered set of tensors. 129 | 130 | #### Quantization 131 | 132 | LLM weights are floating point (decimal) numbers. Just like it requires more 133 | space to represent a large integer (e.g. 1000) compared to a small integer (e.g. 134 | 1), it requires more space to represent a high-precision floating point number 135 | (e.g. 0.0001) compared to a low-precision floating number (e.g. 0.1). The 136 | process of "quantizing" a large language model involves reducing the precision 137 | with which weights are represented in order to reduce the resources required to 138 | use the model. GGML supports a number of different quantization strategies (e.g. 139 | 4-bit, 5-bit, and 8-bit quantization), each of which offers different trade-offs 140 | between efficiency and performance. 141 | [More information](https://github.com/ggerganov/llama.cpp#quantization) about 142 | these trade-offs can be found in the documentation for llama.cpp, which is 143 | another project by the maintainer of GGML. Technical details about quantization 144 | are described in [this video](https://www.youtube.com/watch?v=mii-xFaPCrA) by 145 | @Aemon-Algiz. 146 | -------------------------------------------------------------------------------- /crates/ggml/src/accelerator/metal.rs: -------------------------------------------------------------------------------- 1 | //! Metal support. 2 | use crate::{sys::metal, Buffer, ComputationGraph, Context, Tensor}; 3 | use std::{ptr::NonNull, sync::Arc}; 4 | 5 | /// Acts as a RAII-guard over a `sys::metal::ggml_metal_context`, allocating via 6 | /// `ggml_metal_init` and dropping via `ggml_metal_free`. 7 | pub struct MetalContext { 8 | ptr: Arc>, 9 | 10 | /// References to the context that hold buffers that are used in this Metal context. As Metal does not need to copy 11 | /// buffers to VRAM, we do need to keep the original buffers alive through this reference. 
12 | contexts: Vec>, 13 | } 14 | 15 | impl MetalContext { 16 | /// Create a new Metal context 17 | pub fn new(n_threads: usize) -> Self { 18 | let raw = unsafe { metal::ggml_metal_init(n_threads.try_into().unwrap()) }; 19 | 20 | MetalContext { 21 | contexts: vec![], 22 | ptr: Arc::new(NonNull::new(raw).expect("Should not be null")), 23 | } 24 | } 25 | 26 | /// Register a buffer mapping 27 | pub fn add_scratch_buffer(&mut self, buf: &Buffer) { 28 | unsafe { 29 | let raw_metal_context = self.ptr.as_ptr(); 30 | 31 | //Last we need to add the scratch buffers to the buffers 32 | assert!( 33 | metal::ggml_metal_add_buffer( 34 | raw_metal_context, 35 | "scratch\0".as_ptr().cast(), // FIXME: allocate string and insert number in name 36 | buf.data, 37 | buf.size(), 38 | buf.size() 39 | ), 40 | "{}", 41 | format!("Could not add scratch buffer to metal context") 42 | ); 43 | } 44 | } 45 | 46 | /// Add a context's memory as buffer to this Metal context 47 | pub fn add_context(&mut self, from_context: Arc) { 48 | if !self.ref_context(from_context.clone()) { 49 | return; 50 | } 51 | 52 | unsafe { 53 | let raw_context = from_context.as_ptr(); 54 | let (data_ptr, data_size) = from_context.storage().as_ptr_and_size(&from_context); 55 | let max_size = ggml_sys::ggml_get_max_tensor_size(raw_context); 56 | assert!( 57 | metal::ggml_metal_add_buffer( 58 | self.ptr.as_ptr(), 59 | "wt\0".as_ptr().cast(), // FIXME provide an actual name 60 | data_ptr, 61 | data_size, 62 | max_size 63 | ), 64 | "Could not add weight buffer to metal context" 65 | ); 66 | } 67 | } 68 | } 69 | 70 | impl MetalContext { 71 | /// Registers a context as a context that provides Metal buffers. Returns true if the context was not registered before. 72 | fn ref_context(&mut self, context: Arc) -> bool { 73 | if self.contexts.iter().any(|c| *c == context) { 74 | false 75 | } else { 76 | self.contexts.push(context); 77 | true 78 | } 79 | } 80 | 81 | /// Computes the specified graph using Metal. 82 | pub fn graph_compute(&self, graph: &mut ComputationGraph) { 83 | unsafe { 84 | metal::ggml_metal_graph_compute( 85 | self.ptr.as_ptr(), 86 | graph.inner as *mut ggml_sys::ggml_cgraph as *mut metal::ggml_cgraph, 87 | ); 88 | } 89 | } 90 | 91 | /// Reads a tensor from Metal 92 | pub fn get_tensor(&self, tensor: &Tensor) { 93 | unsafe { 94 | metal::ggml_metal_get_tensor( 95 | self.ptr.as_ptr(), 96 | tensor.ptr.as_ptr() as *mut metal::ggml_tensor, 97 | ) 98 | } 99 | } 100 | } 101 | 102 | impl Drop for MetalContext { 103 | fn drop(&mut self) { 104 | // SAFETY: The only non-weak copy of ptr is no longer accessible after 105 | // this drop call. 106 | unsafe { metal::ggml_metal_free(self.ptr.as_ptr()) } 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /crates/ggml/src/accelerator/mod.rs: -------------------------------------------------------------------------------- 1 | //! Functionality related to hardware acceleration of GGML (GPU, etc.) 2 | use crate::sys; 3 | 4 | #[cfg(feature = "metal")] 5 | pub mod metal; 6 | 7 | #[derive(Debug, Copy, Clone, PartialEq, Eq)] 8 | /// Accelerators supported by `ggml`. 9 | pub enum Accelerator { 10 | /// CuBLAS accelerated 11 | CuBLAS, 12 | /// CLBlast accelerated 13 | CLBlast, 14 | /// Metal accelerated 15 | Metal, 16 | /// Cpu accelerated 17 | None, 18 | } 19 | 20 | /// Returns the accelerator `ggml` was compiled with. 
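///
/// A minimal usage sketch (hypothetical caller code, assuming the `ggml` crate is in scope):
///
/// ```ignore
/// use ggml::accelerator::{get_accelerator, Accelerator};
///
/// match get_accelerator() {
///     Accelerator::CuBLAS | Accelerator::CLBlast => println!("GPU BLAS acceleration"),
///     Accelerator::Metal => println!("Metal acceleration"),
///     Accelerator::None => println!("CPU only"),
/// }
/// ```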
21 | pub fn get_accelerator() -> Accelerator { 22 | #[cfg(feature = "clblast")] 23 | return Accelerator::CLBlast; 24 | #[cfg(feature = "cublas")] 25 | return Accelerator::CuBLAS; 26 | #[cfg(feature = "metal")] 27 | return Accelerator::Metal; 28 | #[cfg(not(any(feature = "cublas", feature = "clblast", feature = "metal")))] 29 | return Accelerator::None; 30 | } 31 | 32 | #[derive(Default, Debug, Copy, Clone, PartialEq, Eq)] 33 | /// Backend to use for a tensor. 34 | pub enum Backend { 35 | /// CPU backend 36 | #[default] 37 | Cpu, 38 | /// GPU backend 39 | Gpu, 40 | /// Multi-GPU backend 41 | GpuSplit, 42 | } 43 | 44 | impl From for sys::ggml_backend { 45 | fn from(b: Backend) -> Self { 46 | match b { 47 | Backend::Cpu => sys::ggml_backend_GGML_BACKEND_CPU, 48 | Backend::Gpu => sys::ggml_backend_GGML_BACKEND_GPU, 49 | Backend::GpuSplit => sys::ggml_backend_GGML_BACKEND_GPU_SPLIT, 50 | } 51 | } 52 | } 53 | 54 | impl TryFrom for Backend { 55 | type Error = (); 56 | fn try_from(b: sys::ggml_backend) -> Result { 57 | match b { 58 | sys::ggml_backend_GGML_BACKEND_CPU => Ok(Backend::Cpu), 59 | sys::ggml_backend_GGML_BACKEND_GPU => Ok(Backend::Gpu), 60 | sys::ggml_backend_GGML_BACKEND_GPU_SPLIT => Ok(Backend::GpuSplit), 61 | _ => Err(()), 62 | } 63 | } 64 | } 65 | 66 | /// Initialize the accelerator. If ggml-sys is compiled with CUDA or CLBlast support, this function will initialize the accelerator. If not this is a no-op. 67 | #[allow(unused_variables)] 68 | pub fn initialize(device: i32) { 69 | #[cfg(feature = "cublas")] 70 | unsafe { 71 | //TODO: Make this configurable 72 | sys::cuda::ggml_init_cublas(); 73 | sys::cuda::ggml_cuda_set_main_device(device); 74 | let split = 1.0f32; 75 | sys::cuda::ggml_cuda_set_tensor_split(&split as *const f32); 76 | } 77 | } 78 | 79 | /// Sets the scratch size for the GPU. If ggml-sys is compiled with CUDA support, this function will set the scratch size. If not this is a no-op. 80 | #[allow(unused_variables)] 81 | pub fn set_scratch_size(size: usize) { 82 | #[cfg(feature = "cublas")] 83 | unsafe { 84 | sys::cuda::ggml_cuda_set_scratch_size(size); 85 | } 86 | } 87 | 88 | /// Frees the scratch memory. If ggml-sys is compiled with CUDA support, this function will free the scratch memory. If not this is a no-op. 89 | pub fn free_scratch() { 90 | #[cfg(feature = "cublas")] 91 | unsafe { 92 | sys::cuda::ggml_cuda_free_scratch(); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /crates/ggml/src/format/loader.rs: -------------------------------------------------------------------------------- 1 | //! The loader module contains the code for loading a model from disk. 2 | //! 3 | //! To handle a specific model, implement [LoadHandler] for your model 4 | //! and call [load] with an instance of your handler. It is up to you 5 | //! to process the data from the handler and construct your model. 6 | 7 | use std::{ 8 | error::Error, 9 | fmt, 10 | io::{BufRead, Seek, SeekFrom}, 11 | }; 12 | 13 | use crate::{ 14 | util::{has_data_left, read_bytes_with_len, read_f32, read_i32, read_u32}, 15 | ContainerType, ElementType, 16 | }; 17 | 18 | /// Helper struct that wraps the magic number of a file format, 19 | /// so that it can be printed in a human-readable format. 
20 | pub struct FormatMagic(pub u32); 21 | impl fmt::Display for FormatMagic { 22 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { 23 | write!( 24 | f, 25 | "{:x} ({})", 26 | self.0, 27 | String::from_utf8_lossy(&self.0.to_le_bytes()) 28 | ) 29 | } 30 | } 31 | impl fmt::Debug for FormatMagic { 32 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { 33 | fmt::Display::fmt(self, f) 34 | } 35 | } 36 | 37 | #[derive(Debug, thiserror::Error)] 38 | /// Errors that can occur while loading a model. 39 | pub enum LoadError { 40 | #[error("invalid file magic number: {0}")] 41 | /// The file magic number is invalid. 42 | InvalidMagic(FormatMagic), 43 | #[error("invalid ggml format: format={0:?}")] 44 | /// An unsupported format version was found. 45 | InvalidFormatVersion(ContainerType), 46 | #[error("non-specific I/O error")] 47 | /// A non-specific IO error. 48 | Io(#[from] std::io::Error), 49 | #[error("could not convert bytes to a UTF-8 string")] 50 | /// One of the strings encountered was not valid UTF-8. 51 | InvalidUtf8(#[from] std::string::FromUtf8Error), 52 | #[error("invalid integer conversion")] 53 | /// One of the integers encountered could not be converted to a more appropriate type. 54 | InvalidIntegerConversion(#[from] std::num::TryFromIntError), 55 | #[error("implementation error")] 56 | /// An error `E` was returned by the implementation of the loader. 57 | ImplementationError(#[source] E), 58 | #[error("unsupported tensor type {ftype} for tensor {tensor_name}")] 59 | /// One of the tensors encountered had an unsupported data type. 60 | UnsupportedElementType { 61 | /// The name of the tensor. 62 | tensor_name: String, 63 | /// The format type that was encountered. 64 | ftype: u32, 65 | }, 66 | #[error("invariant broken: {0}")] 67 | /// An invariant was broken. 68 | InvariantBroken(String), 69 | } 70 | 71 | #[derive(Debug, Clone)] 72 | /// Information about a [tensor](https://en.wikipedia.org/wiki/Tensor_(machine_learning)) that is being read. 73 | pub struct TensorLoadInfo { 74 | /// The name of the tensor. 75 | pub name: String, 76 | /// The number of dimensions in the tensor. 77 | pub n_dims: usize, 78 | /// The dimensions of the tensor. 79 | pub dims: [usize; 2], 80 | /// The number of elements in the tensor. 81 | pub n_elements: usize, 82 | /// The type of the elements in the tensor. 83 | pub element_type: ElementType, 84 | /// start of tensor - start of file 85 | pub start_offset: u64, 86 | } 87 | impl TensorLoadInfo { 88 | /// Get the dimensions of the tensor. 89 | pub fn dims(&self) -> &[usize] { 90 | &self.dims[0..self.n_dims] 91 | } 92 | 93 | /// Calculate the size of the tensor's values in bytes. 94 | pub fn calc_size(&self) -> usize { 95 | data_size(self.element_type, self.dims().iter().product()) 96 | } 97 | 98 | /// Calculates the absolute size in bytes of the tensor's data, given the mmap flag. 99 | pub fn calc_absolute_size(&self, mmap: bool) -> usize { 100 | if mmap { 101 | header_size() 102 | } else { 103 | header_size() + self.calc_size() 104 | } 105 | } 106 | 107 | /// Reads the tensor's data from the given reader in an owned fashion. 108 | /// 109 | /// The behaviour is undefined if the reader does not correspond to this info. 110 | /// 111 | /// Do not use this if loading with `mmap`. 
112 | pub fn read_data(&self, reader: &mut R) -> std::io::Result> { 113 | let n_bytes = self.n_elements * crate::type_size(self.element_type); 114 | let mut data = vec![0; n_bytes]; 115 | reader.seek(SeekFrom::Start(self.start_offset))?; 116 | reader.read_exact(&mut data)?; 117 | Ok(data) 118 | } 119 | } 120 | 121 | /// Returns the size occupied by a tensor's data in bytes given the element type and number of elements. 122 | pub(crate) fn data_size(element_type: ElementType, n_elements: usize) -> usize { 123 | (crate::type_size(element_type) * n_elements) / crate::blck_size(element_type) 124 | } 125 | 126 | /// Returns the size of the ggml tensor header in bytes. 127 | pub(crate) fn header_size() -> usize { 128 | crate::Tensor::C_TYPE_SIZE + crate::OBJECT_SIZE 129 | } 130 | 131 | /// Returns the size of a tensor in bytes given the element type and number of elements. This includes the tensor's header. 132 | pub fn tensor_size(element_type: ElementType, n_elements: usize) -> usize { 133 | header_size() + data_size(element_type, n_elements) 134 | } 135 | 136 | #[derive(Debug, Clone)] 137 | /// Information present within GGML [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)) 138 | /// that is required to continue loading the model. 139 | pub struct PartialHyperparameters { 140 | /// The number of tokens in the model's embedded vocabulary. 141 | pub n_vocab: usize, 142 | } 143 | 144 | /// A handler for loading a GGML model. 145 | pub trait LoadHandler { 146 | /// Called when the [ContainerType] is read. 147 | fn container_type(&mut self, container_type: ContainerType) -> Result<(), E>; 148 | /// Called when a token is read so it can be added to the model's embedded vocabulary. 149 | fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), E>; 150 | /// Called when the model's hyperparameters need to be read. 151 | fn read_hyperparameters( 152 | &mut self, 153 | reader: &mut dyn BufRead, 154 | ) -> Result; 155 | /// Called when a new [crate::Tensor] is read for the model. 156 | fn tensor_buffer(&mut self, info: TensorLoadInfo) -> Result<(), E>; 157 | } 158 | 159 | /// Load a GGML model from a `reader` with the [LoadHandler], which will be called when certain events occur. 160 | pub fn load( 161 | reader: &mut R, 162 | handler: &mut impl LoadHandler, 163 | ) -> Result<(), LoadError> { 164 | // Verify magic 165 | let container_type = ContainerType::read(reader)?; 166 | 167 | match container_type { 168 | ContainerType::Ggml 169 | | ContainerType::Ggmf(1) 170 | | ContainerType::Ggjt(1..=3) 171 | | ContainerType::Ggla(1) => {} 172 | _ => return Err(LoadError::InvalidFormatVersion(container_type)), 173 | } 174 | 175 | handler 176 | .container_type(container_type) 177 | .map_err(LoadError::ImplementationError)?; 178 | 179 | // Load hyper params 180 | let hparams = handler 181 | .read_hyperparameters(reader) 182 | .map_err(LoadError::ImplementationError)?; 183 | let n_vocab = hparams.n_vocab; 184 | 185 | // Load vocabulary 186 | for i in 0..n_vocab { 187 | let len = read_u32(reader)?.try_into()?; 188 | let token = read_bytes_with_len(reader, len)?; 189 | let token_score = match container_type { 190 | ContainerType::Ggmf(_version) | ContainerType::Ggjt(_version) => read_f32(reader)?, 191 | ContainerType::Ggml | ContainerType::Ggla(_) => { 192 | // Legacy model, set empty score 193 | 0. 
194 | } 195 | }; 196 | handler 197 | .vocabulary_token(i, token, token_score) 198 | .map_err(LoadError::ImplementationError)?; 199 | } 200 | 201 | // Load tensor data 202 | match container_type { 203 | ContainerType::Ggmf(_) | ContainerType::Ggml => load_weights(reader, handler, false), 204 | ContainerType::Ggjt(_version) | ContainerType::Ggla(_version) => { 205 | load_weights(reader, handler, true) 206 | } 207 | } 208 | } 209 | 210 | /// # Params 211 | /// 212 | /// `align` 213 | /// align to 4 bytes before reading tensor weights 214 | fn load_weights( 215 | reader: &mut R, 216 | handler: &mut impl LoadHandler, 217 | align: bool, 218 | ) -> Result<(), LoadError> { 219 | while has_data_left(reader)? { 220 | // load tensor header 221 | let n_dims: usize = read_i32(reader)?.try_into()?; 222 | let name_len = read_i32(reader)?; 223 | let ftype = read_u32(reader)?; 224 | 225 | let mut n_elements: usize = 1; 226 | let mut dims = [1usize, 1]; 227 | let ne_len = dims.len(); 228 | if n_dims > ne_len { 229 | return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}"))); 230 | } 231 | 232 | #[allow(clippy::needless_range_loop)] 233 | for i in 0..n_dims { 234 | let dim: usize = read_i32(reader)?.try_into()?; 235 | dims[i] = dim; 236 | n_elements *= dim; 237 | } 238 | 239 | // load tensor name 240 | let name = String::from_utf8(read_bytes_with_len(reader, name_len.try_into()?)?)?; 241 | let ftype = 242 | crate::Type::try_from(ftype).map_err(|_| LoadError::UnsupportedElementType { 243 | tensor_name: name.clone(), 244 | ftype, 245 | })?; 246 | 247 | // sanity check 248 | match ftype { 249 | ElementType::Q4_0 | ElementType::Q4_1 => { 250 | if dims[0] % 64 != 0 { 251 | return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0"))); 252 | } 253 | } 254 | _ => {} 255 | } 256 | 257 | // load tensor weights 258 | let offset_curr = reader.stream_position()?; 259 | let offset_aligned: u64 = if align { 260 | (offset_curr + 31) & !31 261 | } else { 262 | offset_curr 263 | }; 264 | 265 | let tensor_info = TensorLoadInfo { 266 | name, 267 | dims, 268 | n_dims, 269 | n_elements, 270 | element_type: ftype, 271 | start_offset: offset_aligned, 272 | }; 273 | let n_bytes = tensor_info.calc_size(); 274 | handler 275 | .tensor_buffer(tensor_info) 276 | .map_err(LoadError::ImplementationError)?; 277 | reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?; 278 | } 279 | 280 | Ok(()) 281 | } 282 | -------------------------------------------------------------------------------- /crates/ggml/src/format/mod.rs: -------------------------------------------------------------------------------- 1 | //! Loading and saving of [GGML](https://github.com/ggerganov/ggml) files. 2 | 3 | mod loader; 4 | mod saver; 5 | 6 | pub use loader::*; 7 | pub use saver::*; 8 | -------------------------------------------------------------------------------- /crates/ggml/src/format/saver.rs: -------------------------------------------------------------------------------- 1 | //! The saver module implements a way to save a model to disk in the GGJT format. 2 | //! 3 | //! To implement a saver for your model, implement [SaveHandler] for your model 4 | //! and provide data as appropriate, then call [save] with an instance of 5 | //! your handler. 6 | 7 | use std::{ 8 | error::Error, 9 | io::{Seek, Write}, 10 | }; 11 | 12 | use crate::{util, ContainerType, ElementType}; 13 | 14 | #[derive(Debug, thiserror::Error)] 15 | /// Errors that can occur while writing a model. 
16 | pub enum SaveError { 17 | #[error("non-specific I/O error")] 18 | /// A non-specific IO error. 19 | Io(#[from] std::io::Error), 20 | #[error("invalid integer conversion")] 21 | /// One of the integers encountered could not be converted to a more appropriate type. 22 | InvalidIntegerConversion(#[from] std::num::TryFromIntError), 23 | #[error("implementation error")] 24 | /// An error `E` was returned by the implementation of the loader. 25 | ImplementationError(#[source] E), 26 | #[error("invariant broken: {0}")] 27 | /// An invariant was broken. 28 | InvariantBroken(String), 29 | /// An attempt was made to save a model with a container type that does not 30 | /// support vocabulary scoring, despite the model having a scored vocabulary. 31 | #[error("container type does not support vocabulary scoring")] 32 | VocabularyScoringNotSupported, 33 | } 34 | 35 | /// A handler for saving a GGML model. 36 | pub trait SaveHandler { 37 | /// Called when the hyperparameters must be written. 38 | fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), E>; 39 | 40 | /// Called when information for a tensor is to be written. 41 | fn tensor_data(&mut self, tensor_name: &str) -> Result; 42 | } 43 | 44 | /// Information about a [tensor](https://en.wikipedia.org/wiki/Tensor_(machine_learning)) that is to be saved. 45 | #[derive(Clone, PartialEq, Debug)] 46 | pub struct TensorSaveInfo { 47 | /// The number of dimensions in the tensor. 48 | pub n_dims: usize, 49 | /// The dimensions of the tensor. 50 | pub dims: [usize; 2], 51 | /// The type of the elements in the tensor. 52 | pub element_type: ElementType, 53 | /// The data to save to disk. 54 | // TODO: This can be done more efficiently by borrowing the data, but 55 | // I wanted to avoid the lifetime parameter for now, especially as 56 | // the naive solution would borrow `TensorData` for the lifetime of the 57 | // handler, which is obviously not ideal if you're trying to transcode 58 | // an existing file tensor-by-tensor. 59 | pub data: Vec, 60 | } 61 | 62 | /// The container of the model to save. 63 | /// 64 | /// This is separate from [ContainerType] to ensure that the user 65 | /// does not accidentally use an unsupported container type. 66 | #[derive(Clone, Copy, PartialEq, Debug)] 67 | pub enum SaveContainerType { 68 | /// The GGML container. 69 | Ggml, 70 | /// The GGJT container. 71 | GgjtV3, 72 | } 73 | impl From for ContainerType { 74 | fn from(value: SaveContainerType) -> Self { 75 | match value { 76 | SaveContainerType::Ggml => ContainerType::Ggml, 77 | SaveContainerType::GgjtV3 => ContainerType::Ggjt(3), 78 | } 79 | } 80 | } 81 | 82 | /// Saves a model to the given writer. 83 | /// 84 | /// Only GGML and GGJT version 2 are supported. If using GGML, 85 | /// the vocabulary *must* have scores of 0.0. 
86 | pub fn save( 87 | writer: &mut W, 88 | handler: &mut dyn SaveHandler, 89 | container_type: SaveContainerType, 90 | vocabulary: &[(Vec, f32)], 91 | tensor_names: &[String], 92 | ) -> Result<(), SaveError> { 93 | // Write header and hyperparameters 94 | ContainerType::from(container_type).write(writer)?; 95 | 96 | if container_type == SaveContainerType::Ggml 97 | && vocabulary.iter().any(|(_, score)| *score != 0.0) 98 | { 99 | return Err(SaveError::VocabularyScoringNotSupported); 100 | } 101 | 102 | handler 103 | .write_hyperparameters(writer) 104 | .map_err(SaveError::ImplementationError)?; 105 | 106 | // Write vocabulary 107 | for (token, score) in vocabulary { 108 | util::write_u32(writer, token.len().try_into()?)?; 109 | writer.write_all(token)?; 110 | 111 | if container_type != SaveContainerType::Ggml { 112 | util::write_f32(writer, *score)?; 113 | } 114 | } 115 | 116 | // Write tensors 117 | for name in tensor_names { 118 | let TensorSaveInfo { 119 | n_dims, 120 | dims, 121 | element_type, 122 | data, 123 | } = handler 124 | .tensor_data(name) 125 | .map_err(SaveError::ImplementationError)?; 126 | 127 | match element_type { 128 | ElementType::Q4_0 | ElementType::Q4_1 => { 129 | if dims[0] % 64 != 0 { 130 | return Err(SaveError::InvariantBroken(format!("{dims:?}[0] % 64 == 0"))); 131 | } 132 | } 133 | _ => {} 134 | } 135 | 136 | // Write tensor header 137 | util::write_i32(writer, n_dims.try_into()?)?; 138 | util::write_i32(writer, name.len().try_into()?)?; 139 | util::write_u32(writer, element_type.into())?; 140 | for &dim in &dims[0..n_dims] { 141 | util::write_i32(writer, dim.try_into()?)?; 142 | } 143 | 144 | // Write tensor name 145 | writer.write_all(name.as_bytes())?; 146 | 147 | // Align to nearest 32 bytes 148 | if container_type != SaveContainerType::Ggml { 149 | let offset_curr = writer.stream_position()?; 150 | let offset_aligned = (offset_curr + 31) & !31; 151 | let padding = usize::try_from(offset_aligned - offset_curr)?; 152 | writer.write_all(&vec![0; padding])?; 153 | } 154 | 155 | // Write tensor data 156 | writer.write_all(&data)?; 157 | } 158 | 159 | Ok(()) 160 | } 161 | -------------------------------------------------------------------------------- /crates/ggml/src/tensor.rs: -------------------------------------------------------------------------------- 1 | use std::{os::raw::c_void, ptr::NonNull, sync::Weak}; 2 | 3 | use crate::{ 4 | accelerator::Backend, context::ContextInner, i64_to_usize, sys, Type, MAX_NAME_LENGTH, 5 | }; 6 | 7 | /// Tensors are owned by the context. A tensor is alive as long as the 8 | /// underlying context it was created with is alive. 9 | pub struct Tensor { 10 | pub(crate) ptr: NonNull, 11 | pub(crate) inner: Weak, 12 | } 13 | 14 | impl Tensor { 15 | /// Size of the `ggml_tensor` struct in bytes. 16 | /// 17 | /// Exposed for purposes of determining context size. 18 | pub const C_TYPE_SIZE: usize = std::mem::size_of::(); 19 | 20 | /// Sets the name of the tensor. 21 | /// 22 | /// # Safety 23 | /// 24 | /// The name must be a valid UTF-8 string and must not be longer than [`MAX_NAME_LENGTH`] bytes. 
25 | pub fn set_name(mut self, name: &str) -> Tensor { 26 | assert!( 27 | name.len() <= MAX_NAME_LENGTH, 28 | "Tensor name must be less than {} bytes", 29 | MAX_NAME_LENGTH 30 | ); 31 | 32 | let c_name = std::ffi::CString::new(name).unwrap(); 33 | self.with_alive_ctx_mut(|t| unsafe { sys::ggml_set_name(t.ptr.as_ptr(), c_name.as_ptr()) }); 34 | self 35 | } 36 | 37 | /// Gets the name of the tensor 38 | pub fn name(&self) -> String { 39 | self.with_alive_ctx(|| { 40 | let name_ptr = unsafe { sys::ggml_get_name(self.ptr.as_ptr()) }; 41 | let name = unsafe { std::ffi::CStr::from_ptr(name_ptr) }; 42 | name.to_string_lossy().into_owned() 43 | }) 44 | } 45 | 46 | /// Gets the acceleration backend of the tensor 47 | pub fn backend(&self) -> Backend { 48 | self.with_alive_ctx(|| unsafe { 49 | (self.ptr.as_ref().backend as sys::ggml_backend) 50 | .try_into() 51 | .unwrap() 52 | }) 53 | } 54 | 55 | /// Sets the tensor's acceleration backend and moves the tensor's data to the new backend. 56 | pub fn transfer_to(mut self, backend: Backend) -> Tensor { 57 | self.with_alive_ctx_mut(|t| { 58 | let current_backend = t.backend(); 59 | 60 | if current_backend != Backend::Cpu && backend == Backend::Cpu { 61 | unimplemented!("Tensors cannot be moved from an accelerator to the CPU at present"); 62 | } 63 | if backend == Backend::Cpu { 64 | return; 65 | } 66 | t.set_backend(backend); 67 | 68 | #[cfg(feature = "cublas")] 69 | unsafe { 70 | sys::cuda::ggml_cuda_transform_tensor(t.data(), t.ptr.as_ptr()); 71 | } 72 | #[cfg(feature = "clblast")] 73 | unsafe { 74 | sys::opencl::ggml_cl_transform_tensor(t.data(), t.ptr.as_ptr()); 75 | } 76 | 77 | t.mark_as_offloaded(); 78 | }); 79 | self 80 | } 81 | 82 | /// If ggml-sys is compiled with CUDA support, this function will offload the tensor to the GPU. 83 | /// If not, this is a no-op. 84 | /// 85 | /// It will not transfer the data. Use `transfer_to` for that. 86 | #[allow(unused_variables)] 87 | pub fn offload(&self) { 88 | self.with_alive_ctx(|| { 89 | #[cfg(feature = "cublas")] 90 | unsafe { 91 | sys::cuda::ggml_cuda_assign_buffers(self.ptr.as_ptr()); 92 | } 93 | }) 94 | } 95 | 96 | /// If ggml-sys is compiled with CUDA support, this function will offload the tensor to the GPU without using the scratch buffer. 97 | /// If not, this is a no-op. 98 | /// 99 | /// It will not transfer the data. Use `transfer_to` for that. 100 | /// 101 | /// Unlike `offload`, this function will add the tensor to the offloaded tensors map. This is because the non-use of a scratch buffer 102 | /// allows us to safely assume that this tensor will actually point to data. 103 | #[allow(unused_variables)] 104 | pub fn offload_no_scratch(&self) { 105 | self.with_alive_ctx(|| { 106 | #[cfg(feature = "cublas")] 107 | unsafe { 108 | sys::cuda::ggml_cuda_assign_buffers_no_scratch(self.ptr.as_ptr()); 109 | } 110 | self.mark_as_offloaded(); 111 | }) 112 | } 113 | 114 | /// Creates a shared copy of this tensor pointer. 115 | pub fn share(&self) -> Self { 116 | Tensor { 117 | ptr: self.ptr, 118 | inner: Weak::clone(&self.inner), 119 | } 120 | } 121 | 122 | /// Number of bytes used by this tensor. 123 | pub fn nbytes(&self) -> usize { 124 | self.with_alive_ctx(|| { 125 | // SAFETY: The with_alive_call guarantees the context is alive 126 | unsafe { sys::ggml_nbytes(self.ptr.as_ptr()) } 127 | }) 128 | } 129 | 130 | /// Provides raw mutable access to the data contained within the tensor. 
131 | /// 132 | /// # Safety 133 | /// 134 | /// Only `std::slice::from_raw_parts_mut(tensor.data(), tensor.nbytes())` is safe to mutate. 135 | pub unsafe fn data(&mut self) -> *mut c_void { 136 | self.with_alive_ctx(|| { 137 | // SAFETY: The with_alive_call guarantees the context is alive 138 | unsafe { *self.ptr.as_ptr() }.data 139 | }) 140 | } 141 | 142 | /// Set the tensor's data pointer (useful for mmap-ed data) 143 | /// 144 | /// # Safety 145 | /// 146 | /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. 147 | pub unsafe fn set_data(&mut self, data_ptr: *mut c_void) { 148 | self.with_alive_ctx_mut(|t| { 149 | let tensor = t.ptr.as_mut(); 150 | // SAFETY: The with_alive_call guarantees the context is alive 151 | tensor.data = data_ptr; 152 | }) 153 | } 154 | 155 | /// Number of elements in this tensor. 156 | pub fn nelements(&self) -> usize { 157 | self.with_alive_ctx(|| { 158 | // SAFETY: The with_alive_call guarantees the context is alive 159 | i64_to_usize(unsafe { sys::ggml_nelements(self.ptr.as_ptr()) }) 160 | }) 161 | } 162 | 163 | /// Number of elements in each dimension. 164 | pub fn get_ne(&self) -> [i64; 4] { 165 | self.with_alive_ctx(|| unsafe { *self.ptr.as_ptr() }.ne) 166 | } 167 | 168 | /// Stride of each dimension. 169 | pub fn get_nb(&self) -> [usize; 4] { 170 | self.with_alive_ctx(|| unsafe { *self.ptr.as_ptr() }.nb) 171 | } 172 | 173 | /// The data type. 174 | pub fn get_type(&self) -> Type { 175 | self.with_alive_ctx(|| unsafe { *self.ptr.as_ptr() }.type_.try_into().unwrap()) 176 | } 177 | 178 | /// The size of the element type in bytes. 179 | pub fn element_size(&self) -> usize { 180 | self.with_alive_ctx(|| unsafe { sys::ggml_element_size(self.ptr.as_ptr()) }) 181 | } 182 | 183 | /// Writes `src` to this tensor. 184 | /// 185 | /// # Safety 186 | /// 187 | /// This tensor must not be written to or read by from any other code. 188 | pub unsafe fn write_data(&mut self, src: &[u8]) { 189 | std::ptr::copy_nonoverlapping(src.as_ptr(), self.data() as *mut u8, src.len()) 190 | } 191 | 192 | /// Zeroes out this tensor. 193 | pub fn zero_data(&mut self) { 194 | unsafe { std::ptr::write_bytes(self.data() as *mut u8, 0, self.nbytes()) } 195 | } 196 | 197 | /// Reads this tensor into `dst`, starting from `offset`. The size of `dst` 198 | /// will be used to determine how many bytes to read. 199 | /// 200 | /// # Safety 201 | /// 202 | /// This tensor must not be written to or read by from any other code. 203 | pub unsafe fn read_data(&self, offset: usize, dst: &mut [u8]) { 204 | let data = unsafe { sys::ggml_get_data(self.ptr.as_ptr()).add(offset) }; 205 | std::ptr::copy_nonoverlapping(data, dst as *mut _ as _, dst.len()) 206 | } 207 | 208 | /// Frees the memory of a tensor on an accelerator if ggml-sys is compiled with CUDA or CLBlast support. 209 | /// If not, this is a no-op. 210 | /// 211 | /// This is temporary while GGML improves their context memory management. This should only be called by 212 | /// `Context` when it is dropped. 
213 | pub(crate) fn free_accelerator(self) { 214 | #[cfg(feature = "cublas")] 215 | unsafe { 216 | sys::cuda::ggml_cuda_free_data(self.ptr.as_ptr()); 217 | } 218 | #[cfg(feature = "clblast")] 219 | unsafe { 220 | sys::opencl::ggml_cl_free_data(self.ptr.as_ptr()); 221 | } 222 | } 223 | 224 | /// Returns true if this tensor is stored contiguously in memory 225 | pub fn is_contiguous(&self) -> bool { 226 | unsafe { sys::ggml_is_contiguous(self.ptr.as_ptr()) } 227 | } 228 | } 229 | impl Tensor { 230 | fn with_alive_ctx(&self, mut f: impl FnMut() -> U) -> U { 231 | let _ctx = self 232 | .inner 233 | .upgrade() 234 | .expect("Using a tensor after the context was dropped"); 235 | f() 236 | } 237 | 238 | fn with_alive_ctx_mut(&mut self, mut f: impl FnMut(&mut Tensor) -> U) -> U { 239 | let _ctx = self 240 | .inner 241 | .upgrade() 242 | .expect("Using a tensor after the context was dropped"); 243 | f(self) 244 | } 245 | 246 | /// Sets the acceleration backend of the tensor. 247 | /// 248 | /// # Caution 249 | /// 250 | /// This will not move the data to the new backend! See [Tensor::transfer_to] if you want to move the data to the new backend. 251 | fn set_backend(&mut self, backend: Backend) { 252 | unsafe { 253 | self.ptr.as_mut().backend = backend.try_into().unwrap(); 254 | } 255 | } 256 | 257 | /// Adds this tensor to the context's list of offloaded tensors, so that it will be automatically freed. 258 | fn mark_as_offloaded(&self) { 259 | self.inner 260 | .upgrade() 261 | .expect("Attempted to update a dropped context's offloaded tensors") 262 | .offloaded_tensors 263 | .lock() 264 | .unwrap() 265 | .insert(self.name(), self.share()); 266 | } 267 | } 268 | -------------------------------------------------------------------------------- /crates/ggml/src/tests.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::BTreeMap, 3 | error::Error, 4 | io::{BufRead, Write}, 5 | }; 6 | 7 | use crate::*; 8 | use rand::{distributions::Uniform, prelude::*}; 9 | 10 | #[derive(Debug)] 11 | struct DummyError; 12 | impl std::fmt::Display for DummyError { 13 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 14 | std::fmt::Debug::fmt(&self, f) 15 | } 16 | } 17 | impl Error for DummyError {} 18 | 19 | #[test] 20 | fn can_roundtrip_loader_and_saver_ggml() { 21 | let tokenizer = vec![ 22 | ("blazingly".as_bytes().to_vec(), 0.0), 23 | ("fast".as_bytes().to_vec(), 0.0), 24 | ("memory".as_bytes().to_vec(), 0.0), 25 | ("efficient".as_bytes().to_vec(), 0.0), 26 | ]; 27 | 28 | roundtrip_test(format::SaveContainerType::Ggml, tokenizer).unwrap(); 29 | } 30 | 31 | #[test] 32 | fn will_fail_on_scored_ggml_save() { 33 | let tokenizer = vec![ 34 | ("blazingly".as_bytes().to_vec(), 0.1), 35 | ("fast".as_bytes().to_vec(), 0.2), 36 | ("memory".as_bytes().to_vec(), 0.3), 37 | ("efficient".as_bytes().to_vec(), 0.4), 38 | ]; 39 | 40 | assert_eq!( 41 | roundtrip_test(format::SaveContainerType::Ggml, tokenizer) 42 | .unwrap_err() 43 | .to_string(), 44 | format::SaveError::::VocabularyScoringNotSupported.to_string() 45 | ); 46 | } 47 | 48 | #[test] 49 | fn can_roundtrip_loader_and_saver_ggjt_v3() { 50 | let tokenizer = vec![ 51 | ("blazingly".as_bytes().to_vec(), 0.1), 52 | ("fast".as_bytes().to_vec(), 0.2), 53 | ("memory".as_bytes().to_vec(), 0.3), 54 | ("efficient".as_bytes().to_vec(), 0.4), 55 | ]; 56 | 57 | roundtrip_test(format::SaveContainerType::GgjtV3, tokenizer).unwrap(); 58 | } 59 | 60 | fn roundtrip_test( 61 | save_container_type: 
format::SaveContainerType, 62 | tokenizer: Vec<(Vec, f32)>, 63 | ) -> anyhow::Result<()> { 64 | let mut rng = rand::thread_rng(); 65 | let element_type = crate::Type::F16; 66 | let model = Model { 67 | hyperparameters: Hyperparameters { 68 | some_hyperparameter: random(), 69 | some_other_hyperparameter: random(), 70 | tokenizer_size: tokenizer.len().try_into()?, 71 | }, 72 | tokenizer, 73 | tensors: (0..10) 74 | .map(|i| { 75 | let n_dims = Uniform::from(1..3).sample(&mut rng); 76 | let dims = (0..n_dims) 77 | .map(|_| Uniform::from(1..10).sample(&mut rng)) 78 | .chain(std::iter::repeat(1).take(2 - n_dims)) 79 | .collect::>(); 80 | 81 | let n_elements = dims.iter().product::(); 82 | let data = (0..format::data_size(element_type, n_elements)) 83 | .map(|_| random()) 84 | .collect::>(); 85 | 86 | ( 87 | format!("tensor_{}", i), 88 | format::TensorSaveInfo { 89 | n_dims, 90 | dims: dims.try_into().unwrap(), 91 | element_type, 92 | data, 93 | }, 94 | ) 95 | }) 96 | .collect(), 97 | }; 98 | 99 | // Save the model. 100 | let mut buffer = Vec::new(); 101 | let mut cursor = std::io::Cursor::new(&mut buffer); 102 | let mut save_handler = MockSaveHandler { model: &model }; 103 | format::save( 104 | &mut cursor, 105 | &mut save_handler, 106 | save_container_type, 107 | &model.tokenizer, 108 | &model.tensors.keys().cloned().collect::>(), 109 | )?; 110 | 111 | // Load the model and confirm that it is the same as the original. 112 | let mut cursor = std::io::Cursor::new(&buffer); 113 | let mut load_handler = MockLoadHandler { 114 | data: &buffer, 115 | loaded_model: Model::default(), 116 | expected_container_type: save_container_type.into(), 117 | }; 118 | format::load(&mut cursor, &mut load_handler)?; 119 | assert_eq!(load_handler.loaded_model, model); 120 | 121 | Ok(()) 122 | } 123 | 124 | #[derive(Default, PartialEq, Debug)] 125 | struct Hyperparameters { 126 | some_hyperparameter: u32, 127 | some_other_hyperparameter: u32, 128 | tokenizer_size: u32, 129 | } 130 | impl Hyperparameters { 131 | fn read(reader: &mut dyn BufRead) -> Result { 132 | Ok(Self { 133 | some_hyperparameter: util::read_u32(reader)?, 134 | some_other_hyperparameter: util::read_u32(reader)?, 135 | tokenizer_size: util::read_u32(reader)?, 136 | }) 137 | } 138 | 139 | fn write(&self, writer: &mut dyn Write) -> Result<(), std::io::Error> { 140 | util::write_u32(writer, self.some_hyperparameter)?; 141 | util::write_u32(writer, self.some_other_hyperparameter)?; 142 | util::write_u32(writer, self.tokenizer_size)?; 143 | Ok(()) 144 | } 145 | } 146 | 147 | #[derive(Default, PartialEq, Debug)] 148 | struct Model { 149 | hyperparameters: Hyperparameters, 150 | tokenizer: Vec<(Vec, f32)>, 151 | tensors: BTreeMap, 152 | } 153 | 154 | struct MockSaveHandler<'a> { 155 | model: &'a Model, 156 | } 157 | impl format::SaveHandler for MockSaveHandler<'_> { 158 | fn write_hyperparameters(&mut self, writer: &mut dyn Write) -> Result<(), DummyError> { 159 | self.model.hyperparameters.write(writer).unwrap(); 160 | Ok(()) 161 | } 162 | 163 | fn tensor_data(&mut self, tensor_name: &str) -> Result { 164 | self.model 165 | .tensors 166 | .get(tensor_name) 167 | .cloned() 168 | .ok_or(DummyError) 169 | } 170 | } 171 | 172 | struct MockLoadHandler<'a> { 173 | data: &'a [u8], 174 | loaded_model: Model, 175 | expected_container_type: ContainerType, 176 | } 177 | impl format::LoadHandler for MockLoadHandler<'_> { 178 | fn container_type(&mut self, container_type: ContainerType) -> Result<(), DummyError> { 179 | assert_eq!(container_type, 
self.expected_container_type); 180 | Ok(()) 181 | } 182 | 183 | fn vocabulary_token(&mut self, i: usize, token: Vec, score: f32) -> Result<(), DummyError> { 184 | assert_eq!(i, self.loaded_model.tokenizer.len()); 185 | self.loaded_model.tokenizer.push((token, score)); 186 | Ok(()) 187 | } 188 | 189 | fn read_hyperparameters( 190 | &mut self, 191 | reader: &mut dyn BufRead, 192 | ) -> Result { 193 | self.loaded_model.hyperparameters = Hyperparameters::read(reader).unwrap(); 194 | Ok(format::PartialHyperparameters { 195 | n_vocab: self 196 | .loaded_model 197 | .hyperparameters 198 | .tokenizer_size 199 | .try_into() 200 | .unwrap(), 201 | }) 202 | } 203 | 204 | fn tensor_buffer(&mut self, info: format::TensorLoadInfo) -> Result<(), DummyError> { 205 | let data = format::TensorSaveInfo { 206 | n_dims: info.n_dims, 207 | dims: info.dims, 208 | element_type: info.element_type, 209 | data: info 210 | .read_data(&mut std::io::Cursor::new(self.data)) 211 | .unwrap(), 212 | }; 213 | self.loaded_model.tensors.insert(info.name, data); 214 | Ok(()) 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /crates/ggml/src/util.rs: -------------------------------------------------------------------------------- 1 | //! Utilities for reading and writing. 2 | 3 | use std::io::{BufRead, Write}; 4 | 5 | /// Read a fixed-size array of bytes from a reader. 6 | pub fn read_bytes(reader: &mut dyn BufRead) -> Result<[u8; N], std::io::Error> { 7 | let mut bytes = [0u8; N]; 8 | reader.read_exact(&mut bytes)?; 9 | Ok(bytes) 10 | } 11 | 12 | /// Read a `i32` from a reader. 13 | pub fn read_i32(reader: &mut dyn BufRead) -> Result { 14 | Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) 15 | } 16 | 17 | /// Read a `u32` from a reader. 18 | pub fn read_u32(reader: &mut dyn BufRead) -> Result { 19 | Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) 20 | } 21 | 22 | /// Read a `f32` from a reader. 23 | pub fn read_f32(reader: &mut dyn BufRead) -> Result { 24 | Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) 25 | } 26 | 27 | /// Read a `bool` represented as an `i32` from a reader. 28 | pub fn read_bool(reader: &mut dyn BufRead) -> Result { 29 | let val = i32::from_le_bytes(read_bytes::<4>(reader)?); 30 | match val { 31 | 0 => Ok(false), 32 | 1 => Ok(true), 33 | _ => Err(std::io::Error::new( 34 | std::io::ErrorKind::InvalidData, 35 | format!("Invalid i32 value for bool: '{}'", val), 36 | )), 37 | } 38 | } 39 | 40 | /// Read a variable-length array of bytes from a reader. 41 | pub fn read_bytes_with_len( 42 | reader: &mut dyn BufRead, 43 | len: usize, 44 | ) -> Result, std::io::Error> { 45 | let mut bytes = vec![0u8; len]; 46 | reader.read_exact(&mut bytes)?; 47 | Ok(bytes) 48 | } 49 | 50 | /// Write a `i32` from a writer. 51 | pub fn write_i32(writer: &mut dyn Write, value: i32) -> Result<(), std::io::Error> { 52 | writer.write_all(&value.to_le_bytes()) 53 | } 54 | 55 | /// Write a `u32` from a writer. 56 | pub fn write_u32(writer: &mut dyn Write, value: u32) -> Result<(), std::io::Error> { 57 | writer.write_all(&value.to_le_bytes()) 58 | } 59 | 60 | /// Write a `f32` from a writer. 61 | pub fn write_f32(writer: &mut dyn Write, value: f32) -> Result<(), std::io::Error> { 62 | writer.write_all(&value.to_le_bytes()) 63 | } 64 | 65 | /// Write a `bool` represented as an `i32` to a writer. 
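///
/// # Example
///
/// A minimal round-trip sketch (illustrative, not part of the original docs;
/// it assumes an in-memory buffer and uses `write_bool`/`read_bool` from this
/// module):
///
/// ```ignore
/// let mut buffer: Vec<u8> = Vec::new();
/// write_bool(&mut buffer, true).unwrap();
///
/// let mut cursor = std::io::Cursor::new(buffer);
/// assert!(read_bool(&mut cursor).unwrap());
/// ```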
66 | pub fn write_bool(writer: &mut dyn Write, value: bool) -> Result<(), std::io::Error> { 67 | let int_value: i32 = if value { 1 } else { 0 }; 68 | writer.write_all(&int_value.to_le_bytes()) 69 | } 70 | 71 | // NOTE: Implementation from #![feature(buf_read_has_data_left)] 72 | /// Check if there is any data left in the reader. 73 | pub fn has_data_left(reader: &mut impl BufRead) -> Result { 74 | reader.fill_buf().map(|b| !b.is_empty()) 75 | } 76 | -------------------------------------------------------------------------------- /crates/ggml/sys/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "ggml-sys" 3 | version = "0.2.0-dev" 4 | repository = { workspace = true } 5 | edition = "2021" 6 | description = "Raw bindings (i.e. bindgen output) for the ggml library." 7 | license = "MIT" 8 | 9 | [build-dependencies] 10 | cc = "^1.0" 11 | 12 | [features] 13 | cublas = [] 14 | clblast = [] 15 | metal = [] 16 | -------------------------------------------------------------------------------- /crates/ggml/sys/UPDATING.md: -------------------------------------------------------------------------------- 1 | # Updating 2 | 3 | The bindings are automatically generated. See [CONTRIBUTING](../../../CONTRIBUTING.md#regenerating-the-ggml-bindings). 4 | -------------------------------------------------------------------------------- /crates/ggml/sys/src/cuda.rs: -------------------------------------------------------------------------------- 1 | /* automatically generated by rust-bindgen 0.65.1 */ 2 | 3 | use super::ggml_compute_params; 4 | use super::ggml_tensor; 5 | 6 | pub const GGML_CUDA_MAX_DEVICES: u32 = 16; 7 | extern "C" { 8 | pub fn ggml_init_cublas(); 9 | } 10 | extern "C" { 11 | pub fn ggml_cuda_set_tensor_split(tensor_split: *const f32); 12 | } 13 | extern "C" { 14 | pub fn ggml_cuda_mul(src0: *const ggml_tensor, src1: *const ggml_tensor, dst: *mut ggml_tensor); 15 | } 16 | extern "C" { 17 | pub fn ggml_cuda_can_mul_mat( 18 | src0: *const ggml_tensor, 19 | src1: *const ggml_tensor, 20 | dst: *mut ggml_tensor, 21 | ) -> bool; 22 | } 23 | extern "C" { 24 | pub fn ggml_cuda_mul_mat_get_wsize( 25 | src0: *const ggml_tensor, 26 | src1: *const ggml_tensor, 27 | dst: *mut ggml_tensor, 28 | ) -> usize; 29 | } 30 | extern "C" { 31 | pub fn ggml_cuda_mul_mat( 32 | src0: *const ggml_tensor, 33 | src1: *const ggml_tensor, 34 | dst: *mut ggml_tensor, 35 | wdata: *mut ::std::os::raw::c_void, 36 | wsize: usize, 37 | ); 38 | } 39 | extern "C" { 40 | pub fn ggml_cuda_host_malloc(size: usize) -> *mut ::std::os::raw::c_void; 41 | } 42 | extern "C" { 43 | pub fn ggml_cuda_host_free(ptr: *mut ::std::os::raw::c_void); 44 | } 45 | extern "C" { 46 | pub fn ggml_cuda_transform_tensor(data: *mut ::std::os::raw::c_void, tensor: *mut ggml_tensor); 47 | } 48 | extern "C" { 49 | pub fn ggml_cuda_free_data(tensor: *mut ggml_tensor); 50 | } 51 | extern "C" { 52 | pub fn ggml_cuda_assign_buffers(tensor: *mut ggml_tensor); 53 | } 54 | extern "C" { 55 | pub fn ggml_cuda_assign_buffers_no_scratch(tensor: *mut ggml_tensor); 56 | } 57 | extern "C" { 58 | pub fn ggml_cuda_assign_buffers_force_inplace(tensor: *mut ggml_tensor); 59 | } 60 | extern "C" { 61 | pub fn ggml_cuda_set_main_device(main_device: ::std::os::raw::c_int); 62 | } 63 | extern "C" { 64 | pub fn ggml_cuda_set_mul_mat_q(mul_mat_q: bool); 65 | } 66 | extern "C" { 67 | pub fn ggml_cuda_set_scratch_size(scratch_size: usize); 68 | } 69 | extern "C" { 70 | pub fn ggml_cuda_free_scratch(); 71 | } 72 | extern 
"C" { 73 | pub fn ggml_cuda_compute_forward( 74 | params: *mut ggml_compute_params, 75 | tensor: *mut ggml_tensor, 76 | ) -> bool; 77 | } 78 | -------------------------------------------------------------------------------- /crates/ggml/sys/src/llama.rs: -------------------------------------------------------------------------------- 1 | /* automatically generated by rust-bindgen 0.65.1 */ 2 | 3 | pub const LLAMA_MAX_DEVICES: u32 = 1; 4 | pub const LLAMA_FILE_MAGIC_GGJT: u32 = 1734830708; 5 | pub const LLAMA_FILE_MAGIC_GGLA: u32 = 1734831201; 6 | pub const LLAMA_FILE_MAGIC_GGMF: u32 = 1734831462; 7 | pub const LLAMA_FILE_MAGIC_GGML: u32 = 1734831468; 8 | pub const LLAMA_FILE_MAGIC_GGSN: u32 = 1734833006; 9 | pub const LLAMA_FILE_VERSION: u32 = 3; 10 | pub const LLAMA_FILE_MAGIC: u32 = 1734830708; 11 | pub const LLAMA_FILE_MAGIC_UNVERSIONED: u32 = 1734831468; 12 | pub const LLAMA_SESSION_MAGIC: u32 = 1734833006; 13 | pub const LLAMA_SESSION_VERSION: u32 = 1; 14 | pub const LLAMA_DEFAULT_SEED: u32 = 4294967295; 15 | pub const LLAMA_DEFAULT_RMS_EPS: f64 = 0.000005; 16 | pub const LLAMA_FTYPE_ALL_F32: llama_ftype = 0; 17 | pub const LLAMA_FTYPE_MOSTLY_F16: llama_ftype = 1; 18 | pub const LLAMA_FTYPE_MOSTLY_Q4_0: llama_ftype = 2; 19 | pub const LLAMA_FTYPE_MOSTLY_Q4_1: llama_ftype = 3; 20 | pub const LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16: llama_ftype = 4; 21 | pub const LLAMA_FTYPE_MOSTLY_Q8_0: llama_ftype = 7; 22 | pub const LLAMA_FTYPE_MOSTLY_Q5_0: llama_ftype = 8; 23 | pub const LLAMA_FTYPE_MOSTLY_Q5_1: llama_ftype = 9; 24 | pub const LLAMA_FTYPE_MOSTLY_Q2_K: llama_ftype = 10; 25 | pub const LLAMA_FTYPE_MOSTLY_Q3_K_S: llama_ftype = 11; 26 | pub const LLAMA_FTYPE_MOSTLY_Q3_K_M: llama_ftype = 12; 27 | pub const LLAMA_FTYPE_MOSTLY_Q3_K_L: llama_ftype = 13; 28 | pub const LLAMA_FTYPE_MOSTLY_Q4_K_S: llama_ftype = 14; 29 | pub const LLAMA_FTYPE_MOSTLY_Q4_K_M: llama_ftype = 15; 30 | pub const LLAMA_FTYPE_MOSTLY_Q5_K_S: llama_ftype = 16; 31 | pub const LLAMA_FTYPE_MOSTLY_Q5_K_M: llama_ftype = 17; 32 | pub const LLAMA_FTYPE_MOSTLY_Q6_K: llama_ftype = 18; 33 | pub type llama_ftype = ::std::os::raw::c_int; 34 | -------------------------------------------------------------------------------- /crates/ggml/sys/src/metal.rs: -------------------------------------------------------------------------------- 1 | /* automatically generated by rust-bindgen 0.65.1 */ 2 | 3 | pub const GGML_METAL_MAX_BUFFERS: u32 = 16; 4 | #[repr(C)] 5 | #[derive(Debug, Copy, Clone)] 6 | pub struct ggml_tensor { 7 | _unused: [u8; 0], 8 | } 9 | #[repr(C)] 10 | #[derive(Debug, Copy, Clone)] 11 | pub struct ggml_cgraph { 12 | _unused: [u8; 0], 13 | } 14 | #[repr(C)] 15 | #[derive(Debug, Copy, Clone)] 16 | pub struct ggml_metal_context { 17 | _unused: [u8; 0], 18 | } 19 | extern "C" { 20 | pub fn ggml_metal_init(n_cb: ::std::os::raw::c_int) -> *mut ggml_metal_context; 21 | } 22 | extern "C" { 23 | pub fn ggml_metal_free(ctx: *mut ggml_metal_context); 24 | } 25 | extern "C" { 26 | pub fn ggml_metal_set_n_cb(ctx: *mut ggml_metal_context, n_cb: ::std::os::raw::c_int); 27 | } 28 | extern "C" { 29 | pub fn ggml_metal_add_buffer( 30 | ctx: *mut ggml_metal_context, 31 | name: *const ::std::os::raw::c_char, 32 | data: *mut ::std::os::raw::c_void, 33 | size: usize, 34 | max_size: usize, 35 | ) -> bool; 36 | } 37 | extern "C" { 38 | pub fn ggml_metal_set_tensor(ctx: *mut ggml_metal_context, t: *mut ggml_tensor); 39 | } 40 | extern "C" { 41 | pub fn ggml_metal_get_tensor(ctx: *mut ggml_metal_context, t: *mut ggml_tensor); 42 | } 43 | extern "C" { 44 
| pub fn ggml_metal_graph_find_concurrency(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph); 45 | } 46 | extern "C" { 47 | pub fn ggml_metal_if_optimized(ctx: *mut ggml_metal_context) -> bool; 48 | } 49 | extern "C" { 50 | pub fn ggml_metal_graph_compute(ctx: *mut ggml_metal_context, gf: *mut ggml_cgraph); 51 | } 52 | -------------------------------------------------------------------------------- /crates/ggml/sys/src/opencl.rs: -------------------------------------------------------------------------------- 1 | /* automatically generated by rust-bindgen 0.65.1 */ 2 | 3 | use super::ggml_tensor; 4 | 5 | extern "C" { 6 | pub fn ggml_cl_init(); 7 | } 8 | extern "C" { 9 | pub fn ggml_cl_mul(src0: *const ggml_tensor, src1: *const ggml_tensor, dst: *mut ggml_tensor); 10 | } 11 | extern "C" { 12 | pub fn ggml_cl_can_mul_mat( 13 | src0: *const ggml_tensor, 14 | src1: *const ggml_tensor, 15 | dst: *mut ggml_tensor, 16 | ) -> bool; 17 | } 18 | extern "C" { 19 | pub fn ggml_cl_mul_mat_get_wsize( 20 | src0: *const ggml_tensor, 21 | src1: *const ggml_tensor, 22 | dst: *mut ggml_tensor, 23 | ) -> usize; 24 | } 25 | extern "C" { 26 | pub fn ggml_cl_mul_mat( 27 | src0: *const ggml_tensor, 28 | src1: *const ggml_tensor, 29 | dst: *mut ggml_tensor, 30 | wdata: *mut ::std::os::raw::c_void, 31 | wsize: usize, 32 | ); 33 | } 34 | extern "C" { 35 | pub fn ggml_cl_host_malloc(size: usize) -> *mut ::std::os::raw::c_void; 36 | } 37 | extern "C" { 38 | pub fn ggml_cl_host_free(ptr: *mut ::std::os::raw::c_void); 39 | } 40 | extern "C" { 41 | pub fn ggml_cl_free_data(tensor: *const ggml_tensor); 42 | } 43 | extern "C" { 44 | pub fn ggml_cl_transform_tensor(data: *mut ::std::os::raw::c_void, tensor: *mut ggml_tensor); 45 | } 46 | -------------------------------------------------------------------------------- /crates/llm-base/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-base" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "The base for `llm`; provides common structure for model implementations. Not intended for use by end-users." 7 | edition = "2021" 8 | rust-version = "1.65" 9 | readme = "../../README.md" 10 | 11 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 12 | 13 | [dependencies] 14 | ggml = { path = "../ggml", version = "0.2.0-dev" } 15 | 16 | bytemuck = { workspace = true } 17 | rand = { workspace = true } 18 | serde = { workspace = true } 19 | thiserror = { workspace = true } 20 | 21 | partial_sort = "0.2.0" 22 | serde_bytes = "0.11" 23 | memmap2 = { workspace = true } 24 | half = "2" 25 | tokenizers = {version="0.13.4", default-features=false, features=["onig"]} 26 | regex = "1.8" 27 | tracing = { workspace = true } 28 | 29 | llm-samplers = { workspace = true } 30 | 31 | [features] 32 | tokenizers-remote = ["tokenizers/http"] 33 | cublas = ["ggml/cublas"] 34 | clblast = ["ggml/clblast"] 35 | metal = ["ggml/metal"] 36 | -------------------------------------------------------------------------------- /crates/llm-base/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate provides a unified interface for loading and using 2 | //! Large Language Models (LLMs). 3 | //! 4 | //! This is the base crate that implementors can use to implement their own 5 | //! LLMs. 6 | //! 7 | //! 
As a user, you probably want to use the [llm](https://crates.io/crates/llm) crate instead. 8 | #![deny(missing_docs)] 9 | 10 | mod inference_session; 11 | mod loader; 12 | mod lora; 13 | mod quantize; 14 | mod tokenizer; 15 | 16 | pub mod model; 17 | pub mod samplers; 18 | pub mod util; 19 | 20 | use std::sync::{Arc, Mutex}; 21 | 22 | pub use ggml; 23 | pub use ggml::Type as ElementType; 24 | 25 | pub use inference_session::{ 26 | conversation_inference_callback, feed_prompt_callback, GraphOutputs, InferenceError, 27 | InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession, 28 | InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, 29 | ModelKVMemoryType, RewindError, SnapshotError, 30 | }; 31 | pub use llm_samplers::prelude::{Sampler, SamplerChain}; 32 | pub use loader::{ 33 | load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic, 34 | LoadError, LoadProgress, Loader, TensorLoader, 35 | }; 36 | pub use lora::{LoraAdapter, LoraParameters}; 37 | pub use memmap2::Mmap; 38 | pub use model::{Hyperparameters, KnownModel, Model, ModelContext, ModelParameters, OutputRequest}; 39 | pub use quantize::{quantize, QuantizeError, QuantizeProgress}; 40 | pub use regex::Regex; 41 | pub use tokenizer::{ 42 | InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, Tokenizer, TokenizerLoadError, 43 | TokenizerSource, 44 | }; 45 | pub use util::TokenUtf8Buffer; 46 | 47 | #[derive(Clone, Debug)] 48 | /// The parameters for text generation. 49 | /// 50 | /// This needs to be provided during all inference calls, 51 | /// but can be changed between calls. 52 | pub struct InferenceParameters { 53 | /// The sampler to use for sampling tokens from the model's probabilities. 54 | /// 55 | /// Each time the model runs, it generates a distribution of probabilities; each token 56 | /// has a probability of being the next token. The sampler is responsible for sampling 57 | /// from this distribution to generate the next token. Using a different sampler may 58 | /// change the output of the model, or control how deterministic the generated text is. 59 | /// 60 | /// This can be anything that implements [Sampler]. Refer to 61 | /// the `llm-samplers` documentation for possible samplers and suggested 62 | /// combinations: 63 | pub sampler: Arc>, 64 | } 65 | 66 | //Since Sampler implements Send and Sync, InferenceParameters should too. 67 | unsafe impl Send for InferenceParameters {} 68 | unsafe impl Sync for InferenceParameters {} 69 | 70 | impl Default for InferenceParameters { 71 | fn default() -> Self { 72 | Self { 73 | sampler: samplers::default_samplers(), 74 | } 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /crates/llm-base/src/lora.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | loader::FileContext, model::HyperparametersWriteError, util, FileType, Hyperparameters, 3 | LoadError, 4 | }; 5 | 6 | use ggml::{format::TensorLoadInfo, GraphExecutionPlan}; 7 | use std::{ 8 | collections::{HashMap, HashSet}, 9 | fs::File, 10 | path::PathBuf, 11 | }; 12 | 13 | #[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] 14 | /// Parameters for a [LoRA](https://arxiv.org/abs/2106.09685) adapter. 15 | pub struct LoraParameters { 16 | /// r 17 | pub r: i32, 18 | /// alpha 19 | pub alpha: i32, 20 | } 21 | impl LoraParameters { 22 | /// Returns the scaling factor for the LoRA adapter. 
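    ///
    /// # Example
    ///
    /// The scaling is `alpha / r`; a small illustrative check with arbitrary
    /// values (not taken from any real adapter):
    ///
    /// ```ignore
    /// let params = LoraParameters { r: 8, alpha: 16 };
    /// assert_eq!(params.calculate_scaling(), 2.0);
    /// ```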
23 | pub fn calculate_scaling(&self) -> f32 { 24 | (self.alpha as f32) / (self.r as f32) 25 | } 26 | } 27 | impl Hyperparameters for LoraParameters { 28 | fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result { 29 | Ok(LoraParameters { 30 | r: util::read_i32(reader)?, 31 | alpha: util::read_i32(reader)?, 32 | }) 33 | } 34 | 35 | fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> { 36 | util::write_i32(writer, self.r)?; 37 | util::write_i32(writer, self.alpha)?; 38 | Ok(()) 39 | } 40 | 41 | fn n_vocabulary(&self) -> usize { 42 | // LoRA adapters do not have a vocabulary. 43 | 0 44 | } 45 | 46 | fn file_type(&self) -> Option { 47 | None 48 | } 49 | 50 | fn file_type_mut(&mut self) -> Option<&mut FileType> { 51 | None 52 | } 53 | } 54 | 55 | /// [LoRA](https://arxiv.org/abs/2106.09685) adapter for a model. 56 | pub struct LoraAdapter { 57 | /// Scaling to apply to the LoRA weights. 58 | pub scaling: f32, 59 | /// The tensors of the LoRA. 60 | pub tensors: HashMap, 61 | /// Names of the tensors that should be patched. 62 | pub tensors_to_patch: HashSet, 63 | /// File containing the LoRA weights. 64 | pub file: File, 65 | /// Path to the LoRA file. 66 | pub path: PathBuf, 67 | } 68 | 69 | impl LoraAdapter { 70 | /// Patch a tensor via LoRA 71 | pub fn patch( 72 | &mut self, 73 | info: &TensorLoadInfo, 74 | tensor: &mut ggml::Tensor, 75 | ) -> Result<(), LoadError> { 76 | // Check if we need to patch this tensor 77 | let name = &info.name; 78 | if !self.tensors_to_patch.contains(name) { 79 | return Ok(()); 80 | } 81 | 82 | let a_info = self.get_info(&format!("{}.loraA", name))?; 83 | let b_info = self.get_info(&format!("{}.loraB", name))?; 84 | 85 | let must_scale = self.scaling != 1.0; 86 | // Calculate the size of the patch context via the following steps: 87 | // 1. Calculate the size of the two `a` and `b` tensors 88 | // 2. Calculate the size of the original tensor 89 | // 3. Calculate the size of the `ba` and tensors. It has the same dimensions as the original tensor, but is of the element type of the `a` or `b` tensor e.g. fp16 90 | let ba_size = ggml::format::tensor_size(a_info.element_type, info.dims().iter().product()); 91 | let mut patch_context_size = a_info.calc_absolute_size(false) 92 | + b_info.calc_absolute_size(false) 93 | + info.calc_absolute_size(false) 94 | + ba_size; 95 | 96 | // 3b. (Optional) If we need to scale the `ba` tensor, we need to allocate for a second `ba` and the `scaled` tensors which will be crated as an `f32` tensor. 97 | if must_scale { 98 | let scaled_size = 99 | ggml::format::tensor_size(ggml::ElementType::F32, info.dims().iter().product()); 100 | patch_context_size += scaled_size + ba_size; 101 | } 102 | 103 | // 4. 
Add 5% as ggml overhead (I dont know why this is needed but the calculation is always a few 100-1000 bytes off) 104 | patch_context_size = patch_context_size + (patch_context_size / 20); 105 | 106 | // Create a temporary context for the patching operations 107 | // TODO: test if GPU can be enabled (make it configurable) 108 | let patch_context = ggml::Context::new_with_allocate(patch_context_size); 109 | let mut patch_file = FileContext::new(&patch_context, &mut self.file, &self.path); 110 | 111 | // Load the A and B tensors 112 | let a = patch_file.get_tensor(&a_info)?; 113 | let b = patch_file.get_tensor(&b_info)?; 114 | 115 | //Build a ggml context and apply the patch 116 | 117 | let mut gf = patch_context.create_compute_graph(); 118 | 119 | // LoRA formula: w = w + ba*s 120 | let mut ba = patch_context.op_mul_mat(&a, &b); 121 | if must_scale { 122 | let scaling_tensor = patch_context.new_f32(self.scaling); 123 | ba = patch_context.op_scale(&ba, &scaling_tensor); 124 | } 125 | let mut output = patch_context.op_add(tensor, &ba); 126 | 127 | // Compute the graph 128 | gf.build_forward_expand(&output); 129 | 130 | //TODO: maybe pass the model's thread count to this context 131 | let mut plan = GraphExecutionPlan::new(&mut gf, 8); 132 | plan.execute(&patch_context); 133 | 134 | // Overwrite the original tensor. 135 | // The `output` and the `target_tensor` are not from the same context, 136 | // so this should be fine. 137 | unsafe { 138 | std::ptr::copy_nonoverlapping(output.data(), tensor.data(), tensor.nbytes()); 139 | } 140 | 141 | Ok(()) 142 | } 143 | 144 | fn get_info(&self, name: &str) -> Result { 145 | self.tensors 146 | .get(name) 147 | .cloned() 148 | .ok_or(LoadError::UnknownTensor { 149 | path: self.path.to_owned(), 150 | tensor_name: name.to_owned(), 151 | }) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /crates/llm-base/src/model/common.rs: -------------------------------------------------------------------------------- 1 | use ggml::Tensor; 2 | 3 | use crate::{InferenceSession, OutputRequest}; 4 | 5 | /// Return result for just the last token 6 | pub fn read_last_token( 7 | session: &mut InferenceSession, 8 | input_layer: &Tensor, 9 | n_vocab: usize, 10 | n: usize, 11 | ) { 12 | assert_eq!(session.last_logits.len(), n_vocab); 13 | unsafe { 14 | input_layer.read_data( 15 | n_vocab * (n - 1) * std::mem::size_of::(), 16 | bytemuck::cast_slice_mut(&mut session.last_logits), 17 | ) 18 | }; 19 | } 20 | 21 | /// Extract logits from [OutputRequest] evaluation 22 | pub fn extract_logits( 23 | output_request: &mut OutputRequest, 24 | input_layer: &Tensor, 25 | n_vocab: usize, 26 | n: usize, 27 | ) { 28 | if let Some(all_logits) = &mut output_request.all_logits { 29 | all_logits.resize(n_vocab * n, 0.0); 30 | // SAFETY: Tensor data can be read (properly aligned, initialized, 31 | // data will not be mutated or otherwise aliased during the copy), 32 | // and we're not reading past the end of the tensor data. 
33 | assert_eq!(input_layer.nelements(), n_vocab * n); 34 | unsafe { 35 | input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits)); 36 | } 37 | } 38 | } 39 | 40 | /// Extract embeddings from [OutputRequest] evaluation 41 | pub fn extract_embeddings( 42 | output_request: &mut OutputRequest, 43 | embeddings_tensor: &Tensor, 44 | n_embd: usize, 45 | n: usize, 46 | ) { 47 | // Extract embeddings 48 | if let Some(embeddings) = &mut output_request.embeddings { 49 | embeddings.resize(n_embd, 0.0); 50 | // Create a new vector to hold all embeddings 51 | let mut all_embeddings = vec![0.0; n_embd * n]; 52 | // SAFETY: Same rationale as for the "Extract logits" section applies. 53 | assert_eq!(embeddings_tensor.nelements(), n_embd * n); 54 | unsafe { 55 | embeddings_tensor.read_data(0, bytemuck::cast_slice_mut(&mut all_embeddings)); 56 | } 57 | embeddings.copy_from_slice(&all_embeddings[n_embd * (n - 1)..]); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /crates/llm-base/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | //! Large language model traits and types 2 | 3 | use std::{ 4 | error::Error, 5 | fmt::Debug, 6 | io::{BufRead, Write}, 7 | path::{Path, PathBuf}, 8 | sync::Arc, 9 | }; 10 | 11 | use ggml::accelerator::Backend; 12 | use regex::Regex; 13 | use thiserror::Error; 14 | 15 | use crate::{ 16 | loader::TensorLoader, tokenizer::TokenId, FileType, InferenceSession, InferenceSessionConfig, 17 | LoadError, LoadProgress, Tokenizer, TokenizerSource, 18 | }; 19 | 20 | /// Common functions for model evaluation 21 | pub mod common; 22 | 23 | /// Interfaces for creating and interacting with a large language model with a known type 24 | /// of [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)). 25 | pub trait KnownModel: Send + Sync { 26 | /// Hyperparameters for the model. 27 | type Hyperparameters: Hyperparameters; 28 | 29 | /// Load this model from the `path` and configure it per the `params`. The status 30 | /// of the loading process will be reported through `load_progress_callback`. This 31 | /// is a helper function on top of [llm_base::load](crate::load). 32 | fn load( 33 | path: &Path, 34 | tokenizer_source: TokenizerSource, 35 | params: ModelParameters, 36 | load_progress_callback: impl FnMut(LoadProgress), 37 | ) -> Result 38 | where 39 | Self: Sized, 40 | { 41 | crate::load(path, tokenizer_source, params, load_progress_callback) 42 | } 43 | 44 | /// Creates a new model from the provided [ModelParameters] hyperparameters. 45 | /// This function is called by the [load](crate::loader::load) function. 46 | fn new( 47 | hyperparameters: Self::Hyperparameters, 48 | params: ModelParameters, 49 | tokenizer: Tokenizer, 50 | tensor_loader: impl TensorLoader, 51 | ) -> Result 52 | where 53 | Self: Sized; 54 | 55 | /// Starts a new `InferenceSession` for this model. 56 | fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession; 57 | 58 | /// This function is called by the provided [InferenceSession]; it will use this model 59 | /// to generate output by evaluating the `input_tokens`. 60 | /// The [OutputRequest] is used to specify additional data to fetch from the 61 | /// model. 62 | fn evaluate( 63 | &self, 64 | session: &mut InferenceSession, 65 | input_tokens: &[TokenId], 66 | output_request: &mut OutputRequest, 67 | ); 68 | 69 | /// Get the hyperparameters for this model. 
70 | fn hyperparameters(&self) -> &Self::Hyperparameters; 71 | 72 | /// Get the tokenizer for this model. 73 | fn tokenizer(&self) -> &Tokenizer; 74 | 75 | /// Get the context size (configured with [ModelParameters::context_size]) used by 76 | /// this model. 77 | fn context_size(&self) -> usize; 78 | 79 | /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers. 80 | fn bot_token_id(&self) -> Option; 81 | 82 | /// Get the end of text/end of string token ID. This value is defined by model implementers. 83 | fn eot_token_id(&self) -> TokenId; 84 | 85 | /// Get the list of regexes to use to determine if a tensor in this model should be quantized. 86 | fn quantize_tensors() -> Vec; 87 | 88 | /// Get the list of regexes to use to determine if a tensor in this model should not be quantized. 89 | fn skip_quantize_tensors() -> Vec; 90 | 91 | /// Returns whether the model supports deleting tokens. 92 | fn supports_rewind(&self) -> bool { 93 | // Assume we can't delete unless otherwise specified 94 | false 95 | } 96 | } 97 | 98 | /// A type-erased model to allow for interacting with a model without knowing 99 | /// its hyperparameters. 100 | pub trait Model: Send + Sync { 101 | /// Starts a new `InferenceSession` for this model. 102 | fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession; 103 | 104 | /// This function is called by the provided [InferenceSession]; it will use this model 105 | /// to generate output by evaluating the `input_tokens`. 106 | /// The [OutputRequest] is used to specify additional data to fetch from the 107 | /// model. 108 | fn evaluate( 109 | &self, 110 | session: &mut InferenceSession, 111 | input_tokens: &[TokenId], 112 | output_request: &mut OutputRequest, 113 | ); 114 | 115 | /// Get the tokenizer for this model. 116 | fn tokenizer(&self) -> &Tokenizer; 117 | 118 | /// Get the context size (configured with [ModelParameters::context_size]) used by 119 | /// this model. 120 | fn context_size(&self) -> usize; 121 | 122 | /// Get the beginning of text/beginning of string token ID, if available. This value is defined by model implementers. 123 | fn bot_token_id(&self) -> Option; 124 | 125 | /// Get the end of text/end of string token ID. This value is defined by model implementers. 126 | fn eot_token_id(&self) -> TokenId; 127 | 128 | /// Returns whether the model supports deleting tokens. 129 | fn supports_rewind(&self) -> bool; 130 | } 131 | impl> Model for M { 132 | fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession { 133 | KnownModel::start_session(self, config) 134 | } 135 | 136 | fn evaluate( 137 | &self, 138 | session: &mut InferenceSession, 139 | input_tokens: &[TokenId], 140 | output_request: &mut OutputRequest, 141 | ) { 142 | KnownModel::evaluate(self, session, input_tokens, output_request) 143 | } 144 | 145 | fn tokenizer(&self) -> &Tokenizer { 146 | KnownModel::tokenizer(self) 147 | } 148 | 149 | fn context_size(&self) -> usize { 150 | KnownModel::context_size(self) 151 | } 152 | 153 | fn bot_token_id(&self) -> Option { 154 | KnownModel::bot_token_id(self) 155 | } 156 | 157 | fn eot_token_id(&self) -> TokenId { 158 | KnownModel::eot_token_id(self) 159 | } 160 | 161 | fn supports_rewind(&self) -> bool { 162 | KnownModel::supports_rewind(self) 163 | } 164 | } 165 | 166 | /// Implemented by model hyperparameters for interacting with hyperparameters 167 | /// without knowing what they are, as well as writing/reading them as required. 
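///
/// # Example
///
/// A sketch of a minimal implementation for a toy model, modelled on the
/// `LoraParameters` implementation in this crate; the `n_vocab` field here is
/// purely illustrative:
///
/// ```ignore
/// #[derive(Debug, Default, PartialEq, Eq)]
/// struct ToyHyperparameters {
///     n_vocab: u32,
/// }
///
/// impl Hyperparameters for ToyHyperparameters {
///     fn read_ggml(reader: &mut dyn BufRead) -> Result<Self, LoadError> {
///         Ok(Self {
///             n_vocab: util::read_u32(reader)?,
///         })
///     }
///
///     fn write_ggml(&self, writer: &mut dyn Write) -> Result<(), HyperparametersWriteError> {
///         util::write_u32(writer, self.n_vocab)?;
///         Ok(())
///     }
///
///     fn n_vocabulary(&self) -> usize {
///         self.n_vocab as usize
///     }
///
///     fn file_type(&self) -> Option<FileType> {
///         None
///     }
///
///     fn file_type_mut(&mut self) -> Option<&mut FileType> {
///         None
///     }
/// }
/// ```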
168 | pub trait Hyperparameters: Sized + Default + Debug + PartialEq + Eq { 169 | /// Read the parameters in GGML format from a reader. 170 | fn read_ggml(reader: &mut dyn BufRead) -> Result; 171 | 172 | /// Write the parameters in GGML format to a writer. 173 | fn write_ggml(&self, writer: &mut dyn Write) -> Result<(), HyperparametersWriteError>; 174 | 175 | /// Get the number of tokens in the embedded vocabulary, if any. 176 | fn n_vocabulary(&self) -> usize; 177 | 178 | /// Get the filetype of the model. 179 | fn file_type(&self) -> Option; 180 | 181 | /// Get mutable access to filetype of the model. 182 | fn file_type_mut(&mut self) -> Option<&mut FileType>; 183 | } 184 | #[derive(Error, Debug)] 185 | /// Reported from functions that write 186 | pub enum HyperparametersWriteError { 187 | #[error("non-specific I/O error")] 188 | /// A non-specific IO error. 189 | Io(#[from] std::io::Error), 190 | #[error("invalid integer conversion")] 191 | /// One of the integers encountered could not be converted to a more appropriate type. 192 | InvalidIntegerConversion(#[from] std::num::TryFromIntError), 193 | } 194 | 195 | /// Parameters for model-wide behaviour. 196 | #[derive(Debug, Clone)] 197 | pub struct ModelParameters { 198 | /// For [GGML formats](ggml::ContainerType) that support it, [mmap](https://en.wikipedia.org/wiki/Mmap) 199 | /// is the default. Although mmap typically improves performance, setting this value to `false` may 200 | /// be preferred in resource-constrained environments. 201 | pub prefer_mmap: bool, 202 | /// The context size ("memory") the model should use when evaluating a prompt. A larger context 203 | /// consumes more resources, but produces more consistent and coherent responses. 204 | pub context_size: usize, 205 | /// The [LoRA](https://arxiv.org/abs/2106.09685) adapters to use when loading the model. If `None`, no adapters will be used. 206 | pub lora_adapters: Option>, 207 | /// Whether to use GPU acceleration when available 208 | pub use_gpu: bool, 209 | /// If `use_gpu` is active this defines the number of layers to offload to the gpu. If `None`, all layers will be offloaded. 210 | pub gpu_layers: Option, 211 | /// The arguments/overrides to pass to the [custom RoPE](https://arxiv.org/pdf/2306.15595.pdf) function, if it is used by the model. 212 | pub rope_overrides: Option, 213 | /// Enables gouped-query attention for Llama-2 70B model 214 | pub n_gqa: Option, 215 | } 216 | 217 | impl Default for ModelParameters { 218 | fn default() -> Self { 219 | Self { 220 | prefer_mmap: true, 221 | context_size: 2048, 222 | lora_adapters: None, 223 | use_gpu: false, 224 | gpu_layers: None, 225 | rope_overrides: None, 226 | n_gqa: None, 227 | } 228 | } 229 | } 230 | 231 | impl ModelParameters { 232 | /// Returns true if the model should offload the given layer to the accelerator. 233 | pub fn should_offload(&self, layer: usize) -> bool { 234 | if !self.use_gpu { 235 | return false; 236 | } 237 | 238 | self.gpu_layers 239 | .map(|gpu_layers| layer < gpu_layers) 240 | .unwrap_or(true) 241 | } 242 | 243 | /// Returns the backend to use for the given layer. 244 | pub fn backend(&self, layer: usize) -> Backend { 245 | if self.should_offload(layer) { 246 | Backend::Gpu 247 | } else { 248 | Backend::Cpu 249 | } 250 | } 251 | } 252 | 253 | /// Used in a call to [Model::evaluate] or [InferenceSession::infer] to request 254 | /// information from the model. If a value is set to `Some`, the `Vec` will be 255 | /// cleared, resized, and filled with the related data. 
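///
/// # Example
///
/// A sketch of requesting embeddings but not logits, mirroring the
/// `embeddings.rs` example elsewhere in this repository (`model`, `session`
/// and `tokens` are assumed to already exist):
///
/// ```ignore
/// let mut output_request = OutputRequest {
///     all_logits: None,
///     embeddings: Some(Vec::new()),
/// };
/// model.evaluate(&mut session, &tokens, &mut output_request);
/// let embeddings = output_request.embeddings.unwrap();
/// ```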
256 | #[derive(Default, Debug, PartialEq, Clone)] 257 | pub struct OutputRequest { 258 | /// Returns all the logits for evaluation. A logit represents the likelihood 259 | /// that a given token will be generated based on the tokens that have been 260 | /// evaluated or generated so far. Output shape is `n_batch * n_vocab`. 261 | pub all_logits: Option>, 262 | /// Returns all the embeddings for an evaluation. An embedding is a vector 263 | /// that measures the relatedness of text strings. Output shape is 264 | /// `n_batch * n_embd`. 265 | pub embeddings: Option>, 266 | } 267 | 268 | /// Contains the GGML context for a [`Model`]. Implements `Send` and `Sync` 269 | /// to allow for the free transfer of models; this is made possible by this 270 | /// context being effectively inert after creation, so that it cannot be 271 | /// modified across threads. 272 | #[derive(Clone)] 273 | #[allow(clippy::arc_with_non_send_sync)] 274 | pub struct ModelContext(pub(crate) Arc); 275 | unsafe impl Send for ModelContext {} 276 | unsafe impl Sync for ModelContext {} 277 | -------------------------------------------------------------------------------- /crates/llm-base/src/tokenizer/embedded.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use thiserror::Error; 4 | 5 | use super::{Token, TokenId, TokenScore, TokenizationError}; 6 | 7 | #[derive(Debug, Error)] 8 | /// Errors that can occur when using a model tokenizer. 9 | pub enum EmbeddedTokenizerError { 10 | /// Arbitrary error that occurred during use of the model tokenizer. 11 | #[error("Arbitrary error: {0:?}")] 12 | Arbitrary(String), 13 | } 14 | 15 | /// The built-in GGML tokenizer. 16 | #[derive(Debug, Clone, Default)] 17 | pub struct EmbeddedTokenizer { 18 | /// Maps every integer (index) token ID to its corresponding token. 19 | id_to_token: Vec, 20 | 21 | /// Maps every integer (index) token ID to corresponding score. 22 | id_to_token_score: Vec, 23 | 24 | // todo: use a radix tree 25 | /// Maps a token to a token ID. 26 | token_to_id: HashMap, 27 | 28 | /// The longest token in this tokenizer. 29 | max_token_length: usize, 30 | } 31 | 32 | impl EmbeddedTokenizer { 33 | /// Add a token to the internal vocabulary. 34 | /// 35 | /// The token added must have `id` directly after the last token in the vocabulary. 36 | /// 37 | /// # Panics 38 | /// - This function can panic if `id` does not correspond to the next token in the vocabulary. 39 | /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. 40 | pub(crate) fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { 41 | // These are loader invariants. If this is broken, then the loader is broken and this is a bug, 42 | // not an issue with the model itself. 43 | assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); 44 | if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize { 45 | let expected_id = self.id_to_token.len() as TokenId; 46 | panic!("the id of token added should be {expected_id}; is {id}"); 47 | } 48 | 49 | self.max_token_length = self.max_token_length.max(content.len()); 50 | self.id_to_token.push(content.clone()); 51 | self.id_to_token_score.push(score); 52 | self.token_to_id.insert(content, id); 53 | } 54 | 55 | pub(crate) fn id(&self, token: &[u8]) -> Option { 56 | self.token_to_id.get(token).copied() 57 | } 58 | 59 | /// Converts a token index to the token it represents in this tokenizer. 
60 | pub(crate) fn token(&self, idx: usize) -> Vec { 61 | self.id_to_token[idx].clone() 62 | } 63 | 64 | /// Returns the number of tokens in the tokenizer. 65 | pub(crate) fn len(&self) -> usize { 66 | self.id_to_token.len() 67 | } 68 | 69 | /// Returns whether the tokenizer is empty. 70 | pub(crate) fn is_empty(&self) -> bool { 71 | self.id_to_token.is_empty() 72 | } 73 | 74 | // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece 75 | /// Tokenize a `text` with this tokenizer. 76 | /// 77 | /// `bos` controls whether a beginning-of-string token should be inserted. 78 | pub(crate) fn tokenize( 79 | &self, 80 | text: &str, 81 | bos: bool, 82 | ) -> Result, TokenId)>, TokenizationError> { 83 | let len = text.len(); 84 | 85 | let mut score = vec![0usize; len + 1]; 86 | let mut prev = vec![TokenId::default(); len + 1]; 87 | 88 | for i in 0..len { 89 | let max_len = (len - i).min(self.max_token_length); 90 | for sub_len in 1..=max_len { 91 | let sub = &text.as_bytes()[i..i + sub_len]; 92 | let token = self.token_to_id.get(sub); 93 | 94 | if let Some(token) = token { 95 | let token_score = sub.len() * sub.len(); 96 | let local_score = score[i] + token_score; 97 | let next = i + sub_len; 98 | 99 | if score[next] < local_score { 100 | score[next] = local_score; 101 | prev[next] = *token; 102 | } 103 | } 104 | } 105 | } 106 | 107 | // Backward pass 108 | let mut res = vec![]; 109 | let mut i = len; 110 | while i > 0 { 111 | let token_id = prev[i]; 112 | if token_id == 0 { 113 | return Err(TokenizationError::TokenizationFailed { 114 | error: Box::new(EmbeddedTokenizerError::Arbitrary( 115 | "the backward pass for the tokenizer encountered a non-set token" 116 | .to_string(), 117 | )), 118 | }); 119 | } 120 | let token = self.id_to_token[token_id as usize].as_slice(); 121 | res.push((token.to_vec(), token_id)); 122 | i -= token.len(); 123 | } 124 | 125 | if bos { 126 | // TODO: replace with vocab.bos 127 | res.push((vec![], 1)); 128 | } 129 | 130 | // Pieces are in reverse order so correct that 131 | res.reverse(); 132 | 133 | Ok(res) 134 | } 135 | 136 | /// Decode a list `tokens` with this tokenizer. 137 | pub(crate) fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { 138 | let mut vec = vec![]; 139 | 140 | for token in tokens { 141 | if skip_special_tokens && token == 1 { 142 | continue; 143 | } 144 | 145 | vec.append(&mut self.id_to_token[token as usize].to_vec()); 146 | } 147 | 148 | vec 149 | } 150 | 151 | pub(crate) fn iter(&self) -> impl Iterator + '_ { 152 | self.id_to_token 153 | .iter() 154 | .zip(self.id_to_token_score.iter()) 155 | .map(|(token, score)| (token.clone(), *score)) 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /crates/llm-base/src/tokenizer/huggingface.rs: -------------------------------------------------------------------------------- 1 | use super::{TokenId, TokenizationError}; 2 | 3 | /// A Hugging Face tokenizer. 4 | #[derive(Debug, Clone)] 5 | pub struct HuggingFaceTokenizer { 6 | pub(crate) tokenizer: tokenizers::Tokenizer, 7 | } 8 | 9 | impl HuggingFaceTokenizer { 10 | /// Create a new `HuggingFaceTokenizer`. 11 | pub fn new(tokenizer: tokenizers::Tokenizer) -> Self { 12 | Self { tokenizer } 13 | } 14 | } 15 | 16 | impl HuggingFaceTokenizer { 17 | pub(crate) fn id(&self, token: &[u8]) -> Option { 18 | self.tokenizer 19 | .token_to_id(std::str::from_utf8(token).unwrap()) 20 | } 21 | 22 | /// Converts a token index to the token it represents in this tokenizer. 
23 | pub(crate) fn token(&self, idx: usize) -> Vec { 24 | self.tokenizer 25 | .decode(&[idx as u32], true) 26 | .expect("Cannot decode token from tokenizer tokenizer.") 27 | .as_bytes() 28 | .to_vec() 29 | } 30 | 31 | /// Returns the number of tokens in the tokenizer. 32 | pub(crate) fn len(&self) -> usize { 33 | self.tokenizer.get_vocab_size(false) 34 | } 35 | 36 | /// Returns whether the tokenizer is empty. 37 | pub(crate) fn is_empty(&self) -> bool { 38 | self.tokenizer.get_vocab_size(false) == 0 39 | } 40 | 41 | /// Tokenize a `text` with this tokenizer. 42 | /// 43 | /// `bos` controls whether a beginning-of-string token should be inserted. 44 | pub(crate) fn tokenize( 45 | &self, 46 | text: &str, 47 | bos: bool, 48 | ) -> Result, TokenId)>, TokenizationError> { 49 | let encoding = self 50 | .tokenizer 51 | .encode(text, false) 52 | .map_err(|e| TokenizationError::TokenizationFailed { error: e })?; 53 | 54 | let encoding = self 55 | .tokenizer 56 | .post_process(encoding, None, bos) 57 | .map_err(|e| TokenizationError::TokenizationFailed { error: e })?; 58 | 59 | Ok(encoding 60 | .get_tokens() 61 | .iter() 62 | .map(|t| t.as_bytes().to_vec()) 63 | .zip(encoding.get_ids().iter().copied()) 64 | .collect()) 65 | } 66 | 67 | /// Decode a list `tokens` with this tokenizer. 68 | pub(crate) fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { 69 | self.tokenizer 70 | .decode(&tokens, skip_special_tokens) 71 | .expect("Cannot decode token from tokenizer.") 72 | .as_bytes() 73 | .to_vec() 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /crates/llm-base/src/util.rs: -------------------------------------------------------------------------------- 1 | //! Utilities for interacting with LLMs and loading them. 2 | pub use ggml::util::*; 3 | 4 | use std::{ 5 | io::BufRead, 6 | path::{Path, PathBuf}, 7 | }; 8 | 9 | /// NOTE: The original code relies in promotion rules and automatic cast between 10 | /// int to float. What we do instead is use this macro to convert every term of 11 | /// the multiplication to f64, which should have enough precision bits to hold 12 | /// the final value, then cast to usize. I have observed a discrepancy between 13 | /// the ctx_size found using this code, and the one in llama.cpp. The number for 14 | /// rust ends up being slightly lower, but no "out of memory" errors are 15 | /// reported by ggml. 16 | #[macro_export] 17 | #[doc(hidden)] 18 | macro_rules! mulf { 19 | ($term:expr, $($terms:expr),*) => { 20 | usize::try_from((($term as f64) $(* ($terms as f64))*) as u64).unwrap() 21 | }; 22 | } 23 | 24 | use memmap2::{Mmap, MmapAsRawDesc, MmapOptions}; 25 | use thiserror::Error; 26 | 27 | use crate::{FileType, LoadError}; 28 | 29 | /// Read the filetype from a reader. 30 | pub fn read_filetype(reader: &mut dyn BufRead) -> Result { 31 | let ftype = read_i32(reader)?; 32 | FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype)) 33 | } 34 | 35 | /// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. 36 | /// 37 | /// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8 38 | /// from multiple tokens. This helps alleviate that issue. 39 | #[derive(Clone, PartialEq, Eq, Default)] 40 | pub struct TokenUtf8Buffer(Vec); 41 | impl TokenUtf8Buffer { 42 | /// Create a new buffer. 43 | pub const fn new() -> Self { 44 | Self(vec![]) 45 | } 46 | 47 | /// Add a token to the buffer. 
If the buffer contains a valid string of UTF-8 text, 48 | /// it is returned and the buffer is cleared for next use. 49 | pub fn push(&mut self, token: &[u8]) -> Option { 50 | self.0.extend_from_slice(token); 51 | match std::str::from_utf8(&self.0) { 52 | Ok(s) => { 53 | let out = s.to_owned(); 54 | self.0 = vec![]; 55 | Some(out) 56 | } 57 | Err(..) => { 58 | for i in 1..self.0.len() { 59 | let slice = &self.0[i..]; 60 | if slice.is_empty() { 61 | break; 62 | } 63 | 64 | if let Ok(s) = std::str::from_utf8(slice) { 65 | let out = s.to_owned(); 66 | self.0 = vec![]; 67 | return Some(out); 68 | } 69 | } 70 | None 71 | } 72 | } 73 | } 74 | } 75 | 76 | #[derive(Error, Debug)] 77 | /// Errors encountered during the loading process. 78 | pub enum FindAllModelFilesError { 79 | #[error("no parent path for {path:?}")] 80 | /// There is no parent path for a given path. 81 | NoParentPath { 82 | /// The path without a parent. 83 | path: PathBuf, 84 | }, 85 | #[error("non-specific I/O error")] 86 | /// A non-specific IO error. 87 | IO(#[from] std::io::Error), 88 | } 89 | 90 | /// Find all the files related to a model. 91 | pub fn find_all_model_files(main_path: &Path) -> Result, FindAllModelFilesError> { 92 | let mut main_path_parent = 93 | main_path 94 | .parent() 95 | .ok_or_else(|| FindAllModelFilesError::NoParentPath { 96 | path: main_path.to_owned(), 97 | })?; 98 | if main_path_parent.to_str() == Some("") { 99 | main_path_parent = Path::new("."); 100 | } 101 | Ok(collect_related_paths( 102 | main_path, 103 | std::fs::read_dir(main_path_parent)? 104 | .filter_map(Result::ok) 105 | .map(|de| de.path()), 106 | )) 107 | } 108 | 109 | fn collect_related_paths( 110 | main_path: &Path, 111 | directory_paths: impl Iterator, 112 | ) -> Vec { 113 | let main_filename = main_path.file_name().and_then(|p| p.to_str()); 114 | 115 | let mut paths: Vec = directory_paths 116 | .filter(|p| { 117 | p.file_name() 118 | .and_then(|p| p.to_str()) 119 | .zip(main_filename) 120 | .map_or(false, |(part_filename, main_filename)| match part_filename 121 | .strip_prefix(main_filename) 122 | { 123 | Some(suffix) => { 124 | suffix.is_empty() 125 | || (suffix 126 | .strip_prefix('.') 127 | .map_or(false, |s| s.parse::().is_ok())) 128 | } 129 | None => false, 130 | }) 131 | }) 132 | .collect(); 133 | paths.sort(); 134 | paths 135 | } 136 | 137 | /// mmap with MAP_POPULATE 138 | pub fn mmap_populate(file: T) -> Result { 139 | unsafe { MmapOptions::new().populate().map(file) } 140 | } 141 | 142 | /// Calculate softmax for a slice 143 | pub fn softmax(logits: &[f32]) -> Vec { 144 | let mut probs = logits.to_vec(); 145 | let max_logit = probs.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)); 146 | let sum: f32 = logits.iter().map(|v| (v - max_logit).exp()).sum(); 147 | for v in probs.iter_mut() { 148 | *v = (*v - max_logit).exp() / sum; 149 | } 150 | probs 151 | } 152 | 153 | #[cfg(test)] 154 | mod tests { 155 | use super::*; 156 | 157 | #[test] 158 | fn test_collect_related_paths() { 159 | let main_path = PathBuf::from("/models/llama.bin"); 160 | let directory_paths = [ 161 | "/models/llama.bin", 162 | "/models/llama.bin.1", 163 | "/models/llama.bin.2", 164 | "/models/llama.bin.tmp", 165 | ] 166 | .map(PathBuf::from); 167 | let expected_paths = [ 168 | "/models/llama.bin", 169 | "/models/llama.bin.1", 170 | "/models/llama.bin.2", 171 | ] 172 | .map(PathBuf::from); 173 | 174 | let output_paths = collect_related_paths(&main_path, directory_paths.into_iter()); 175 | assert_eq!(expected_paths.as_slice(), output_paths); 176 | } 177 | 178 
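    // Not part of the original test suite: an illustrative sanity check for the
    // `softmax` helper defined above (assumes f32 rounding error stays below 1e-6).
    #[test]
    fn test_softmax_sums_to_one() {
        let probs = softmax(&[1.0, 2.0, 3.0]);
        let sum: f32 = probs.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
        // Larger logits should yield larger probabilities.
        assert!(probs[2] > probs[1] && probs[1] > probs[0]);
    }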
| #[test] 179 | fn test_valid_utf8() { 180 | let mut buffer = TokenUtf8Buffer::new(); 181 | assert_eq!(buffer.push(b"hello").as_deref(), Some("hello")); 182 | assert_eq!(buffer.push(&[0xE2, 0x82, 0xAC]).as_deref(), Some("€")); 183 | } 184 | 185 | #[test] 186 | fn test_partial_utf8() { 187 | let mut buffer = TokenUtf8Buffer::new(); 188 | assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); 189 | assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); 190 | } 191 | 192 | #[test] 193 | fn test_invalid_prelude_for_valid_utf8() { 194 | let mut buffer = TokenUtf8Buffer::new(); 195 | assert_eq!(buffer.push(&[0xD8]).as_deref(), None); 196 | assert_eq!(buffer.push(&[0xE2, 0x82]).as_deref(), None); 197 | assert_eq!(buffer.push(&[0xAC]).as_deref(), Some("€")); 198 | } 199 | } 200 | -------------------------------------------------------------------------------- /crates/llm/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "A Rust ecosystem of libraries for running inference on large language models, inspired by llama.cpp." 7 | edition = "2021" 8 | readme = "../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../llm-base", version = "0.2.0-dev" } 12 | llm-llama = { path = "../models/llama", optional = true, version = "0.2.0-dev" } 13 | llm-gpt2 = { path = "../models/gpt2", optional = true, version = "0.2.0-dev" } 14 | llm-gptj = { path = "../models/gptj", optional = true, version = "0.2.0-dev" } 15 | llm-bloom = { path = "../models/bloom", optional = true, version = "0.2.0-dev" } 16 | llm-gptneox = { path = "../models/gptneox", optional = true, version = "0.2.0-dev" } 17 | llm-mpt = { path = "../models/mpt", optional = true, version = "0.2.0-dev" } 18 | llm-falcon = { path = "../models/falcon", optional = true, version = "0.2.0-dev" } 19 | 20 | serde = { workspace = true } 21 | tracing = { workspace = true } 22 | 23 | [dev-dependencies] 24 | bytesize = { workspace = true } 25 | log = { workspace = true } 26 | rand = { workspace = true } 27 | rustyline = { workspace = true } 28 | spinoff = { workspace = true } 29 | serde_json = { workspace = true } 30 | clap = { workspace = true } 31 | 32 | [features] 33 | default = ["models", "tokenizers-remote"] 34 | 35 | tokenizers-remote = ["llm-base/tokenizers-remote"] 36 | 37 | models = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"] 38 | llama = ["dep:llm-llama"] 39 | gpt2 = ["dep:llm-gpt2"] 40 | gptj = ["dep:llm-gptj"] 41 | bloom = ["dep:llm-bloom"] 42 | gptneox = ["dep:llm-gptneox"] 43 | mpt = ["dep:llm-mpt"] 44 | # Falcon is off by default. See `llm_falcon`'s module documentation for more information. 
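# (Illustrative note, not from the original manifest: a downstream crate that
# wants to experiment with it could opt in with something like
# `llm = { version = "0.2.0-dev", features = ["falcon"] }` in its own Cargo.toml.)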
45 | falcon = ["dep:llm-falcon"] 46 | 47 | cublas = ["llm-base/cublas"] 48 | clblast = ["llm-base/clblast"] 49 | metal = ["llm-base/metal"] 50 | -------------------------------------------------------------------------------- /crates/llm/examples/embeddings.rs: -------------------------------------------------------------------------------- 1 | use std::path::PathBuf; 2 | 3 | use clap::Parser; 4 | 5 | #[derive(Parser)] 6 | struct Args { 7 | model_architecture: llm::ModelArchitecture, 8 | model_path: PathBuf, 9 | #[arg(long, short = 'v')] 10 | pub tokenizer_path: Option, 11 | #[arg(long, short = 'r')] 12 | pub tokenizer_repository: Option, 13 | #[arg(long, short = 'q')] 14 | pub query: Option, 15 | #[arg(long, short = 'c')] 16 | pub comparands: Vec, 17 | } 18 | impl Args { 19 | pub fn to_tokenizer_source(&self) -> llm::TokenizerSource { 20 | match (&self.tokenizer_path, &self.tokenizer_repository) { 21 | (Some(_), Some(_)) => { 22 | panic!("Cannot specify both --tokenizer-path and --tokenizer-repository"); 23 | } 24 | (Some(path), None) => llm::TokenizerSource::HuggingFaceTokenizerFile(path.to_owned()), 25 | (None, Some(repo)) => llm::TokenizerSource::HuggingFaceRemote(repo.to_owned()), 26 | (None, None) => llm::TokenizerSource::Embedded, 27 | } 28 | } 29 | } 30 | 31 | fn main() { 32 | let args = Args::parse(); 33 | 34 | let tokenizer_source = args.to_tokenizer_source(); 35 | let model_architecture = args.model_architecture; 36 | let model_path = args.model_path; 37 | let query = args 38 | .query 39 | .as_deref() 40 | .unwrap_or("My favourite animal is the dog"); 41 | let comparands = if !args.comparands.is_empty() { 42 | args.comparands 43 | } else { 44 | vec![ 45 | "My favourite animal is the dog".to_string(), 46 | "I have just adopted a cute dog".to_string(), 47 | "My favourite animal is the cat".to_string(), 48 | ] 49 | }; 50 | 51 | // Load model 52 | let model_params = llm::ModelParameters::default(); 53 | let model = llm::load_dynamic( 54 | Some(model_architecture), 55 | &model_path, 56 | tokenizer_source, 57 | model_params, 58 | llm::load_progress_callback_stdout, 59 | ) 60 | .unwrap_or_else(|err| { 61 | panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") 62 | }); 63 | let inference_parameters = llm::InferenceParameters::default(); 64 | 65 | // Generate embeddings for query and comparands 66 | let query_embeddings = get_embeddings(model.as_ref(), &inference_parameters, query); 67 | let comparand_embeddings: Vec<(String, Vec)> = comparands 68 | .iter() 69 | .map(|text| { 70 | ( 71 | text.clone(), 72 | get_embeddings(model.as_ref(), &inference_parameters, text), 73 | ) 74 | }) 75 | .collect(); 76 | 77 | // Print embeddings 78 | fn print_embeddings(text: &str, embeddings: &[f32]) { 79 | println!("{text}"); 80 | println!(" Embeddings length: {}", embeddings.len()); 81 | println!(" Embeddings first 10: {:.02?}", embeddings.get(0..10)); 82 | } 83 | 84 | print_embeddings(query, &query_embeddings); 85 | println!("---"); 86 | for (text, embeddings) in &comparand_embeddings { 87 | print_embeddings(text, embeddings); 88 | } 89 | 90 | // Calculate the cosine similarity between the query and each comparand, and sort by similarity 91 | let mut similarities: Vec<(&str, f32)> = comparand_embeddings 92 | .iter() 93 | .map(|(text, embeddings)| { 94 | ( 95 | text.as_str(), 96 | cosine_similarity(&query_embeddings, embeddings), 97 | ) 98 | }) 99 | .collect(); 100 | similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap()); 101 | 102 | // Print similarities 103 | 
println!("---"); 104 | println!("Similarities:"); 105 | for (text, score) in similarities { 106 | println!(" {text}: {score}"); 107 | } 108 | } 109 | 110 | fn get_embeddings( 111 | model: &dyn llm::Model, 112 | inference_parameters: &llm::InferenceParameters, 113 | query: &str, 114 | ) -> Vec { 115 | let mut session = model.start_session(Default::default()); 116 | let mut output_request = llm::OutputRequest { 117 | all_logits: None, 118 | embeddings: Some(Vec::new()), 119 | }; 120 | let vocab = model.tokenizer(); 121 | let beginning_of_sentence = true; 122 | let query_token_ids = vocab 123 | .tokenize(query, beginning_of_sentence) 124 | .unwrap() 125 | .iter() 126 | .map(|(_, tok)| *tok) 127 | .collect::>(); 128 | model.evaluate(&mut session, &query_token_ids, &mut output_request); 129 | output_request.embeddings.unwrap() 130 | } 131 | 132 | fn cosine_similarity(v1: &[f32], v2: &[f32]) -> f32 { 133 | let dot_product = dot(v1, v2); 134 | let magnitude1 = magnitude(v1); 135 | let magnitude2 = magnitude(v2); 136 | 137 | dot_product / (magnitude1 * magnitude2) 138 | } 139 | 140 | fn dot(v1: &[f32], v2: &[f32]) -> f32 { 141 | v1.iter().zip(v2.iter()).map(|(&x, &y)| x * y).sum() 142 | } 143 | 144 | fn magnitude(v: &[f32]) -> f32 { 145 | v.iter().map(|&x| x * x).sum::().sqrt() 146 | } 147 | -------------------------------------------------------------------------------- /crates/llm/examples/inference.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use std::{convert::Infallible, io::Write, path::PathBuf}; 3 | 4 | #[derive(Parser)] 5 | struct Args { 6 | model_architecture: llm::ModelArchitecture, 7 | model_path: PathBuf, 8 | #[arg(long, short = 'p')] 9 | prompt: Option, 10 | #[arg(long, short = 'v')] 11 | pub tokenizer_path: Option, 12 | #[arg(long, short = 'r')] 13 | pub tokenizer_repository: Option, 14 | } 15 | impl Args { 16 | pub fn to_tokenizer_source(&self) -> llm::TokenizerSource { 17 | match (&self.tokenizer_path, &self.tokenizer_repository) { 18 | (Some(_), Some(_)) => { 19 | panic!("Cannot specify both --tokenizer-path and --tokenizer-repository"); 20 | } 21 | (Some(path), None) => llm::TokenizerSource::HuggingFaceTokenizerFile(path.to_owned()), 22 | (None, Some(repo)) => llm::TokenizerSource::HuggingFaceRemote(repo.to_owned()), 23 | (None, None) => llm::TokenizerSource::Embedded, 24 | } 25 | } 26 | } 27 | 28 | fn main() { 29 | let args = Args::parse(); 30 | 31 | let tokenizer_source = args.to_tokenizer_source(); 32 | let model_architecture = args.model_architecture; 33 | let model_path = args.model_path; 34 | let prompt = args 35 | .prompt 36 | .as_deref() 37 | .unwrap_or("Rust is a cool programming language because"); 38 | 39 | let now = std::time::Instant::now(); 40 | 41 | let model = llm::load_dynamic( 42 | Some(model_architecture), 43 | &model_path, 44 | tokenizer_source, 45 | Default::default(), 46 | llm::load_progress_callback_stdout, 47 | ) 48 | .unwrap_or_else(|err| { 49 | panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") 50 | }); 51 | 52 | println!( 53 | "Model fully loaded! 
Elapsed: {}ms", 54 | now.elapsed().as_millis() 55 | ); 56 | 57 | let mut session = model.start_session(Default::default()); 58 | 59 | let res = session.infer::( 60 | model.as_ref(), 61 | &mut rand::thread_rng(), 62 | &llm::InferenceRequest { 63 | prompt: prompt.into(), 64 | parameters: &llm::InferenceParameters::default(), 65 | play_back_previous_tokens: false, 66 | maximum_token_count: None, 67 | }, 68 | // OutputRequest 69 | &mut Default::default(), 70 | |r| match r { 71 | llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => { 72 | print!("{t}"); 73 | std::io::stdout().flush().unwrap(); 74 | 75 | Ok(llm::InferenceFeedback::Continue) 76 | } 77 | _ => Ok(llm::InferenceFeedback::Continue), 78 | }, 79 | ); 80 | 81 | match res { 82 | Ok(result) => println!("\n\nInference stats:\n{result}"), 83 | Err(err) => println!("\n{err}"), 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /crates/llm/examples/vicuna-chat.rs: -------------------------------------------------------------------------------- 1 | use clap::Parser; 2 | use llm_base::conversation_inference_callback; 3 | use rustyline::error::ReadlineError; 4 | use std::{convert::Infallible, io::Write, path::PathBuf}; 5 | 6 | #[derive(Parser)] 7 | struct Args { 8 | model_architecture: llm::ModelArchitecture, 9 | model_path: PathBuf, 10 | #[arg(long, short = 'v')] 11 | pub tokenizer_path: Option, 12 | #[arg(long, short = 'r')] 13 | pub tokenizer_repository: Option, 14 | } 15 | impl Args { 16 | pub fn to_tokenizer_source(&self) -> llm::TokenizerSource { 17 | match (&self.tokenizer_path, &self.tokenizer_repository) { 18 | (Some(_), Some(_)) => { 19 | panic!("Cannot specify both --tokenizer-path and --tokenizer-repository"); 20 | } 21 | (Some(path), None) => llm::TokenizerSource::HuggingFaceTokenizerFile(path.to_owned()), 22 | (None, Some(repo)) => llm::TokenizerSource::HuggingFaceRemote(repo.to_owned()), 23 | (None, None) => llm::TokenizerSource::Embedded, 24 | } 25 | } 26 | } 27 | 28 | fn main() { 29 | let args = Args::parse(); 30 | 31 | let tokenizer_source = args.to_tokenizer_source(); 32 | let model_architecture = args.model_architecture; 33 | let model_path = args.model_path; 34 | let model = llm::load_dynamic( 35 | Some(model_architecture), 36 | &model_path, 37 | tokenizer_source, 38 | Default::default(), 39 | llm::load_progress_callback_stdout, 40 | ) 41 | .unwrap_or_else(|err| { 42 | panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") 43 | }); 44 | 45 | let mut session = model.start_session(Default::default()); 46 | 47 | let character_name = "### Assistant"; 48 | let user_name = "### Human"; 49 | let persona = "A chat between a human and an assistant."; 50 | let history = format!( 51 | "{character_name}: Hello - How may I help you today?\n\ 52 | {user_name}: What is the capital of France?\n\ 53 | {character_name}: Paris is the capital of France." 
54 | ); 55 | 56 | let inference_parameters = llm::InferenceParameters::default(); 57 | 58 | session 59 | .feed_prompt( 60 | model.as_ref(), 61 | format!("{persona}\n{history}").as_str(), 62 | &mut Default::default(), 63 | llm::feed_prompt_callback(|resp| match resp { 64 | llm::InferenceResponse::PromptToken(t) 65 | | llm::InferenceResponse::InferredToken(t) => { 66 | print_token(t); 67 | 68 | Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Continue) 69 | } 70 | _ => Ok(llm::InferenceFeedback::Continue), 71 | }), 72 | ) 73 | .expect("Failed to ingest initial prompt."); 74 | 75 | let mut rl = rustyline::DefaultEditor::new().expect("Failed to create input reader"); 76 | 77 | let mut rng = rand::thread_rng(); 78 | let mut res = llm::InferenceStats::default(); 79 | 80 | loop { 81 | println!(); 82 | let readline = rl.readline(format!("{user_name}: ").as_str()); 83 | print!("{character_name}:"); 84 | match readline { 85 | Ok(line) => { 86 | let stats = session 87 | .infer::<Infallible>( 88 | model.as_ref(), 89 | &mut rng, 90 | &llm::InferenceRequest { 91 | prompt: format!("{user_name}: {line}\n{character_name}:") 92 | .as_str() 93 | .into(), 94 | parameters: &inference_parameters, 95 | play_back_previous_tokens: false, 96 | maximum_token_count: None, 97 | }, 98 | &mut Default::default(), 99 | conversation_inference_callback(&format!("{character_name}:"), print_token), 100 | ) 101 | .unwrap_or_else(|e| panic!("{e}")); 102 | 103 | res.feed_prompt_duration = res 104 | .feed_prompt_duration 105 | .saturating_add(stats.feed_prompt_duration); 106 | res.prompt_tokens += stats.prompt_tokens; 107 | res.predict_duration = res.predict_duration.saturating_add(stats.predict_duration); 108 | res.predict_tokens += stats.predict_tokens; 109 | } 110 | Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { 111 | break; 112 | } 113 | Err(err) => { 114 | println!("{err}"); 115 | } 116 | } 117 | } 118 | 119 | println!("\n\nInference stats:\n{res}"); 120 | } 121 | 122 | fn print_token(t: String) { 123 | print!("{t}"); 124 | std::io::stdout().flush().unwrap(); 125 | } 126 | -------------------------------------------------------------------------------- /crates/llm/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This crate provides a unified interface for loading and using 2 | //! Large Language Models (LLMs). The following models are supported: 3 | //! 4 | //! - [BLOOM](llm_bloom) 5 | //! - [GPT-2](llm_gpt2) 6 | //! - [GPT-J](llm_gptj) 7 | //! - [GPT-NeoX](llm_gptneox) 8 | //! - [LLaMA](llm_llama) 9 | //! - [MPT](llm_mpt) 10 | //! - Falcon (currently disabled due to incompleteness) 11 | //! 12 | //! At present, the only supported backend is [GGML](https://github.com/ggerganov/ggml), but this is expected to 13 | //! change in the future. 14 | //! 15 | //! # Example 16 | //! 17 | //! ```no_run 18 | //! use std::io::Write; 19 | //! use llm::Model; 20 | //! 21 | //! // load a GGML model from disk 22 | //! let llama = llm::load::<llm::models::Llama>( 23 | //! // path to GGML file 24 | //! std::path::Path::new("/path/to/model"), 25 | //! // llm::TokenizerSource 26 | //! llm::TokenizerSource::Embedded, 27 | //! // llm::ModelParameters 28 | //! Default::default(), 29 | //! // load progress callback 30 | //! llm::load_progress_callback_stdout 31 | //! ) 32 | //! .unwrap_or_else(|err| panic!("Failed to load model: {err}")); 33 | //! 34 | //! // use the model to generate text from a prompt 35 | //! let mut session = llama.start_session(Default::default()); 36 | //! let res = session.infer::<std::convert::Infallible>( 37 | //!
// model to use for text generation 38 | //! &llama, 39 | //! // randomness provider 40 | //! &mut rand::thread_rng(), 41 | //! // the prompt to use for text generation, as well as other 42 | //! // inference parameters 43 | //! &llm::InferenceRequest { 44 | //! prompt: "Rust is a cool programming language because".into(), 45 | //! parameters: &llm::InferenceParameters::default(), 46 | //! play_back_previous_tokens: false, 47 | //! maximum_token_count: None, 48 | //! }, 49 | //! // llm::OutputRequest 50 | //! &mut Default::default(), 51 | //! // output callback 52 | //! |r| match r { 53 | //! llm::InferenceResponse::PromptToken(t) | llm::InferenceResponse::InferredToken(t) => { 54 | //! print!("{t}"); 55 | //! std::io::stdout().flush().unwrap(); 56 | //! 57 | //! Ok(llm::InferenceFeedback::Continue) 58 | //! } 59 | //! _ => Ok(llm::InferenceFeedback::Continue), 60 | //! } 61 | //! ); 62 | //! 63 | //! match res { 64 | //! Ok(result) => println!("\n\nInference stats:\n{result}"), 65 | //! Err(err) => println!("\n{err}"), 66 | //! } 67 | //! ``` 68 | #![deny(missing_docs)] 69 | 70 | use std::{ 71 | error::Error, 72 | fmt::{Debug, Display}, 73 | path::Path, 74 | str::FromStr, 75 | }; 76 | 77 | // Try not to expose too many GGML details here. 78 | // This is the "user-facing" API, and GGML may not always be our backend. 79 | pub use llm_base::{ 80 | conversation_inference_callback, feed_prompt_callback, 81 | ggml::accelerator::get_accelerator as ggml_get_accelerator, 82 | ggml::accelerator::Accelerator as GgmlAccelerator, ggml::format as ggml_format, 83 | ggml::RoPEOverrides, load, load_progress_callback_stdout, quantize, samplers, ElementType, 84 | FileType, FileTypeFormat, FormatMagic, Hyperparameters, InferenceError, InferenceFeedback, 85 | InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession, 86 | InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, 87 | InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, 88 | ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, RewindError, 89 | SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, 90 | TokenizerSource, 91 | }; 92 | 93 | use serde::Serialize; 94 | 95 | macro_rules! define_models { 96 | ($(($model_lowercase:ident, $model_lowercase_str:literal, $model_pascalcase:ident, $krate_ident:ident, $display_name:literal)),*) => { 97 | /// All available models. 98 | pub mod models { 99 | $( 100 | #[cfg(feature = $model_lowercase_str)] 101 | pub use $krate_ident::{self as $model_lowercase, $model_pascalcase}; 102 | )* 103 | } 104 | 105 | #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize)] 106 | /// All available model architectures. 107 | pub enum ModelArchitecture { 108 | $( 109 | #[cfg(feature = $model_lowercase_str)] 110 | #[doc = concat!("[", $display_name, "](", stringify!($krate_ident), ")")] 111 | $model_pascalcase, 112 | )* 113 | } 114 | 115 | impl ModelArchitecture { 116 | /// All available model architectures 117 | pub const ALL: &'static [Self] = &[ 118 | $( 119 | #[cfg(feature = $model_lowercase_str)] 120 | Self::$model_pascalcase, 121 | )* 122 | ]; 123 | } 124 | 125 | impl ModelArchitecture { 126 | /// Use a visitor to dispatch some code based on the model architecture. 
127 | pub fn visit<R>(&self, visitor: &mut impl ModelArchitectureVisitor<R>) -> R { 128 | match self { 129 | $( 130 | #[cfg(feature = $model_lowercase_str)] 131 | Self::$model_pascalcase => visitor.visit::<models::$model_pascalcase>(), 132 | )* 133 | } 134 | } 135 | } 136 | 137 | impl FromStr for ModelArchitecture { 138 | type Err = UnsupportedModelArchitecture; 139 | 140 | fn from_str(s: &str) -> Result<Self, Self::Err> { 141 | use ModelArchitecture::*; 142 | match s 143 | .to_lowercase() 144 | .chars() 145 | .filter(|c| c.is_alphanumeric()) 146 | .collect::<String>() 147 | .as_str() 148 | { 149 | $( 150 | #[cfg(feature = $model_lowercase_str)] 151 | $model_lowercase_str => Ok($model_pascalcase), 152 | )* 153 | 154 | _ => Err(UnsupportedModelArchitecture(format!( 155 | "{s} is not one of supported model architectures: {:?}", ModelArchitecture::ALL 156 | ))), 157 | } 158 | } 159 | } 160 | 161 | impl Display for ModelArchitecture { 162 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 163 | match self { 164 | $( 165 | #[cfg(feature = $model_lowercase_str)] 166 | Self::$model_pascalcase => write!(f, $display_name), 167 | )* 168 | } 169 | } 170 | } 171 | }; 172 | } 173 | 174 | define_models!( 175 | (bloom, "bloom", Bloom, llm_bloom, "BLOOM"), 176 | (gpt2, "gpt2", Gpt2, llm_gpt2, "GPT-2"), 177 | (gptj, "gptj", GptJ, llm_gptj, "GPT-J"), 178 | (gptneox, "gptneox", GptNeoX, llm_gptneox, "GPT-NeoX"), 179 | (llama, "llama", Llama, llm_llama, "LLaMA"), 180 | (mpt, "mpt", Mpt, llm_mpt, "MPT"), 181 | (falcon, "falcon", Falcon, llm_falcon, "Falcon") 182 | ); 183 | 184 | /// Used to dispatch some code based on the model architecture. 185 | pub trait ModelArchitectureVisitor<R> { 186 | /// Visit a model architecture. 187 | fn visit<M: KnownModel + 'static>(&mut self) -> R; 188 | } 189 | 190 | /// An unsupported model architecture was specified. 191 | pub struct UnsupportedModelArchitecture(String); 192 | impl Display for UnsupportedModelArchitecture { 193 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 194 | write!(f, "{}", self.0) 195 | } 196 | } 197 | impl Error for UnsupportedModelArchitecture {} 198 | impl Debug for UnsupportedModelArchitecture { 199 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 200 | f.debug_tuple("UnsupportedModelArchitecture") 201 | .field(&self.0) 202 | .finish() 203 | } 204 | } 205 | 206 | /// A helper function that loads the specified model from disk using an architecture 207 | /// specified at runtime. If no architecture is specified, it will try to infer it 208 | /// from the model's metadata. 209 | /// 210 | /// This method returns a [`Box`], which means that the model will have single ownership. 211 | /// If you'd like to share ownership (i.e. to use the model in multiple threads), we 212 | /// suggest using [`Arc::from(Box)`](https://doc.rust-lang.org/std/sync/struct.Arc.html#impl-From%3CBox%3CT,+Global%3E%3E-for-Arc%3CT%3E) 213 | /// to convert the [`Box`] into an [`Arc`](std::sync::Arc) after loading.
214 | pub fn load_dynamic( 215 | architecture: Option<ModelArchitecture>, 216 | path: &Path, 217 | tokenizer_source: TokenizerSource, 218 | params: ModelParameters, 219 | load_progress_callback: impl FnMut(LoadProgress), 220 | ) -> Result<Box<dyn Model>, LoadError> { 221 | fn load_model<M: KnownModel + 'static>( 222 | path: &Path, 223 | tokenizer_source: TokenizerSource, 224 | params: ModelParameters, 225 | load_progress_callback: impl FnMut(LoadProgress), 226 | ) -> Result<Box<dyn Model>, LoadError> { 227 | Ok(Box::new(load::<M>( 228 | path, 229 | tokenizer_source, 230 | params, 231 | load_progress_callback, 232 | )?)) 233 | } 234 | 235 | let architecture = architecture.ok_or_else(|| LoadError::MissingModelArchitecture { 236 | path: path.to_owned(), 237 | })?; 238 | 239 | struct LoadVisitor<'a, F: FnMut(LoadProgress)> { 240 | path: &'a Path, 241 | tokenizer_source: TokenizerSource, 242 | params: ModelParameters, 243 | load_progress_callback: F, 244 | } 245 | impl<'a, F: FnMut(LoadProgress)> ModelArchitectureVisitor<Result<Box<dyn Model>, LoadError>> 246 | for LoadVisitor<'a, F> 247 | { 248 | fn visit<M: KnownModel + 'static>(&mut self) -> Result<Box<dyn Model>, LoadError> { 249 | load_model::<M>( 250 | self.path, 251 | self.tokenizer_source.clone(), 252 | self.params.clone(), 253 | &mut self.load_progress_callback, 254 | ) 255 | } 256 | } 257 | 258 | architecture.visit(&mut LoadVisitor { 259 | path, 260 | tokenizer_source, 261 | params, 262 | load_progress_callback, 263 | }) 264 | } 265 | 266 | #[cfg(test)] 267 | mod tests { 268 | use super::*; 269 | 270 | #[test] 271 | fn test_model_architecture_from_str() { 272 | for arch in ModelArchitecture::ALL { 273 | assert_eq!( 274 | arch, 275 | &arch.to_string().parse::<ModelArchitecture>().unwrap() 276 | ); 277 | } 278 | } 279 | } 280 | -------------------------------------------------------------------------------- /crates/models/bloom/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-bloom" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of BLOOM (BigScience Large Open-science Open-access Multilingual Language Model) for the `llm` ecosystem." 7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } -------------------------------------------------------------------------------- /crates/models/falcon/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-falcon" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of Falcon for the `llm` ecosystem." 7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } -------------------------------------------------------------------------------- /crates/models/gpt2/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-gpt2" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of GPT-2 for the `llm` ecosystem."
7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } 12 | 13 | bytemuck = { workspace = true } 14 | -------------------------------------------------------------------------------- /crates/models/gptj/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-gptj" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of GPT-J for the `llm` ecosystem." 7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 11 | 12 | [dependencies] 13 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } -------------------------------------------------------------------------------- /crates/models/gptneox/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /Cargo.lock 3 | -------------------------------------------------------------------------------- /crates/models/gptneox/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-gptneox" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of GPT-NeoX for the `llm` ecosystem." 7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } 12 | -------------------------------------------------------------------------------- /crates/models/llama/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-llama" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of LLaMA (Large Language Model Meta AI) for the `llm` ecosystem." 7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } 12 | tracing = { version = "0.1", features = ["log"] } 13 | 14 | -------------------------------------------------------------------------------- /crates/models/mpt/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llm-mpt" 3 | version = "0.2.0-dev" 4 | license = { workspace = true } 5 | repository = { workspace = true } 6 | description = "An implementation of MosaicPretrainedTransformer (MPT) for the `llm` ecosystem." 7 | edition = "2021" 8 | readme = "../../../README.md" 9 | 10 | [dependencies] 11 | llm-base = { path = "../../llm-base", version = "0.2.0-dev" } -------------------------------------------------------------------------------- /doc/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributors Guide 2 | 3 | The purpose of this document is to make it easy for open-source community 4 | members to contribute to this project. We'd love to discuss your contributions 5 | with you via a GitHub [Issue](https://github.com/rustformers/llm/issues/new) or 6 | [Discussion](https://github.com/rustformers/llm/discussions/new?category=ideas), 7 | or on [Discord](https://discord.gg/YB9WaXYAWU)! 
8 | 9 | ## Checking Changes 10 | 11 | This project uses a [GitHub workflow](../.github/workflows/rust.yml) to enforce 12 | code standards. 13 | 14 | The `rusty-hook` project is used to run a similar set of checks automatically before committing. 15 | If you would like to run these checks locally, use `cargo run -p precommit-check`. 16 | 17 | ## Regenerating GGML Bindings 18 | 19 | Follow these steps to update the GGML submodule and regenerate the Rust bindings 20 | (this is only necessary if your changes depend on new GGML features): 21 | 22 | ```shell 23 | git submodule update --remote 24 | cargo run --release --package generate-ggml-bindings 25 | ``` 26 | 27 | ## Debugging 28 | 29 | This repository includes a [`launch.json` file](../.vscode/launch.json) that can 30 | be used for 31 | [debugging with Visual Studio Code](https://code.visualstudio.com/docs/editor/debugging) - 32 | this file will need to be updated to reflect where models are stored on your 33 | system. Debugging with Visual Studio Code requires a 34 | [language extension](https://code.visualstudio.com/docs/languages/rust#_install-debugging-support) 35 | that depends on your operating system. Keep in mind that debugging text 36 | generation is extremely slow, but debugging model loading is not. 37 | 38 | ## LLM References 39 | 40 | Here are some tried-and-true references for learning more about large language 41 | models: 42 | 43 | - [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/) - an 44 | excellent technical description of how this seminal language model generates 45 | text 46 | - [Andrej Karpathy's "Neural Networks: Zero to Hero"](https://karpathy.ai/zero-to-hero.html) - 47 | a series of in-depth YouTube videos that guide the viewer through creating a 48 | neural network, a large language model, and a fully functioning chatbot, from 49 | scratch (in Python) 50 | - [rustygrad](https://github.com/Mathemmagician/rustygrad) - a native Rust 51 | implementation of Andrej Karpathy's micrograd 52 | - [Understanding Deep Learning](https://udlbook.github.io/udlbook/) (Chapter 12 53 | specifically) 54 | -------------------------------------------------------------------------------- /doc/acceleration-support.md: -------------------------------------------------------------------------------- 1 | # Acceleration Support 2 | 3 | The `llm` ecosystem of crates, including `llm`, `llm-base` and `ggml` support various acceleration backends, selectable via `--features` flags. The availability of supported backends varies by platform, and these crates can only be built with a single active acceleration backend at a time. If CuBLAS and CLBlast are both specified, CuBLAS is prioritized and CLBlast is ignored. 4 | 5 | | Platform/OS | `cublas` | `clblast` | `metal` | 6 | | ----------- | ------------------ | ------------------ | ------------------ | 7 | | Windows | :heavy_check_mark: | :heavy_check_mark: | :x: | 8 | | Linux | :heavy_check_mark: | :heavy_check_mark: | :x: | 9 | | MacOS | :x: | :x: | :heavy_check_mark: | 10 | 11 | ## Utilizing GPU Support 12 | 13 | To activate GPU support (assuming that you have enabled one of the features above), set the `use_gpu` attribute of the `ModelParameters` to `true`. 14 | 15 | - **CLI Users**: You can enable GPU support by adding the `--use-gpu` flag. 16 | 17 | - **Backend Consideration**: For users leveraging the `cublas` or `clblast` backends, you can specify the number of layers you wish to offload to your GPU with the `gpu_layers` parameter in the `ModelParameters`. 
By default, all layers are offloaded. 18 | 19 | However, if your model size exceeds your GPU's VRAM, you can specify a limit, like `20`, to offload only the first 20 layers. For CLI users, this can be achieved using the `--gpu-layers` parameter. 20 | 21 | **Example**: To run a `llama` model with CUDA acceleration and offload all its layers, your CLI command might resemble: 22 | 23 | ```bash 24 | cargo run --release --features cublas -- infer -a llama -m [path/to/model.bin] --use-gpu -p "Help a llama is standing in my garden!" 25 | ``` 26 | 27 | 💡 **Protip**: For those with ample VRAM using `cublas` or `clblast`, you can significantly reduce your prompt's feed time by increasing the batch size; for example, you can use `256` or `512` (default is `8`). 28 | 29 | - Programmatic users of `llm` can adjust this by setting the `n_batch` parameter in the `InferenceSessionConfig` when initializing a session. 30 | 31 | - CLI users can utilize the `--batch-size` parameter to achieve this. 32 | 33 | ## Supported Accelerated Models 34 | 35 | While specific accelerators only support certain model architectures, some unmarked architectures may function, but their performance is not guaranteed—it hinges on the operations used by the model's architecture. The table below lists models with confirmed compatibility for each accelerator: 36 | 37 | | Model/accelerator | `cublas` | `clblast` | `metal` | 38 | | ----------------- | -------- | --------- | ------- | 39 | | LLaMA | ✅ | ✅ | ✅ | 40 | | MPT | ❌ | ❌ | ❌ | 41 | | Falcon | ❌ | ❌ | ❌ | 42 | | GPT-NeoX | ❌ | ❌ | ❌ | 43 | | GPT-J | ✅ | ❌ | ❌ | 44 | | GPT-2 | ❌ | ❌ | ❌ | 45 | | BLOOM | ❌ | ❌ | ❌ | 46 | 47 | ## Pre-requisites for Building with Accelerated Support 48 | 49 | To build with acceleration support, certain dependencies must be installed. These dependencies are contingent upon your chosen platform and the specific acceleration backend you're working with. 50 | 51 | For developers aiming to distribute packages equipped with acceleration capabilities, our [CI/CD setup](../.github/workflows/rust.yml) serves as an exemplary foundation. 52 | 53 | ### Windows 54 | 55 | #### CuBLAS 56 | 57 | CUDA must be installed. You can download CUDA from the official [Nvidia site](https://developer.nvidia.com/cuda-downloads). 58 | 59 | #### CLBlast 60 | 61 | CLBlast can be installed via [vcpkg](https://vcpkg.io/en/getting-started.html) using the command `vcpkg install clblast`. After installation, the `OPENCL_PATH` and `CLBLAST_PATH` environment variables should be set to the `opencl_x64-windows` and `clblast_x64-windows` directories respectively. 62 | 63 | Here's an example of the required commands: 64 | 65 | ``` 66 | git clone https://github.com/Microsoft/vcpkg.git 67 | .\vcpkg\bootstrap-vcpkg.bat 68 | .\vcpkg\vcpkg install clblast 69 | set OPENCL_PATH=....\vcpkg\packages\opencl_x64-windows 70 | set CLBLAST_PATH=....\vcpkg\packages\clblast_x64-windows 71 | ``` 72 | 73 | ⚠️ When working with MSVC in a Windows environment, it is essential to set the `-Ctarget-feature=+crt-static` Rust flag. This flag is critical as it enables the static linking of the C runtime, which can be paramount for certain deployment scenarios or specific runtime environments. 74 | 75 | To set this flag, you can modify the .cargo\config file in your project directory. 
Please add the following configuration snippet: 76 | 77 | ``` 78 | [target.x86_64-pc-windows-msvc] 79 | rustflags = ["-Ctarget-feature=+crt-static"] 80 | ``` 81 | 82 | This will ensure the Rust flag is appropriately set for your compilation process. 83 | 84 | For a comprehensive guide on the usage of Rust flags, including other possible ways to set them, please refer to this detailed [StackOverflow discussion](https://stackoverflow.com/questions/38040327/how-to-pass-rustc-flags-to-cargo). Make sure to choose an option that best fits your project requirements and development environment. 85 | 86 | ⚠️ For `llm` to function properly, it requires the `clblast.dll` and `OpenCL.dll` files. These files can be found within the `bin` subdirectory of their respective vcpkg packages. There are two options to ensure `llm` can access these files: 87 | 88 | 1. Amend your `PATH` environment variable to include the `bin` directories of each respective package. 89 | 90 | 2. Manually copy the `clblast.dll` and `OpenCL.dll` files into the `./target/release` or `./target/debug` directories. The destination directory will depend on the profile that was active during the compilation process. 91 | 92 | Please choose the option that best suits your needs and environment configuration. 93 | 94 | ### Linux 95 | 96 | #### CuBLAS 97 | 98 | You need to have CUDA installed on your system. CUDA can be downloaded and installed from the official [Nvidia site](https://developer.nvidia.com/cuda-downloads). On Linux distributions that do not have `CUDA_PATH` set, the environment variables `CUDA_INCLUDE_PATH` and `CUDA_LIB_PATH` can be set to their corresponding paths. 99 | 100 | #### CLBlast 101 | 102 | CLBlast can be installed on Linux through various package managers. For example, using `apt` you can install it via `sudo apt install clblast`. After installation, make sure that the `OPENCL_PATH` and `CLBLAST_PATH` environment variables are correctly set. Additionally the environment variables `OPENCL_INCLUDE_PATH`/`OPENCL_LIB_PATH` & `CBLAST_INCLUDE_PATH`/`CLBLAST_LIB_PATH` can be used to specify the location of the files. All environment variables are supported by all listed operating systems. 103 | 104 | ### MacOS 105 | 106 | #### Metal 107 | 108 | Xcode and the associated command-line tools should be installed on your system, and you should be running a version of MacOS that supports Metal. For more detailed information, please consult the [official Metal documentation](https://developer.apple.com/metal/). 109 | 110 | To enable Metal using the CLI, ensure it was built successfully using `--features=metal` and then pass the `--use-gpu` flag. 111 | 112 | The current underlying implementation of Metal in GGML is still in flux and has some limitations: 113 | 114 | - Evaluating a model with more than one token at a time is not currently supported in GGML's Metal implementation. An `llm` inference session will fall back to the CPU implementation (typically during the 'feed prompt' phase) but will automatically use the GPU once a single token is passed per evaluation (typically after prompt feeding). 115 | - Not all model architectures will be equally stable when used with Metal due to ongoing work in the underlying implementation. Expect `llama` models to work fine though. 116 | - With Metal, it is possible but not required to use `mmap`. As buffers do not need to be copied to VRAM on M1, `mmap` is the most efficient however. 117 | - Debug messages may be logged by the underlying GGML Metal implementation. 
This will likely go away in the future for release builds of `llm`. 118 | -------------------------------------------------------------------------------- /doc/img/llm-crab-llama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rustformers/llm/b11ffb14fde039f068c808df8dcbd1257a01b8db/doc/img/llm-crab-llama.png -------------------------------------------------------------------------------- /doc/known-good-models.md: -------------------------------------------------------------------------------- 1 | # Known-good models 2 | 3 | The following models have been tested and are known to work with `llm`. 4 | 5 | Models are distributed as single files, but come in a variety of quantization levels. 6 | You will need to select the quantization level that is appropriate for your application. 7 | For more information, see [Getting Models](../README.md#getting-models) in the README. 8 | 9 | The LLaMA architecture is the most well-supported. 10 | 11 | ## LLaMA 12 | 13 | We have chosen not to include any models based on the original LLaMA model due to licensing concerns. 14 | However, the OpenLLaMA models are available under the Apache 2.0 license and are compatible with `llm`. 15 | 16 | - 17 | - 18 | - 19 | 20 | Models based on the original LLaMA model are also compatible, but you will need to find them yourselves 21 | due to their licensing. 22 | 23 | ## GPT-2 24 | 25 | - : note that this is `f16`-only and 26 | we recommend you quantize it using `llm` for best performance. 27 | - 28 | 29 | ## GPT-J 30 | 31 | - 32 | - 33 | 34 | ## MPT 35 | 36 | - 37 | 38 | ## GPT-NeoX/RedPajama 39 | 40 | - 41 | - 42 | - 43 | - 44 | 45 | ## BLOOM 46 | 47 | - 48 | - 49 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1681202837, 9 | "narHash": "sha256-H+Rh19JDwRtpVPAWp64F+rlEtxUWBAQW28eAi3SRSzg=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "cfacdce06f30d2b68473a46042957675eebb3401", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "naersk": { 22 | "inputs": { 23 | "nixpkgs": "nixpkgs" 24 | }, 25 | "locked": { 26 | "lastModified": 1679567394, 27 | "narHash": "sha256-ZvLuzPeARDLiQUt6zSZFGOs+HZmE+3g4QURc8mkBsfM=", 28 | "owner": "nix-community", 29 | "repo": "naersk", 30 | "rev": "88cd22380154a2c36799fe8098888f0f59861a15", 31 | "type": "github" 32 | }, 33 | "original": { 34 | "owner": "nix-community", 35 | "repo": "naersk", 36 | "type": "github" 37 | } 38 | }, 39 | "nixpkgs": { 40 | "locked": { 41 | "lastModified": 1683267615, 42 | "narHash": "sha256-A/zAy9YauwdPut90h6cYC1zgP/WmuW9zmJ+K/c5i6uc=", 43 | "owner": "NixOS", 44 | "repo": "nixpkgs", 45 | "rev": "0b6445b611472740f02eae9015150c07c5373340", 46 | "type": "github" 47 | }, 48 | "original": { 49 | "id": "nixpkgs", 50 | "type": "indirect" 51 | } 52 | }, 53 | "nixpkgs_2": { 54 | "locked": { 55 | "lastModified": 1683267615, 56 | "narHash": "sha256-A/zAy9YauwdPut90h6cYC1zgP/WmuW9zmJ+K/c5i6uc=", 57 | "owner": "nixos", 58 | "repo": "nixpkgs", 59 | "rev": "0b6445b611472740f02eae9015150c07c5373340", 60 | "type": "github" 61 | }, 62 | "original": { 63 | "owner": "nixos", 64 | "ref": "nixpkgs-unstable", 65 | "repo": "nixpkgs", 66 | "type": "github" 
67 | } 68 | }, 69 | "root": { 70 | "inputs": { 71 | "flake-utils": "flake-utils", 72 | "naersk": "naersk", 73 | "nixpkgs": "nixpkgs_2" 74 | } 75 | }, 76 | "systems": { 77 | "locked": { 78 | "lastModified": 1681028828, 79 | "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 80 | "owner": "nix-systems", 81 | "repo": "default", 82 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 83 | "type": "github" 84 | }, 85 | "original": { 86 | "owner": "nix-systems", 87 | "repo": "default", 88 | "type": "github" 89 | } 90 | } 91 | }, 92 | "root": "root", 93 | "version": 7 94 | } 95 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Rust-based tool for inference of LLMs."; 3 | inputs = { 4 | nixpkgs.url = github:nixos/nixpkgs/nixpkgs-unstable; 5 | naersk.url = github:nix-community/naersk; 6 | flake-utils.url = github:numtide/flake-utils; 7 | }; 8 | 9 | outputs = { self, nixpkgs, naersk, flake-utils }: 10 | flake-utils.lib.eachDefaultSystem (system: 11 | let 12 | pkgs = nixpkgs.legacyPackages.${system}; 13 | naersk' = pkgs.callPackage naersk { }; 14 | llm = naersk'.buildPackage { 15 | src = ./.; 16 | }; 17 | in 18 | { 19 | formatter = pkgs.nixpkgs-fmt; 20 | packages.default = llm; 21 | apps.default = { 22 | type = "app"; 23 | program = "${llm}/bin/llm"; 24 | }; 25 | devShells.default = with pkgs; mkShell { 26 | packages = [ cargo rustc rust-analyzer rustfmt cmake ]; 27 | RUST_SRC_PATH = rustPlatform.rustLibSrc; 28 | }; 29 | } 30 | ); 31 | } 32 | -------------------------------------------------------------------------------- /utils/Dockerfile: -------------------------------------------------------------------------------- 1 | # Start with a rust alpine image 2 | FROM rust:alpine3.17 as builder 3 | # This is important, see https://github.com/rust-lang/docker-rust/issues/85 4 | ENV RUSTFLAGS="-C target-feature=-crt-static" 5 | # if needed, add additional dependencies here 6 | RUN apk add --no-cache musl-dev 7 | # set the workdir and copy the source into it 8 | WORKDIR /app 9 | COPY ./ /app 10 | # do a release build 11 | RUN cargo build --release --bin llm 12 | RUN strip target/release/llm 13 | 14 | # use a plain alpine image, the alpine version needs to match the builder 15 | FROM alpine:3.17 16 | # if needed, install additional dependencies here 17 | RUN apk add --no-cache libgcc 18 | # copy the binary into the final image 19 | COPY --from=builder /app/target/release/llm . 20 | # set the binary as entrypoint 21 | ENTRYPOINT ["/llm"] 22 | -------------------------------------------------------------------------------- /utils/prompts/alpaca.txt: -------------------------------------------------------------------------------- 1 | Below is an instruction that describes a task. Write a response that appropriately completes the request. 2 | 3 | ### Instruction: 4 | 5 | {{PROMPT}} 6 | 7 | ### Response: 8 | 9 | -------------------------------------------------------------------------------- /utils/prompts/pygmalion-message.txt: -------------------------------------------------------------------------------- 1 | You: -------------------------------------------------------------------------------- /utils/prompts/pygmalion-prelude.txt: -------------------------------------------------------------------------------- 1 | Assistant's Persona: Assistant is a highly intelligent language model trained to comply with user requests. 
2 | 3 | Assistant: How may I help you? -------------------------------------------------------------------------------- /utils/prompts/vicuna-message.txt: -------------------------------------------------------------------------------- 1 | User: -------------------------------------------------------------------------------- /utils/prompts/vicuna-prelude.txt: -------------------------------------------------------------------------------- 1 | A chat between a human ("User") and an AI assistant ("Assistant"). The assistant gives helpful, detailed, and polite answers to the human's questions. 2 | 3 | Assistant: How may I help you? --------------------------------------------------------------------------------