├── .dockerignore ├── .github ├── dependabot.yml └── workflows │ ├── llama-cpp-rs-check.yml │ ├── publish-upon-release.yml │ ├── update-llama-cpp.yml │ └── update-toml-version.yaml ├── .gitignore ├── .gitmodules ├── Cargo.lock ├── Cargo.toml ├── LICENSE-APACHE ├── LISENCE-MIT ├── README.md ├── examples ├── embeddings │ ├── Cargo.toml │ └── src │ │ └── main.rs ├── reranker │ ├── Cargo.toml │ ├── README.md │ └── src │ │ └── main.rs ├── simple │ ├── Cargo.toml │ └── src │ │ └── main.rs └── usage.rs ├── llama-cpp-2 ├── Cargo.toml ├── README.md └── src │ ├── context.rs │ ├── context │ ├── kv_cache.rs │ ├── params.rs │ └── session.rs │ ├── grammar │ ├── arithmetic.gbnf │ ├── c.gbnf │ ├── chess.gbnf │ ├── japanese.gbnf │ ├── json.gbnf │ ├── json_arr.gbnf │ ├── list.gbnf │ └── tests.rs │ ├── lib.rs │ ├── llama_backend.rs │ ├── llama_batch.rs │ ├── log.rs │ ├── model.rs │ ├── model │ ├── params.rs │ └── params │ │ └── kv_overrides.rs │ ├── sampling.rs │ ├── timing.rs │ ├── token.rs │ ├── token │ ├── data.rs │ ├── data_array.rs │ └── logit_bias.rs │ └── token_type.rs ├── llama-cpp-sys-2 ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build.rs ├── src │ └── lib.rs └── wrapper.h └── test-build.Dockerfile /.dockerignore: -------------------------------------------------------------------------------- 1 | target 2 | Dockerfile 3 | .dockerignore 4 | .gitignore -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "cargo" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: github-actions 8 | directory: / 9 | schedule: 10 | interval: "weekly" 11 | -------------------------------------------------------------------------------- /.github/workflows/llama-cpp-rs-check.yml: -------------------------------------------------------------------------------- 1 | name: Llama Cpp Rs Check 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.ref }} 11 | cancel-in-progress: true 12 | 13 | permissions: read-all 14 | 15 | jobs: 16 | check: 17 | name: Run Tests on LLama Cpp Rs 18 | runs-on: ubuntu-latest 19 | steps: 20 | - name: Checkout 21 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 22 | with: 23 | submodules: recursive 24 | - name: Install Compile Deps 25 | env: 26 | DEBIAN_FRONTEND: noninteractive 27 | run: 28 | sudo apt-get update && sudo apt-get install -y build-essential curl libssl-dev libclang-dev pkg-config cmake git 29 | - uses: dtolnay/rust-toolchain@stable 30 | with: 31 | components: clippy, rustfmt 32 | - name: Clippy 33 | run: cargo clippy 34 | - name: Fmt 35 | run: cargo fmt 36 | - name: Test 37 | run: cargo test --features sampler 38 | arm64: 39 | name: Check that it builds on various targets 40 | runs-on: ubuntu-latest 41 | strategy: 42 | matrix: 43 | target: [ linux/arm64, linux/amd64 ] 44 | steps: 45 | - name: checkout 46 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 47 | - name: Setup QEMU 48 | uses: docker/setup-qemu-action@29109295f81e9208d7d86ff1c6c12d2833863392 49 | with: 50 | platforms: arm64,amd64 51 | - name: Set up Docker Buildx 52 | uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 53 | - name: Build 54 | uses: docker/build-push-action@v6 55 | with: 56 | file: test-build.Dockerfile 57 | target: base-cuda 58 | platforms: ${{ 
matrix.target }} 59 | mac: 60 | name: Check that it builds on mac 61 | runs-on: macos-latest 62 | steps: 63 | - name: checkout 64 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 65 | with: 66 | submodules: recursive 67 | - name: Setup Rust 68 | uses: dtolnay/rust-toolchain@stable 69 | - name: Build 70 | run: cargo build --features sampler 71 | windows: 72 | name: Check that it builds on windows 73 | runs-on: windows-latest 74 | steps: 75 | - name: checkout 76 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 77 | with: 78 | submodules: recursive 79 | - name: Setup Rust 80 | uses: dtolnay/rust-toolchain@stable 81 | - name: Build 82 | run: cargo build --features sampler 83 | - name: Test 84 | run: cargo test --features sampler -------------------------------------------------------------------------------- /.github/workflows/publish-upon-release.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Crates.io 2 | 3 | on: 4 | release: 5 | types: 6 | - created 7 | workflow_dispatch: 8 | 9 | permissions: 10 | contents: write 11 | 12 | jobs: 13 | publish: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 18 | with: 19 | submodules: recursive 20 | - name: Publish crates for llama-cpp-sys-2 21 | run: RUST_BACKTRACE=1 cargo publish --package llama-cpp-sys-2 --token ${{ secrets.CARGO_REGISTRY_TOKEN }} --verbose 22 | - name: Publish crates for llama-cpp-2 23 | run: RUST_BACKTRACE=1 cargo publish --package llama-cpp-2 --token ${{ secrets.CARGO_REGISTRY_TOKEN }} --verbose 24 | 25 | # Trigger the 'update-toml-version' workflow 26 | - name: Dispatch Update TOML Version Event 27 | if: success() # Ensure this runs only if the previous steps were successful 28 | run: | 29 | curl -X POST \ 30 | -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ 31 | -H "Accept: application/vnd.github.everest-preview+json" \ 32 | "https://api.github.com/repos/${{ github.repository }}/dispatches" \ 33 | -d '{"event_type": "trigger-update-toml-version"}' 34 | -------------------------------------------------------------------------------- /.github/workflows/update-llama-cpp.yml: -------------------------------------------------------------------------------- 1 | name: Update llama cpp nightly 2 | on: 3 | schedule: 4 | - cron: '0 0 * * *' 5 | workflow_dispatch: { } 6 | 7 | permissions: 8 | pull-requests: write 9 | contents: write 10 | 11 | jobs: 12 | update: 13 | runs-on: ubuntu-latest 14 | name: Update llama cpp 15 | steps: 16 | - name: Set date 17 | run: echo "DATE=$(date -I)" >> $GITHUB_ENV 18 | - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 19 | name: Checkout latest 20 | with: 21 | submodules: recursive 22 | - name: Create branch 23 | run: git checkout -b update-llama-cpp-${{ env.DATE }} 24 | - name: Update submodules 25 | run: git submodule update --remote 26 | - name: Config git 27 | run: | 28 | git config --global user.email "marcus@utilityai.ca" 29 | git config --global user.name "Marcus Dunn" 30 | - name: Commit 31 | run: git commit -am "updated llama.cpp" 32 | - name: Push 33 | run: git push --set-upstream origin update-llama-cpp-${{ env.DATE }} --force 34 | - name: Close any outdated PRs 35 | env: 36 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 37 | run: | 38 | gh pr list --json number,title --jq '.[] | select(.title | contains("Updated llama-cpp (bot)")) | .number' | xargs -I {} gh pr close {} 39 | - name: Create open PR 40 | env: 41 | 
GITHUB_TOKEN: ${{ secrets.LLAMA_CPP_RS_UPDATE_LLAMA_CPP_ACTION}} 42 | run: | 43 | unset GITHUB_TOKEN 44 | echo ${{ secrets.LLAMA_CPP_RS_UPDATE_LLAMA_CPP_ACTION }} | gh auth login --with-token 45 | gh pr create --fill --head update-llama-cpp-${{ env.DATE }} --title "Updated llama-cpp (bot)" 46 | -------------------------------------------------------------------------------- /.github/workflows/update-toml-version.yaml: -------------------------------------------------------------------------------- 1 | name: Update version in TOML files 2 | 3 | on: 4 | repository_dispatch: 5 | types: [ trigger-update-toml-version ] 6 | workflow_dispatch: 7 | 8 | permissions: 9 | contents: write 10 | pull-requests: write 11 | 12 | jobs: 13 | modify_files: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - name: Checkout code 18 | uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 19 | with: 20 | submodules: recursive 21 | 22 | - name: Update version in TOML files 23 | env: 24 | GH_TOKEN: ${{ github.token }} 25 | run: | 26 | # Extract the current version from the TOML file 27 | CURRENT_VERSION=$(awk -F '"' '/^version/ {print $2}' llama-cpp-2/Cargo.toml) 28 | # Increment the version 29 | NEXT_VERSION=$(echo "$CURRENT_VERSION" | awk -F. -v OFS=. '{++$NF; print}') 30 | # Update version in llama-cpp-sys-2 Cargo.toml 31 | sed -i "s/^version = \".*\"/version = \"$NEXT_VERSION\"/g" llama-cpp-sys-2/Cargo.toml 32 | # Update version in llama-cpp-2 Cargo.toml 33 | sed -i "s/^version = \".*\"/version = \"$NEXT_VERSION\"/g" llama-cpp-2/Cargo.toml 34 | sed -i "s/^\(llama-cpp-sys-2 = { path = \"\.\.\/llama-cpp-sys-2\", version = \)\"$CURRENT_VERSION\"/\1\"$NEXT_VERSION\"/" llama-cpp-2/Cargo.toml 35 | # Update the version in the simple Cargo.toml 36 | sed -i "s/^version = \".*\"/version = \"$NEXT_VERSION\"/g" examples/simple/Cargo.toml 37 | sed -i "s/^\(llama-cpp-2 = { path = \"\.\.\/llama-cpp-2\", version = \)\"$CURRENT_VERSION\"/\1\"$NEXT_VERSION\"/" examples/simple/Cargo.toml 38 | # Update the version in the root embeddings Cargo.toml 39 | sed -i "s/^version = \".*\"/version = \"$NEXT_VERSION\"/g" examples/embeddings/Cargo.toml 40 | sed -i "s/^\(llama-cpp-2 = { path = \"\.\.\/llama-cpp-2\", version = \)\"$CURRENT_VERSION\"/\1\"$NEXT_VERSION\"/" examples/embeddings/Cargo.toml 41 | # Update Cargo.lock by running cargo check 42 | cargo check 43 | # Commit the changes 44 | git config --global user.email "actions@github.com" 45 | git config --global user.name "GitHub Actions" 46 | git add llama-cpp-sys-2/Cargo.toml llama-cpp-2/Cargo.toml examples/simple/Cargo.toml examples/embeddings/Cargo.toml Cargo.lock 47 | git commit -m "Bump version to $NEXT_VERSION [skip ci]" 48 | # Create a branch for the changes 49 | git checkout -b version-bump-$NEXT_VERSION 50 | # Push the changes and create a pull request 51 | git push origin version-bump-$NEXT_VERSION --force 52 | gh pr create --base main --head version-bump-$NEXT_VERSION --title "Bumped version to $NEXT_VERSION" --fill 53 | 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # These are backup files generated by rustfmt 7 | **/*.rs.bk 8 | 9 | # MSVC Windows builds of rustc generate these, which store debugging information 10 | *.pdb -------------------------------------------------------------------------------- /.gitmodules: 
-------------------------------------------------------------------------------- 1 | [submodule "llama-cpp-sys-2/llama.cpp"] 2 | path = llama-cpp-sys-2/llama.cpp 3 | url = https://github.com/ggml-org/llama.cpp 4 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = [ 4 | "llama-cpp-sys-2", 5 | "llama-cpp-2", 6 | "examples/embeddings", 7 | "examples/simple", 8 | "examples/reranker", 9 | ] 10 | 11 | [workspace.dependencies] 12 | # core library deps 13 | thiserror = "1" 14 | tracing = "0.1" 15 | tracing-core = "0.1" 16 | 17 | # examples and benchmarks 18 | hf-hub = { version = "0.3.2" } 19 | criterion = "0.5.1" 20 | pprof = "0.13.0" 21 | bindgen = "0.69.5" 22 | cc = "1.2.25" 23 | anyhow = "1.0.98" 24 | clap = "4.5.39" 25 | encoding_rs = "0.8.35" 26 | tracing-subscriber = { version = "0.3", features = ["json"] } 27 | 28 | [workspace.lints.rust] 29 | missing_docs = { level = "warn" } 30 | missing_debug_implementations = { level = "warn" } 31 | 32 | [workspace.lints.clippy] 33 | pedantic = { level = "warn" } 34 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /LISENCE-MIT: -------------------------------------------------------------------------------- 1 | Copyright (c) Dial AI 2 | 3 | Permission is hereby granted, free of charge, to any 4 | person obtaining a copy of this software and associated 5 | documentation files (the "Software"), to deal in the 6 | Software without restriction, including without 7 | limitation the rights to use, copy, modify, merge, 8 | publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software 10 | is furnished to do so, subject to the following 11 | conditions: 12 | 13 | The above copyright notice and this permission notice 14 | shall be included in all copies or substantial portions 15 | of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 18 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 19 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 20 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 21 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 22 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 23 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 24 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 25 | DEALINGS IN THE SOFTWARE. 
26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🦙 [llama-cpp-rs][readme]   [![Docs]][docs.rs] [![Latest Version]][crates.io] [![License]][crates.io] 2 | 3 | [Docs]: https://img.shields.io/docsrs/llama-cpp-2.svg 4 | 5 | [Latest Version]: https://img.shields.io/crates/v/llama-cpp-2.svg 6 | 7 | [crates.io]: https://crates.io/crates/llama-cpp-2 8 | 9 | [docs.rs]: https://docs.rs/llama-cpp-2 10 | 11 | [License]: https://img.shields.io/crates/l/llama-cpp-2.svg 12 | 13 | [llama-cpp-sys]: https://crates.io/crates/llama-cpp-sys-2 14 | 15 | [utilityai]: https://utilityai.ca 16 | 17 | [readme]: https://github.com/utilityai/llama-cpp-rs/tree/main/llama-cpp-2 18 | 19 | This is the home for [llama-cpp-2][crates.io]. It also contains the [llama-cpp-sys] bindings, which are updated semi-regularly 20 | and in sync with [llama-cpp-2][crates.io]. 21 | 22 | This project was created with the explicit goal of staying as up to date as possible with llama.cpp; as a result it is 23 | dead simple, very close to raw bindings, and does not follow semver meaningfully. 24 | 25 | Check out [docs.rs] for crate documentation or the [readme] for high-level information about the project. 26 | 27 | ## Try it 28 | 29 | We maintain a super simple example of using the library: 30 | 31 | Clone the repo 32 | 33 | ```bash 34 | git clone --recursive https://github.com/utilityai/llama-cpp-rs 35 | cd llama-cpp-rs 36 | ``` 37 | 38 | Run the simple example (add `--features cuda` if you have a CUDA GPU) 39 | 40 | ```bash 41 | cargo run --release --bin simple -- --prompt "The way to kill a linux process is" hf-model TheBloke/Llama-2-7B-GGUF llama-2-7b.Q4_K_M.gguf 42 | ``` 43 | 44 |
45 | Output 46 |
 47 | ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
 48 | ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
 49 | ggml_init_cublas: found 1 CUDA devices:
 50 |   Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes
 51 | llama_model_params { n_gpu_layers: 1000, split_mode: 1, main_gpu: 0, tensor_split: 0x0, progress_callback: None, progress_callback_user_data: 0x0, kv_overrides: 0x0, vocab_only: false, use_mmap: true, use_mlock: false }
 52 | llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from /home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-GGUF/snapshots/b4e04e128f421c93a5f1e34ac4d7ca9b0af47b80/llama-2-7b.Q4_K_M.gguf (version GGUF V2)
 53 | llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
 54 | llama_model_loader: - kv   0:                       general.architecture str              = llama
 55 | llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
 56 | llama_model_loader: - kv   2:                       llama.context_length u32              = 4096
 57 | llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
 58 | llama_model_loader: - kv   4:                          llama.block_count u32              = 32
 59 | llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
 60 | llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
 61 | llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
 62 | llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 32
 63 | llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
 64 | llama_model_loader: - kv  10:                          general.file_type u32              = 15
 65 | llama_model_loader: - kv  11:                       tokenizer.ggml.model str              = llama
 66 | llama_model_loader: - kv  12:                      tokenizer.ggml.tokens arr[str,32000]   = ["", "", "", "<0x00>", "<...
 67 | llama_model_loader: - kv  13:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
 68 | llama_model_loader: - kv  14:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
 69 | llama_model_loader: - kv  15:                tokenizer.ggml.bos_token_id u32              = 1
 70 | llama_model_loader: - kv  16:                tokenizer.ggml.eos_token_id u32              = 2
 71 | llama_model_loader: - kv  17:            tokenizer.ggml.unknown_token_id u32              = 0
 72 | llama_model_loader: - kv  18:               general.quantization_version u32              = 2
 73 | llama_model_loader: - type  f32:   65 tensors
 74 | llama_model_loader: - type q4_K:  193 tensors
 75 | llama_model_loader: - type q6_K:   33 tensors
 76 | llm_load_vocab: special tokens definition check successful ( 259/32000 ).
 77 | llm_load_print_meta: format           = GGUF V2
 78 | llm_load_print_meta: arch             = llama
 79 | llm_load_print_meta: vocab type       = SPM
 80 | llm_load_print_meta: n_vocab          = 32000
 81 | llm_load_print_meta: n_merges         = 0
 82 | llm_load_print_meta: n_ctx_train      = 4096
 83 | llm_load_print_meta: n_embd           = 4096
 84 | llm_load_print_meta: n_head           = 32
 85 | llm_load_print_meta: n_head_kv        = 32
 86 | llm_load_print_meta: n_layer          = 32
 87 | llm_load_print_meta: n_rot            = 128
 88 | llm_load_print_meta: n_embd_head_k    = 128
 89 | llm_load_print_meta: n_embd_head_v    = 128
 90 | llm_load_print_meta: n_gqa            = 1
 91 | llm_load_print_meta: n_embd_k_gqa     = 4096
 92 | llm_load_print_meta: n_embd_v_gqa     = 4096
 93 | llm_load_print_meta: f_norm_eps       = 0.0e+00
 94 | llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
 95 | llm_load_print_meta: f_clamp_kqv      = 0.0e+00
 96 | llm_load_print_meta: f_max_alibi_bias = 0.0e+00
 97 | llm_load_print_meta: n_ff             = 11008
 98 | llm_load_print_meta: n_expert         = 0
 99 | llm_load_print_meta: n_expert_used    = 0
100 | llm_load_print_meta: rope scaling     = linear
101 | llm_load_print_meta: freq_base_train  = 10000.0
102 | llm_load_print_meta: freq_scale_train = 1
103 | llm_load_print_meta: n_yarn_orig_ctx  = 4096
104 | llm_load_print_meta: rope_finetuned   = unknown
105 | llm_load_print_meta: model type       = 7B
106 | llm_load_print_meta: model ftype      = Q4_K - Medium
107 | llm_load_print_meta: model params     = 6.74 B
108 | llm_load_print_meta: model size       = 3.80 GiB (4.84 BPW) 
109 | llm_load_print_meta: general.name     = LLaMA v2
110 | llm_load_print_meta: BOS token        = 1 '<s>'
111 | llm_load_print_meta: EOS token        = 2 '</s>'
112 | llm_load_print_meta: UNK token        = 0 '<unk>'
113 | llm_load_print_meta: LF token         = 13 '<0x0A>'
114 | llm_load_tensors: ggml ctx size =    0.22 MiB
115 | llm_load_tensors: offloading 32 repeating layers to GPU
116 | llm_load_tensors: offloading non-repeating layers to GPU
117 | llm_load_tensors: offloaded 33/33 layers to GPU
118 | llm_load_tensors:      CUDA0 buffer size =  3820.94 MiB
119 | llm_load_tensors:        CPU buffer size =    70.31 MiB
120 | ..................................................................................................
121 | Loaded "/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-GGUF/snapshots/b4e04e128f421c93a5f1e34ac4d7ca9b0af47b80/llama-2-7b.Q4_K_M.gguf"
122 | llama_new_context_with_model: n_ctx      = 2048
123 | llama_new_context_with_model: freq_base  = 10000.0
124 | llama_new_context_with_model: freq_scale = 1
125 | llama_kv_cache_init:      CUDA0 KV buffer size =  1024.00 MiB
126 | llama_new_context_with_model: KV self size  = 1024.00 MiB, K (f16):  512.00 MiB, V (f16):  512.00 MiB
127 | llama_new_context_with_model:  CUDA_Host input buffer size   =    13.02 MiB
128 | ggml_gallocr_reserve_n: reallocating CUDA0 buffer from size 0.00 MiB to 164.01 MiB
129 | ggml_gallocr_reserve_n: reallocating CUDA_Host buffer from size 0.00 MiB to 8.00 MiB
130 | llama_new_context_with_model:      CUDA0 compute buffer size =   164.01 MiB
131 | llama_new_context_with_model:  CUDA_Host compute buffer size =     8.00 MiB
132 | llama_new_context_with_model: graph splits (measure): 3
133 | n_len = 32, n_ctx = 2048, k_kv_req = 32
134 | 
135 | The way to kill a linux process is to send it a SIGKILL signal.
136 | The way to kill a windows process is to send it a S
137 | 
138 | decoded 24 tokens in 0.23 s, speed 105.65 t/s
139 | 
140 | load time = 727.50 ms
141 | sample time = 0.46 ms / 24 runs (0.02 ms per token, 51835.85 tokens per second)
142 | prompt eval time = 68.52 ms / 9 tokens (7.61 ms per token, 131.35 tokens per second)
143 | eval time = 225.70 ms / 24 runs (9.40 ms per token, 106.34 tokens per second)
144 | total time = 954.18 ms
145 | 
146 |
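If you want to call the library from your own code rather than run the CLI, the core of the `simple` example reduces to the loop sketched below. This is a condensed, untested sketch based on `examples/simple/src/main.rs` (included later in this file); it assumes `anyhow` for error handling, and the model filename, prompt, and 32-token limit are placeholders. A real loop should also stop on the model's end-of-generation token and decode UTF-8 incrementally, as the full example does.

```rust
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;
use std::io::Write;
use std::num::NonZeroU32;
use std::path::PathBuf;

fn main() -> anyhow::Result<()> {
    // one-time backend initialisation
    let backend = LlamaBackend::init()?;

    // load a GGUF model from disk (placeholder path)
    let model_params = LlamaModelParams::default();
    let model = LlamaModel::load_from_file(
        &backend,
        PathBuf::from("llama-2-7b.Q4_K_M.gguf"),
        &model_params,
    )?;

    // create a context with a 2048-token window
    let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));
    let mut ctx = model.new_context(&backend, ctx_params)?;

    // tokenize the prompt and submit it as a single batch,
    // requesting logits only for the last prompt token
    let tokens = model.str_to_token("The way to kill a linux process is", AddBos::Always)?;
    let mut batch = LlamaBatch::new(512, 1);
    let last_index = tokens.len() as i32 - 1;
    for (i, token) in (0_i32..).zip(tokens) {
        batch.add(token, i, &[0], i == last_index)?;
    }
    ctx.decode(&mut batch)?;

    // sample and print new tokens one at a time
    let mut sampler =
        LlamaSampler::chain_simple([LlamaSampler::dist(1234), LlamaSampler::greedy()]);
    let n_len = 32; // total prompt + output length, matching the example's default
    let mut n_cur = batch.n_tokens();
    while n_cur <= n_len {
        let token = sampler.sample(&ctx, batch.n_tokens() - 1);
        sampler.accept(token);
        // token_to_str can fail for tokens that are not valid UTF-8 on their own;
        // a real loop should also break on the model's end-of-generation token
        print!("{}", model.token_to_str(token, Special::Tokenize)?);
        std::io::stdout().flush()?;

        // feed the sampled token back in for the next decode step
        batch.clear();
        batch.add(token, n_cur, &[0], true)?;
        n_cur += 1;
        ctx.decode(&mut batch)?;
    }
    println!();
    Ok(())
}
```

As in the full example, building with `--features cuda` (or `metal`/`vulkan`) and constructing the model parameters with `LlamaModelParams::default().with_n_gpu_layers(1000)` offloads layers to the GPU.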
147 | 148 | ## Hacking 149 | 150 | Ensure that when you clone this project you also clone the submodules. This can be done with the following command: 151 | 152 | ```sh 153 | git clone --recursive https://github.com/utilityai/llama-cpp-rs 154 | ``` 155 | 156 | or if you have already cloned the project you can run: 157 | 158 | ```sh 159 | git submodule update --init --recursive 160 | ``` 161 | -------------------------------------------------------------------------------- /examples/embeddings/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "embeddings" 3 | version = "0.1.109" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.69" } 8 | hf-hub = { workspace = true } 9 | clap = { workspace = true, features = ["derive"] } 10 | anyhow = { workspace = true } 11 | 12 | [features] 13 | cuda = ["llama-cpp-2/cuda"] 14 | metal = ["llama-cpp-2/metal"] 15 | native = ["llama-cpp-2/native"] 16 | vulkan = ["llama-cpp-2/vulkan"] 17 | 18 | [lints] 19 | workspace = true 20 | -------------------------------------------------------------------------------- /examples/embeddings/src/main.rs: -------------------------------------------------------------------------------- 1 | //! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2. 2 | #![allow( 3 | clippy::cast_possible_wrap, 4 | clippy::cast_possible_truncation, 5 | clippy::cast_precision_loss, 6 | clippy::cast_sign_loss 7 | )] 8 | 9 | use std::io::Write; 10 | use std::path::PathBuf; 11 | use std::time::Duration; 12 | 13 | use anyhow::{bail, Context, Result}; 14 | use clap::Parser; 15 | use hf_hub::api::sync::ApiBuilder; 16 | 17 | use llama_cpp_2::context::params::LlamaContextParams; 18 | use llama_cpp_2::context::LlamaContext; 19 | use llama_cpp_2::ggml_time_us; 20 | use llama_cpp_2::llama_backend::LlamaBackend; 21 | use llama_cpp_2::llama_batch::LlamaBatch; 22 | use llama_cpp_2::model::params::LlamaModelParams; 23 | use llama_cpp_2::model::LlamaModel; 24 | use llama_cpp_2::model::{AddBos, Special}; 25 | 26 | #[derive(clap::Parser, Debug, Clone)] 27 | struct Args { 28 | /// The path to the model 29 | #[command(subcommand)] 30 | model: Model, 31 | /// The prompt 32 | #[clap(default_value = "Hello my name is")] 33 | prompt: String, 34 | /// Whether to normalise the produced embeddings 35 | #[clap(short)] 36 | normalise: bool, 37 | /// Disable offloading layers to the gpu 38 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 39 | #[clap(long)] 40 | disable_gpu: bool, 41 | } 42 | 43 | #[derive(clap::Subcommand, Debug, Clone)] 44 | enum Model { 45 | /// Use an already downloaded model 46 | Local { 47 | /// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa` 48 | path: PathBuf, 49 | }, 50 | /// Download a model from huggingface (or use a cached version) 51 | #[clap(name = "hf-model")] 52 | HuggingFace { 53 | /// the repo containing the model. e.g. `BAAI/bge-small-en-v1.5` 54 | repo: String, 55 | /// the model name. e.g. 
`BAAI-bge-small-v1.5.Q4_K_M.gguf` 56 | model: String, 57 | }, 58 | } 59 | 60 | impl Model { 61 | /// Convert the model to a path - may download from huggingface 62 | fn get_or_load(self) -> Result { 63 | match self { 64 | Model::Local { path } => Ok(path), 65 | Model::HuggingFace { model, repo } => ApiBuilder::new() 66 | .with_progress(true) 67 | .build() 68 | .with_context(|| "unable to create huggingface api")? 69 | .model(repo) 70 | .get(&model) 71 | .with_context(|| "unable to download model"), 72 | } 73 | } 74 | } 75 | 76 | fn main() -> Result<()> { 77 | let Args { 78 | model, 79 | prompt, 80 | normalise, 81 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 82 | disable_gpu, 83 | } = Args::parse(); 84 | 85 | // init LLM 86 | let backend = LlamaBackend::init()?; 87 | 88 | // offload all layers to the gpu 89 | let model_params = { 90 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 91 | if !disable_gpu { 92 | LlamaModelParams::default().with_n_gpu_layers(1000) 93 | } else { 94 | LlamaModelParams::default() 95 | } 96 | #[cfg(not(any(feature = "cuda", feature = "vulkan")))] 97 | LlamaModelParams::default() 98 | }; 99 | 100 | let model_path = model 101 | .get_or_load() 102 | .with_context(|| "failed to get model from args")?; 103 | 104 | let model = LlamaModel::load_from_file(&backend, model_path, &model_params) 105 | .with_context(|| "unable to load model")?; 106 | 107 | // initialize the context 108 | let ctx_params = LlamaContextParams::default() 109 | .with_n_threads_batch(std::thread::available_parallelism()?.get().try_into()?) 110 | .with_embeddings(true); 111 | 112 | let mut ctx = model 113 | .new_context(&backend, ctx_params) 114 | .with_context(|| "unable to create the llama_context")?; 115 | 116 | // Split the prompt to display the batching functionality 117 | let prompt_lines = prompt.lines(); 118 | 119 | // tokenize the prompt 120 | let tokens_lines_list = prompt_lines 121 | .map(|line| model.str_to_token(line, AddBos::Always)) 122 | .collect::, _>>() 123 | .with_context(|| format!("failed to tokenize {prompt}"))?; 124 | 125 | let n_ctx = ctx.n_ctx() as usize; 126 | let n_ctx_train = model.n_ctx_train(); 127 | 128 | eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}"); 129 | 130 | if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) { 131 | bail!("One of the provided prompts exceeds the size of the context window"); 132 | } 133 | 134 | // print the prompt token-by-token 135 | eprintln!(); 136 | 137 | for (i, token_line) in tokens_lines_list.iter().enumerate() { 138 | eprintln!("Prompt {i}"); 139 | for token in token_line { 140 | // Attempt to convert token to string and print it; if it fails, print the token instead 141 | match model.token_to_str(*token, Special::Tokenize) { 142 | Ok(token_str) => eprintln!("{token} --> {token_str}"), 143 | Err(e) => { 144 | eprintln!("Failed to convert token to string, error: {e}"); 145 | eprintln!("Token value: {token}"); 146 | } 147 | } 148 | } 149 | eprintln!(); 150 | } 151 | 152 | std::io::stderr().flush()?; 153 | 154 | // create a llama_batch with the size of the context 155 | // we use this object to submit token data for decoding 156 | let mut batch = LlamaBatch::new(n_ctx, 1); 157 | 158 | let mut max_seq_id_batch = 0; 159 | let mut output = Vec::with_capacity(tokens_lines_list.len()); 160 | 161 | let t_main_start = ggml_time_us(); 162 | 163 | for tokens in &tokens_lines_list { 164 | // Flush the batch if the next prompt would exceed our batch size 165 | if (batch.n_tokens() as usize + tokens.len()) > n_ctx { 166 | 
batch_decode( 167 | &mut ctx, 168 | &mut batch, 169 | max_seq_id_batch, 170 | &mut output, 171 | normalise, 172 | )?; 173 | max_seq_id_batch = 0; 174 | } 175 | 176 | batch.add_sequence(tokens, max_seq_id_batch, false)?; 177 | max_seq_id_batch += 1; 178 | } 179 | // Handle final batch 180 | batch_decode( 181 | &mut ctx, 182 | &mut batch, 183 | max_seq_id_batch, 184 | &mut output, 185 | normalise, 186 | )?; 187 | 188 | let t_main_end = ggml_time_us(); 189 | 190 | for (i, embeddings) in output.iter().enumerate() { 191 | eprintln!("Embeddings {i}: {embeddings:?}"); 192 | eprintln!(); 193 | } 194 | 195 | let duration = Duration::from_micros((t_main_end - t_main_start) as u64); 196 | let total_tokens: usize = tokens_lines_list.iter().map(Vec::len).sum(); 197 | eprintln!( 198 | "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n", 199 | total_tokens, 200 | duration.as_secs_f32(), 201 | total_tokens as f32 / duration.as_secs_f32() 202 | ); 203 | 204 | println!("{}", ctx.timings()); 205 | 206 | Ok(()) 207 | } 208 | 209 | fn batch_decode( 210 | ctx: &mut LlamaContext, 211 | batch: &mut LlamaBatch, 212 | s_batch: i32, 213 | output: &mut Vec>, 214 | normalise: bool, 215 | ) -> Result<()> { 216 | ctx.clear_kv_cache(); 217 | ctx.decode(batch).with_context(|| "llama_decode() failed")?; 218 | 219 | for i in 0..s_batch { 220 | let embedding = ctx 221 | .embeddings_seq_ith(i) 222 | .with_context(|| "Failed to get embeddings")?; 223 | let output_embeddings = if normalise { 224 | normalize(embedding) 225 | } else { 226 | embedding.to_vec() 227 | }; 228 | 229 | output.push(output_embeddings); 230 | } 231 | 232 | batch.clear(); 233 | 234 | Ok(()) 235 | } 236 | 237 | fn normalize(input: &[f32]) -> Vec { 238 | let magnitude = input 239 | .iter() 240 | .fold(0.0, |acc, &val| val.mul_add(val, acc)) 241 | .sqrt(); 242 | 243 | input.iter().map(|&val| val / magnitude).collect() 244 | } 245 | -------------------------------------------------------------------------------- /examples/reranker/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "reranker" 3 | version = "0.1.86" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.86" } 8 | hf-hub = { workspace = true } 9 | clap = { workspace = true, features = ["derive"] } 10 | anyhow = { workspace = true } 11 | encoding_rs = { workspace = true } 12 | 13 | [features] 14 | cuda = ["llama-cpp-2/cuda"] 15 | metal = ["llama-cpp-2/metal"] 16 | native = ["llama-cpp-2/native"] 17 | vulkan = ["llama-cpp-2/vulkan"] 18 | 19 | [lints] 20 | workspace = true -------------------------------------------------------------------------------- /examples/reranker/README.md: -------------------------------------------------------------------------------- 1 | # Rust Reranker Implementation 2 | 3 | A Rust implementation of cross-encoder based reranking using llama-cpp-2. Cross-encoder reranking is a more accurate way to determine similarity between queries and documents compared to traditional embedding-based approaches. 4 | 5 | ## Overview 6 | 7 | This implementation adds a new pooling type `LLAMA_POOLING_TYPE_RANK` which enables cross-encoder based reranking. 
Unlike traditional embedding approaches that encode query and document separately, this method: 8 | 9 | - Processes query and document pairs together in a single pass 10 | - Directly evaluates semantic relationships between the pairs 11 | - Outputs raw similarity scores indicating relevance 12 | 13 | ## Installation 14 | 15 | ```bash 16 | # Follow instructions to clone repo. 17 | # Navigate to examples reranker 18 | cd examples/reranker 19 | 20 | # Build the project 21 | cargo build --release 22 | ``` 23 | 24 | ## Usage 25 | 26 | ### Command Line Interface 27 | 28 | ```bash 29 | cargo run --release -- \ 30 | --model-path "models/bge-reranker-v2-m3.gguf" \ 31 | --query "what is panda?" \ 32 | --documents "hi" \ 33 | --documents "it's a bear" \ 34 | --documents "The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." \ 35 | --pooling rank 36 | ``` 37 | Should output (with bge-reranker-v2-m3-Q5_0): 38 | rerank score 0: -6.551 39 | rerank score 1: -3.802 40 | rerank score 2: 4.522 41 | 42 | ### CLI Arguments 43 | 44 | - `--model-path`: Path to the GGUF model file 45 | - `--query`: The search query 46 | - `--documents`: One or more documents to rank against the query 47 | - `--pooling`: Pooling type (options: none, mean, rank) 48 | 49 | ### Pooling Types 50 | 51 | - `rank`: Performs cross-encoder reranking 52 | 53 | 54 | Note: The raw scores are not normalized through a sigmoid function. If you need scores between 0 and 1, you'll need to implement sigmoid normalization in your application code. 55 | 56 | # Additional notes 57 | 58 | - Query and documents are concatenated using the format queryanswer 59 | 60 | ## Supported Models 61 | 62 | Some tested models: 63 | 64 | - [BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3) 65 | - [jinaai/jina-reranker-v1-tiny-en](https://huggingface.co/jinaai/jina-reranker-v1-tiny-en) 66 | 67 | Other models have not been tested, but anything supported by llama.cpp should work. 68 | 69 | ## Implementation Details 70 | 71 | This is a close Rust port of the reranker implementation discussed in [llama.cpp PR #9510](https://github.com/ggerganov/llama.cpp/pull/9510). 72 | 73 | ## Potential issues 74 | 75 | The BOS, EOS, and SEP tokens are currently hardcoded. Ideally they should be read from the model, with the prompts built accordingly for each specific model. -------------------------------------------------------------------------------- /examples/reranker/src/main.rs: -------------------------------------------------------------------------------- 1 | //! This is a translation of embedding.cpp in llama.cpp using llama-cpp-2.
2 | #![allow( 3 | clippy::cast_possible_wrap, 4 | clippy::cast_possible_truncation, 5 | clippy::cast_precision_loss, 6 | clippy::cast_sign_loss 7 | )] 8 | 9 | use std::io::Write; 10 | use std::path::PathBuf; 11 | use std::time::Duration; 12 | 13 | use anyhow::{bail, Context, Result}; 14 | use clap::Parser; 15 | use hf_hub::api::sync::ApiBuilder; 16 | 17 | use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType}; 18 | use llama_cpp_2::context::LlamaContext; 19 | use llama_cpp_2::ggml_time_us; 20 | use llama_cpp_2::llama_backend::LlamaBackend; 21 | use llama_cpp_2::llama_batch::LlamaBatch; 22 | use llama_cpp_2::model::params::LlamaModelParams; 23 | use llama_cpp_2::model::LlamaModel; 24 | use llama_cpp_2::model::{AddBos, Special}; 25 | 26 | #[derive(clap::Parser, Debug, Clone)] 27 | #[command(author, version, about, long_about = None)] 28 | struct Args { 29 | /// Path to the model file 30 | #[clap(long)] 31 | model_path: PathBuf, 32 | 33 | /// The query to embed 34 | #[clap(long)] 35 | query: String, 36 | 37 | /// The documents to embed and compare against 38 | #[clap(long, num_args = 1..)] 39 | documents: Vec, 40 | 41 | /// Pooling type (none, mean, or rank) 42 | #[clap(long, default_value = "none")] 43 | pooling: String, 44 | 45 | /// Whether to normalise the produced embeddings 46 | #[clap(long, default_value_t = true)] 47 | normalise: bool, 48 | } 49 | 50 | fn main() -> Result<()> { 51 | let Args { 52 | model_path, 53 | query, 54 | documents, 55 | pooling, 56 | normalise, 57 | } = Args::parse(); 58 | 59 | // init LLM 60 | let backend = LlamaBackend::init()?; 61 | 62 | // offload all layers to the gpu 63 | let model_params = { 64 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 65 | if !disable_gpu { 66 | LlamaModelParams::default().with_n_gpu_layers(1000) 67 | } else { 68 | LlamaModelParams::default() 69 | } 70 | #[cfg(not(any(feature = "cuda", feature = "vulkan")))] 71 | LlamaModelParams::default() 72 | }; 73 | 74 | let model = LlamaModel::load_from_file(&backend, model_path, &model_params) 75 | .with_context(|| "unable to load model")?; 76 | // println!("pooling: {}", pooling); 77 | let pooling_type = match pooling.as_str() { 78 | "mean" => LlamaPoolingType::Mean, 79 | "none" => LlamaPoolingType::None, 80 | "rank" => LlamaPoolingType::Rank, 81 | _ => LlamaPoolingType::Unspecified, 82 | }; 83 | 84 | let ctx_params = LlamaContextParams::default() 85 | .with_n_threads_batch(std::thread::available_parallelism()?.get().try_into()?) 86 | .with_embeddings(true) 87 | .with_pooling_type(pooling_type); 88 | println!("ctx_params: {:?}", ctx_params); 89 | let mut ctx = model 90 | .new_context(&backend, ctx_params) 91 | .with_context(|| "unable to create the llama_context")?; 92 | 93 | let n_embd = model.n_embd(); 94 | 95 | let prompt_lines = { 96 | let mut lines = Vec::new(); 97 | for doc in documents { 98 | // Todo! 
update to get eos and sep from model instead of hardcoding 99 | lines.push(format!("{query}{eos}{sep}{doc}", sep = "", eos = "")); 100 | } 101 | lines 102 | }; 103 | 104 | println!("prompt_lines: {:?}", prompt_lines); 105 | // tokenize the prompt 106 | let tokens_lines_list = prompt_lines 107 | .iter() 108 | .map(|line| model.str_to_token(line, AddBos::Always)) 109 | .collect::, _>>() 110 | .with_context(|| format!("failed to tokenize {:?}", prompt_lines))?; 111 | 112 | let n_ctx = ctx.n_ctx() as usize; 113 | let n_ctx_train = model.n_ctx_train(); 114 | 115 | eprintln!("n_ctx = {n_ctx}, n_ctx_train = {n_ctx_train}"); 116 | 117 | if tokens_lines_list.iter().any(|tok| n_ctx < tok.len()) { 118 | bail!("One of the provided prompts exceeds the size of the context window"); 119 | } 120 | 121 | // print the prompt token-by-token 122 | eprintln!(); 123 | 124 | for (i, token_line) in tokens_lines_list.iter().enumerate() { 125 | eprintln!("Prompt {i} --> {}", prompt_lines[i]); 126 | eprintln!("Number of tokens: {}", token_line.len()); 127 | for token in token_line { 128 | // Attempt to convert token to string and print it; if it fails, print the token instead 129 | match model.token_to_str(*token, Special::Tokenize) { 130 | Ok(token_str) => eprintln!("{token} --> {token_str}"), 131 | Err(e) => { 132 | eprintln!("Failed to convert token to string, error: {e}"); 133 | eprintln!("Token value: {token}"); 134 | } 135 | } 136 | } 137 | eprintln!(); 138 | } 139 | 140 | std::io::stderr().flush()?; 141 | 142 | // create a llama_batch with the size of the context 143 | // we use this object to submit token data for decoding 144 | let mut batch = LlamaBatch::new(2048, 1); 145 | 146 | // Todo! update to get n_embd to init vector size for better memory management 147 | // let mut n_embd_count = if pooling == "none" { 148 | // tokens_lines_list.iter().map(|tokens| tokens.len()).sum() 149 | // } else { 150 | // tokens_lines_list.len() 151 | // }; 152 | let mut embeddings_stored = 0; 153 | let mut max_seq_id_batch = 0; 154 | let mut output = Vec::with_capacity(tokens_lines_list.len()); 155 | 156 | let t_main_start = ggml_time_us(); 157 | 158 | for tokens in &tokens_lines_list { 159 | // Flush the batch if the next prompt would exceed our batch size 160 | if (batch.n_tokens() as usize + tokens.len()) > 2048 { 161 | batch_decode( 162 | &mut ctx, 163 | &mut batch, 164 | max_seq_id_batch, 165 | n_embd, 166 | &mut output, 167 | normalise, 168 | pooling.clone(), 169 | )?; 170 | embeddings_stored += if pooling == "none" { 171 | batch.n_tokens() 172 | } else { 173 | max_seq_id_batch 174 | }; 175 | max_seq_id_batch = 0; 176 | batch.clear(); 177 | } 178 | 179 | batch.add_sequence(tokens, max_seq_id_batch, false)?; 180 | max_seq_id_batch += 1; 181 | } 182 | // Handle final batch 183 | batch_decode( 184 | &mut ctx, 185 | &mut batch, 186 | max_seq_id_batch, 187 | n_embd, 188 | &mut output, 189 | normalise, 190 | pooling.clone(), 191 | )?; 192 | 193 | let t_main_end = ggml_time_us(); 194 | 195 | for (j, embeddings) in output.iter().enumerate() { 196 | if pooling == "none" { 197 | eprintln!("embedding {j}: "); 198 | for i in 0..n_embd as usize { 199 | if !normalise { 200 | eprint!("{:6.5} ", embeddings[i]); 201 | } else { 202 | eprint!("{:9.6} ", embeddings[i]); 203 | } 204 | } 205 | eprintln!(); 206 | } else if pooling == "rank" { 207 | eprintln!("rerank score {j}: {:8.3}", embeddings[0]); 208 | } else { 209 | eprintln!("embedding {j}: "); 210 | for i in 0..n_embd as usize { 211 | if !normalise { 212 | eprint!("{:6.5} ", 
embeddings[i]); 213 | } else { 214 | eprint!("{:9.6} ", embeddings[i]); 215 | } 216 | } 217 | eprintln!(); 218 | } 219 | } 220 | 221 | let duration = Duration::from_micros((t_main_end - t_main_start) as u64); 222 | let total_tokens: usize = tokens_lines_list.iter().map(Vec::len).sum(); 223 | eprintln!( 224 | "Created embeddings for {} tokens in {:.2} s, speed {:.2} t/s\n", 225 | total_tokens, 226 | duration.as_secs_f32(), 227 | total_tokens as f32 / duration.as_secs_f32() 228 | ); 229 | 230 | println!("{}", ctx.timings()); 231 | 232 | Ok(()) 233 | } 234 | 235 | fn batch_decode( 236 | ctx: &mut LlamaContext, 237 | batch: &mut LlamaBatch, 238 | s_batch: i32, 239 | n_embd: i32, 240 | output: &mut Vec>, 241 | normalise: bool, 242 | pooling: String, 243 | ) -> Result<()> { 244 | eprintln!( 245 | "{}: n_tokens = {}, n_seq = {}", 246 | stringify!(batch_decode), 247 | batch.n_tokens(), 248 | s_batch 249 | ); 250 | 251 | // Clear previous kv_cache values 252 | ctx.clear_kv_cache(); 253 | 254 | ctx.decode(batch).with_context(|| "llama_decode() failed")?; 255 | 256 | for i in 0..s_batch { 257 | let embeddings = ctx 258 | .embeddings_seq_ith(i) 259 | .with_context(|| "Failed to get sequence embeddings")?; 260 | let normalized = if normalise { 261 | if pooling == "rank" { 262 | normalize_embeddings(&embeddings, -1) 263 | } else { 264 | normalize_embeddings(&embeddings, 2) 265 | } 266 | } else { 267 | embeddings.to_vec() 268 | }; 269 | output.push(normalized); 270 | } 271 | 272 | batch.clear(); 273 | 274 | Ok(()) 275 | } 276 | 277 | /// Normalizes embeddings based on different normalization strategies 278 | fn normalize_embeddings(input: &[f32], embd_norm: i32) -> Vec { 279 | let n = input.len(); 280 | let mut output = vec![0.0; n]; 281 | 282 | let sum = match embd_norm { 283 | -1 => 1.0, // no normalization 284 | 0 => { 285 | // max absolute 286 | let max_abs = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max) / 32760.0; 287 | max_abs as f64 288 | } 289 | 2 => { 290 | // euclidean norm 291 | input 292 | .iter() 293 | .map(|x| (*x as f64).powi(2)) 294 | .sum::() 295 | .sqrt() 296 | } 297 | p => { 298 | // p-norm 299 | let sum = input.iter().map(|x| (x.abs() as f64).powi(p)).sum::(); 300 | sum.powf(1.0 / p as f64) 301 | } 302 | }; 303 | 304 | let norm = if sum > 0.0 { 1.0 / sum } else { 0.0 }; 305 | 306 | for i in 0..n { 307 | output[i] = (input[i] as f64 * norm) as f32; 308 | } 309 | 310 | output 311 | } 312 | 313 | // /// Calculates cosine similarity between two embedding vectors 314 | // fn embedding_similarity_cos(embd1: &[f32], embd2: &[f32]) -> f32 { 315 | // assert_eq!(embd1.len(), embd2.len(), "Embedding vectors must be the same length"); 316 | 317 | // let (sum, sum1, sum2) = embd1.iter().zip(embd2.iter()).fold( 318 | // (0.0f64, 0.0f64, 0.0f64), 319 | // |(sum, sum1, sum2), (e1, e2)| { 320 | // let e1 = *e1 as f64; 321 | // let e2 = *e2 as f64; 322 | // ( 323 | // sum + e1 * e2, 324 | // sum1 + e1 * e1, 325 | // sum2 + e2 * e2 326 | // ) 327 | // } 328 | // ); 329 | 330 | // // Handle zero vectors 331 | // if sum1 == 0.0 || sum2 == 0.0 { 332 | // return if sum1 == 0.0 && sum2 == 0.0 { 333 | // 1.0 // two zero vectors are similar 334 | // } else { 335 | // 0.0 336 | // }; 337 | // } 338 | 339 | // (sum / (sum1.sqrt() * sum2.sqrt())) as f32 340 | // } 341 | -------------------------------------------------------------------------------- /examples/simple/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "simple" 3 | version = 
"0.1.109" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | llama-cpp-2 = { path = "../../llama-cpp-2", version = "0.1.69" } 10 | hf-hub = { workspace = true } 11 | clap = { workspace = true, features = ["derive"] } 12 | anyhow = { workspace = true } 13 | encoding_rs = { workspace = true } 14 | tracing-subscriber = { workspace = true } 15 | 16 | [features] 17 | cuda = ["llama-cpp-2/cuda"] 18 | metal = ["llama-cpp-2/metal"] 19 | native = ["llama-cpp-2/native"] 20 | vulkan = ["llama-cpp-2/vulkan"] 21 | 22 | [lints] 23 | workspace = true 24 | -------------------------------------------------------------------------------- /examples/simple/src/main.rs: -------------------------------------------------------------------------------- 1 | //! This is a translation of simple.cpp in llama.cpp using llama-cpp-2. 2 | #![allow( 3 | clippy::cast_possible_wrap, 4 | clippy::cast_possible_truncation, 5 | clippy::cast_precision_loss, 6 | clippy::cast_sign_loss 7 | )] 8 | 9 | use anyhow::{anyhow, bail, Context, Result}; 10 | use clap::Parser; 11 | use hf_hub::api::sync::ApiBuilder; 12 | use llama_cpp_2::context::params::LlamaContextParams; 13 | use llama_cpp_2::llama_backend::LlamaBackend; 14 | use llama_cpp_2::llama_batch::LlamaBatch; 15 | use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; 16 | use llama_cpp_2::model::params::LlamaModelParams; 17 | use llama_cpp_2::model::LlamaModel; 18 | use llama_cpp_2::model::{AddBos, Special}; 19 | use llama_cpp_2::sampling::LlamaSampler; 20 | use llama_cpp_2::{ggml_time_us, send_logs_to_tracing, LogOptions}; 21 | 22 | use std::ffi::CString; 23 | use std::io::Write; 24 | use std::num::NonZeroU32; 25 | use std::path::PathBuf; 26 | use std::pin::pin; 27 | use std::str::FromStr; 28 | use std::time::Duration; 29 | 30 | #[derive(clap::Parser, Debug, Clone)] 31 | struct Args { 32 | /// The path to the model 33 | #[command(subcommand)] 34 | model: Model, 35 | /// The prompt 36 | #[clap(short = 'p', long)] 37 | prompt: Option, 38 | /// Read the prompt from a file 39 | #[clap(short = 'f', long, help = "prompt file to start generation")] 40 | file: Option, 41 | /// set the length of the prompt + output in tokens 42 | #[arg(long, default_value_t = 32)] 43 | n_len: i32, 44 | /// override some parameters of the model 45 | #[arg(short = 'o', value_parser = parse_key_val)] 46 | key_value_overrides: Vec<(String, ParamOverrideValue)>, 47 | /// Disable offloading layers to the gpu 48 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 49 | #[clap(long)] 50 | disable_gpu: bool, 51 | #[arg(short = 's', long, help = "RNG seed (default: 1234)")] 52 | seed: Option, 53 | #[arg( 54 | short = 't', 55 | long, 56 | help = "number of threads to use during generation (default: use all available threads)" 57 | )] 58 | threads: Option, 59 | #[arg( 60 | long, 61 | help = "number of threads to use during batch and prompt processing (default: use all available threads)" 62 | )] 63 | threads_batch: Option, 64 | #[arg( 65 | short = 'c', 66 | long, 67 | help = "size of the prompt context (default: loaded from themodel)" 68 | )] 69 | ctx_size: Option, 70 | #[arg(short = 'v', long, help = "enable verbose llama.cpp logs")] 71 | verbose: bool, 72 | } 73 | 74 | /// Parse a single key-value pair 75 | fn parse_key_val(s: &str) -> Result<(String, ParamOverrideValue)> { 76 | let pos = s 77 | .find('=') 78 | .ok_or_else(|| anyhow!("invalid KEY=value: no `=` found in `{}`", s))?; 79 | let key = 
s[..pos].parse()?; 80 | let value: String = s[pos + 1..].parse()?; 81 | let value = i64::from_str(&value) 82 | .map(ParamOverrideValue::Int) 83 | .or_else(|_| f64::from_str(&value).map(ParamOverrideValue::Float)) 84 | .or_else(|_| bool::from_str(&value).map(ParamOverrideValue::Bool)) 85 | .map_err(|_| anyhow!("must be one of i64, f64, or bool"))?; 86 | 87 | Ok((key, value)) 88 | } 89 | 90 | #[derive(clap::Subcommand, Debug, Clone)] 91 | enum Model { 92 | /// Use an already downloaded model 93 | Local { 94 | /// The path to the model. e.g. `/home/marcus/.cache/huggingface/hub/models--TheBloke--Llama-2-7B-Chat-GGUF/blobs/08a5566d61d7cb6b420c3e4387a39e0078e1f2fe5f055f3a03887385304d4bfa` 95 | path: PathBuf, 96 | }, 97 | /// Download a model from huggingface (or use a cached version) 98 | #[clap(name = "hf-model")] 99 | HuggingFace { 100 | /// the repo containing the model. e.g. `TheBloke/Llama-2-7B-Chat-GGUF` 101 | repo: String, 102 | /// the model name. e.g. `llama-2-7b-chat.Q4_K_M.gguf` 103 | model: String, 104 | }, 105 | } 106 | 107 | impl Model { 108 | /// Convert the model to a path - may download from huggingface 109 | fn get_or_load(self) -> Result { 110 | match self { 111 | Model::Local { path } => Ok(path), 112 | Model::HuggingFace { model, repo } => ApiBuilder::new() 113 | .with_progress(true) 114 | .build() 115 | .with_context(|| "unable to create huggingface api")? 116 | .model(repo) 117 | .get(&model) 118 | .with_context(|| "unable to download model"), 119 | } 120 | } 121 | } 122 | 123 | #[allow(clippy::too_many_lines)] 124 | fn main() -> Result<()> { 125 | let Args { 126 | n_len, 127 | model, 128 | prompt, 129 | file, 130 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 131 | disable_gpu, 132 | key_value_overrides, 133 | seed, 134 | threads, 135 | threads_batch, 136 | ctx_size, 137 | verbose, 138 | } = Args::parse(); 139 | 140 | if verbose { 141 | tracing_subscriber::fmt().init(); 142 | } 143 | send_logs_to_tracing(LogOptions::default().with_logs_enabled(verbose)); 144 | 145 | // init LLM 146 | let backend = LlamaBackend::init()?; 147 | 148 | // offload all layers to the gpu 149 | let model_params = { 150 | #[cfg(any(feature = "cuda", feature = "vulkan"))] 151 | if !disable_gpu { 152 | LlamaModelParams::default().with_n_gpu_layers(1000) 153 | } else { 154 | LlamaModelParams::default() 155 | } 156 | #[cfg(not(any(feature = "cuda", feature = "vulkan")))] 157 | LlamaModelParams::default() 158 | }; 159 | 160 | let prompt = if let Some(str) = prompt { 161 | if file.is_some() { 162 | bail!("either prompt or file must be specified, but not both") 163 | } 164 | str 165 | } else if let Some(file) = file { 166 | std::fs::read_to_string(&file).with_context(|| format!("unable to read {file}"))? 
167 | } else { 168 | "Hello my name is".to_string() 169 | }; 170 | 171 | let mut model_params = pin!(model_params); 172 | 173 | for (k, v) in &key_value_overrides { 174 | let k = CString::new(k.as_bytes()).with_context(|| format!("invalid key: {k}"))?; 175 | model_params.as_mut().append_kv_override(k.as_c_str(), *v); 176 | } 177 | 178 | let model_path = model 179 | .get_or_load() 180 | .with_context(|| "failed to get model from args")?; 181 | 182 | let model = LlamaModel::load_from_file(&backend, model_path, &model_params) 183 | .with_context(|| "unable to load model")?; 184 | 185 | // initialize the context 186 | let mut ctx_params = 187 | LlamaContextParams::default().with_n_ctx(ctx_size.or(Some(NonZeroU32::new(2048).unwrap()))); 188 | 189 | if let Some(threads) = threads { 190 | ctx_params = ctx_params.with_n_threads(threads); 191 | } 192 | if let Some(threads_batch) = threads_batch.or(threads) { 193 | ctx_params = ctx_params.with_n_threads_batch(threads_batch); 194 | } 195 | 196 | let mut ctx = model 197 | .new_context(&backend, ctx_params) 198 | .with_context(|| "unable to create the llama_context")?; 199 | 200 | // tokenize the prompt 201 | 202 | let tokens_list = model 203 | .str_to_token(&prompt, AddBos::Always) 204 | .with_context(|| format!("failed to tokenize {prompt}"))?; 205 | 206 | let n_cxt = ctx.n_ctx() as i32; 207 | let n_kv_req = tokens_list.len() as i32 + (n_len - tokens_list.len() as i32); 208 | 209 | eprintln!("n_len = {n_len}, n_ctx = {n_cxt}, k_kv_req = {n_kv_req}"); 210 | 211 | // make sure the KV cache is big enough to hold all the prompt and generated tokens 212 | if n_kv_req > n_cxt { 213 | bail!( 214 | "n_kv_req > n_ctx, the required kv cache size is not big enough 215 | either reduce n_len or increase n_ctx" 216 | ) 217 | } 218 | 219 | if tokens_list.len() >= usize::try_from(n_len)? { 220 | bail!("the prompt is too long, it has more tokens than n_len") 221 | } 222 | 223 | // print the prompt token-by-token 224 | eprintln!(); 225 | 226 | for token in &tokens_list { 227 | eprint!("{}", model.token_to_str(*token, Special::Tokenize)?); 228 | } 229 | 230 | std::io::stderr().flush()?; 231 | 232 | // create a llama_batch with size 512 233 | // we use this object to submit token data for decoding 234 | let mut batch = LlamaBatch::new(512, 1); 235 | 236 | let last_index: i32 = (tokens_list.len() - 1) as i32; 237 | for (i, token) in (0_i32..).zip(tokens_list.into_iter()) { 238 | // llama_decode will output logits only for the last token of the prompt 239 | let is_last = i == last_index; 240 | batch.add(token, i, &[0], is_last)?; 241 | } 242 | 243 | ctx.decode(&mut batch) 244 | .with_context(|| "llama_decode() failed")?; 245 | 246 | // main loop 247 | 248 | let mut n_cur = batch.n_tokens(); 249 | let mut n_decode = 0; 250 | 251 | let t_main_start = ggml_time_us(); 252 | 253 | // The `Decoder` 254 | let mut decoder = encoding_rs::UTF_8.new_decoder(); 255 | 256 | let mut sampler = LlamaSampler::chain_simple([ 257 | LlamaSampler::dist(seed.unwrap_or(1234)), 258 | LlamaSampler::greedy(), 259 | ]); 260 | 261 | while n_cur <= n_len { 262 | // sample the next token 263 | { 264 | let token = sampler.sample(&ctx, batch.n_tokens() - 1); 265 | 266 | sampler.accept(token); 267 | 268 | // is it an end of stream? 
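// (`is_eog_token` is true for any end-of-generation token the model's vocab defines,
// e.g. EOS or EOT, so chat models that end a turn with a dedicated token also stop here.)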
269 | if model.is_eog_token(token) { 270 | eprintln!(); 271 | break; 272 | } 273 | 274 | let output_bytes = model.token_to_bytes(token, Special::Tokenize)?; 275 | // use `Decoder.decode_to_string()` to avoid the intermediate buffer 276 | let mut output_string = String::with_capacity(32); 277 | let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); 278 | print!("{output_string}"); 279 | std::io::stdout().flush()?; 280 | 281 | batch.clear(); 282 | batch.add(token, n_cur, &[0], true)?; 283 | } 284 | 285 | n_cur += 1; 286 | 287 | ctx.decode(&mut batch).with_context(|| "failed to eval")?; 288 | 289 | n_decode += 1; 290 | } 291 | 292 | eprintln!("\n"); 293 | 294 | let t_main_end = ggml_time_us(); 295 | 296 | let duration = Duration::from_micros((t_main_end - t_main_start) as u64); 297 | 298 | eprintln!( 299 | "decoded {} tokens in {:.2} s, speed {:.2} t/s\n", 300 | n_decode, 301 | duration.as_secs_f32(), 302 | n_decode as f32 / duration.as_secs_f32() 303 | ); 304 | 305 | println!("{}", ctx.timings()); 306 | 307 | Ok(()) 308 | } 309 | -------------------------------------------------------------------------------- /examples/usage.rs: -------------------------------------------------------------------------------- 1 | //! # Usage 2 | //! 3 | //! This is just about the smallest possible way to do inference. To fetch a model from hugging face: 4 | //! 5 | //! ```console 6 | //! git clone --recursive https://github.com/utilityai/llama-cpp-rs 7 | //! cd llama-cpp-rs/examples/usage 8 | //! wget https://huggingface.co/Qwen/Qwen2-1.5B-Instruct-GGUF/resolve/main/qwen2-1_5b-instruct-q4_0.gguf 9 | //! cargo run --example usage -- qwen2-1_5b-instruct-q4_0.gguf 10 | //! ``` 11 | use llama_cpp_2::context::params::LlamaContextParams; 12 | use llama_cpp_2::llama_backend::LlamaBackend; 13 | use llama_cpp_2::llama_batch::LlamaBatch; 14 | use llama_cpp_2::model::params::LlamaModelParams; 15 | use llama_cpp_2::model::LlamaModel; 16 | use llama_cpp_2::model::{AddBos, Special}; 17 | use llama_cpp_2::sampling::LlamaSampler; 18 | use std::io::Write; 19 | 20 | #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] 21 | fn main() { 22 | let model_path = std::env::args().nth(1).expect("Please specify model path"); 23 | let backend = LlamaBackend::init().unwrap(); 24 | let params = LlamaModelParams::default(); 25 | 26 | let prompt = 27 | "<|im_start|>user\nHello! 
how are you?<|im_end|>\n<|im_start|>assistant\n".to_string(); 28 | LlamaContextParams::default(); 29 | let model = 30 | LlamaModel::load_from_file(&backend, model_path, ¶ms).expect("unable to load model"); 31 | let ctx_params = LlamaContextParams::default(); 32 | let mut ctx = model 33 | .new_context(&backend, ctx_params) 34 | .expect("unable to create the llama_context"); 35 | let tokens_list = model 36 | .str_to_token(&prompt, AddBos::Always) 37 | .unwrap_or_else(|_| panic!("failed to tokenize {prompt}")); 38 | let n_len = 64; 39 | 40 | // create a llama_batch with size 512 41 | // we use this object to submit token data for decoding 42 | let mut batch = LlamaBatch::new(512, 1); 43 | 44 | let last_index = tokens_list.len() as i32 - 1; 45 | for (i, token) in (0_i32..).zip(tokens_list.into_iter()) { 46 | // llama_decode will output logits only for the last token of the prompt 47 | let is_last = i == last_index; 48 | batch.add(token, i, &[0], is_last).unwrap(); 49 | } 50 | ctx.decode(&mut batch).expect("llama_decode() failed"); 51 | 52 | let mut n_cur = batch.n_tokens(); 53 | 54 | // The `Decoder` 55 | let mut decoder = encoding_rs::UTF_8.new_decoder(); 56 | let mut sampler = LlamaSampler::greedy(); 57 | 58 | while n_cur <= n_len { 59 | // sample the next token 60 | { 61 | let token = sampler.sample(&ctx, batch.n_tokens() - 1); 62 | 63 | sampler.accept(token); 64 | 65 | // is it an end of stream? 66 | if token == model.token_eos() { 67 | eprintln!(); 68 | break; 69 | } 70 | 71 | let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap(); 72 | // use `Decoder.decode_to_string()` to avoid the intermediate buffer 73 | let mut output_string = String::with_capacity(32); 74 | let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false); 75 | print!("{output_string}"); 76 | std::io::stdout().flush().unwrap(); 77 | 78 | batch.clear(); 79 | batch.add(token, n_cur, &[0], true).unwrap(); 80 | } 81 | 82 | n_cur += 1; 83 | 84 | ctx.decode(&mut batch).expect("failed to eval"); 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /llama-cpp-2/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llama-cpp-2" 3 | description = "llama.cpp bindings for Rust" 4 | version = "0.1.109" 5 | edition = "2021" 6 | license = "MIT OR Apache-2.0" 7 | repository = "https://github.com/utilityai/llama-cpp-rs" 8 | 9 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 10 | 11 | [dependencies] 12 | enumflags2 = "0.7.11" 13 | llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.69" } 14 | thiserror = { workspace = true } 15 | tracing = { workspace = true } 16 | tracing-core = { workspace = true } 17 | 18 | [dev-dependencies] 19 | encoding_rs = { workspace = true } 20 | 21 | [features] 22 | default = ["openmp", "android-shared-stdcxx"] 23 | cuda = ["llama-cpp-sys-2/cuda"] 24 | cuda-no-vmm = ["cuda", "llama-cpp-sys-2/cuda-no-vmm"] 25 | metal = ["llama-cpp-sys-2/metal"] 26 | dynamic-link = ["llama-cpp-sys-2/dynamic-link"] 27 | vulkan = ["llama-cpp-sys-2/vulkan"] 28 | native = ["llama-cpp-sys-2/native"] 29 | openmp = ["llama-cpp-sys-2/openmp"] 30 | sampler = [] 31 | # Only has an impact on Android. 
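# (i.e. it only changes how the C++ standard library is linked for Android targets.)
# As a hedged usage sketch, a downstream crate opts into these optional backends from
# its own Cargo.toml, for example:
#   llama-cpp-2 = { version = "0.1.109", features = ["cuda"] }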
32 | android-shared-stdcxx = ["llama-cpp-sys-2/shared-stdcxx"] 33 | 34 | 35 | [target.'cfg(all(target_os = "macos", any(target_arch = "aarch64", target_arch = "arm64")))'.dependencies] 36 | llama-cpp-sys-2 = { path = "../llama-cpp-sys-2", version = "0.1.69", features = [ 37 | "metal", 38 | ] } 39 | 40 | [lints] 41 | workspace = true 42 | 43 | [package.metadata.docs.rs] 44 | features = ["sampler"] 45 | 46 | [[example]] 47 | name = "usage" 48 | path = "../examples/usage.rs" 49 | -------------------------------------------------------------------------------- /llama-cpp-2/README.md: -------------------------------------------------------------------------------- 1 | # llama-cpp-rs-2 2 | 3 | [utilityai]: https://utilityai.ca 4 | 5 | A wrapper around the [llama-cpp](https://github.com/ggerganov/llama.cpp/) library for Rust. 6 | 7 | # Info 8 | 9 | This is part of the project powering all the LLMs at [utilityai]. It is tightly coupled to llama.cpp and mimics its API as 10 | closely as possible while being safe, in order to stay up to date. 11 | 12 | # Dependencies 13 | 14 | This uses bindgen to build the bindings to llama.cpp. This means that you need to have clang installed on your system. 15 | 16 | If this is a problem for you, open an issue, and we can look into including pre-generated bindings. 17 | 18 | See [bindgen](https://rust-lang.github.io/rust-bindgen/requirements.html) for more information. 19 | 20 | # Disclaimer 21 | 22 | This crate is *not safe*. There are absolutely ways to misuse the llama.cpp API provided here to create UB; please create an issue if you spot one. Do not use this code for tasks where UB is not acceptable. 23 | 24 | This is not a simple library to use. In an ideal world a nice abstraction would be written on top of this crate to 25 | provide an ergonomic API - the main benefit of this crate over raw bindings is safety (and not much of that), and not much else. 26 | 27 | We compensate for this shortcoming (we hope) by providing lots of examples and good documentation. Testing is a work in 28 | progress. 29 | 30 | # Contributing 31 | 32 | Contributions are welcome. Please open an issue before starting work on a non-trivial PR. 33 | -------------------------------------------------------------------------------- /llama-cpp-2/src/context.rs: -------------------------------------------------------------------------------- 1 | //! Safe wrapper around `llama_context`. 2 | 3 | use std::fmt::{Debug, Formatter}; 4 | use std::num::NonZeroI32; 5 | use std::ptr::NonNull; 6 | use std::slice; 7 | 8 | use crate::llama_batch::LlamaBatch; 9 | use crate::model::{LlamaLoraAdapter, LlamaModel}; 10 | use crate::timing::LlamaTimings; 11 | use crate::token::data::LlamaTokenData; 12 | use crate::token::data_array::LlamaTokenDataArray; 13 | use crate::token::LlamaToken; 14 | use crate::{ 15 | DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError, 16 | LlamaLoraAdapterSetError, 17 | }; 18 | 19 | pub mod kv_cache; 20 | pub mod params; 21 | pub mod session; 22 | 23 | /// Safe wrapper around `llama_context`. 24 | #[allow(clippy::module_name_repetitions)] 25 | pub struct LlamaContext<'a> { 26 | pub(crate) context: NonNull<llama_cpp_sys_2::llama_context>, 27 | /// a reference to the context's model.
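/// The `'a` borrow ties this context to its model, so the model must outlive the context.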
28 | pub model: &'a LlamaModel, 29 | initialized_logits: Vec, 30 | embeddings_enabled: bool, 31 | } 32 | 33 | impl Debug for LlamaContext<'_> { 34 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 35 | f.debug_struct("LlamaContext") 36 | .field("context", &self.context) 37 | .finish() 38 | } 39 | } 40 | 41 | impl<'model> LlamaContext<'model> { 42 | pub(crate) fn new( 43 | llama_model: &'model LlamaModel, 44 | llama_context: NonNull, 45 | embeddings_enabled: bool, 46 | ) -> Self { 47 | Self { 48 | context: llama_context, 49 | model: llama_model, 50 | initialized_logits: Vec::new(), 51 | embeddings_enabled, 52 | } 53 | } 54 | 55 | /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`]. 56 | #[must_use] 57 | pub fn n_batch(&self) -> u32 { 58 | unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) } 59 | } 60 | 61 | /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`]. 62 | #[must_use] 63 | pub fn n_ubatch(&self) -> u32 { 64 | unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) } 65 | } 66 | 67 | /// Gets the size of the context. 68 | #[must_use] 69 | pub fn n_ctx(&self) -> u32 { 70 | unsafe { llama_cpp_sys_2::llama_n_ctx(self.context.as_ptr()) } 71 | } 72 | 73 | /// Decodes the batch. 74 | /// 75 | /// # Errors 76 | /// 77 | /// - `DecodeError` if the decoding failed. 78 | /// 79 | /// # Panics 80 | /// 81 | /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems) 82 | pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> { 83 | let result = 84 | unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) }; 85 | 86 | match NonZeroI32::new(result) { 87 | None => { 88 | self.initialized_logits 89 | .clone_from(&batch.initialized_logits); 90 | Ok(()) 91 | } 92 | Some(error) => Err(DecodeError::from(error)), 93 | } 94 | } 95 | 96 | /// Encodes the batch. 97 | /// 98 | /// # Errors 99 | /// 100 | /// - `EncodeError` if the decoding failed. 101 | /// 102 | /// # Panics 103 | /// 104 | /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems) 105 | pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> { 106 | let result = 107 | unsafe { llama_cpp_sys_2::llama_encode(self.context.as_ptr(), batch.llama_batch) }; 108 | 109 | match NonZeroI32::new(result) { 110 | None => { 111 | self.initialized_logits 112 | .clone_from(&batch.initialized_logits); 113 | Ok(()) 114 | } 115 | Some(error) => Err(EncodeError::from(error)), 116 | } 117 | } 118 | 119 | /// Get the embeddings for the `i`th sequence in the current context. 120 | /// 121 | /// # Returns 122 | /// 123 | /// A slice containing the embeddings for the last decoded batch. 124 | /// The size corresponds to the `n_embd` parameter of the context's model. 125 | /// 126 | /// # Errors 127 | /// 128 | /// - When the current context was constructed without enabling embeddings. 129 | /// - If the current model had a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`] 130 | /// - If the given sequence index exceeds the max sequence id. 
131 | /// 132 | /// # Panics 133 | /// 134 | /// * `n_embd` does not fit into a usize 135 | pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> { 136 | if !self.embeddings_enabled { 137 | return Err(EmbeddingsError::NotEnabled); 138 | } 139 | 140 | let n_embd = 141 | usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize"); 142 | 143 | unsafe { 144 | let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i); 145 | 146 | // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here. 147 | if embedding.is_null() { 148 | Err(EmbeddingsError::NonePoolType) 149 | } else { 150 | Ok(slice::from_raw_parts(embedding, n_embd)) 151 | } 152 | } 153 | } 154 | 155 | /// Get the embeddings for the `i`th token in the current context. 156 | /// 157 | /// # Returns 158 | /// 159 | /// A slice containing the embeddings for the last decoded batch of the given token. 160 | /// The size corresponds to the `n_embd` parameter of the context's model. 161 | /// 162 | /// # Errors 163 | /// 164 | /// - When the current context was constructed without enabling embeddings. 165 | /// - When the given token didn't have logits enabled when it was passed. 166 | /// - If the given token index exceeds the max token id. 167 | /// 168 | /// # Panics 169 | /// 170 | /// * `n_embd` does not fit into a usize 171 | pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> { 172 | if !self.embeddings_enabled { 173 | return Err(EmbeddingsError::NotEnabled); 174 | } 175 | 176 | let n_embd = 177 | usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize"); 178 | 179 | unsafe { 180 | let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i); 181 | // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here. 182 | if embedding.is_null() { 183 | Err(EmbeddingsError::LogitsNotEnabled) 184 | } else { 185 | Ok(slice::from_raw_parts(embedding, n_embd)) 186 | } 187 | } 188 | } 189 | 190 | /// Get the logits for the last token in the context. 191 | /// 192 | /// # Returns 193 | /// An iterator over unsorted `LlamaTokenData` containing the 194 | /// logits for the last token in the context. 195 | /// 196 | /// # Panics 197 | /// 198 | /// - underlying logits data is null 199 | pub fn candidates(&self) -> impl Iterator + '_ { 200 | (0_i32..).zip(self.get_logits()).map(|(i, logit)| { 201 | let token = LlamaToken::new(i); 202 | LlamaTokenData::new(token, *logit, 0_f32) 203 | }) 204 | } 205 | 206 | /// Get the token data array for the last token in the context. 207 | /// 208 | /// This is a convience method that implements: 209 | /// ```ignore 210 | /// LlamaTokenDataArray::from_iter(ctx.candidates(), false) 211 | /// ``` 212 | /// 213 | /// # Panics 214 | /// 215 | /// - underlying logits data is null 216 | #[must_use] 217 | pub fn token_data_array(&self) -> LlamaTokenDataArray { 218 | LlamaTokenDataArray::from_iter(self.candidates(), false) 219 | } 220 | 221 | /// Token logits obtained from the last call to `decode()`. 222 | /// The logits for which `batch.logits[i] != 0` are stored contiguously 223 | /// in the order they have appeared in the batch. 224 | /// Rows: number of tokens for which `batch.logits[i] != 0` 225 | /// Cols: `n_vocab` 226 | /// 227 | /// # Returns 228 | /// 229 | /// A slice containing the logits for the last decoded token. 230 | /// The size corresponds to the `n_vocab` parameter of the context's model. 
231 | /// 232 | /// # Panics 233 | /// 234 | /// - `n_vocab` does not fit into a usize 235 | /// - token data returned is null 236 | #[must_use] 237 | pub fn get_logits(&self) -> &[f32] { 238 | let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) }; 239 | assert!(!data.is_null(), "logits data for last token is null"); 240 | let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize"); 241 | 242 | unsafe { slice::from_raw_parts(data, len) } 243 | } 244 | 245 | /// Get the logits for the ith token in the context. 246 | /// 247 | /// # Panics 248 | /// 249 | /// - logit `i` is not initialized. 250 | pub fn candidates_ith(&self, i: i32) -> impl Iterator + '_ { 251 | (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| { 252 | let token = LlamaToken::new(i); 253 | LlamaTokenData::new(token, *logit, 0_f32) 254 | }) 255 | } 256 | 257 | /// Get the token data array for the ith token in the context. 258 | /// 259 | /// This is a convience method that implements: 260 | /// ```ignore 261 | /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false) 262 | /// ``` 263 | /// 264 | /// # Panics 265 | /// 266 | /// - logit `i` is not initialized. 267 | #[must_use] 268 | pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray { 269 | LlamaTokenDataArray::from_iter(self.candidates_ith(i), false) 270 | } 271 | 272 | /// Get the logits for the ith token in the context. 273 | /// 274 | /// # Panics 275 | /// 276 | /// - `i` is greater than `n_ctx` 277 | /// - `n_vocab` does not fit into a usize 278 | /// - logit `i` is not initialized. 279 | #[must_use] 280 | pub fn get_logits_ith(&self, i: i32) -> &[f32] { 281 | assert!( 282 | self.initialized_logits.contains(&i), 283 | "logit {i} is not initialized. only {:?} is", 284 | self.initialized_logits 285 | ); 286 | assert!( 287 | self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"), 288 | "n_ctx ({}) must be greater than i ({})", 289 | self.n_ctx(), 290 | i 291 | ); 292 | 293 | let data = unsafe { llama_cpp_sys_2::llama_get_logits_ith(self.context.as_ptr(), i) }; 294 | let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize"); 295 | 296 | unsafe { slice::from_raw_parts(data, len) } 297 | } 298 | 299 | /// Reset the timings for the context. 300 | pub fn reset_timings(&mut self) { 301 | unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) } 302 | } 303 | 304 | /// Returns the timings for the context. 305 | pub fn timings(&mut self) -> LlamaTimings { 306 | let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) }; 307 | LlamaTimings { timings } 308 | } 309 | 310 | /// Sets a lora adapter. 311 | /// 312 | /// # Errors 313 | /// 314 | /// See [`LlamaLoraAdapterSetError`] for more information. 315 | pub fn lora_adapter_set( 316 | &self, 317 | adapter: &mut LlamaLoraAdapter, 318 | scale: f32, 319 | ) -> Result<(), LlamaLoraAdapterSetError> { 320 | let err_code = unsafe { 321 | llama_cpp_sys_2::llama_set_adapter_lora( 322 | self.context.as_ptr(), 323 | adapter.lora_adapter.as_ptr(), 324 | scale, 325 | ) 326 | }; 327 | if err_code != 0 { 328 | return Err(LlamaLoraAdapterSetError::ErrorResult(err_code)); 329 | } 330 | 331 | tracing::debug!("Set lora adapter"); 332 | Ok(()) 333 | } 334 | 335 | /// Remove a lora adapter. 336 | /// 337 | /// # Errors 338 | /// 339 | /// See [`LlamaLoraAdapterRemoveError`] for more information. 
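///
/// ```ignore
/// // Hedged sketch: apply an adapter at full strength, then detach it again later.
/// ctx.lora_adapter_set(&mut adapter, 1.0)?;
/// ctx.lora_adapter_remove(&mut adapter)?;
/// ```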
340 | pub fn lora_adapter_remove( 341 | &self, 342 | adapter: &mut LlamaLoraAdapter, 343 | ) -> Result<(), LlamaLoraAdapterRemoveError> { 344 | let err_code = unsafe { 345 | llama_cpp_sys_2::llama_rm_adapter_lora( 346 | self.context.as_ptr(), 347 | adapter.lora_adapter.as_ptr(), 348 | ) 349 | }; 350 | if err_code != 0 { 351 | return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code)); 352 | } 353 | 354 | tracing::debug!("Remove lora adapter"); 355 | Ok(()) 356 | } 357 | } 358 | 359 | impl Drop for LlamaContext<'_> { 360 | fn drop(&mut self) { 361 | unsafe { llama_cpp_sys_2::llama_free(self.context.as_ptr()) } 362 | } 363 | } 364 | -------------------------------------------------------------------------------- /llama-cpp-2/src/context/kv_cache.rs: -------------------------------------------------------------------------------- 1 | //! utilities for working with the kv cache 2 | 3 | use crate::context::LlamaContext; 4 | use std::ffi::c_int; 5 | use std::num::{NonZeroU8, TryFromIntError}; 6 | 7 | /// Errors that can occur when attempting to prepare values for the kv cache 8 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 9 | #[allow(clippy::module_name_repetitions)] 10 | pub enum KvCacheConversionError { 11 | /// Sequence id conversion to i32 failed 12 | #[error("Provided sequence id is too large for a i32")] 13 | SeqIdTooLarge(#[source] TryFromIntError), 14 | /// Position 0 conversion to i32 failed 15 | #[error("Provided start position is too large for a i32")] 16 | P0TooLarge(#[source] TryFromIntError), 17 | /// Position 1 conversion to i32 failed 18 | #[error("Provided end position is too large for a i32")] 19 | P1TooLarge(#[source] TryFromIntError), 20 | } 21 | 22 | impl LlamaContext<'_> { 23 | /// Copy the cache from one sequence to another. 24 | /// 25 | /// # Parameters 26 | /// 27 | /// * `src` - The sequence id to copy the cache from. 28 | /// * `dest` - The sequence id to copy the cache to. 29 | /// * `size` - The size of the cache to copy. 30 | pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) { 31 | unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) } 32 | } 33 | 34 | /// Copy the cache from one sequence to another. 35 | /// 36 | /// # Returns 37 | /// A `Result` indicating whether the operation was successful. 38 | /// 39 | /// # Parameters 40 | /// * `src` - The sequence id to copy the cache from. 41 | /// * `dest` - The sequence id to copy the cache to. 42 | /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`. 43 | /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`. 44 | /// 45 | /// # Errors 46 | /// If either position exceeds [`i32::MAX`]. 47 | pub fn copy_kv_cache_seq( 48 | &mut self, 49 | src: i32, 50 | dest: i32, 51 | p0: Option, 52 | p1: Option, 53 | ) -> Result<(), KvCacheConversionError> { 54 | let p0 = p0 55 | .map_or(Ok(-1), i32::try_from) 56 | .map_err(KvCacheConversionError::P0TooLarge)?; 57 | let p1 = p1 58 | .map_or(Ok(-1), i32::try_from) 59 | .map_err(KvCacheConversionError::P1TooLarge)?; 60 | unsafe { 61 | llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1); 62 | } 63 | Ok(()) 64 | } 65 | 66 | /// Clear the kv cache for the given sequence within the specified range `[p0, p1)` 67 | /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed. 
68 | /// 69 | /// # Returns 70 | /// A `Result` indicating whether the operation was successful. If the sequence id or 71 | /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned. 72 | /// 73 | /// # Parameters 74 | /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences 75 | /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`. 76 | /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`. 77 | /// 78 | /// # Errors 79 | /// If the sequence id or either position exceeds [`i32::MAX`]. 80 | pub fn clear_kv_cache_seq( 81 | &mut self, 82 | src: Option, 83 | p0: Option, 84 | p1: Option, 85 | ) -> Result { 86 | let src = src 87 | .map_or(Ok(-1), i32::try_from) 88 | .map_err(KvCacheConversionError::SeqIdTooLarge)?; 89 | let p0 = p0 90 | .map_or(Ok(-1), i32::try_from) 91 | .map_err(KvCacheConversionError::P0TooLarge)?; 92 | let p1 = p1 93 | .map_or(Ok(-1), i32::try_from) 94 | .map_err(KvCacheConversionError::P1TooLarge)?; 95 | Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) }) 96 | } 97 | 98 | /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) 99 | #[must_use] 100 | pub fn get_kv_cache_used_cells(&self) -> i32 { 101 | unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) } 102 | } 103 | 104 | /// Clear the KV cache 105 | pub fn clear_kv_cache(&mut self) { 106 | unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) } 107 | } 108 | 109 | /// Removes all tokens that do not belong to the specified sequence 110 | /// 111 | /// # Parameters 112 | /// 113 | /// * `seq_id` - The sequence id to keep 114 | pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) { 115 | unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) } 116 | } 117 | 118 | #[allow(clippy::doc_markdown)] 119 | /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)` 120 | /// If the KV cache is RoPEd, the KV data is updated accordingly: 121 | /// - lazily on next [`LlamaContext::decode`] 122 | /// - explicitly with [`Self::kv_cache_update`] 123 | /// 124 | /// # Returns 125 | /// A `Result` indicating whether the operation was successful. 126 | /// 127 | /// # Parameters 128 | /// 129 | /// * `seq_id` - The sequence id to update 130 | /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`. 131 | /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`. 132 | /// * `delta` - The relative position to add to the tokens 133 | /// 134 | /// # Errors 135 | /// If either position exceeds [`i32::MAX`]. 
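///
/// ```ignore
/// // Hedged sketch: drop the first 8 cached tokens of sequence 0, then shift the
/// // remaining positions left by 8 so decoding continues with a contiguous cache.
/// ctx.clear_kv_cache_seq(Some(0), None, Some(8))?;
/// ctx.kv_cache_seq_add(0, Some(8), None, -8)?;
/// ```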
136 | pub fn kv_cache_seq_add( 137 | &mut self, 138 | seq_id: i32, 139 | p0: Option, 140 | p1: Option, 141 | delta: i32, 142 | ) -> Result<(), KvCacheConversionError> { 143 | let p0 = p0 144 | .map_or(Ok(-1), i32::try_from) 145 | .map_err(KvCacheConversionError::P0TooLarge)?; 146 | let p1 = p1 147 | .map_or(Ok(-1), i32::try_from) 148 | .map_err(KvCacheConversionError::P1TooLarge)?; 149 | unsafe { 150 | llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta); 151 | } 152 | Ok(()) 153 | } 154 | 155 | /// Integer division of the positions by factor of `d > 1` 156 | /// If the KV cache is `RoPEd`, the KV data is updated accordingly: 157 | /// - lazily on next [`LlamaContext::decode`] 158 | /// - explicitly with [`Self::kv_cache_update`] 159 | /// 160 | /// # Returns 161 | /// A `Result` indicating whether the operation was successful. 162 | /// 163 | /// # Parameters 164 | /// 165 | /// * `seq_id` - The sequence id to update 166 | /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`. 167 | /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`. 168 | /// * `d` - The factor to divide the positions by 169 | /// 170 | /// # Errors 171 | /// If either position exceeds [`i32::MAX`]. 172 | pub fn kv_cache_seq_div( 173 | &mut self, 174 | seq_id: i32, 175 | p0: Option, 176 | p1: Option, 177 | d: NonZeroU8, 178 | ) -> Result<(), KvCacheConversionError> { 179 | let p0 = p0 180 | .map_or(Ok(-1), i32::try_from) 181 | .map_err(KvCacheConversionError::P0TooLarge)?; 182 | let p1 = p1 183 | .map_or(Ok(-1), i32::try_from) 184 | .map_err(KvCacheConversionError::P1TooLarge)?; 185 | let d = c_int::from(d.get()); 186 | unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) } 187 | Ok(()) 188 | } 189 | 190 | /// Returns the largest position present in the KV cache for the specified sequence 191 | /// 192 | /// # Parameters 193 | /// 194 | /// * `seq_id` - The sequence id to get the max position for 195 | #[must_use] 196 | pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 { 197 | unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) } 198 | } 199 | 200 | /// Defragment the KV cache 201 | /// This will be applied: 202 | /// - lazily on next [`LlamaContext::decode`] 203 | /// - explicitly with [`Self::kv_cache_update`] 204 | pub fn kv_cache_defrag(&mut self) { 205 | unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) } 206 | } 207 | 208 | /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.) 209 | pub fn kv_cache_update(&mut self) { 210 | unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) } 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /llama-cpp-2/src/context/params.rs: -------------------------------------------------------------------------------- 1 | //! A safe wrapper around `llama_context_params`. 2 | use std::fmt::Debug; 3 | use std::num::NonZeroU32; 4 | 5 | /// A rusty wrapper around `rope_scaling_type`. 
6 | #[repr(i8)] 7 | #[derive(Copy, Clone, Debug, PartialEq, Eq)] 8 | pub enum RopeScalingType { 9 | /// The scaling type is unspecified 10 | Unspecified = -1, 11 | /// No scaling 12 | None = 0, 13 | /// Linear scaling 14 | Linear = 1, 15 | /// Yarn scaling 16 | Yarn = 2, 17 | } 18 | 19 | /// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if 20 | /// the value is not recognized. 21 | impl From for RopeScalingType { 22 | fn from(value: i32) -> Self { 23 | match value { 24 | 0 => Self::None, 25 | 1 => Self::Linear, 26 | 2 => Self::Yarn, 27 | _ => Self::Unspecified, 28 | } 29 | } 30 | } 31 | 32 | /// Create a `c_int` from a `RopeScalingType`. 33 | impl From for i32 { 34 | fn from(value: RopeScalingType) -> Self { 35 | match value { 36 | RopeScalingType::None => 0, 37 | RopeScalingType::Linear => 1, 38 | RopeScalingType::Yarn => 2, 39 | RopeScalingType::Unspecified => -1, 40 | } 41 | } 42 | } 43 | 44 | /// A rusty wrapper around `LLAMA_POOLING_TYPE`. 45 | #[repr(i8)] 46 | #[derive(Copy, Clone, Debug, PartialEq, Eq)] 47 | pub enum LlamaPoolingType { 48 | /// The pooling type is unspecified 49 | Unspecified = -1, 50 | /// No pooling 51 | None = 0, 52 | /// Mean pooling 53 | Mean = 1, 54 | /// CLS pooling 55 | Cls = 2, 56 | /// Last pooling 57 | Last = 3, 58 | /// Rank pooling 59 | Rank = 4, 60 | } 61 | 62 | /// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if 63 | /// the value is not recognized. 64 | impl From for LlamaPoolingType { 65 | fn from(value: i32) -> Self { 66 | match value { 67 | 0 => Self::None, 68 | 1 => Self::Mean, 69 | 2 => Self::Cls, 70 | 3 => Self::Last, 71 | 4 => Self::Rank, 72 | _ => Self::Unspecified, 73 | } 74 | } 75 | } 76 | 77 | /// Create a `c_int` from a `LlamaPoolingType`. 78 | impl From for i32 { 79 | fn from(value: LlamaPoolingType) -> Self { 80 | match value { 81 | LlamaPoolingType::None => 0, 82 | LlamaPoolingType::Mean => 1, 83 | LlamaPoolingType::Cls => 2, 84 | LlamaPoolingType::Last => 3, 85 | LlamaPoolingType::Rank => 4, 86 | LlamaPoolingType::Unspecified => -1, 87 | } 88 | } 89 | } 90 | 91 | /// A safe wrapper around `llama_context_params`. 92 | /// 93 | /// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods. 94 | /// 95 | /// # Examples 96 | /// 97 | /// ```rust 98 | /// # use std::num::NonZeroU32; 99 | /// use llama_cpp_2::context::params::LlamaContextParams; 100 | /// 101 | ///let ctx_params = LlamaContextParams::default() 102 | /// .with_n_ctx(NonZeroU32::new(2048)); 103 | /// 104 | /// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048)); 105 | /// ``` 106 | #[derive(Debug, Clone)] 107 | #[allow( 108 | missing_docs, 109 | clippy::struct_excessive_bools, 110 | clippy::module_name_repetitions 111 | )] 112 | pub struct LlamaContextParams { 113 | pub(crate) context_params: llama_cpp_sys_2::llama_context_params, 114 | } 115 | 116 | /// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync. 
117 | unsafe impl Send for LlamaContextParams {} 118 | unsafe impl Sync for LlamaContextParams {} 119 | 120 | impl LlamaContextParams { 121 | /// Set the side of the context 122 | /// 123 | /// # Examples 124 | /// 125 | /// ```rust 126 | /// # use std::num::NonZeroU32; 127 | /// use llama_cpp_2::context::params::LlamaContextParams; 128 | /// let params = LlamaContextParams::default(); 129 | /// let params = params.with_n_ctx(NonZeroU32::new(2048)); 130 | /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048)); 131 | /// ``` 132 | #[must_use] 133 | pub fn with_n_ctx(mut self, n_ctx: Option) -> Self { 134 | self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get); 135 | self 136 | } 137 | 138 | /// Get the size of the context. 139 | /// 140 | /// [`None`] if the context size is specified by the model and not the context. 141 | /// 142 | /// # Examples 143 | /// 144 | /// ```rust 145 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 146 | /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512)); 147 | #[must_use] 148 | pub fn n_ctx(&self) -> Option { 149 | NonZeroU32::new(self.context_params.n_ctx) 150 | } 151 | 152 | /// Set the `n_batch` 153 | /// 154 | /// # Examples 155 | /// 156 | /// ```rust 157 | /// # use std::num::NonZeroU32; 158 | /// use llama_cpp_2::context::params::LlamaContextParams; 159 | /// let params = LlamaContextParams::default() 160 | /// .with_n_batch(2048); 161 | /// assert_eq!(params.n_batch(), 2048); 162 | /// ``` 163 | #[must_use] 164 | pub fn with_n_batch(mut self, n_batch: u32) -> Self { 165 | self.context_params.n_batch = n_batch; 166 | self 167 | } 168 | 169 | /// Get the `n_batch` 170 | /// 171 | /// # Examples 172 | /// 173 | /// ```rust 174 | /// use llama_cpp_2::context::params::LlamaContextParams; 175 | /// let params = LlamaContextParams::default(); 176 | /// assert_eq!(params.n_batch(), 2048); 177 | /// ``` 178 | #[must_use] 179 | pub fn n_batch(&self) -> u32 { 180 | self.context_params.n_batch 181 | } 182 | 183 | /// Set the `n_ubatch` 184 | /// 185 | /// # Examples 186 | /// 187 | /// ```rust 188 | /// # use std::num::NonZeroU32; 189 | /// use llama_cpp_2::context::params::LlamaContextParams; 190 | /// let params = LlamaContextParams::default() 191 | /// .with_n_ubatch(512); 192 | /// assert_eq!(params.n_ubatch(), 512); 193 | /// ``` 194 | #[must_use] 195 | pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self { 196 | self.context_params.n_ubatch = n_ubatch; 197 | self 198 | } 199 | 200 | /// Get the `n_ubatch` 201 | /// 202 | /// # Examples 203 | /// 204 | /// ```rust 205 | /// use llama_cpp_2::context::params::LlamaContextParams; 206 | /// let params = LlamaContextParams::default(); 207 | /// assert_eq!(params.n_ubatch(), 512); 208 | /// ``` 209 | #[must_use] 210 | pub fn n_ubatch(&self) -> u32 { 211 | self.context_params.n_ubatch 212 | } 213 | 214 | /// Set the `flash_attention` parameter 215 | /// 216 | /// # Examples 217 | /// 218 | /// ```rust 219 | /// use llama_cpp_2::context::params::LlamaContextParams; 220 | /// let params = LlamaContextParams::default() 221 | /// .with_flash_attention(true); 222 | /// assert_eq!(params.flash_attention(), true); 223 | /// ``` 224 | #[must_use] 225 | pub fn with_flash_attention(mut self, enabled: bool) -> Self { 226 | self.context_params.flash_attn = enabled; 227 | self 228 | } 229 | 230 | /// Get the `flash_attention` parameter 231 | /// 232 | /// # Examples 233 | /// 234 | /// ```rust 235 | /// use llama_cpp_2::context::params::LlamaContextParams; 236 | /// let 
params = LlamaContextParams::default(); 237 | /// assert_eq!(params.flash_attention(), false); 238 | /// ``` 239 | #[must_use] 240 | pub fn flash_attention(&self) -> bool { 241 | self.context_params.flash_attn 242 | } 243 | 244 | /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU 245 | /// 246 | /// # Examples 247 | /// 248 | /// ```rust 249 | /// use llama_cpp_2::context::params::LlamaContextParams; 250 | /// let params = LlamaContextParams::default() 251 | /// .with_offload_kqv(false); 252 | /// assert_eq!(params.offload_kqv(), false); 253 | /// ``` 254 | #[must_use] 255 | pub fn with_offload_kqv(mut self, enabled: bool) -> Self { 256 | self.context_params.offload_kqv = enabled; 257 | self 258 | } 259 | 260 | /// Get the `offload_kqv` parameter 261 | /// 262 | /// # Examples 263 | /// 264 | /// ```rust 265 | /// use llama_cpp_2::context::params::LlamaContextParams; 266 | /// let params = LlamaContextParams::default(); 267 | /// assert_eq!(params.offload_kqv(), true); 268 | /// ``` 269 | #[must_use] 270 | pub fn offload_kqv(&self) -> bool { 271 | self.context_params.offload_kqv 272 | } 273 | 274 | /// Set the type of rope scaling. 275 | /// 276 | /// # Examples 277 | /// 278 | /// ```rust 279 | /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType}; 280 | /// let params = LlamaContextParams::default() 281 | /// .with_rope_scaling_type(RopeScalingType::Linear); 282 | /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear); 283 | /// ``` 284 | #[must_use] 285 | pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self { 286 | self.context_params.rope_scaling_type = i32::from(rope_scaling_type); 287 | self 288 | } 289 | 290 | /// Get the type of rope scaling. 291 | /// 292 | /// # Examples 293 | /// 294 | /// ```rust 295 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 296 | /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified); 297 | /// ``` 298 | #[must_use] 299 | pub fn rope_scaling_type(&self) -> RopeScalingType { 300 | RopeScalingType::from(self.context_params.rope_scaling_type) 301 | } 302 | 303 | /// Set the rope frequency base. 304 | /// 305 | /// # Examples 306 | /// 307 | /// ```rust 308 | /// use llama_cpp_2::context::params::LlamaContextParams; 309 | /// let params = LlamaContextParams::default() 310 | /// .with_rope_freq_base(0.5); 311 | /// assert_eq!(params.rope_freq_base(), 0.5); 312 | /// ``` 313 | #[must_use] 314 | pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self { 315 | self.context_params.rope_freq_base = rope_freq_base; 316 | self 317 | } 318 | 319 | /// Get the rope frequency base. 320 | /// 321 | /// # Examples 322 | /// 323 | /// ```rust 324 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 325 | /// assert_eq!(params.rope_freq_base(), 0.0); 326 | /// ``` 327 | #[must_use] 328 | pub fn rope_freq_base(&self) -> f32 { 329 | self.context_params.rope_freq_base 330 | } 331 | 332 | /// Set the rope frequency scale. 
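/// (As in `llama.cpp`, leaving this at `0.0` means the scaling factor is taken from the model.)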
333 | /// 334 | /// # Examples 335 | /// 336 | /// ```rust 337 | /// use llama_cpp_2::context::params::LlamaContextParams; 338 | /// let params = LlamaContextParams::default() 339 | /// .with_rope_freq_scale(0.5); 340 | /// assert_eq!(params.rope_freq_scale(), 0.5); 341 | /// ``` 342 | #[must_use] 343 | pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self { 344 | self.context_params.rope_freq_scale = rope_freq_scale; 345 | self 346 | } 347 | 348 | /// Get the rope frequency scale. 349 | /// 350 | /// # Examples 351 | /// 352 | /// ```rust 353 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 354 | /// assert_eq!(params.rope_freq_scale(), 0.0); 355 | /// ``` 356 | #[must_use] 357 | pub fn rope_freq_scale(&self) -> f32 { 358 | self.context_params.rope_freq_scale 359 | } 360 | 361 | /// Get the number of threads. 362 | /// 363 | /// # Examples 364 | /// 365 | /// ```rust 366 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 367 | /// assert_eq!(params.n_threads(), 4); 368 | /// ``` 369 | #[must_use] 370 | pub fn n_threads(&self) -> i32 { 371 | self.context_params.n_threads 372 | } 373 | 374 | /// Get the number of threads allocated for batches. 375 | /// 376 | /// # Examples 377 | /// 378 | /// ```rust 379 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 380 | /// assert_eq!(params.n_threads_batch(), 4); 381 | /// ``` 382 | #[must_use] 383 | pub fn n_threads_batch(&self) -> i32 { 384 | self.context_params.n_threads_batch 385 | } 386 | 387 | /// Set the number of threads. 388 | /// 389 | /// # Examples 390 | /// 391 | /// ```rust 392 | /// use llama_cpp_2::context::params::LlamaContextParams; 393 | /// let params = LlamaContextParams::default() 394 | /// .with_n_threads(8); 395 | /// assert_eq!(params.n_threads(), 8); 396 | /// ``` 397 | #[must_use] 398 | pub fn with_n_threads(mut self, n_threads: i32) -> Self { 399 | self.context_params.n_threads = n_threads; 400 | self 401 | } 402 | 403 | /// Set the number of threads allocated for batches. 404 | /// 405 | /// # Examples 406 | /// 407 | /// ```rust 408 | /// use llama_cpp_2::context::params::LlamaContextParams; 409 | /// let params = LlamaContextParams::default() 410 | /// .with_n_threads_batch(8); 411 | /// assert_eq!(params.n_threads_batch(), 8); 412 | /// ``` 413 | #[must_use] 414 | pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self { 415 | self.context_params.n_threads_batch = n_threads; 416 | self 417 | } 418 | 419 | /// Check whether embeddings are enabled 420 | /// 421 | /// # Examples 422 | /// 423 | /// ```rust 424 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 425 | /// assert!(!params.embeddings()); 426 | /// ``` 427 | #[must_use] 428 | pub fn embeddings(&self) -> bool { 429 | self.context_params.embeddings 430 | } 431 | 432 | /// Enable the use of embeddings 433 | /// 434 | /// # Examples 435 | /// 436 | /// ```rust 437 | /// use llama_cpp_2::context::params::LlamaContextParams; 438 | /// let params = LlamaContextParams::default() 439 | /// .with_embeddings(true); 440 | /// assert!(params.embeddings()); 441 | /// ``` 442 | #[must_use] 443 | pub fn with_embeddings(mut self, embedding: bool) -> Self { 444 | self.context_params.embeddings = embedding; 445 | self 446 | } 447 | 448 | /// Set the evaluation callback. 
449 | /// 450 | /// # Examples 451 | /// 452 | /// ```no_run 453 | /// extern "C" fn cb_eval_fn( 454 | /// t: *mut llama_cpp_sys_2::ggml_tensor, 455 | /// ask: bool, 456 | /// user_data: *mut std::ffi::c_void, 457 | /// ) -> bool { 458 | /// false 459 | /// } 460 | /// 461 | /// use llama_cpp_2::context::params::LlamaContextParams; 462 | /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn)); 463 | /// ``` 464 | #[must_use] 465 | pub fn with_cb_eval( 466 | mut self, 467 | cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback, 468 | ) -> Self { 469 | self.context_params.cb_eval = cb_eval; 470 | self 471 | } 472 | 473 | /// Set the evaluation callback user data. 474 | /// 475 | /// # Examples 476 | /// 477 | /// ```no_run 478 | /// use llama_cpp_2::context::params::LlamaContextParams; 479 | /// let params = LlamaContextParams::default(); 480 | /// let user_data = std::ptr::null_mut(); 481 | /// let params = params.with_cb_eval_user_data(user_data); 482 | /// ``` 483 | #[must_use] 484 | pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self { 485 | self.context_params.cb_eval_user_data = cb_eval_user_data; 486 | self 487 | } 488 | 489 | /// Set the type of pooling. 490 | /// 491 | /// # Examples 492 | /// 493 | /// ```rust 494 | /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType}; 495 | /// let params = LlamaContextParams::default() 496 | /// .with_pooling_type(LlamaPoolingType::Last); 497 | /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last); 498 | /// ``` 499 | #[must_use] 500 | pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self { 501 | self.context_params.pooling_type = i32::from(pooling_type); 502 | self 503 | } 504 | 505 | /// Get the type of pooling. 506 | /// 507 | /// # Examples 508 | /// 509 | /// ```rust 510 | /// let params = llama_cpp_2::context::params::LlamaContextParams::default(); 511 | /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified); 512 | /// ``` 513 | #[must_use] 514 | pub fn pooling_type(&self) -> LlamaPoolingType { 515 | LlamaPoolingType::from(self.context_params.pooling_type) 516 | } 517 | } 518 | 519 | /// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`) 520 | /// ``` 521 | /// # use std::num::NonZeroU32; 522 | /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType}; 523 | /// let params = LlamaContextParams::default(); 524 | /// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512"); 525 | /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified); 526 | /// ``` 527 | impl Default for LlamaContextParams { 528 | fn default() -> Self { 529 | let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() }; 530 | Self { context_params } 531 | } 532 | } 533 | -------------------------------------------------------------------------------- /llama-cpp-2/src/context/session.rs: -------------------------------------------------------------------------------- 1 | //! 
utilities for working with session files 2 | 3 | use crate::context::LlamaContext; 4 | use crate::token::LlamaToken; 5 | use std::ffi::{CString, NulError}; 6 | use std::path::{Path, PathBuf}; 7 | 8 | /// Failed to save a Session file 9 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 10 | pub enum SaveSessionError { 11 | /// llama.cpp failed to save the session file 12 | #[error("Failed to save session file")] 13 | FailedToSave, 14 | 15 | /// null byte in string 16 | #[error("null byte in string {0}")] 17 | NullError(#[from] NulError), 18 | 19 | /// failed to convert path to str 20 | #[error("failed to convert path {0} to str")] 21 | PathToStrError(PathBuf), 22 | } 23 | 24 | /// Failed to load a Session file 25 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 26 | pub enum LoadSessionError { 27 | /// llama.cpp failed to load the session file 28 | #[error("Failed to load session file")] 29 | FailedToLoad, 30 | 31 | /// null byte in string 32 | #[error("null byte in string {0}")] 33 | NullError(#[from] NulError), 34 | 35 | /// failed to convert path to str 36 | #[error("failed to convert path {0} to str")] 37 | PathToStrError(PathBuf), 38 | 39 | /// Insufficient max length 40 | #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")] 41 | InsufficientMaxLength { 42 | /// The length of the session file 43 | n_out: usize, 44 | /// The maximum length 45 | max_tokens: usize, 46 | }, 47 | } 48 | 49 | impl LlamaContext<'_> { 50 | /// Save the current session to a file. 51 | /// 52 | /// # Parameters 53 | /// 54 | /// * `path_session` - The file to save to. 55 | /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled. 56 | /// 57 | /// # Errors 58 | /// 59 | /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save the session file. 60 | pub fn save_session_file( 61 | &self, 62 | path_session: impl AsRef, 63 | tokens: &[LlamaToken], 64 | ) -> Result<(), SaveSessionError> { 65 | let path = path_session.as_ref(); 66 | let path = path 67 | .to_str() 68 | .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?; 69 | 70 | let cstr = CString::new(path)?; 71 | 72 | if unsafe { 73 | llama_cpp_sys_2::llama_save_session_file( 74 | self.context.as_ptr(), 75 | cstr.as_ptr(), 76 | tokens.as_ptr().cast::(), 77 | tokens.len(), 78 | ) 79 | } { 80 | Ok(()) 81 | } else { 82 | Err(SaveSessionError::FailedToSave) 83 | } 84 | } 85 | /// Load a session file into the current context. 86 | /// 87 | /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data. 88 | /// 89 | /// # Parameters 90 | /// 91 | /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error. 92 | /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error. 93 | /// 94 | /// # Errors 95 | /// 96 | /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load the session file. (e.g. the file does not exist, is not a session file, etc.) 
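///
/// ```ignore
/// // Hedged sketch: reload a previously saved prompt prefix (here "session.bin" is a
/// // hypothetical path) so its KV cache entries need not be recomputed, then feed the
/// // returned tokens back to the context to continue decoding.
/// let tokens = ctx.load_session_file("session.bin", 2048)?;
/// ```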
97 | pub fn load_session_file( 98 | &mut self, 99 | path_session: impl AsRef, 100 | max_tokens: usize, 101 | ) -> Result, LoadSessionError> { 102 | let path = path_session.as_ref(); 103 | let path = path 104 | .to_str() 105 | .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?; 106 | 107 | let cstr = CString::new(path)?; 108 | let mut tokens: Vec = Vec::with_capacity(max_tokens); 109 | let mut n_out = 0; 110 | 111 | // SAFETY: cast is valid as LlamaToken is repr(transparent) 112 | let tokens_out = tokens.as_mut_ptr().cast::(); 113 | 114 | let load_session_success = unsafe { 115 | llama_cpp_sys_2::llama_load_session_file( 116 | self.context.as_ptr(), 117 | cstr.as_ptr(), 118 | tokens_out, 119 | max_tokens, 120 | &mut n_out, 121 | ) 122 | }; 123 | if load_session_success { 124 | if n_out > max_tokens { 125 | return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens }); 126 | } 127 | // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written 128 | unsafe { 129 | tokens.set_len(n_out); 130 | } 131 | Ok(tokens) 132 | } else { 133 | Err(LoadSessionError::FailedToLoad) 134 | } 135 | } 136 | 137 | /// Returns the maximum size in bytes of the state (rng, logits, embedding 138 | /// and `kv_cache`) - will often be smaller after compacting tokens 139 | #[must_use] 140 | pub fn get_state_size(&self) -> usize { 141 | unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) } 142 | } 143 | 144 | /// Copies the state to the specified destination address. 145 | /// 146 | /// Returns the number of bytes copied 147 | /// 148 | /// # Safety 149 | /// 150 | /// Destination needs to have allocated enough memory. 151 | pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize { 152 | unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) } 153 | } 154 | 155 | /// Set the state reading from the specified address 156 | /// Returns the number of bytes read 157 | /// 158 | /// # Safety 159 | /// 160 | /// help wanted: not entirely sure what the safety requirements are here. 161 | pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize { 162 | unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) } 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/arithmetic.gbnf: -------------------------------------------------------------------------------- 1 | root ::= (expr "=" ws term "\n")+ 2 | expr ::= term ([-+*/] term)* 3 | term ::= ident | num | "(" ws expr ")" ws 4 | ident ::= [a-z] [a-z0-9_]* ws 5 | num ::= [0-9]+ ws 6 | ws ::= [ \t\n]* 7 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/c.gbnf: -------------------------------------------------------------------------------- 1 | root ::= (declaration)* 2 | 3 | declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}" 4 | 5 | dataType ::= "int" ws | "float" ws | "char" ws 6 | identifier ::= [a-zA-Z_] [a-zA-Z_0-9]* 7 | 8 | parameter ::= dataType identifier 9 | 10 | statement ::= 11 | ( dataType identifier ws "=" ws expression ";" ) | 12 | ( identifier ws "=" ws expression ";" ) | 13 | ( identifier ws "(" argList? ")" ";" ) | 14 | ( "return" ws expression ";" ) | 15 | ( "while" "(" condition ")" "{" statement* "}" ) | 16 | ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) | 17 | ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? 
) | 18 | ( singleLineComment ) | 19 | ( multiLineComment ) 20 | 21 | forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression 22 | forUpdate ::= identifier ws "=" ws expression 23 | 24 | condition ::= expression relationOperator expression 25 | relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">") 26 | 27 | expression ::= term (("+" | "-") term)* 28 | term ::= factor(("*" | "/") factor)* 29 | 30 | factor ::= identifier | number | unaryTerm | funcCall | parenExpression 31 | unaryTerm ::= "-" factor 32 | funcCall ::= identifier "(" argList? ")" 33 | parenExpression ::= "(" ws expression ws ")" 34 | 35 | argList ::= expression ("," ws expression)* 36 | 37 | number ::= [0-9]+ 38 | 39 | singleLineComment ::= "//" [^\n]* "\n" 40 | multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/" 41 | 42 | ws ::= ([ \t\n]+) 43 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/chess.gbnf: -------------------------------------------------------------------------------- 1 | # Specifies chess moves as a list in algebraic notation, using PGN conventions 2 | 3 | # Force first move to "1. ", then any 1-2 digit number after, relying on model to follow the pattern 4 | root ::= "1. " move " " move "\n" ([1-9] [0-9]? ". " move " " move "\n")+ 5 | move ::= (pawn | nonpawn | castle) [+#]? 6 | 7 | # piece type, optional file/rank, optional capture, dest file & rank 8 | nonpawn ::= [NBKQR] [a-h]? [1-8]? "x"? [a-h] [1-8] 9 | 10 | # optional file & capture, dest file & rank, optional promotion 11 | pawn ::= ([a-h] "x")? [a-h] [1-8] ("=" [NBKQR])? 12 | 13 | castle ::= "O-O" "-O"? 14 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/japanese.gbnf: -------------------------------------------------------------------------------- 1 | # A probably incorrect grammar for Japanese 2 | root ::= jp-char+ ([ \t\n] jp-char+)* 3 | jp-char ::= hiragana | katakana | punctuation | cjk 4 | hiragana ::= [ぁ-ゟ] 5 | katakana ::= [ァ-ヿ] 6 | punctuation ::= [、-〾] 7 | cjk ::= [一-鿿] 8 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/json.gbnf: -------------------------------------------------------------------------------- 1 | root ::= object 2 | value ::= object | array | string | number | ("true" | "false" | "null") 3 | 4 | object ::= 5 | "{" ( 6 | string ":" value 7 | ("," string ":" value)* 8 | )? "}" 9 | 10 | array ::= 11 | "[" ( 12 | value 13 | ("," value)* 14 | )? "]" 15 | 16 | string ::= 17 | "\"" ( 18 | [^"\\] | 19 | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes 20 | )* "\"" 21 | 22 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? 23 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/json_arr.gbnf: -------------------------------------------------------------------------------- 1 | # This is the same as json.gbnf but we restrict whitespaces at the end of the root array 2 | # Useful for generating JSON arrays 3 | 4 | root ::= arr 5 | value ::= object | array | string | number | ("true" | "false" | "null") ws 6 | 7 | arr ::= 8 | "[\n" ws ( 9 | value 10 | (",\n" ws value)* 11 | )? "]" 12 | 13 | object ::= 14 | "{" ws ( 15 | string ":" ws value 16 | ("," ws string ":" ws value)* 17 | )? "}" ws 18 | 19 | array ::= 20 | "[" ws ( 21 | value 22 | ("," ws value)* 23 | )? 
"]" ws 24 | 25 | string ::= 26 | "\"" ( 27 | [^"\\] | 28 | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes 29 | )* "\"" ws 30 | 31 | number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws 32 | 33 | # Optional space: by convention, applied in this grammar after literal chars when allowed 34 | ws ::= ([ \t\n] ws)? 35 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/list.gbnf: -------------------------------------------------------------------------------- 1 | root ::= item* 2 | 3 | # comment 4 | item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n" 5 | -------------------------------------------------------------------------------- /llama-cpp-2/src/grammar/tests.rs: -------------------------------------------------------------------------------- 1 | use std::fs::File; 2 | use std::io::BufReader; 3 | use std::path::Path; 4 | 5 | use super::*; 6 | 7 | #[test] 8 | fn check_parse() { 9 | let dir = Path::new("src/grammar"); 10 | let files = std::fs::read_dir(dir) 11 | .expect("Failed to read grammar directory") 12 | .filter_map(Result::ok) 13 | .map(|os_str| os_str.path()) 14 | .filter(|p| p.is_file()) 15 | .filter(|f| f.extension().unwrap_or_default() == "gbnf") 16 | .map(File::open) 17 | .collect::>(); 18 | assert!( 19 | !files.is_empty(), 20 | "No grammar files found in {}", 21 | dir.canonicalize().unwrap().display() 22 | ); 23 | for file in files { 24 | let reader = BufReader::new(file.unwrap()); 25 | let file = std::io::read_to_string(reader).unwrap(); 26 | LlamaGrammar::from_str(&file).unwrap(); 27 | } 28 | } 29 | 30 | #[test] 31 | fn check_parse_simple() { 32 | let parse_state = ParseState::from_str(r#"root ::= "cat""#).unwrap(); 33 | assert_eq!( 34 | ParseState { 35 | symbol_ids: BTreeMap::from([("root".to_string(), 0),]), 36 | rules: vec![vec![ 37 | llama_grammar_element { 38 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR, 39 | value: 'c' as u32, 40 | }, 41 | llama_grammar_element { 42 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR, 43 | value: 'a' as u32, 44 | }, 45 | llama_grammar_element { 46 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR, 47 | value: 't' as u32, 48 | }, 49 | llama_grammar_element { 50 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_END, 51 | value: 0, 52 | } 53 | ]], 54 | }, 55 | parse_state 56 | ); 57 | } 58 | 59 | #[test] 60 | fn check_parse_char_range() { 61 | let parse_state = ParseState::from_str(r#"root ::= [a-zA-Z]"#).unwrap(); 62 | assert_eq!( 63 | ParseState { 64 | symbol_ids: BTreeMap::from([("root".to_string(), 0),]), 65 | rules: vec![vec![ 66 | llama_grammar_element { 67 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR, 68 | value: 'a' as u32 69 | }, 70 | llama_grammar_element { 71 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_RNG_UPPER, 72 | value: 'z' as u32 73 | }, 74 | llama_grammar_element { 75 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_ALT, 76 | value: 'A' as u32 77 | }, 78 | llama_grammar_element { 79 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_CHAR_RNG_UPPER, 80 | value: 'Z' as u32 81 | }, 82 | llama_grammar_element { 83 | type_: llama_cpp_sys_2::LLAMA_GRETYPE_END, 84 | value: 0 85 | } 86 | ]] 87 | }, 88 | parse_state 89 | ); 90 | } 91 | -------------------------------------------------------------------------------- /llama-cpp-2/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! Bindings to the llama.cpp library. 2 | //! 3 | //! 
As llama.cpp is a very fast moving target, this crate does not attempt to create a stable API 4 | //! with all the rust idioms. Instead it provided safe wrappers around nearly direct bindings to 5 | //! llama.cpp. This makes it easier to keep up with the changes in llama.cpp, but does mean that 6 | //! the API is not as nice as it could be. 7 | //! 8 | //! # Examples 9 | //! 10 | //! - [simple](https://github.com/utilityai/llama-cpp-rs/tree/main/simple) 11 | //! 12 | //! # Feature Flags 13 | //! 14 | //! - `cuda` enables CUDA gpu support. 15 | //! - `sampler` adds the [`context::sample::sampler`] struct for a more rusty way of sampling. 16 | use std::ffi::NulError; 17 | use std::fmt::Debug; 18 | use std::num::NonZeroI32; 19 | 20 | use crate::llama_batch::BatchAddError; 21 | use std::os::raw::c_int; 22 | use std::path::PathBuf; 23 | use std::string::FromUtf8Error; 24 | 25 | pub mod context; 26 | pub mod llama_backend; 27 | pub mod llama_batch; 28 | mod log; 29 | pub mod model; 30 | pub mod sampling; 31 | pub mod timing; 32 | pub mod token; 33 | pub mod token_type; 34 | 35 | /// A failable result from a llama.cpp function. 36 | pub type Result = std::result::Result; 37 | 38 | /// All errors that can occur in the llama-cpp crate. 39 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 40 | pub enum LLamaCppError { 41 | /// The backend was already initialized. This can generally be ignored as initializing the backend 42 | /// is idempotent. 43 | #[error("BackendAlreadyInitialized")] 44 | BackendAlreadyInitialized, 45 | /// There was an error while get the chat template from model. 46 | #[error("{0}")] 47 | ChatTemplateError(#[from] ChatTemplateError), 48 | /// There was an error while decoding a batch. 49 | #[error("{0}")] 50 | DecodeError(#[from] DecodeError), 51 | /// There was an error while encoding a batch. 52 | #[error("{0}")] 53 | EncodeError(#[from] EncodeError), 54 | /// There was an error loading a model. 55 | #[error("{0}")] 56 | LlamaModelLoadError(#[from] LlamaModelLoadError), 57 | /// There was an error creating a new model context. 58 | #[error("{0}")] 59 | LlamaContextLoadError(#[from] LlamaContextLoadError), 60 | /// There was an error adding a token to a batch. 61 | #[error["{0}"]] 62 | BatchAddError(#[from] BatchAddError), 63 | /// see [`EmbeddingsError`] 64 | #[error(transparent)] 65 | EmbeddingError(#[from] EmbeddingsError), 66 | // See [`LlamaSamplerError`] 67 | } 68 | 69 | /// There was an error while getting the chat template from a model. 70 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 71 | pub enum ChatTemplateError { 72 | /// gguf has no chat template (by that name) 73 | #[error("chat template not found - returned null pointer")] 74 | MissingTemplate, 75 | 76 | /// chat template contained a null byte 77 | #[error("null byte in string {0}")] 78 | NullError(#[from] NulError), 79 | 80 | /// The chat template was not valid utf8. 81 | #[error(transparent)] 82 | Utf8Error(#[from] std::str::Utf8Error), 83 | } 84 | 85 | /// Failed fetching metadata value 86 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 87 | pub enum MetaValError { 88 | /// The provided string contains an unexpected null-byte 89 | #[error("null byte in string {0}")] 90 | NullError(#[from] NulError), 91 | 92 | /// The returned data contains invalid UTF8 data 93 | #[error("FromUtf8Error {0}")] 94 | FromUtf8Error(#[from] FromUtf8Error), 95 | 96 | /// Got negative return value. This happens if the key or index queried does not exist. 97 | #[error("Negative return value. 
Likely due to a missing index or key. Got return value: {0}")] 98 | NegativeReturn(i32), 99 | } 100 | 101 | /// Failed to Load context 102 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 103 | pub enum LlamaContextLoadError { 104 | /// llama.cpp returned null 105 | #[error("null reference from llama.cpp")] 106 | NullReturn, 107 | } 108 | 109 | /// Failed to decode a batch. 110 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 111 | pub enum DecodeError { 112 | /// No kv cache slot was available. 113 | #[error("Decode Error 1: NoKvCacheSlot")] 114 | NoKvCacheSlot, 115 | /// The number of tokens in the batch was 0. 116 | #[error("Decode Error -1: n_tokens == 0")] 117 | NTokensZero, 118 | /// An unknown error occurred. 119 | #[error("Decode Error {0}: unknown")] 120 | Unknown(c_int), 121 | } 122 | 123 | /// Failed to decode a batch. 124 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 125 | pub enum EncodeError { 126 | /// No kv cache slot was available. 127 | #[error("Encode Error 1: NoKvCacheSlot")] 128 | NoKvCacheSlot, 129 | /// The number of tokens in the batch was 0. 130 | #[error("Encode Error -1: n_tokens == 0")] 131 | NTokensZero, 132 | /// An unknown error occurred. 133 | #[error("Encode Error {0}: unknown")] 134 | Unknown(c_int), 135 | } 136 | 137 | /// When embedding related functions fail 138 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 139 | pub enum EmbeddingsError { 140 | /// Embeddings weren't enabled in the context options 141 | #[error("Embeddings weren't enabled in the context options")] 142 | NotEnabled, 143 | /// Logits weren't enabled for the given token 144 | #[error("Logits were not enabled for the given token")] 145 | LogitsNotEnabled, 146 | /// The given sequence index exceeds the max sequence id 147 | #[error("Can't use sequence embeddings with a model supporting only LLAMA_POOLING_TYPE_NONE")] 148 | NonePoolType, 149 | } 150 | 151 | /// Decode a error from llama.cpp into a [`DecodeError`]. 152 | impl From for DecodeError { 153 | fn from(value: NonZeroI32) -> Self { 154 | match value.get() { 155 | 1 => DecodeError::NoKvCacheSlot, 156 | -1 => DecodeError::NTokensZero, 157 | i => DecodeError::Unknown(i), 158 | } 159 | } 160 | } 161 | 162 | /// Encode a error from llama.cpp into a [`EncodeError`]. 163 | impl From for EncodeError { 164 | fn from(value: NonZeroI32) -> Self { 165 | match value.get() { 166 | 1 => EncodeError::NoKvCacheSlot, 167 | -1 => EncodeError::NTokensZero, 168 | i => EncodeError::Unknown(i), 169 | } 170 | } 171 | } 172 | 173 | /// An error that can occur when loading a model. 174 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 175 | pub enum LlamaModelLoadError { 176 | /// There was a null byte in a provided string and thus it could not be converted to a C string. 177 | #[error("null byte in string {0}")] 178 | NullError(#[from] NulError), 179 | /// llama.cpp returned a nullptr - this could be many different causes. 180 | #[error("null result from llama cpp")] 181 | NullResult, 182 | /// Failed to convert the path to a rust str. This means the path was not valid unicode 183 | #[error("failed to convert path {0} to str")] 184 | PathToStrError(PathBuf), 185 | } 186 | 187 | /// An error that can occur when loading a model. 188 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 189 | pub enum LlamaLoraAdapterInitError { 190 | /// There was a null byte in a provided string and thus it could not be converted to a C string. 
191 | #[error("null byte in string {0}")] 192 | NullError(#[from] NulError), 193 | /// llama.cpp returned a nullptr - this could be many different causes. 194 | #[error("null result from llama cpp")] 195 | NullResult, 196 | /// Failed to convert the path to a rust str. This means the path was not valid unicode 197 | #[error("failed to convert path {0} to str")] 198 | PathToStrError(PathBuf), 199 | } 200 | 201 | /// An error that can occur when loading a model. 202 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 203 | pub enum LlamaLoraAdapterSetError { 204 | /// llama.cpp returned a non-zero error code. 205 | #[error("error code from llama cpp")] 206 | ErrorResult(i32), 207 | } 208 | 209 | /// An error that can occur when loading a model. 210 | #[derive(Debug, Eq, PartialEq, thiserror::Error)] 211 | pub enum LlamaLoraAdapterRemoveError { 212 | /// llama.cpp returned a non-zero error code. 213 | #[error("error code from llama cpp")] 214 | ErrorResult(i32), 215 | } 216 | 217 | /// get the time (in microseconds) according to llama.cpp 218 | /// ``` 219 | /// # use llama_cpp_2::llama_time_us; 220 | /// # use llama_cpp_2::llama_backend::LlamaBackend; 221 | /// let backend = LlamaBackend::init().unwrap(); 222 | /// let time = llama_time_us(); 223 | /// assert!(time > 0); 224 | /// ``` 225 | #[must_use] 226 | pub fn llama_time_us() -> i64 { 227 | unsafe { llama_cpp_sys_2::llama_time_us() } 228 | } 229 | 230 | /// get the max number of devices according to llama.cpp (this is generally cuda devices) 231 | /// ``` 232 | /// # use llama_cpp_2::max_devices; 233 | /// let max_devices = max_devices(); 234 | /// assert!(max_devices >= 0); 235 | /// ``` 236 | #[must_use] 237 | pub fn max_devices() -> usize { 238 | unsafe { llama_cpp_sys_2::llama_max_devices() } 239 | } 240 | 241 | /// is memory mapping supported according to llama.cpp 242 | /// ``` 243 | /// # use llama_cpp_2::mmap_supported; 244 | /// let mmap_supported = mmap_supported(); 245 | /// if mmap_supported { 246 | /// println!("mmap_supported!"); 247 | /// } 248 | /// ``` 249 | #[must_use] 250 | pub fn mmap_supported() -> bool { 251 | unsafe { llama_cpp_sys_2::llama_supports_mmap() } 252 | } 253 | 254 | /// is memory locking supported according to llama.cpp 255 | /// ``` 256 | /// # use llama_cpp_2::mlock_supported; 257 | /// let mlock_supported = mlock_supported(); 258 | /// if mlock_supported { 259 | /// println!("mlock_supported!"); 260 | /// } 261 | /// ``` 262 | #[must_use] 263 | pub fn mlock_supported() -> bool { 264 | unsafe { llama_cpp_sys_2::llama_supports_mlock() } 265 | } 266 | 267 | /// An error that can occur when converting a token to a string. 268 | #[derive(Debug, thiserror::Error, Clone)] 269 | #[non_exhaustive] 270 | pub enum TokenToStringError { 271 | /// the token type was unknown 272 | #[error("Unknown Token Type")] 273 | UnknownTokenType, 274 | /// There was insufficient buffer space to convert the token to a string. 275 | #[error("Insufficient Buffer Space {0}")] 276 | InsufficientBufferSpace(c_int), 277 | /// The token was not valid utf8. 278 | #[error("FromUtf8Error {0}")] 279 | FromUtf8Error(#[from] FromUtf8Error), 280 | } 281 | 282 | /// Failed to convert a string to a token sequence. 283 | #[derive(Debug, thiserror::Error)] 284 | pub enum StringToTokenError { 285 | /// the string contained a null byte and thus could not be converted to a c string. 286 | #[error("{0}")] 287 | NulError(#[from] NulError), 288 | #[error("{0}")] 289 | /// Failed to convert a provided integer to a [`c_int`]. 
290 | CIntConversionError(#[from] std::num::TryFromIntError), 291 | } 292 | 293 | /// Failed to apply model chat template. 294 | #[derive(Debug, thiserror::Error)] 295 | pub enum NewLlamaChatMessageError { 296 | /// the string contained a null byte and thus could not be converted to a c string. 297 | #[error("{0}")] 298 | NulError(#[from] NulError), 299 | } 300 | 301 | /// Failed to apply model chat template. 302 | #[derive(Debug, thiserror::Error)] 303 | pub enum ApplyChatTemplateError { 304 | /// the string contained a null byte and thus could not be converted to a c string. 305 | #[error("{0}")] 306 | NulError(#[from] NulError), 307 | /// the string could not be converted to utf8. 308 | #[error("{0}")] 309 | FromUtf8Error(#[from] FromUtf8Error), 310 | } 311 | 312 | /// Get the time in microseconds according to ggml 313 | /// 314 | /// ``` 315 | /// # use std::time::Duration; 316 | /// # use llama_cpp_2::llama_backend::LlamaBackend; 317 | /// let backend = LlamaBackend::init().unwrap(); 318 | /// use llama_cpp_2::ggml_time_us; 319 | /// 320 | /// let start = ggml_time_us(); 321 | /// 322 | /// std::thread::sleep(Duration::from_micros(10)); 323 | /// 324 | /// let end = ggml_time_us(); 325 | /// 326 | /// let elapsed = end - start; 327 | /// 328 | /// assert!(elapsed >= 10) 329 | #[must_use] 330 | pub fn ggml_time_us() -> i64 { 331 | unsafe { llama_cpp_sys_2::ggml_time_us() } 332 | } 333 | 334 | /// checks if mlock is supported 335 | /// 336 | /// ``` 337 | /// # use llama_cpp_2::llama_supports_mlock; 338 | /// 339 | /// if llama_supports_mlock() { 340 | /// println!("mlock is supported!"); 341 | /// } else { 342 | /// println!("mlock is not supported!"); 343 | /// } 344 | /// ``` 345 | #[must_use] 346 | pub fn llama_supports_mlock() -> bool { 347 | unsafe { llama_cpp_sys_2::llama_supports_mlock() } 348 | } 349 | 350 | /// Options to configure how llama.cpp logs are intercepted. 351 | #[derive(Default, Debug, Clone)] 352 | pub struct LogOptions { 353 | disabled: bool, 354 | } 355 | 356 | impl LogOptions { 357 | /// If enabled, logs are sent to tracing. If disabled, all logs are suppressed. Default is for 358 | /// logs to be sent to tracing. 359 | pub fn with_logs_enabled(mut self, enabled: bool) -> Self { 360 | self.disabled = !enabled; 361 | self 362 | } 363 | } 364 | 365 | extern "C" fn logs_to_trace( 366 | level: llama_cpp_sys_2::ggml_log_level, 367 | text: *const ::std::os::raw::c_char, 368 | data: *mut ::std::os::raw::c_void, 369 | ) { 370 | // In the "fast-path" (i.e. the vast majority of logs) we want to avoid needing to take the log state 371 | // lock at all. Similarly, we try to avoid any heap allocations within this function. This is accomplished 372 | // by being a dummy pass-through to tracing in the normal case of DEBUG/INFO/WARN/ERROR logs that are 373 | // newline terminated and limiting the slow-path of locks and/or heap allocations for other cases. 374 | use std::borrow::Borrow; 375 | 376 | let log_state = unsafe { &*(data as *const log::State) }; 377 | 378 | let text = unsafe { std::ffi::CStr::from_ptr(text) }; 379 | let text = text.to_string_lossy(); 380 | let text: &str = text.borrow(); 381 | 382 | if log_state.options.disabled { 383 | return; 384 | } 385 | 386 | // As best I can tell llama.cpp / ggml require all log format strings at call sites to have the '\n'. 387 | // If it's missing, it means that you expect more logs via CONT (or there's a typo in the codebase). 
To 388 | // distinguish typo from intentional support for CONT, we have to buffer until the next message comes in 389 | // to know how to flush it. 390 | 391 | if level == llama_cpp_sys_2::GGML_LOG_LEVEL_CONT { 392 | log_state.cont_buffered_log(text); 393 | } else if text.ends_with('\n') { 394 | log_state.emit_non_cont_line(level, text); 395 | } else { 396 | log_state.buffer_non_cont(level, text); 397 | } 398 | } 399 | 400 | /// Redirect llama.cpp logs into tracing. 401 | pub fn send_logs_to_tracing(options: LogOptions) { 402 | // TODO: Reinitialize the state to support calling send_logs_to_tracing multiple times. 403 | 404 | // We set up separate log states for llama.cpp and ggml to make sure that CONT logs between the two 405 | // can't possibly interfere with each other. In other words, if llama.cpp emits a log without a trailing 406 | // newline and calls a GGML function, the logs won't be weirdly intermixed and instead we'll llama.cpp logs 407 | // will CONT previous llama.cpp logs and GGML logs will CONT previous ggml logs. 408 | let llama_heap_state = Box::as_ref( 409 | log::LLAMA_STATE 410 | .get_or_init(|| Box::new(log::State::new(log::Module::LlamaCpp, options.clone()))), 411 | ) as *const _; 412 | let ggml_heap_state = Box::as_ref( 413 | log::GGML_STATE.get_or_init(|| Box::new(log::State::new(log::Module::GGML, options))), 414 | ) as *const _; 415 | 416 | unsafe { 417 | // GGML has to be set after llama since setting llama sets ggml as well. 418 | llama_cpp_sys_2::llama_log_set(Some(logs_to_trace), llama_heap_state as *mut _); 419 | llama_cpp_sys_2::ggml_log_set(Some(logs_to_trace), ggml_heap_state as *mut _); 420 | } 421 | } 422 | -------------------------------------------------------------------------------- /llama-cpp-2/src/llama_backend.rs: -------------------------------------------------------------------------------- 1 | //! Representation of an initialized llama backend 2 | 3 | use crate::LLamaCppError; 4 | use llama_cpp_sys_2::ggml_log_level; 5 | use std::sync::atomic::AtomicBool; 6 | use std::sync::atomic::Ordering::SeqCst; 7 | 8 | /// Representation of an initialized llama backend 9 | /// This is required as a parameter for most llama functions as the backend must be initialized 10 | /// before any llama functions are called. This type is proof of initialization. 11 | #[derive(Eq, PartialEq, Debug)] 12 | pub struct LlamaBackend {} 13 | 14 | static LLAMA_BACKEND_INITIALIZED: AtomicBool = AtomicBool::new(false); 15 | 16 | impl LlamaBackend { 17 | /// Mark the llama backend as initialized 18 | fn mark_init() -> crate::Result<()> { 19 | match LLAMA_BACKEND_INITIALIZED.compare_exchange(false, true, SeqCst, SeqCst) { 20 | Ok(_) => Ok(()), 21 | Err(_) => Err(LLamaCppError::BackendAlreadyInitialized), 22 | } 23 | } 24 | 25 | /// Initialize the llama backend (without numa). 
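///
/// The returned [`LlamaBackend`] should be kept alive for as long as llama.cpp is in use:
/// dropping it calls `llama_backend_free`, after which the backend may be initialized again.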
26 | /// 27 | /// # Examples 28 | /// 29 | /// ``` 30 | ///# use llama_cpp_2::llama_backend::LlamaBackend; 31 | ///# use llama_cpp_2::LLamaCppError; 32 | ///# use std::error::Error; 33 | /// 34 | ///# fn main() -> Result<(), Box> { 35 | /// 36 | /// 37 | /// let backend = LlamaBackend::init()?; 38 | /// // the llama backend can only be initialized once 39 | /// assert_eq!(Err(LLamaCppError::BackendAlreadyInitialized), LlamaBackend::init()); 40 | /// 41 | ///# Ok(()) 42 | ///# } 43 | /// ``` 44 | #[tracing::instrument(skip_all)] 45 | pub fn init() -> crate::Result { 46 | Self::mark_init()?; 47 | unsafe { llama_cpp_sys_2::llama_backend_init() } 48 | Ok(LlamaBackend {}) 49 | } 50 | 51 | /// Initialize the llama backend (with numa). 52 | /// ``` 53 | ///# use llama_cpp_2::llama_backend::LlamaBackend; 54 | ///# use std::error::Error; 55 | ///# use llama_cpp_2::llama_backend::NumaStrategy; 56 | /// 57 | ///# fn main() -> Result<(), Box> { 58 | /// 59 | /// let llama_backend = LlamaBackend::init_numa(NumaStrategy::MIRROR)?; 60 | /// 61 | ///# Ok(()) 62 | ///# } 63 | /// ``` 64 | #[tracing::instrument(skip_all)] 65 | pub fn init_numa(strategy: NumaStrategy) -> crate::Result { 66 | Self::mark_init()?; 67 | unsafe { 68 | llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy)); 69 | } 70 | Ok(LlamaBackend {}) 71 | } 72 | 73 | /// Was the code built for a GPU backend & is a supported one available. 74 | pub fn supports_gpu_offload(&self) -> bool { 75 | unsafe { llama_cpp_sys_2::llama_supports_gpu_offload() } 76 | } 77 | 78 | /// Does this platform support loading the model via mmap. 79 | pub fn supports_mmap(&self) -> bool { 80 | unsafe { llama_cpp_sys_2::llama_supports_mmap() } 81 | } 82 | 83 | /// Does this platform support locking the model in RAM. 84 | pub fn supports_mlock(&self) -> bool { 85 | unsafe { llama_cpp_sys_2::llama_supports_mlock() } 86 | } 87 | 88 | /// Change the output of llama.cpp's logging to be voided instead of pushed to `stderr`. 89 | pub fn void_logs(&mut self) { 90 | unsafe extern "C" fn void_log( 91 | _level: ggml_log_level, 92 | _text: *const ::std::os::raw::c_char, 93 | _user_data: *mut ::std::os::raw::c_void, 94 | ) { 95 | } 96 | 97 | unsafe { 98 | llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut()); 99 | } 100 | } 101 | } 102 | 103 | /// A rusty wrapper around `numa_strategy`. 104 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 105 | pub enum NumaStrategy { 106 | /// The numa strategy is disabled. 107 | DISABLED, 108 | /// help wanted: what does this do? 109 | DISTRIBUTE, 110 | /// help wanted: what does this do? 111 | ISOLATE, 112 | /// help wanted: what does this do? 113 | NUMACTL, 114 | /// help wanted: what does this do? 115 | MIRROR, 116 | /// help wanted: what does this do? 117 | COUNT, 118 | } 119 | 120 | /// An invalid numa strategy was provided. 121 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] 122 | pub struct InvalidNumaStrategy( 123 | /// The invalid numa strategy that was provided. 
124 | pub llama_cpp_sys_2::ggml_numa_strategy, 125 | ); 126 | 127 | impl TryFrom for NumaStrategy { 128 | type Error = InvalidNumaStrategy; 129 | 130 | fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result { 131 | match value { 132 | llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Ok(Self::DISABLED), 133 | llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Ok(Self::DISTRIBUTE), 134 | llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Ok(Self::ISOLATE), 135 | llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Ok(Self::NUMACTL), 136 | llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Ok(Self::MIRROR), 137 | llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Ok(Self::COUNT), 138 | value => Err(InvalidNumaStrategy(value)), 139 | } 140 | } 141 | } 142 | 143 | impl From for llama_cpp_sys_2::ggml_numa_strategy { 144 | fn from(value: NumaStrategy) -> Self { 145 | match value { 146 | NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED, 147 | NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE, 148 | NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE, 149 | NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL, 150 | NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR, 151 | NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT, 152 | } 153 | } 154 | } 155 | 156 | /// Drops the llama backend. 157 | /// ``` 158 | /// 159 | ///# use llama_cpp_2::llama_backend::LlamaBackend; 160 | ///# use std::error::Error; 161 | /// 162 | ///# fn main() -> Result<(), Box> { 163 | /// let backend = LlamaBackend::init()?; 164 | /// drop(backend); 165 | /// // can be initialized again after being dropped 166 | /// let backend = LlamaBackend::init()?; 167 | ///# Ok(()) 168 | ///# } 169 | /// 170 | /// ``` 171 | impl Drop for LlamaBackend { 172 | fn drop(&mut self) { 173 | match LLAMA_BACKEND_INITIALIZED.compare_exchange(true, false, SeqCst, SeqCst) { 174 | Ok(_) => {} 175 | Err(_) => { 176 | unreachable!("This should not be reachable as the only ways to obtain a llama backend involve marking the backend as initialized.") 177 | } 178 | } 179 | unsafe { llama_cpp_sys_2::llama_backend_free() } 180 | } 181 | } 182 | 183 | #[cfg(test)] 184 | mod tests { 185 | use super::*; 186 | 187 | #[test] 188 | fn numa_from_and_to() { 189 | let numas = [ 190 | NumaStrategy::DISABLED, 191 | NumaStrategy::DISTRIBUTE, 192 | NumaStrategy::ISOLATE, 193 | NumaStrategy::NUMACTL, 194 | NumaStrategy::MIRROR, 195 | NumaStrategy::COUNT, 196 | ]; 197 | 198 | for numa in &numas { 199 | let from = llama_cpp_sys_2::ggml_numa_strategy::from(*numa); 200 | let to = NumaStrategy::try_from(from).expect("Failed to convert from and to"); 201 | assert_eq!(*numa, to); 202 | } 203 | } 204 | 205 | #[test] 206 | fn check_invalid_numa() { 207 | let invalid = 800; 208 | let invalid = NumaStrategy::try_from(invalid); 209 | assert_eq!(invalid, Err(InvalidNumaStrategy(invalid.unwrap_err().0))); 210 | } 211 | } 212 | -------------------------------------------------------------------------------- /llama-cpp-2/src/llama_batch.rs: -------------------------------------------------------------------------------- 1 | //! Safe wrapper around `llama_batch`. 2 | 3 | use crate::token::LlamaToken; 4 | use llama_cpp_sys_2::{llama_batch, llama_batch_free, llama_batch_init, llama_pos, llama_seq_id}; 5 | 6 | /// A safe wrapper around `llama_batch`. 7 | #[derive(Debug)] 8 | pub struct LlamaBatch { 9 | /// The number of tokens the batch was allocated with. 
they are safe to write to - but not necessarily read from as they are not necessarily initialized 10 | allocated: usize, 11 | /// The logits that are initialized. Used by [`LlamaContext`] to ensure that only initialized logits are accessed. 12 | pub(crate) initialized_logits: Vec, 13 | #[allow(clippy::doc_markdown)] 14 | /// The llama_cpp batch. always initialize by `llama_cpp_sys_2::llama_batch_init(allocated, , )` 15 | pub(crate) llama_batch: llama_batch, 16 | } 17 | 18 | /// Errors that can occur when adding a token to a batch. 19 | #[derive(thiserror::Error, Debug, PartialEq, Eq)] 20 | pub enum BatchAddError { 21 | /// There was not enough space in the batch to add the token. 22 | #[error("Insufficient Space of {0}")] 23 | InsufficientSpace(usize), 24 | /// Empty buffer is provided for [`LlamaBatch::get_one`] 25 | #[error("Empty buffer")] 26 | EmptyBuffer, 27 | } 28 | 29 | impl LlamaBatch { 30 | /// Clear the batch. This does not free the memory associated with the batch, but it does reset 31 | /// the number of tokens to 0. 32 | pub fn clear(&mut self) { 33 | self.llama_batch.n_tokens = 0; 34 | self.initialized_logits.clear(); 35 | } 36 | 37 | /// add a token to the batch for sequences `seq_ids` at position `pos`. If `logits` is true, the 38 | /// token will be initialized and can be read from after the next decode. 39 | /// 40 | /// # Panics 41 | /// 42 | /// - [`self.llama_batch.n_tokens`] does not fit into a usize 43 | /// - [`seq_ids.len()`] does not fit into a [`llama_seq_id`] 44 | /// 45 | /// # Errors 46 | /// 47 | /// returns a error if there is insufficient space in the buffer 48 | pub fn add( 49 | &mut self, 50 | LlamaToken(id): LlamaToken, 51 | pos: llama_pos, 52 | seq_ids: &[i32], 53 | logits: bool, 54 | ) -> Result<(), BatchAddError> { 55 | if self.allocated 56 | < usize::try_from(self.n_tokens() + 1).expect("cannot fit n_tokens into a usize") 57 | { 58 | return Err(BatchAddError::InsufficientSpace(self.allocated)); 59 | } 60 | let offset = self.llama_batch.n_tokens; 61 | let offset_usize = usize::try_from(offset).expect("cannot fit n_tokens into a usize"); 62 | unsafe { 63 | // batch.token [batch.n_tokens] = id; 64 | self.llama_batch.token.add(offset_usize).write(id); 65 | // batch.pos [batch.n_tokens] = pos, 66 | self.llama_batch.pos.add(offset_usize).write(pos); 67 | // batch.n_seq_id[batch.n_tokens] = seq_ids.size(); 68 | self.llama_batch.n_seq_id.add(offset_usize).write( 69 | llama_seq_id::try_from(seq_ids.len()) 70 | .expect("cannot fit seq_ids.len() into a llama_seq_id"), 71 | ); 72 | // for (size_t i = 0; i < seq_ids.size(); ++i) { 73 | // batch.seq_id[batch.n_tokens][i] = seq_ids[i]; 74 | // } 75 | for (i, seq_id) in seq_ids.iter().enumerate() { 76 | let tmp = *self.llama_batch.seq_id.add(offset_usize); 77 | tmp.add(i).write(*seq_id); 78 | } 79 | // batch.logits [batch.n_tokens] = logits; 80 | self.llama_batch 81 | .logits 82 | .add(offset_usize) 83 | .write(i8::from(logits)); 84 | } 85 | 86 | if logits { 87 | self.initialized_logits.push(offset); 88 | } else { 89 | self.initialized_logits.retain(|l| l != &offset); 90 | } 91 | 92 | // batch.n_tokens++; 93 | self.llama_batch.n_tokens += 1; 94 | 95 | Ok(()) 96 | } 97 | 98 | /// Add a sequence of tokens to the batch for the given sequence id. If `logits_all` is true, the 99 | /// tokens will be initialized and can be read from after the next decode. 100 | /// 101 | /// Either way the last token in the sequence will have its logits set to `true`. 
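///
/// # Example
///
/// A minimal sketch: `tokens` is assumed to be a prompt that was tokenized elsewhere
/// (for example with the model's tokenizer).
///
/// ```
/// # use llama_cpp_2::llama_batch::{BatchAddError, LlamaBatch};
/// # use llama_cpp_2::token::LlamaToken;
/// # fn fill(tokens: &[LlamaToken]) -> Result<LlamaBatch, BatchAddError> {
/// let mut batch = LlamaBatch::new(512, 1);
/// // a single sequence (id 0); only the last token requests logits
/// batch.add_sequence(tokens, 0, false)?;
/// # Ok(batch)
/// # }
/// ```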
102 | /// 103 | /// # Errors 104 | /// 105 | /// Returns an error if there is insufficient space in the buffer 106 | /// 107 | /// # Panics 108 | /// 109 | /// - [`self.llama_batch.n_tokens`] does not fit into a [`usize`] 110 | /// - [`n_tokens - 1`] does not fit into a [`llama_pos`] 111 | pub fn add_sequence( 112 | &mut self, 113 | tokens: &[LlamaToken], 114 | seq_id: i32, 115 | logits_all: bool, 116 | ) -> Result<(), BatchAddError> { 117 | let n_tokens_0 = 118 | usize::try_from(self.llama_batch.n_tokens).expect("cannot fit n_tokens into a usize"); 119 | let n_tokens = tokens.len(); 120 | 121 | if self.allocated < n_tokens_0 + n_tokens { 122 | return Err(BatchAddError::InsufficientSpace(self.allocated)); 123 | } 124 | 125 | let last_index = llama_pos::try_from(n_tokens.saturating_sub(1)) 126 | .expect("cannot fit n_tokens into a llama_pos"); 127 | for (i, token) in (0..).zip(tokens.iter()) { 128 | self.add(*token, i, &[seq_id], logits_all || i == last_index)?; 129 | } 130 | 131 | Ok(()) 132 | } 133 | 134 | /// Create a new `LlamaBatch` that can contain up to `n_tokens` tokens. 135 | /// 136 | /// # Arguments 137 | /// 138 | /// - `n_tokens`: the maximum number of tokens that can be added to the batch 139 | /// - `n_seq_max`: the maximum number of sequences that can be added to the batch (generally 1 unless you know what you are doing) 140 | /// 141 | /// # Panics 142 | /// 143 | /// Panics if `n_tokens` is greater than `i32::MAX`. 144 | #[must_use] 145 | pub fn new(n_tokens: usize, n_seq_max: i32) -> Self { 146 | let n_tokens_i32 = i32::try_from(n_tokens).expect("cannot fit n_tokens into a i32"); 147 | let batch = unsafe { llama_batch_init(n_tokens_i32, 0, n_seq_max) }; 148 | 149 | LlamaBatch { 150 | allocated: n_tokens, 151 | initialized_logits: vec![], 152 | llama_batch: batch, 153 | } 154 | } 155 | 156 | /// ``llama_batch_get_one`` 157 | /// Return batch for single sequence of tokens 158 | /// 159 | /// NOTE: this is a helper function to facilitate transition to the new batch API 160 | /// 161 | /// # Errors 162 | /// If the provided token buffer is empty. 163 | /// 164 | /// # Panics 165 | /// If the number of tokens in ``tokens`` exceeds [`i32::MAX`]. 166 | pub fn get_one(tokens: &[LlamaToken]) -> Result { 167 | if tokens.is_empty() { 168 | return Err(BatchAddError::EmptyBuffer); 169 | } 170 | let batch = unsafe { 171 | let ptr = tokens.as_ptr() as *mut i32; 172 | llama_cpp_sys_2::llama_batch_get_one( 173 | ptr, 174 | tokens 175 | .len() 176 | .try_into() 177 | .expect("number of tokens exceeds i32::MAX"), 178 | ) 179 | }; 180 | let batch = Self { 181 | allocated: 0, 182 | initialized_logits: vec![(tokens.len() - 1) 183 | .try_into() 184 | .expect("number of tokens exceeds i32::MAX + 1")], 185 | llama_batch: batch, 186 | }; 187 | Ok(batch) 188 | } 189 | 190 | /// Returns the number of tokens in the batch. 191 | #[must_use] 192 | pub fn n_tokens(&self) -> i32 { 193 | self.llama_batch.n_tokens 194 | } 195 | } 196 | 197 | impl Drop for LlamaBatch { 198 | /// Drops the `LlamaBatch`. 199 | /// 200 | /// ``` 201 | /// # use llama_cpp_2::llama_batch::LlamaBatch; 202 | /// # use std::error::Error; 203 | /// # fn main() -> Result<(), Box> { 204 | /// let batch = LlamaBatch::new(512, 1); 205 | /// // frees the memory associated with the batch. 
(allocated by llama.cpp) 206 | /// drop(batch); 207 | /// # Ok(()) 208 | /// # } 209 | fn drop(&mut self) { 210 | unsafe { 211 | if self.allocated > 0 { 212 | llama_batch_free(self.llama_batch); 213 | } 214 | } 215 | } 216 | } 217 | -------------------------------------------------------------------------------- /llama-cpp-2/src/log.rs: -------------------------------------------------------------------------------- 1 | use super::LogOptions; 2 | use std::sync::OnceLock; 3 | use tracing_core::{callsite, field, identify_callsite, Interest, Kind, Metadata}; 4 | 5 | static FIELD_NAMES: &[&str] = &["message", "module"]; 6 | 7 | struct OverridableFields { 8 | message: tracing::field::Field, 9 | target: tracing::field::Field, 10 | } 11 | 12 | macro_rules! log_cs { 13 | ($level:expr, $cs:ident, $meta:ident, $fields:ident, $ty:ident) => { 14 | struct $ty; 15 | static $cs: $ty = $ty; 16 | static $meta: Metadata<'static> = Metadata::new( 17 | "log event", 18 | "llama-cpp-2", 19 | $level, 20 | ::core::option::Option::None, 21 | ::core::option::Option::None, 22 | ::core::option::Option::None, 23 | field::FieldSet::new(FIELD_NAMES, identify_callsite!(&$cs)), 24 | Kind::EVENT, 25 | ); 26 | static $fields: std::sync::LazyLock = std::sync::LazyLock::new(|| { 27 | let fields = $meta.fields(); 28 | OverridableFields { 29 | message: fields.field("message").unwrap(), 30 | target: fields.field("module").unwrap(), 31 | } 32 | }); 33 | 34 | impl callsite::Callsite for $ty { 35 | fn set_interest(&self, _: Interest) {} 36 | fn metadata(&self) -> &'static Metadata<'static> { 37 | &$meta 38 | } 39 | } 40 | }; 41 | } 42 | log_cs!( 43 | tracing_core::Level::DEBUG, 44 | DEBUG_CS, 45 | DEBUG_META, 46 | DEBUG_FIELDS, 47 | DebugCallsite 48 | ); 49 | log_cs!( 50 | tracing_core::Level::INFO, 51 | INFO_CS, 52 | INFO_META, 53 | INFO_FIELDS, 54 | InfoCallsite 55 | ); 56 | log_cs!( 57 | tracing_core::Level::WARN, 58 | WARN_CS, 59 | WARN_META, 60 | WARN_FIELDS, 61 | WarnCallsite 62 | ); 63 | log_cs!( 64 | tracing_core::Level::ERROR, 65 | ERROR_CS, 66 | ERROR_META, 67 | ERROR_FIELDS, 68 | ErrorCallsite 69 | ); 70 | 71 | #[derive(Clone, Copy)] 72 | pub(super) enum Module { 73 | GGML, 74 | LlamaCpp, 75 | } 76 | 77 | impl Module { 78 | const fn name(&self) -> &'static str { 79 | match self { 80 | Module::GGML => "ggml", 81 | Module::LlamaCpp => "llama.cpp", 82 | } 83 | } 84 | } 85 | 86 | fn meta_for_level( 87 | level: llama_cpp_sys_2::ggml_log_level, 88 | ) -> (&'static Metadata<'static>, &'static OverridableFields) { 89 | match level { 90 | llama_cpp_sys_2::GGML_LOG_LEVEL_DEBUG => (&DEBUG_META, &DEBUG_FIELDS), 91 | llama_cpp_sys_2::GGML_LOG_LEVEL_INFO => (&INFO_META, &INFO_FIELDS), 92 | llama_cpp_sys_2::GGML_LOG_LEVEL_WARN => (&WARN_META, &WARN_FIELDS), 93 | llama_cpp_sys_2::GGML_LOG_LEVEL_ERROR => (&ERROR_META, &ERROR_FIELDS), 94 | _ => { 95 | unreachable!("Illegal log level to be called here") 96 | } 97 | } 98 | } 99 | 100 | pub(super) struct State { 101 | pub(super) options: LogOptions, 102 | module: Module, 103 | buffered: std::sync::Mutex>, 104 | previous_level: std::sync::atomic::AtomicI32, 105 | is_buffering: std::sync::atomic::AtomicBool, 106 | } 107 | 108 | impl State { 109 | pub(super) fn new(module: Module, options: LogOptions) -> Self { 110 | Self { 111 | options, 112 | module, 113 | buffered: Default::default(), 114 | previous_level: Default::default(), 115 | is_buffering: Default::default(), 116 | } 117 | } 118 | 119 | fn generate_log(target: Module, level: llama_cpp_sys_2::ggml_log_level, text: &str) { 120 | // 
Annoying but tracing requires that the provided target name is a string literal and 121 | // even &'static str isn't enough so we have to duplicate the generation AND we can't even 122 | // extract the interrior module within llama.cpp/ggml to be able to propagate it forward. 123 | // This happens because the target is part of a static variable injected by the macro that's 124 | // initialized with said target. 125 | 126 | let (module, text) = text 127 | .char_indices() 128 | .take_while(|(_, c)| c.is_ascii_lowercase() || *c == '_') 129 | .last() 130 | .and_then(|(pos, _)| { 131 | let next_two = text.get(pos + 1..pos + 3); 132 | if next_two == Some(": ") { 133 | let (sub_module, text) = text.split_at(pos + 1); 134 | let text = text.split_at(2).1; 135 | Some((Some(format!("{}::{sub_module}", target.name())), text)) 136 | } else { 137 | None 138 | } 139 | }) 140 | .unwrap_or((None, text)); 141 | 142 | let (meta, fields) = meta_for_level(level); 143 | 144 | tracing::dispatcher::get_default(|dispatcher| { 145 | if dispatcher.enabled(meta) { 146 | dispatcher.event(&tracing::Event::new( 147 | meta, 148 | &meta.fields().value_set(&[ 149 | (&fields.message, Some(&text as &dyn tracing::field::Value)), 150 | ( 151 | &fields.target, 152 | module.as_ref().map(|s| s as &dyn tracing::field::Value), 153 | ), 154 | ]), 155 | )); 156 | } 157 | }); 158 | } 159 | 160 | /// Append more text to the previously buffered log. The text may or may not end with a newline. 161 | pub(super) fn cont_buffered_log(&self, text: &str) { 162 | let mut lock = self.buffered.lock().unwrap(); 163 | 164 | if let Some((previous_log_level, mut buffer)) = lock.take() { 165 | buffer.push_str(text); 166 | if buffer.ends_with('\n') { 167 | self.is_buffering 168 | .store(false, std::sync::atomic::Ordering::Release); 169 | Self::generate_log(self.module, previous_log_level, buffer.as_str()); 170 | } else { 171 | *lock = Some((previous_log_level, buffer)); 172 | } 173 | } else { 174 | let level = self 175 | .previous_level 176 | .load(std::sync::atomic::Ordering::Acquire) 177 | as llama_cpp_sys_2::ggml_log_level; 178 | tracing::warn!( 179 | inferred_level = level, 180 | text = text, 181 | origin = "crate", 182 | "llma.cpp sent out a CONT log without any previously buffered message" 183 | ); 184 | *lock = Some((level, text.to_string())); 185 | } 186 | } 187 | 188 | /// Start buffering a message. Not the CONT log level and text is missing a newline. 189 | pub(super) fn buffer_non_cont(&self, level: llama_cpp_sys_2::ggml_log_level, text: &str) { 190 | debug_assert!(!text.ends_with('\n')); 191 | debug_assert_ne!(level, llama_cpp_sys_2::GGML_LOG_LEVEL_CONT); 192 | 193 | if let Some((previous_log_level, buffer)) = self 194 | .buffered 195 | .lock() 196 | .unwrap() 197 | .replace((level, text.to_string())) 198 | { 199 | tracing::warn!( 200 | level = previous_log_level, 201 | text = &buffer, 202 | origin = "crate", 203 | "Message buffered unnnecessarily due to missing newline and not followed by a CONT" 204 | ); 205 | Self::generate_log(self.module, previous_log_level, buffer.as_str()) 206 | } 207 | 208 | self.is_buffering 209 | .store(true, std::sync::atomic::Ordering::Release); 210 | self.previous_level 211 | .store(level as i32, std::sync::atomic::Ordering::Release); 212 | } 213 | 214 | // Emit a normal unbuffered log message (not the CONT log level and the text ends with a newline). 
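// Any message that was spuriously buffered (it lacked a '\n' but was not followed by a CONT)
// is flushed first; the trailing newline is then stripped, and GGML_LOG_LEVEL_NONE lines are
// forwarded as INFO events tagged with `no_log_level`.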
215 | pub(super) fn emit_non_cont_line(&self, level: llama_cpp_sys_2::ggml_log_level, text: &str) { 216 | debug_assert!(text.ends_with('\n')); 217 | debug_assert_ne!(level, llama_cpp_sys_2::GGML_LOG_LEVEL_CONT); 218 | 219 | if self 220 | .is_buffering 221 | .swap(false, std::sync::atomic::Ordering::Acquire) 222 | { 223 | if let Some((buf_level, buf_text)) = self.buffered.lock().unwrap().take() { 224 | // This warning indicates a bug within llama.cpp 225 | tracing::warn!(level = buf_level, text = buf_text, origin = "crate", "llama.cpp message buffered spuriously due to missing \\n and being followed by a non-CONT message!"); 226 | Self::generate_log(self.module, buf_level, buf_text.as_str()); 227 | } 228 | } 229 | 230 | self.previous_level 231 | .store(level as i32, std::sync::atomic::Ordering::Release); 232 | 233 | let (text, newline) = text.split_at(text.len() - 1); 234 | debug_assert_eq!(newline, "\n"); 235 | 236 | match level { 237 | llama_cpp_sys_2::GGML_LOG_LEVEL_NONE => { 238 | // TODO: Support logging this to stdout directly via options? 239 | tracing::info!(no_log_level = true, text); 240 | } 241 | llama_cpp_sys_2::GGML_LOG_LEVEL_DEBUG 242 | | llama_cpp_sys_2::GGML_LOG_LEVEL_INFO 243 | | llama_cpp_sys_2::GGML_LOG_LEVEL_WARN 244 | | llama_cpp_sys_2::GGML_LOG_LEVEL_ERROR => Self::generate_log(self.module, level, text), 245 | llama_cpp_sys_2::GGML_LOG_LEVEL_CONT => unreachable!(), 246 | _ => { 247 | tracing::warn!( 248 | level = level, 249 | text = text, 250 | origin = "crate", 251 | "Unknown llama.cpp log level" 252 | ) 253 | } 254 | } 255 | } 256 | } 257 | 258 | pub(super) static LLAMA_STATE: OnceLock> = OnceLock::new(); 259 | pub(super) static GGML_STATE: OnceLock> = OnceLock::new(); 260 | -------------------------------------------------------------------------------- /llama-cpp-2/src/model/params.rs: -------------------------------------------------------------------------------- 1 | //! A safe wrapper around `llama_model_params`. 2 | 3 | use crate::model::params::kv_overrides::KvOverrides; 4 | use std::ffi::{c_char, CStr}; 5 | use std::fmt::{Debug, Formatter}; 6 | use std::pin::Pin; 7 | use std::ptr::null; 8 | 9 | pub mod kv_overrides; 10 | 11 | /// A safe wrapper around `llama_model_params`. 12 | #[allow(clippy::module_name_repetitions)] 13 | pub struct LlamaModelParams { 14 | pub(crate) params: llama_cpp_sys_2::llama_model_params, 15 | kv_overrides: Vec, 16 | } 17 | 18 | impl Debug for LlamaModelParams { 19 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 20 | f.debug_struct("LlamaModelParams") 21 | .field("n_gpu_layers", &self.params.n_gpu_layers) 22 | .field("main_gpu", &self.params.main_gpu) 23 | .field("vocab_only", &self.params.vocab_only) 24 | .field("use_mmap", &self.params.use_mmap) 25 | .field("use_mlock", &self.params.use_mlock) 26 | .field("kv_overrides", &"vec of kv_overrides") 27 | .finish() 28 | } 29 | } 30 | 31 | impl LlamaModelParams { 32 | /// See [`KvOverrides`] 33 | /// 34 | /// # Examples 35 | /// 36 | /// ```rust 37 | /// # use llama_cpp_2::model::params::LlamaModelParams; 38 | /// let params = Box::pin(LlamaModelParams::default()); 39 | /// let kv_overrides = params.kv_overrides(); 40 | /// let count = kv_overrides.into_iter().count(); 41 | /// assert_eq!(count, 0); 42 | /// ``` 43 | #[must_use] 44 | pub fn kv_overrides(&self) -> KvOverrides { 45 | KvOverrides::new(self) 46 | } 47 | 48 | /// Appends a key-value override to the model parameters. It must be pinned as this creates a self-referential struct. 
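/// Pinning is required because `params.kv_overrides` stores a raw pointer into the
/// `kv_overrides` vector owned by this same struct; if the struct were moved, or the vector
/// reallocated without the pointer being refreshed, that pointer would dangle.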
49 | /// 50 | /// # Examples 51 | /// 52 | /// ```rust 53 | /// # use std::ffi::{CStr, CString}; 54 | /// use std::pin::pin; 55 | /// # use llama_cpp_2::model::params::LlamaModelParams; 56 | /// # use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue; 57 | /// let mut params = pin!(LlamaModelParams::default()); 58 | /// let key = CString::new("key").expect("CString::new failed"); 59 | /// params.as_mut().append_kv_override(&key, ParamOverrideValue::Int(50)); 60 | /// 61 | /// let kv_overrides = params.kv_overrides().into_iter().collect::>(); 62 | /// assert_eq!(kv_overrides.len(), 1); 63 | /// 64 | /// let (k, v) = &kv_overrides[0]; 65 | /// assert_eq!(v, &ParamOverrideValue::Int(50)); 66 | /// 67 | /// assert_eq!(k.to_bytes(), b"key", "expected key to be 'key', was {:?}", k); 68 | /// ``` 69 | #[allow(clippy::missing_panics_doc)] // panics are just to enforce internal invariants, not user errors 70 | pub fn append_kv_override( 71 | mut self: Pin<&mut Self>, 72 | key: &CStr, 73 | value: kv_overrides::ParamOverrideValue, 74 | ) { 75 | let kv_override = self 76 | .kv_overrides 77 | .get_mut(0) 78 | .expect("kv_overrides did not have a next allocated"); 79 | 80 | assert_eq!(kv_override.key[0], 0, "last kv_override was not empty"); 81 | 82 | // There should be some way to do this without iterating over everything. 83 | for (i, &c) in key.to_bytes_with_nul().iter().enumerate() { 84 | kv_override.key[i] = c_char::try_from(c).expect("invalid character in key"); 85 | } 86 | 87 | kv_override.tag = value.tag(); 88 | kv_override.__bindgen_anon_1 = value.value(); 89 | 90 | // set to null pointer for panic safety (as push may move the vector, invalidating the pointer) 91 | self.params.kv_overrides = null(); 92 | 93 | // push the next one to ensure we maintain the iterator invariant of ending with a 0 94 | self.kv_overrides 95 | .push(llama_cpp_sys_2::llama_model_kv_override { 96 | key: [0; 128], 97 | tag: 0, 98 | __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { 99 | val_i64: 0, 100 | }, 101 | }); 102 | 103 | // set the pointer to the (potentially) new vector 104 | self.params.kv_overrides = self.kv_overrides.as_ptr(); 105 | 106 | eprintln!("saved ptr: {:?}", self.params.kv_overrides); 107 | } 108 | } 109 | 110 | impl LlamaModelParams { 111 | /// Get the number of layers to offload to the GPU. 112 | #[must_use] 113 | pub fn n_gpu_layers(&self) -> i32 { 114 | self.params.n_gpu_layers 115 | } 116 | 117 | /// The GPU that is used for scratch and small tensors 118 | #[must_use] 119 | pub fn main_gpu(&self) -> i32 { 120 | self.params.main_gpu 121 | } 122 | 123 | /// only load the vocabulary, no weights 124 | #[must_use] 125 | pub fn vocab_only(&self) -> bool { 126 | self.params.vocab_only 127 | } 128 | 129 | /// use mmap if possible 130 | #[must_use] 131 | pub fn use_mmap(&self) -> bool { 132 | self.params.use_mmap 133 | } 134 | 135 | /// force system to keep model in RAM 136 | #[must_use] 137 | pub fn use_mlock(&self) -> bool { 138 | self.params.use_mlock 139 | } 140 | 141 | /// sets the number of gpu layers to offload to the GPU. 
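/// Values larger than `i32::MAX` are clamped to `i32::MAX`.
///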
142 | /// ``` 143 | /// # use llama_cpp_2::model::params::LlamaModelParams; 144 | /// let params = LlamaModelParams::default(); 145 | /// let params = params.with_n_gpu_layers(1); 146 | /// assert_eq!(params.n_gpu_layers(), 1); 147 | /// ``` 148 | #[must_use] 149 | pub fn with_n_gpu_layers(mut self, n_gpu_layers: u32) -> Self { 150 | // The only way this conversion can fail is if u32 overflows the i32 - in which case we set 151 | // to MAX 152 | let n_gpu_layers = i32::try_from(n_gpu_layers).unwrap_or(i32::MAX); 153 | self.params.n_gpu_layers = n_gpu_layers; 154 | self 155 | } 156 | 157 | /// sets the main GPU 158 | #[must_use] 159 | pub fn with_main_gpu(mut self, main_gpu: i32) -> Self { 160 | self.params.main_gpu = main_gpu; 161 | self 162 | } 163 | 164 | /// sets `vocab_only` 165 | #[must_use] 166 | pub fn with_vocab_only(mut self, vocab_only: bool) -> Self { 167 | self.params.vocab_only = vocab_only; 168 | self 169 | } 170 | 171 | /// sets `use_mlock` 172 | #[must_use] 173 | pub fn with_use_mlock(mut self, use_mlock: bool) -> Self { 174 | self.params.use_mlock = use_mlock; 175 | self 176 | } 177 | } 178 | 179 | /// Default parameters for `LlamaModel`. (as defined in llama.cpp by `llama_model_default_params`) 180 | /// ``` 181 | /// # use llama_cpp_2::model::params::LlamaModelParams; 182 | /// let params = LlamaModelParams::default(); 183 | /// #[cfg(not(target_os = "macos"))] 184 | /// assert_eq!(params.n_gpu_layers(), 0, "n_gpu_layers should be 0"); 185 | /// #[cfg(target_os = "macos")] 186 | /// assert_eq!(params.n_gpu_layers(), 999, "n_gpu_layers should be 999"); 187 | /// assert_eq!(params.main_gpu(), 0, "main_gpu should be 0"); 188 | /// assert_eq!(params.vocab_only(), false, "vocab_only should be false"); 189 | /// assert_eq!(params.use_mmap(), true, "use_mmap should be true"); 190 | /// assert_eq!(params.use_mlock(), false, "use_mlock should be false"); 191 | /// ``` 192 | impl Default for LlamaModelParams { 193 | fn default() -> Self { 194 | let default_params = unsafe { llama_cpp_sys_2::llama_model_default_params() }; 195 | LlamaModelParams { 196 | params: default_params, 197 | // push the next one to ensure we maintain the iterator invariant of ending with a 0 198 | kv_overrides: vec![llama_cpp_sys_2::llama_model_kv_override { 199 | key: [0; 128], 200 | tag: 0, 201 | __bindgen_anon_1: llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { 202 | val_i64: 0, 203 | }, 204 | }], 205 | } 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /llama-cpp-2/src/model/params/kv_overrides.rs: -------------------------------------------------------------------------------- 1 | //! Key-value overrides for a model. 2 | 3 | use crate::model::params::LlamaModelParams; 4 | use std::ffi::{CStr, CString}; 5 | use std::fmt::Debug; 6 | 7 | /// An override value for a model parameter. 
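///
/// A small illustrative example of constructing override values; the keys they apply to are
/// supplied separately via [`LlamaModelParams::append_kv_override`]:
///
/// ```
/// # use llama_cpp_2::model::params::kv_overrides::ParamOverrideValue;
/// let n_gpu = ParamOverrideValue::Int(16);
/// let scale = ParamOverrideValue::Float(1.0);
/// let flag = ParamOverrideValue::Bool(true);
/// assert_eq!(n_gpu, ParamOverrideValue::Int(16));
/// ```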
8 | #[derive(Debug, Clone, Copy, PartialEq)] 9 | pub enum ParamOverrideValue { 10 | /// A string value 11 | Bool(bool), 12 | /// A float value 13 | Float(f64), 14 | /// A integer value 15 | Int(i64), 16 | /// A string value 17 | Str([std::os::raw::c_char; 128]), 18 | } 19 | 20 | impl ParamOverrideValue { 21 | pub(crate) fn tag(&self) -> llama_cpp_sys_2::llama_model_kv_override_type { 22 | match self { 23 | ParamOverrideValue::Bool(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL, 24 | ParamOverrideValue::Float(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT, 25 | ParamOverrideValue::Int(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT, 26 | ParamOverrideValue::Str(_) => llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR, 27 | } 28 | } 29 | 30 | pub(crate) fn value(&self) -> llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { 31 | match self { 32 | ParamOverrideValue::Bool(value) => { 33 | llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_bool: *value } 34 | } 35 | ParamOverrideValue::Float(value) => { 36 | llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_f64: *value } 37 | } 38 | ParamOverrideValue::Int(value) => { 39 | llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_i64: *value } 40 | } 41 | ParamOverrideValue::Str(c_string) => { 42 | llama_cpp_sys_2::llama_model_kv_override__bindgen_ty_1 { val_str: *c_string } 43 | } 44 | } 45 | } 46 | } 47 | 48 | impl From<&llama_cpp_sys_2::llama_model_kv_override> for ParamOverrideValue { 49 | fn from( 50 | llama_cpp_sys_2::llama_model_kv_override { 51 | key: _, 52 | tag, 53 | __bindgen_anon_1, 54 | }: &llama_cpp_sys_2::llama_model_kv_override, 55 | ) -> Self { 56 | match *tag { 57 | llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_INT => { 58 | ParamOverrideValue::Int(unsafe { __bindgen_anon_1.val_i64 }) 59 | } 60 | llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_FLOAT => { 61 | ParamOverrideValue::Float(unsafe { __bindgen_anon_1.val_f64 }) 62 | } 63 | llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_BOOL => { 64 | ParamOverrideValue::Bool(unsafe { __bindgen_anon_1.val_bool }) 65 | } 66 | llama_cpp_sys_2::LLAMA_KV_OVERRIDE_TYPE_STR => { 67 | ParamOverrideValue::Str(unsafe { __bindgen_anon_1.val_str }) 68 | } 69 | _ => unreachable!("Unknown tag of {tag}"), 70 | } 71 | } 72 | } 73 | 74 | /// A struct implementing [`IntoIterator`] over the key-value overrides for a model. 75 | #[derive(Debug)] 76 | pub struct KvOverrides<'a> { 77 | model_params: &'a LlamaModelParams, 78 | } 79 | 80 | impl KvOverrides<'_> { 81 | pub(super) fn new(model_params: &LlamaModelParams) -> KvOverrides { 82 | KvOverrides { model_params } 83 | } 84 | } 85 | 86 | impl<'a> IntoIterator for KvOverrides<'a> { 87 | // I'm fairly certain this could be written returning by reference, but I'm not sure how to do it safely. I do not 88 | // expect this to be a performance bottleneck so the copy should be fine. (let me know if it's not fine!) 89 | type Item = (CString, ParamOverrideValue); 90 | type IntoIter = KvOverrideValueIterator<'a>; 91 | 92 | fn into_iter(self) -> Self::IntoIter { 93 | KvOverrideValueIterator { 94 | model_params: self.model_params, 95 | current: 0, 96 | } 97 | } 98 | } 99 | 100 | /// An iterator over the key-value overrides for a model. 
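///
/// Obtained by calling `into_iter()` on [`KvOverrides`]; yields `(CString, ParamOverrideValue)`
/// pairs until the terminating entry with an empty key is reached.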
101 | #[derive(Debug)] 102 | pub struct KvOverrideValueIterator<'a> { 103 | model_params: &'a LlamaModelParams, 104 | current: usize, 105 | } 106 | 107 | impl Iterator for KvOverrideValueIterator<'_> { 108 | type Item = (CString, ParamOverrideValue); 109 | 110 | fn next(&mut self) -> Option { 111 | let overrides = self.model_params.params.kv_overrides; 112 | if overrides.is_null() { 113 | return None; 114 | } 115 | 116 | // SAFETY: llama.cpp seems to guarantee that the last element contains an empty key or is valid. We've checked 117 | // the prev one in the last iteration, the next one should be valid or 0 (and thus safe to deref) 118 | let current = unsafe { *overrides.add(self.current) }; 119 | 120 | if current.key[0] == 0 { 121 | return None; 122 | } 123 | 124 | let value = ParamOverrideValue::from(¤t); 125 | 126 | let key = unsafe { CStr::from_ptr(current.key.as_ptr()).to_owned() }; 127 | 128 | self.current += 1; 129 | Some((key, value)) 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /llama-cpp-2/src/timing.rs: -------------------------------------------------------------------------------- 1 | //! Safe wrapper around `llama_timings`. 2 | use std::fmt::{Debug, Display, Formatter}; 3 | 4 | /// A wrapper around `llama_timings`. 5 | #[derive(Clone, Copy, Debug)] 6 | pub struct LlamaTimings { 7 | pub(crate) timings: llama_cpp_sys_2::llama_perf_context_data, 8 | } 9 | 10 | impl LlamaTimings { 11 | /// Create a new `LlamaTimings`. 12 | /// ``` 13 | /// # use llama_cpp_2::timing::LlamaTimings; 14 | /// let timings = LlamaTimings::new(1.0, 2.0, 3.0, 4.0, 5, 6); 15 | /// let timings_str = "load time = 2.00 ms 16 | /// prompt eval time = 3.00 ms / 5 tokens (0.60 ms per token, 1666.67 tokens per second) 17 | /// eval time = 4.00 ms / 6 runs (0.67 ms per token, 1500.00 tokens per second)\n"; 18 | /// assert_eq!(timings_str, format!("{}", timings)); 19 | /// ``` 20 | #[allow(clippy::too_many_arguments)] 21 | #[must_use] 22 | pub fn new( 23 | t_start_ms: f64, 24 | t_load_ms: f64, 25 | t_p_eval_ms: f64, 26 | t_eval_ms: f64, 27 | n_p_eval: i32, 28 | n_eval: i32, 29 | ) -> Self { 30 | Self { 31 | timings: llama_cpp_sys_2::llama_perf_context_data { 32 | t_start_ms, 33 | t_load_ms, 34 | t_p_eval_ms, 35 | t_eval_ms, 36 | n_p_eval, 37 | n_eval, 38 | }, 39 | } 40 | } 41 | 42 | /// Get the start time in milliseconds. 43 | #[must_use] 44 | pub fn t_start_ms(&self) -> f64 { 45 | self.timings.t_start_ms 46 | } 47 | 48 | /// Get the load time in milliseconds. 49 | #[must_use] 50 | pub fn t_load_ms(&self) -> f64 { 51 | self.timings.t_load_ms 52 | } 53 | 54 | /// Get the prompt evaluation time in milliseconds. 55 | #[must_use] 56 | pub fn t_p_eval_ms(&self) -> f64 { 57 | self.timings.t_p_eval_ms 58 | } 59 | 60 | /// Get the evaluation time in milliseconds. 61 | #[must_use] 62 | pub fn t_eval_ms(&self) -> f64 { 63 | self.timings.t_eval_ms 64 | } 65 | 66 | /// Get the number of prompt evaluations. 67 | #[must_use] 68 | pub fn n_p_eval(&self) -> i32 { 69 | self.timings.n_p_eval 70 | } 71 | 72 | /// Get the number of evaluations. 73 | #[must_use] 74 | pub fn n_eval(&self) -> i32 { 75 | self.timings.n_eval 76 | } 77 | 78 | /// Set the start time in milliseconds. 79 | pub fn set_t_start_ms(&mut self, t_start_ms: f64) { 80 | self.timings.t_start_ms = t_start_ms; 81 | } 82 | 83 | /// Set the load time in milliseconds. 
84 | pub fn set_t_load_ms(&mut self, t_load_ms: f64) { 85 | self.timings.t_load_ms = t_load_ms; 86 | } 87 | 88 | /// Set the prompt evaluation time in milliseconds. 89 | pub fn set_t_p_eval_ms(&mut self, t_p_eval_ms: f64) { 90 | self.timings.t_p_eval_ms = t_p_eval_ms; 91 | } 92 | 93 | /// Set the evaluation time in milliseconds. 94 | pub fn set_t_eval_ms(&mut self, t_eval_ms: f64) { 95 | self.timings.t_eval_ms = t_eval_ms; 96 | } 97 | 98 | /// Set the number of prompt evaluations. 99 | pub fn set_n_p_eval(&mut self, n_p_eval: i32) { 100 | self.timings.n_p_eval = n_p_eval; 101 | } 102 | 103 | /// Set the number of evaluations. 104 | pub fn set_n_eval(&mut self, n_eval: i32) { 105 | self.timings.n_eval = n_eval; 106 | } 107 | } 108 | 109 | impl Display for LlamaTimings { 110 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 111 | writeln!(f, "load time = {:.2} ms", self.t_load_ms())?; 112 | writeln!( 113 | f, 114 | "prompt eval time = {:.2} ms / {} tokens ({:.2} ms per token, {:.2} tokens per second)", 115 | self.t_p_eval_ms(), 116 | self.n_p_eval(), 117 | self.t_p_eval_ms() / f64::from(self.n_p_eval()), 118 | 1e3 / self.t_p_eval_ms() * f64::from(self.n_p_eval()) 119 | )?; 120 | writeln!( 121 | f, 122 | "eval time = {:.2} ms / {} runs ({:.2} ms per token, {:.2} tokens per second)", 123 | self.t_eval_ms(), 124 | self.n_eval(), 125 | self.t_eval_ms() / f64::from(self.n_eval()), 126 | 1e3 / self.t_eval_ms() * f64::from(self.n_eval()) 127 | )?; 128 | Ok(()) 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /llama-cpp-2/src/token.rs: -------------------------------------------------------------------------------- 1 | //! Safe wrappers around `llama_token_data` and `llama_token_data_array`. 2 | 3 | use std::fmt::Debug; 4 | use std::fmt::Display; 5 | 6 | pub mod data; 7 | pub mod data_array; 8 | pub mod logit_bias; 9 | 10 | /// A safe wrapper for `llama_token`. 11 | #[repr(transparent)] 12 | #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)] 13 | #[allow(clippy::module_name_repetitions)] 14 | pub struct LlamaToken(pub llama_cpp_sys_2::llama_token); 15 | 16 | impl Display for LlamaToken { 17 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 18 | write!(f, "{}", self.0) 19 | } 20 | } 21 | 22 | impl LlamaToken { 23 | /// Create a new `LlamaToken` from a i32. 24 | /// 25 | /// ``` 26 | /// # use llama_cpp_2::token::LlamaToken; 27 | /// let token = LlamaToken::new(0); 28 | /// assert_eq!(token, LlamaToken(0)); 29 | /// ``` 30 | #[must_use] 31 | pub fn new(token_id: i32) -> Self { 32 | Self(token_id) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /llama-cpp-2/src/token/data.rs: -------------------------------------------------------------------------------- 1 | //! Safe wrapper around `llama_token_data`. 2 | use crate::token::LlamaToken; 3 | 4 | /// A transparent wrapper around `llama_token_data`. 5 | /// 6 | /// Do not rely on `repr(transparent)` for this type. It should be considered an implementation 7 | /// detail and may change across minor versions. 8 | #[derive(Clone, Copy, Debug, PartialEq)] 9 | #[repr(transparent)] 10 | #[allow(clippy::module_name_repetitions)] 11 | pub struct LlamaTokenData { 12 | data: llama_cpp_sys_2::llama_token_data, 13 | } 14 | 15 | impl LlamaTokenData { 16 | /// Create a new token data from a token, logit, and probability. 
17 | /// ```
18 | /// # use llama_cpp_2::token::LlamaToken;
19 | /// # use llama_cpp_2::token::data::LlamaTokenData;
20 | /// let token = LlamaToken::new(1);
21 | /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
/// ```
22 | #[must_use]
23 | pub fn new(LlamaToken(id): LlamaToken, logit: f32, p: f32) -> Self {
24 | LlamaTokenData {
25 | data: llama_cpp_sys_2::llama_token_data { id, logit, p },
26 | }
27 | }
28 | /// Get the token's id
29 | /// ```
30 | /// # use llama_cpp_2::token::LlamaToken;
31 | /// # use llama_cpp_2::token::data::LlamaTokenData;
32 | /// let token = LlamaToken::new(1);
33 | /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
34 | /// assert_eq!(token_data.id(), token);
35 | /// ```
36 | #[must_use]
37 | pub fn id(&self) -> LlamaToken {
38 | LlamaToken(self.data.id)
39 | }
40 |
41 | /// Get the token's logit
42 | /// ```
43 | /// # use llama_cpp_2::token::LlamaToken;
44 | /// # use llama_cpp_2::token::data::LlamaTokenData;
45 | /// let token = LlamaToken::new(1);
46 | /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
47 | /// assert_eq!(token_data.logit(), 1.0);
48 | /// ```
49 | #[must_use]
50 | pub fn logit(&self) -> f32 {
51 | self.data.logit
52 | }
53 |
54 | /// Get the token's probability
55 | /// ```
56 | /// # use llama_cpp_2::token::LlamaToken;
57 | /// # use llama_cpp_2::token::data::LlamaTokenData;
58 | /// let token = LlamaToken::new(1);
59 | /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
60 | /// assert_eq!(token_data.p(), 1.0);
61 | /// ```
62 | #[must_use]
63 | pub fn p(&self) -> f32 {
64 | self.data.p
65 | }
66 |
67 | /// Set the token's id
68 | /// ```
69 | /// # use llama_cpp_2::token::LlamaToken;
70 | /// # use llama_cpp_2::token::data::LlamaTokenData;
71 | /// let token = LlamaToken::new(1);
72 | /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
73 | /// token_data.set_id(LlamaToken::new(2));
74 | /// assert_eq!(token_data.id(), LlamaToken::new(2));
75 | /// ```
76 | pub fn set_id(&mut self, id: LlamaToken) {
77 | self.data.id = id.0;
78 | }
79 |
80 | /// Set the token's logit
81 | /// ```
82 | /// # use llama_cpp_2::token::LlamaToken;
83 | /// # use llama_cpp_2::token::data::LlamaTokenData;
84 | /// let token = LlamaToken::new(1);
85 | /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
86 | /// token_data.set_logit(2.0);
87 | /// assert_eq!(token_data.logit(), 2.0);
88 | /// ```
89 | pub fn set_logit(&mut self, logit: f32) {
90 | self.data.logit = logit;
91 | }
92 |
93 | /// Set the token's probability
94 | /// ```
95 | /// # use llama_cpp_2::token::LlamaToken;
96 | /// # use llama_cpp_2::token::data::LlamaTokenData;
97 | /// let token = LlamaToken::new(1);
98 | /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
99 | /// token_data.set_p(2.0);
100 | /// assert_eq!(token_data.p(), 2.0);
101 | /// ```
102 | pub fn set_p(&mut self, p: f32) {
103 | self.data.p = p;
104 | }
105 | }
106 | -------------------------------------------------------------------------------- /llama-cpp-2/src/token/data_array.rs: --------------------------------------------------------------------------------
1 | //! A rusty equivalent of `llama_token_data_array`.
2 | use std::ptr;
3 |
4 | use crate::{sampling::LlamaSampler, token::data::LlamaTokenData};
5 |
6 | use super::LlamaToken;
7 |
8 | /// A safe wrapper around `llama_token_data_array`.
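///
/// A short illustrative sketch: build an array from raw logits and greedily pick a
/// token. The logit values are made up, and the block is marked `ignore` as an
/// editorial example rather than a tested doctest:
/// ```ignore
/// # use llama_cpp_2::token::data::LlamaTokenData;
/// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
/// # use llama_cpp_2::token::LlamaToken;
/// let mut candidates = LlamaTokenDataArray::from_iter(
///     [1.5_f32, 0.2, -0.7].into_iter().enumerate().map(|(id, logit)| {
///         LlamaTokenData::new(LlamaToken::new(id as i32), logit, 0.0)
///     }),
///     false,
/// );
/// // Greedy sampling selects the candidate with the highest logit (token 0 here).
/// let best = candidates.sample_token_greedy();
/// assert_eq!(best, LlamaToken::new(0));
/// ```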
9 | #[derive(Debug, Clone, PartialEq)]
10 | #[allow(clippy::module_name_repetitions)]
11 | pub struct LlamaTokenDataArray {
12 | /// the underlying data
13 | pub data: Vec<LlamaTokenData>,
14 | /// the index of the selected token in ``data``
15 | pub selected: Option<usize>,
16 | /// is the data sorted?
17 | pub sorted: bool,
18 | }
19 |
20 | impl LlamaTokenDataArray {
21 | /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted.
22 | ///
23 | /// ```
24 | /// # use llama_cpp_2::token::data::LlamaTokenData;
25 | /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
26 | /// # use llama_cpp_2::token::LlamaToken;
27 | /// let array = LlamaTokenDataArray::new(vec![
28 | /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
29 | /// LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
30 | /// ], false);
31 | /// assert_eq!(array.data.len(), 2);
32 | /// assert_eq!(array.sorted, false);
33 | /// ```
34 | #[must_use]
35 | pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
36 | Self {
37 | data,
38 | selected: None,
39 | sorted,
40 | }
41 | }
42 |
43 | /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted.
44 | /// ```
45 | /// # use llama_cpp_2::token::data::LlamaTokenData;
46 | /// # use llama_cpp_2::token::data_array::LlamaTokenDataArray;
47 | /// # use llama_cpp_2::token::LlamaToken;
48 | /// let array = LlamaTokenDataArray::from_iter([
49 | /// LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
50 | /// LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
51 | /// ], false);
52 | /// assert_eq!(array.data.len(), 2);
53 | /// assert_eq!(array.sorted, false);
/// ```
54 | pub fn from_iter<T>(data: T, sorted: bool) -> LlamaTokenDataArray
55 | where
56 | T: IntoIterator<Item = LlamaTokenData>,
57 | {
58 | Self::new(data.into_iter().collect(), sorted)
59 | }
60 |
61 | /// Returns the current selected token, if one exists.
62 | #[must_use]
63 | pub fn selected_token(&self) -> Option<LlamaToken> {
64 | self.data.get(self.selected?).map(LlamaTokenData::id)
65 | }
66 | }
67 |
68 | impl LlamaTokenDataArray {
69 | /// Modify the underlying data as a `llama_token_data_array` and reconstruct the `LlamaTokenDataArray`.
70 | ///
71 | /// # Panics
72 | ///
73 | /// Panics if some of the safety conditions are not met. (we cannot check all of them at
74 | /// runtime so breaking them is UB)
75 | ///
76 | /// SAFETY:
77 | /// The returned array formed by the data pointer and the length must entirely consist of
78 | /// initialized token data and the length must be less than the capacity of this array's data
79 | /// buffer.
80 | /// If the data is not sorted, `sorted` must be false.
81 | pub(crate) unsafe fn modify_as_c_llama_token_data_array<T>(
82 | &mut self,
83 | modify: impl FnOnce(&mut llama_cpp_sys_2::llama_token_data_array) -> T,
84 | ) -> T {
85 | let size = self.data.len();
86 | let data = self
87 | .data
88 | .as_mut_ptr()
89 | .cast::<llama_cpp_sys_2::llama_token_data>();
90 |
91 | let mut c_llama_token_data_array = llama_cpp_sys_2::llama_token_data_array {
92 | data,
93 | size,
94 | selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
95 | sorted: self.sorted,
96 | };
97 |
98 | let result = modify(&mut c_llama_token_data_array);
99 |
100 | assert!(
101 | c_llama_token_data_array.size <= self.data.capacity(),
102 | "Size of the returned array exceeds the data buffer's capacity!"
103 | ); 104 | if !ptr::eq(c_llama_token_data_array.data, data) { 105 | ptr::copy( 106 | c_llama_token_data_array.data, 107 | data, 108 | c_llama_token_data_array.size, 109 | ); 110 | } 111 | self.data.set_len(c_llama_token_data_array.size); 112 | 113 | self.sorted = c_llama_token_data_array.sorted; 114 | self.selected = c_llama_token_data_array 115 | .selected 116 | .try_into() 117 | .ok() 118 | .filter(|&s| s < self.data.len()); 119 | 120 | result 121 | } 122 | 123 | /// Modifies the data array by applying a sampler to it 124 | pub fn apply_sampler(&mut self, sampler: &LlamaSampler) { 125 | unsafe { 126 | self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| { 127 | llama_cpp_sys_2::llama_sampler_apply(sampler.sampler, c_llama_token_data_array); 128 | }); 129 | } 130 | } 131 | 132 | /// Modifies the data array by applying a sampler to it 133 | #[must_use] 134 | pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self { 135 | self.apply_sampler(sampler); 136 | self 137 | } 138 | 139 | /// Randomly selects a token from the candidates based on their probabilities. 140 | /// 141 | /// # Panics 142 | /// If the internal llama.cpp sampler fails to select a token. 143 | pub fn sample_token(&mut self, seed: u32) -> LlamaToken { 144 | self.apply_sampler(&LlamaSampler::dist(seed)); 145 | self.selected_token() 146 | .expect("Dist sampler failed to select a token!") 147 | } 148 | 149 | /// Selects the token with the highest probability. 150 | /// 151 | /// # Panics 152 | /// If the internal llama.cpp sampler fails to select a token. 153 | pub fn sample_token_greedy(&mut self) -> LlamaToken { 154 | self.apply_sampler(&LlamaSampler::greedy()); 155 | self.selected_token() 156 | .expect("Greedy sampler failed to select a token!") 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /llama-cpp-2/src/token/logit_bias.rs: -------------------------------------------------------------------------------- 1 | //! Safe wrapper around `llama_logit_bias`. 2 | use crate::token::LlamaToken; 3 | 4 | /// A transparent wrapper around `llama_logit_bias`. 5 | /// 6 | /// Represents a bias to be applied to a specific token during text generation. 7 | /// The bias modifies the likelihood of the token being selected. 8 | /// 9 | /// Do not rely on `repr(transparent)` for this type. It should be considered an implementation 10 | /// detail and may change across minor versions. 11 | #[derive(Clone, Copy, Debug, PartialEq)] 12 | #[repr(transparent)] 13 | #[allow(clippy::module_name_repetitions)] 14 | pub struct LlamaLogitBias { 15 | logit_bias: llama_cpp_sys_2::llama_logit_bias, 16 | } 17 | 18 | impl LlamaLogitBias { 19 | /// Creates a new logit bias for a specific token with the given bias value. 20 | /// 21 | /// # Examples 22 | /// ``` 23 | /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; 24 | /// let token = LlamaToken::new(1); 25 | /// let bias = LlamaLogitBias::new(token, 1.5); 26 | /// ``` 27 | #[must_use] 28 | pub fn new(LlamaToken(token): LlamaToken, bias: f32) -> Self { 29 | Self { 30 | logit_bias: llama_cpp_sys_2::llama_logit_bias { 31 | token, 32 | bias, 33 | }, 34 | } 35 | } 36 | 37 | /// Gets the token this bias applies to. 
38 | /// 39 | /// # Examples 40 | /// ``` 41 | /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; 42 | /// let token = LlamaToken::new(1); 43 | /// let bias = LlamaLogitBias::new(token, 1.5); 44 | /// assert_eq!(bias.token(), token); 45 | /// ``` 46 | #[must_use] 47 | pub fn token(&self) -> LlamaToken { 48 | LlamaToken(self.logit_bias.token) 49 | } 50 | 51 | /// Gets the bias value. 52 | /// 53 | /// # Examples 54 | /// ``` 55 | /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; 56 | /// let token = LlamaToken::new(1); 57 | /// let bias = LlamaLogitBias::new(token, 1.5); 58 | /// assert_eq!(bias.bias(), 1.5); 59 | /// ``` 60 | #[must_use] 61 | pub fn bias(&self) -> f32 { 62 | self.logit_bias.bias 63 | } 64 | 65 | /// Sets the token this bias applies to. 66 | /// 67 | /// # Examples 68 | /// ``` 69 | /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; 70 | /// let token = LlamaToken::new(1); 71 | /// let mut bias = LlamaLogitBias::new(token, 1.5); 72 | /// let new_token = LlamaToken::new(2); 73 | /// bias.set_token(new_token); 74 | /// assert_eq!(bias.token(), new_token); 75 | /// ``` 76 | pub fn set_token(&mut self, token: LlamaToken) { 77 | self.logit_bias.token = token.0; 78 | } 79 | 80 | /// Sets the bias value. 81 | /// 82 | /// # Examples 83 | /// ``` 84 | /// # use llama_cpp_2::token::{LlamaToken, logit_bias::LlamaLogitBias}; 85 | /// let token = LlamaToken::new(1); 86 | /// let mut bias = LlamaLogitBias::new(token, 1.5); 87 | /// bias.set_bias(2.0); 88 | /// assert_eq!(bias.bias(), 2.0); 89 | /// ``` 90 | pub fn set_bias(&mut self, bias: f32) { 91 | self.logit_bias.bias = bias; 92 | } 93 | } -------------------------------------------------------------------------------- /llama-cpp-2/src/token_type.rs: -------------------------------------------------------------------------------- 1 | //! Utilities for working with `llama_token_type` values. 2 | use enumflags2::{bitflags, BitFlags}; 3 | use std::ops::{Deref, DerefMut}; 4 | 5 | /// A rust flavored equivalent of `llama_token_type`. 
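///
/// Token attributes are bit flags, so a single token may carry several at once. A short
/// illustrative sketch (the flag combination is arbitrary and the block is marked
/// `ignore` as an editorial example):
/// ```ignore
/// # use llama_cpp_2::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
/// let attrs = LlamaTokenAttrs(LlamaTokenAttr::Control | LlamaTokenAttr::RStrip);
/// assert!(attrs.contains(LlamaTokenAttr::Control));
/// assert!(!attrs.contains(LlamaTokenAttr::Byte));
/// ```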
6 | #[derive(Eq, PartialEq, Debug, Clone, Copy)]
7 | #[bitflags]
8 | #[repr(u32)]
9 | #[allow(clippy::module_name_repetitions, missing_docs)]
10 | pub enum LlamaTokenAttr {
11 | Unknown = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNKNOWN as _,
12 | Unused = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_UNUSED as _,
13 | Normal = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_NORMAL as _,
14 | Control = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_CONTROL as _,
15 | UserDefined = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_USER_DEFINED as _,
16 | Byte = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_BYTE as _,
17 | Normalized = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_NORMALIZED as _,
18 | LStrip = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_LSTRIP as _,
19 | RStrip = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_RSTRIP as _,
20 | SingleWord = llama_cpp_sys_2::LLAMA_TOKEN_ATTR_SINGLE_WORD as _,
21 | }
22 |
23 | /// A set of `LlamaTokenAttrs`
24 | #[derive(Debug, Clone, Copy, PartialEq, Eq)]
25 | pub struct LlamaTokenAttrs(pub BitFlags<LlamaTokenAttr>);
26 |
27 | impl Deref for LlamaTokenAttrs {
28 | type Target = BitFlags<LlamaTokenAttr>;
29 |
30 | fn deref(&self) -> &Self::Target {
31 | &self.0
32 | }
33 | }
34 |
35 | impl DerefMut for LlamaTokenAttrs {
36 | fn deref_mut(&mut self) -> &mut Self::Target {
37 | &mut self.0
38 | }
39 | }
40 |
41 | impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for LlamaTokenAttrs {
42 | type Error = LlamaTokenTypeFromIntError;
43 |
44 | fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
45 | Ok(Self(BitFlags::from_bits(value as _).map_err(|e| {
46 | LlamaTokenTypeFromIntError::UnknownValue(e.invalid_bits())
47 | })?))
48 | }
49 | }
50 |
51 | /// An error type for `LlamaTokenType::try_from`.
52 | #[derive(thiserror::Error, Debug, Eq, PartialEq)]
53 | pub enum LlamaTokenTypeFromIntError {
54 | /// The value is not a valid `llama_token_type`.
55 | #[error("Unknown Value {0}")]
56 | UnknownValue(std::ffi::c_uint),
57 | }
58 | -------------------------------------------------------------------------------- /llama-cpp-sys-2/Cargo.lock: --------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3
4 |
5 | [[package]]
6 | name = "llama-cpp-sys"
7 | version = "0.1.0"
8 | -------------------------------------------------------------------------------- /llama-cpp-sys-2/Cargo.toml: --------------------------------------------------------------------------------
1 | [package]
2 | name = "llama-cpp-sys-2"
3 | description = "Low Level Bindings to llama.cpp"
4 | version = "0.1.109"
5 | edition = "2021"
6 | license = "MIT OR Apache-2.0"
7 | repository = "https://github.com/utilityai/llama-cpp-rs"
8 | links = "llama"
9 |
10 | include = [
11 | "wrapper.h",
12 | "build.rs",
13 | "/src",
14 |
15 | "/llama.cpp/common/**/*.h",
16 | "/llama.cpp/common/**/*.hpp",
17 | "/llama.cpp/common/**/*.cpp",
18 | "/llama.cpp/ggml/include/*.h",
19 | "/llama.cpp/ggml/src/*.h",
20 | "/llama.cpp/ggml/src/*.c",
21 | "/llama.cpp/ggml/src/*.cpp",
22 | "/llama.cpp/src/*.h",
23 | "/llama.cpp/src/*.cpp",
24 |
25 | "/llama.cpp/convert_hf_to_gguf.py", # Yes, it's required
26 |
27 | # Erroneously the llama.cpp code currently generates the build-info.cpp
28 | # into the source directory of the build instead of into the target directory
29 | # as it should. Will try submitting something upstream to clean this up as
30 | # well but for now explicitly exclude this from the build.
Previously this was 31 | # implicitly excluded because the llama.cpp code was copied wholesale into the 32 | # target directory for building which is why this problem wasn't visible before 33 | # (i.e. we'd package the llama.cpp source from the submodule & thus this build-info.cpp 34 | # generated file would still be ignored because it would only exist in the separate 35 | # copy within the target directory. An alternative, if we do want to capture build-info.cpp 36 | # within the package would be to change the CI task to add `--allow-dirty` to the package 37 | # command. 38 | "!/llama.cpp/common/build-info.cpp", 39 | "/llama.cpp/common/build-info.cpp.in", 40 | 41 | "/llama.cpp/ggml/src/ggml-cuda.cu", 42 | "/llama.cpp/ggml/src/ggml-metal.m", 43 | "/llama.cpp/ggml/src/ggml-metal.metal", 44 | 45 | "/llama.cpp/include/llama.h", 46 | "/llama.cpp/include/llama-cpp.h", 47 | 48 | "/llama.cpp/ggml/src/ggml-cpu/**/*", 49 | "/llama.cpp/ggml/src/ggml-cuda/**/*", 50 | "/llama.cpp/ggml/src/ggml-metal/**/*", 51 | "/llama.cpp/ggml/src/ggml-vulkan/**/*", 52 | 53 | "/llama.cpp/ggml/src/llamafile/sgemm.h", 54 | "/llama.cpp/ggml/src/llamafile/sgemm.cpp", 55 | 56 | "/llama.cpp/pocs", 57 | 58 | "/llama.cpp/CMakeLists.txt", 59 | "/llama.cpp/common/CMakeLists.txt", 60 | "/llama.cpp/ggml/CMakeLists.txt", 61 | "/llama.cpp/ggml/src/CMakeLists.txt", 62 | "/llama.cpp/src/CMakeLists.txt", 63 | 64 | "/llama.cpp/cmake", 65 | "/llama.cpp/ggml/cmake", 66 | "/llama.cpp/common/cmake", 67 | ] 68 | 69 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 70 | 71 | [dependencies] 72 | 73 | [build-dependencies] 74 | bindgen = { workspace = true } 75 | cc = { workspace = true, features = ["parallel"] } 76 | cmake = "0.1" 77 | find_cuda_helper = "0.2.0" 78 | glob = "0.3.2" 79 | walkdir = "2" 80 | 81 | [features] 82 | cuda = [] 83 | # Disables the need to dynamically link against libcuda.so / cuda.dll 84 | cuda-no-vmm = ["cuda"] 85 | metal = [] 86 | dynamic-link = [] 87 | vulkan = [] 88 | native = [] 89 | openmp = [] 90 | # Only has an impact on Android. 91 | shared-stdcxx = [] 92 | -------------------------------------------------------------------------------- /llama-cpp-sys-2/README.md: -------------------------------------------------------------------------------- 1 | # llama-cpp-sys 2 | 3 | Raw bindings to llama.cpp with cuda support. 4 | 5 | See [llama-cpp-2](https://crates.io/crates/llama-cpp-2) for a safe API. 6 | -------------------------------------------------------------------------------- /llama-cpp-sys-2/build.rs: -------------------------------------------------------------------------------- 1 | use cmake::Config; 2 | use glob::glob; 3 | use std::env; 4 | use std::path::{Path, PathBuf}; 5 | use std::process::Command; 6 | use walkdir::DirEntry; 7 | 8 | enum WindowsVariant { 9 | Msvc, 10 | Other, 11 | } 12 | 13 | enum AppleVariant { 14 | MacOS, 15 | Other, 16 | } 17 | 18 | enum TargetOs { 19 | Windows(WindowsVariant), 20 | Apple(AppleVariant), 21 | Linux, 22 | Android, 23 | } 24 | 25 | macro_rules! 
debug_log { 26 | ($($arg:tt)*) => { 27 | if std::env::var("BUILD_DEBUG").is_ok() { 28 | println!("cargo:warning=[DEBUG] {}", format!($($arg)*)); 29 | } 30 | }; 31 | } 32 | 33 | fn parse_target_os() -> Result<(TargetOs, String), String> { 34 | let target = env::var("TARGET").unwrap(); 35 | 36 | if target.contains("windows") { 37 | if target.ends_with("-windows-msvc") { 38 | Ok((TargetOs::Windows(WindowsVariant::Msvc), target)) 39 | } else { 40 | Ok((TargetOs::Windows(WindowsVariant::Other), target)) 41 | } 42 | } else if target.contains("apple") { 43 | if target.ends_with("-apple-darwin") { 44 | Ok((TargetOs::Apple(AppleVariant::MacOS), target)) 45 | } else { 46 | Ok((TargetOs::Apple(AppleVariant::Other), target)) 47 | } 48 | } else if target.contains("android") { 49 | Ok((TargetOs::Android, target)) 50 | } else if target.contains("linux") { 51 | Ok((TargetOs::Linux, target)) 52 | } else { 53 | Err(target) 54 | } 55 | } 56 | 57 | fn get_cargo_target_dir() -> Result> { 58 | let out_dir = env::var("OUT_DIR")?; 59 | let path = PathBuf::from(out_dir); 60 | let target_dir = path 61 | .ancestors() 62 | .nth(3) 63 | .ok_or("OUT_DIR is not deep enough")?; 64 | Ok(target_dir.to_path_buf()) 65 | } 66 | 67 | fn extract_lib_names(out_dir: &Path, build_shared_libs: bool) -> Vec { 68 | let lib_pattern = if cfg!(windows) { 69 | "*.lib" 70 | } else if cfg!(target_os = "macos") { 71 | if build_shared_libs { 72 | "*.dylib" 73 | } else { 74 | "*.a" 75 | } 76 | } else if build_shared_libs { 77 | "*.so" 78 | } else { 79 | "*.a" 80 | }; 81 | let libs_dir = out_dir.join("lib*"); 82 | let pattern = libs_dir.join(lib_pattern); 83 | debug_log!("Extract libs {}", pattern.display()); 84 | 85 | let mut lib_names: Vec = Vec::new(); 86 | 87 | // Process the libraries based on the pattern 88 | for entry in glob(pattern.to_str().unwrap()).unwrap() { 89 | match entry { 90 | Ok(path) => { 91 | let stem = path.file_stem().unwrap(); 92 | let stem_str = stem.to_str().unwrap(); 93 | 94 | // Remove the "lib" prefix if present 95 | let lib_name = if stem_str.starts_with("lib") { 96 | stem_str.strip_prefix("lib").unwrap_or(stem_str) 97 | } else { 98 | if path.extension() == Some(std::ffi::OsStr::new("a")) { 99 | let target = path.parent().unwrap().join(format!("lib{}.a", stem_str)); 100 | std::fs::rename(&path, &target).unwrap_or_else(|e| { 101 | panic!("Failed to rename {path:?} to {target:?}: {e:?}"); 102 | }) 103 | } 104 | stem_str 105 | }; 106 | lib_names.push(lib_name.to_string()); 107 | } 108 | Err(e) => println!("cargo:warning=error={}", e), 109 | } 110 | } 111 | lib_names 112 | } 113 | 114 | fn extract_lib_assets(out_dir: &Path) -> Vec { 115 | let shared_lib_pattern = if cfg!(windows) { 116 | "*.dll" 117 | } else if cfg!(target_os = "macos") { 118 | "*.dylib" 119 | } else { 120 | "*.so" 121 | }; 122 | 123 | let shared_libs_dir = if cfg!(windows) { "bin" } else { "lib" }; 124 | let libs_dir = out_dir.join(shared_libs_dir); 125 | let pattern = libs_dir.join(shared_lib_pattern); 126 | debug_log!("Extract lib assets {}", pattern.display()); 127 | let mut files = Vec::new(); 128 | 129 | for entry in glob(pattern.to_str().unwrap()).unwrap() { 130 | match entry { 131 | Ok(path) => { 132 | files.push(path); 133 | } 134 | Err(e) => eprintln!("cargo:warning=error={}", e), 135 | } 136 | } 137 | 138 | files 139 | } 140 | 141 | fn macos_link_search_path() -> Option { 142 | let output = Command::new("clang") 143 | .arg("--print-search-dirs") 144 | .output() 145 | .ok()?; 146 | if !output.status.success() { 147 | println!( 148 | "failed 
to run 'clang --print-search-dirs', continuing without a link search path" 149 | ); 150 | return None; 151 | } 152 | 153 | let stdout = String::from_utf8_lossy(&output.stdout); 154 | for line in stdout.lines() { 155 | if line.contains("libraries: =") { 156 | let path = line.split('=').nth(1)?; 157 | return Some(format!("{}/lib/darwin", path)); 158 | } 159 | } 160 | 161 | println!("failed to determine link search path, continuing without it"); 162 | None 163 | } 164 | 165 | fn is_hidden(e: &DirEntry) -> bool { 166 | e.file_name() 167 | .to_str() 168 | .map(|s| s.starts_with('.')) 169 | .unwrap_or_default() 170 | } 171 | 172 | fn main() { 173 | println!("cargo:rerun-if-changed=build.rs"); 174 | 175 | let (target_os, target_triple) = 176 | parse_target_os().unwrap_or_else(|t| panic!("Failed to parse target os {t}")); 177 | let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); 178 | 179 | let target_dir = get_cargo_target_dir().unwrap(); 180 | let manifest_dir = env::var("CARGO_MANIFEST_DIR").expect("Failed to get CARGO_MANIFEST_DIR"); 181 | let llama_src = Path::new(&manifest_dir).join("llama.cpp"); 182 | let build_shared_libs = cfg!(feature = "dynamic-link"); 183 | 184 | let build_shared_libs = std::env::var("LLAMA_BUILD_SHARED_LIBS") 185 | .map(|v| v == "1") 186 | .unwrap_or(build_shared_libs); 187 | let profile = env::var("LLAMA_LIB_PROFILE").unwrap_or("Release".to_string()); 188 | let static_crt = env::var("LLAMA_STATIC_CRT") 189 | .map(|v| v == "1") 190 | .unwrap_or(false); 191 | 192 | println!("cargo:rerun-if-env-changed=LLAMA_LIB_PROFILE"); 193 | println!("cargo:rerun-if-env-changed=LLAMA_BUILD_SHARED_LIBS"); 194 | println!("cargo:rerun-if-env-changed=LLAMA_STATIC_CRT"); 195 | 196 | debug_log!("TARGET: {}", target_triple); 197 | debug_log!("CARGO_MANIFEST_DIR: {}", manifest_dir); 198 | debug_log!("TARGET_DIR: {}", target_dir.display()); 199 | debug_log!("OUT_DIR: {}", out_dir.display()); 200 | debug_log!("BUILD_SHARED: {}", build_shared_libs); 201 | 202 | // Make sure that changes to the llama.cpp project trigger a rebuild. 
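// (For example, a change to `llama.cpp/src/llama.cpp` or to any `CMakeLists.txt` in the
// submodule results in a `cargo:rerun-if-changed=<path>` directive; the path here is
// illustrative, and the exact set is whatever the walk below finds.)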
203 | let rebuild_on_children_of = [ 204 | llama_src.join("src"), 205 | llama_src.join("ggml/src"), 206 | llama_src.join("common"), 207 | ]; 208 | for entry in walkdir::WalkDir::new(&llama_src) 209 | .into_iter() 210 | .filter_entry(|e| !is_hidden(e)) 211 | { 212 | let entry = entry.expect("Failed to obtain entry"); 213 | let rebuild = entry 214 | .file_name() 215 | .to_str() 216 | .map(|f| f.starts_with("CMake")) 217 | .unwrap_or_default() 218 | || rebuild_on_children_of 219 | .iter() 220 | .any(|src_folder| entry.path().starts_with(src_folder)); 221 | if rebuild { 222 | println!("cargo:rerun-if-changed={}", entry.path().display()); 223 | } 224 | } 225 | 226 | // Speed up build 227 | env::set_var( 228 | "CMAKE_BUILD_PARALLEL_LEVEL", 229 | std::thread::available_parallelism() 230 | .unwrap() 231 | .get() 232 | .to_string(), 233 | ); 234 | 235 | // Bindings 236 | let bindings = bindgen::Builder::default() 237 | .header("wrapper.h") 238 | .clang_arg(format!("-I{}", llama_src.join("include").display())) 239 | .clang_arg(format!("-I{}", llama_src.join("ggml/include").display())) 240 | .clang_arg(format!("--target={}", target_triple)) 241 | .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) 242 | .derive_partialeq(true) 243 | .allowlist_function("ggml_.*") 244 | .allowlist_type("ggml_.*") 245 | .allowlist_function("llama_.*") 246 | .allowlist_type("llama_.*") 247 | .prepend_enum_name(false) 248 | .generate() 249 | .expect("Failed to generate bindings"); 250 | 251 | // Write the generated bindings to an output file 252 | let bindings_path = out_dir.join("bindings.rs"); 253 | bindings 254 | .write_to_file(bindings_path) 255 | .expect("Failed to write bindings"); 256 | 257 | println!("cargo:rerun-if-changed=wrapper.h"); 258 | 259 | debug_log!("Bindings Created"); 260 | 261 | // Build with Cmake 262 | 263 | let mut config = Config::new(&llama_src); 264 | 265 | // Would require extra source files to pointlessly 266 | // be included in what's uploaded to and downloaded from 267 | // crates.io, so deactivating these instead 268 | config.define("LLAMA_BUILD_TESTS", "OFF"); 269 | config.define("LLAMA_BUILD_EXAMPLES", "OFF"); 270 | config.define("LLAMA_BUILD_SERVER", "OFF"); 271 | config.define("LLAMA_BUILD_TOOLS", "OFF"); 272 | config.define("LLAMA_CURL", "OFF"); 273 | 274 | config.define( 275 | "BUILD_SHARED_LIBS", 276 | if build_shared_libs { "ON" } else { "OFF" }, 277 | ); 278 | 279 | if matches!(target_os, TargetOs::Apple(_)) { 280 | config.define("GGML_BLAS", "OFF"); 281 | } 282 | 283 | if (matches!(target_os, TargetOs::Windows(WindowsVariant::Msvc)) 284 | && matches!( 285 | profile.as_str(), 286 | "Release" | "RelWithDebInfo" | "MinSizeRel" 287 | )) 288 | { 289 | // Debug Rust builds under MSVC turn off optimization even though we're ideally building the release profile of llama.cpp. 290 | // Looks like an upstream bug: 291 | // https://github.com/rust-lang/cmake-rs/issues/240 292 | // For now explicitly reinject the optimization flags that a CMake Release build is expected to have on in this scenario. 293 | // This fixes CPU inference performance when part of a Rust debug build. 
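// (In practice this means a debug-profile `cargo build` on MSVC still compiles the
// vendored llama.cpp with Release-style optimization; the re-added flags are exactly
// the ones listed in the loop below.)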
294 | for flag in &["/O2", "/DNDEBUG", "/Ob2"] { 295 | config.cflag(flag); 296 | config.cxxflag(flag); 297 | } 298 | } 299 | 300 | config.static_crt(static_crt); 301 | 302 | if matches!(target_os, TargetOs::Android) { 303 | // build flags for android taken from this doc 304 | // https://github.com/ggerganov/llama.cpp/blob/master/docs/android.md 305 | let android_ndk = env::var("ANDROID_NDK") 306 | .expect("Please install Android NDK and ensure that ANDROID_NDK env variable is set"); 307 | 308 | println!("cargo::rerun-if-env-changed=ANDROID_NDK"); 309 | 310 | config.define( 311 | "CMAKE_TOOLCHAIN_FILE", 312 | format!("{android_ndk}/build/cmake/android.toolchain.cmake"), 313 | ); 314 | if env::var("ANDROID_PLATFORM").is_ok() { 315 | println!("cargo::rerun-if-env-changed=ANDROID_PLATFORM"); 316 | } else { 317 | config.define("ANDROID_PLATFORM", "android-28"); 318 | } 319 | if target_triple.contains("aarch64") || target_triple.contains("armv7") { 320 | config.cflag("-march=armv8.7a"); 321 | config.cxxflag("-march=armv8.7a"); 322 | } else if target_triple.contains("x86_64") { 323 | config.cflag("-march=x86-64"); 324 | config.cxxflag("-march=x86-64"); 325 | } else if target_triple.contains("i686") { 326 | config.cflag("-march=i686"); 327 | config.cxxflag("-march=i686"); 328 | } else { 329 | // Rather than guessing just fail. 330 | panic!("Unsupported Android target {target_triple}"); 331 | } 332 | config.define("GGML_LLAMAFILE", "OFF"); 333 | if cfg!(feature = "shared-stdcxx") { 334 | println!("cargo:rustc-link-lib=dylib=stdc++"); 335 | println!("cargo:rustc-link-lib=c++_shared"); 336 | } 337 | } 338 | 339 | if matches!(target_os, TargetOs::Linux) 340 | && target_triple.contains("aarch64") 341 | && !env::var(format!("CARGO_FEATURE_{}", "native".to_uppercase())).is_ok() 342 | { 343 | // If the native feature is not enabled, we take off the native ARM64 support. 344 | // It is useful in docker environments where the native feature is not enabled. 345 | config.define("GGML_NATIVE", "OFF"); 346 | config.define("GGML_CPU_ARM_ARCH", "armv8-a"); 347 | } 348 | 349 | if cfg!(feature = "vulkan") { 350 | config.define("GGML_VULKAN", "ON"); 351 | match target_os { 352 | TargetOs::Windows(_) => { 353 | let vulkan_path = env::var("VULKAN_SDK").expect( 354 | "Please install Vulkan SDK and ensure that VULKAN_SDK env variable is set", 355 | ); 356 | let vulkan_lib_path = Path::new(&vulkan_path).join("Lib"); 357 | println!("cargo:rustc-link-search={}", vulkan_lib_path.display()); 358 | println!("cargo:rustc-link-lib=vulkan-1"); 359 | } 360 | TargetOs::Linux => { 361 | println!("cargo:rustc-link-lib=vulkan"); 362 | } 363 | _ => (), 364 | } 365 | } 366 | 367 | if cfg!(feature = "cuda") { 368 | config.define("GGML_CUDA", "ON"); 369 | 370 | if cfg!(feature = "cuda-no-vmm") { 371 | config.define("GGML_CUDA_NO_VMM", "ON"); 372 | } 373 | } 374 | 375 | // Android doesn't have OpenMP support AFAICT and openmp is a default feature. Do this here 376 | // rather than modifying the defaults in Cargo.toml just in case someone enables the OpenMP feature 377 | // and tries to build for Android anyway. 
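// (For example, a default-feature build targeting Android ends up passing
// `-DGGML_OPENMP=OFF` to CMake, while the same build for desktop Linux passes
// `-DGGML_OPENMP=ON`; the flags are shown for illustration of the branch below.)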
378 | if cfg!(feature = "openmp") && !matches!(target_os, TargetOs::Android) { 379 | config.define("GGML_OPENMP", "ON"); 380 | } else { 381 | config.define("GGML_OPENMP", "OFF"); 382 | } 383 | 384 | // General 385 | config 386 | .profile(&profile) 387 | .very_verbose(std::env::var("CMAKE_VERBOSE").is_ok()) // Not verbose by default 388 | .always_configure(false); 389 | 390 | let build_dir = config.build(); 391 | let build_info_src = llama_src.join("common/build-info.cpp"); 392 | let build_info_target = build_dir.join("build-info.cpp"); 393 | std::fs::rename(&build_info_src,&build_info_target).unwrap_or_else(|move_e| { 394 | // Rename may fail if the target directory is on a different filesystem/disk from the source. 395 | // Fall back to copy + delete to achieve the same effect in this case. 396 | std::fs::copy(&build_info_src, &build_info_target).unwrap_or_else(|copy_e| { 397 | panic!("Failed to rename {build_info_src:?} to {build_info_target:?}. Move failed with {move_e:?} and copy failed with {copy_e:?}"); 398 | }); 399 | std::fs::remove_file(&build_info_src).unwrap_or_else(|e| { 400 | panic!("Failed to delete {build_info_src:?} after copying to {build_info_target:?}: {e:?} (move failed because {move_e:?})"); 401 | }); 402 | }); 403 | 404 | // Search paths 405 | println!("cargo:rustc-link-search={}", out_dir.join("lib").display()); 406 | println!( 407 | "cargo:rustc-link-search={}", 408 | out_dir.join("lib64").display() 409 | ); 410 | println!("cargo:rustc-link-search={}", build_dir.display()); 411 | 412 | if cfg!(feature = "cuda") && !build_shared_libs { 413 | println!("cargo:rerun-if-env-changed=CUDA_PATH"); 414 | 415 | for lib_dir in find_cuda_helper::find_cuda_lib_dirs() { 416 | println!("cargo:rustc-link-search=native={}", lib_dir.display()); 417 | } 418 | 419 | // Logic from ggml-cuda/CMakeLists.txt 420 | println!("cargo:rustc-link-lib=static=cudart_static"); 421 | if matches!(target_os, TargetOs::Windows(_)) { 422 | println!("cargo:rustc-link-lib=static=cublas"); 423 | println!("cargo:rustc-link-lib=static=cublasLt"); 424 | } else { 425 | println!("cargo:rustc-link-lib=static=cublas_static"); 426 | println!("cargo:rustc-link-lib=static=cublasLt_static"); 427 | } 428 | 429 | // Need to link against libcuda.so unless GGML_CUDA_NO_VMM is defined. 
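// (Concretely: without `cuda-no-vmm` this emits `cargo:rustc-link-lib=cuda`, so the
// CUDA driver library must be available at link time; with `cuda-no-vmm` enabled the
// directive below is skipped entirely.)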
430 | if !cfg!(feature = "cuda-no-vmm") { 431 | println!("cargo:rustc-link-lib=cuda"); 432 | } 433 | 434 | println!("cargo:rustc-link-lib=static=culibos"); 435 | } 436 | 437 | // Link libraries 438 | let llama_libs_kind = if build_shared_libs { "dylib" } else { "static" }; 439 | let llama_libs = extract_lib_names(&out_dir, build_shared_libs); 440 | assert_ne!(llama_libs.len(), 0); 441 | 442 | for lib in llama_libs { 443 | let link = format!("cargo:rustc-link-lib={}={}", llama_libs_kind, lib); 444 | debug_log!("LINK {link}",); 445 | println!("{link}",); 446 | } 447 | 448 | // OpenMP 449 | if cfg!(feature = "openmp") && target_triple.contains("gnu") { 450 | println!("cargo:rustc-link-lib=gomp"); 451 | } 452 | 453 | match target_os { 454 | TargetOs::Windows(WindowsVariant::Msvc) => { 455 | if cfg!(debug_assertions) { 456 | println!("cargo:rustc-link-lib=dylib=msvcrtd"); 457 | } 458 | } 459 | TargetOs::Linux => { 460 | println!("cargo:rustc-link-lib=dylib=stdc++"); 461 | } 462 | TargetOs::Apple(variant) => { 463 | println!("cargo:rustc-link-lib=framework=Foundation"); 464 | println!("cargo:rustc-link-lib=framework=Metal"); 465 | println!("cargo:rustc-link-lib=framework=MetalKit"); 466 | println!("cargo:rustc-link-lib=framework=Accelerate"); 467 | println!("cargo:rustc-link-lib=c++"); 468 | 469 | match variant { 470 | AppleVariant::MacOS => { 471 | // On (older) OSX we need to link against the clang runtime, 472 | // which is hidden in some non-default path. 473 | // 474 | // More details at https://github.com/alexcrichton/curl-rust/issues/279. 475 | if let Some(path) = macos_link_search_path() { 476 | println!("cargo:rustc-link-lib=clang_rt.osx"); 477 | println!("cargo:rustc-link-search={}", path); 478 | } 479 | } 480 | AppleVariant::Other => (), 481 | } 482 | } 483 | _ => (), 484 | } 485 | 486 | // copy DLLs to target 487 | if build_shared_libs { 488 | let libs_assets = extract_lib_assets(&out_dir); 489 | for asset in libs_assets { 490 | let asset_clone = asset.clone(); 491 | let filename = asset_clone.file_name().unwrap(); 492 | let filename = filename.to_str().unwrap(); 493 | let dst = target_dir.join(filename); 494 | debug_log!("HARD LINK {} TO {}", asset.display(), dst.display()); 495 | if !dst.exists() { 496 | std::fs::hard_link(asset.clone(), dst).unwrap(); 497 | } 498 | 499 | // Copy DLLs to examples as well 500 | if target_dir.join("examples").exists() { 501 | let dst = target_dir.join("examples").join(filename); 502 | debug_log!("HARD LINK {} TO {}", asset.display(), dst.display()); 503 | if !dst.exists() { 504 | std::fs::hard_link(asset.clone(), dst).unwrap(); 505 | } 506 | } 507 | 508 | // Copy DLLs to target/profile/deps as well for tests 509 | let dst = target_dir.join("deps").join(filename); 510 | debug_log!("HARD LINK {} TO {}", asset.display(), dst.display()); 511 | if !dst.exists() { 512 | std::fs::hard_link(asset.clone(), dst).unwrap(); 513 | } 514 | } 515 | } 516 | } 517 | -------------------------------------------------------------------------------- /llama-cpp-sys-2/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! See [llama-cpp-2](https://crates.io/crates/llama-cpp-2) for a documented and safe API. 
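//!
//! These are raw, unsafe FFI bindings generated by bindgen. A minimal sketch of direct
//! use (illustrative only, so it is marked `ignore`; real programs should prefer the
//! safe wrappers in `llama-cpp-2`):
//!
//! ```ignore
//! unsafe {
//!     // Initialize and later tear down the llama.cpp backend through the raw C API.
//!     llama_cpp_sys_2::llama_backend_init();
//!     llama_cpp_sys_2::llama_backend_free();
//! }
//! ```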
2 | 3 | #![allow(non_upper_case_globals)] 4 | #![allow(non_camel_case_types)] 5 | #![allow(non_snake_case)] 6 | 7 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 8 | -------------------------------------------------------------------------------- /llama-cpp-sys-2/wrapper.h: -------------------------------------------------------------------------------- 1 | #include "llama.cpp/include/llama.h" -------------------------------------------------------------------------------- /test-build.Dockerfile: -------------------------------------------------------------------------------- 1 | ARG CUDA_VERSION=12.3.1 2 | ARG UBUNTU_VERSION=22.04 3 | FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS base-cuda 4 | 5 | # Install requirements for rustup install + bindgen: https://rust-lang.github.io/rust-bindgen/requirements.html 6 | RUN DEBIAN_FRONTEND=noninteractive apt update -y && apt install -y curl llvm-dev libclang-dev clang pkg-config libssl-dev cmake git 7 | RUN curl https://sh.rustup.rs -sSf | bash -s -- -y 8 | ENV PATH=/root/.cargo/bin:$PATH 9 | 10 | COPY . . 11 | RUN cargo build --bin simple --features cuda 12 | 13 | FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} AS base-cuda-runtime 14 | 15 | COPY --from=base-cuda /target/debug/simple /usr/local/bin/simple 16 | 17 | ENTRYPOINT ["/usr/local/bin/simple"] 18 | --------------------------------------------------------------------------------