├── .cargo └── config ├── .docker ├── cpu.dockerfile ├── nvidia.dockerfile └── nvidia.md ├── .dockerignore ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── LICENSE.third_parties ├── README.md ├── build.rs ├── examples └── api_hello_world.py ├── proto └── sentencepiece_model.proto ├── rllama.gif └── src ├── benches └── benchmark.rs ├── data_source.rs ├── embedding.rs ├── huggingface_loader.rs ├── lib.rs ├── main.rs ├── model_params.rs ├── protomodels ├── mod.rs └── sentencepiece_model.rs ├── rllama_main.rs ├── semaphore.rs ├── simd_support.rs ├── tensor.rs ├── tensor_opencl_support.rs ├── token_sampler.rs ├── tokenizer.rs ├── transformer.rs ├── unpickler.rs └── weight_compression.rs /.cargo/config: -------------------------------------------------------------------------------- 1 | [build] 2 | rustflags = ["-C", "target-feature=+avx2,+avx,+sse,+fma"] 3 | -------------------------------------------------------------------------------- /.docker/cpu.dockerfile: -------------------------------------------------------------------------------- 1 | FROM debian:bookworm 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | RUN apt update -y 5 | RUN apt install -y curl \ 6 | apt-utils \ 7 | unzip \ 8 | tar \ 9 | curl \ 10 | xz-utils \ 11 | build-essential \ 12 | gcc 13 | 14 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs > /rustup.sh 15 | RUN chmod +x /rustup.sh 16 | RUN /rustup.sh -y 17 | 18 | RUN bash -c 'export LD_LIBRARY_PATH=/usr/lib:/lib:/usr/lib64:/lib64; export PATH="$PATH:$HOME/.cargo/bin";rustup default nightly' 19 | 20 | COPY . /opt/rllama 21 | RUN bash -c 'export PATH="$PATH:$HOME/.cargo/bin";cd /opt/rllama;RUSTFLAGS="-C target-feature=+sse2,+avx,+fma,+avx2" cargo build --release --features server' 22 | RUN ln -s /opt/rllama/target/release/rllama /usr/bin 23 | -------------------------------------------------------------------------------- /.docker/nvidia.dockerfile: -------------------------------------------------------------------------------- 1 | FROM debian:bookworm 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | RUN apt update -y 5 | RUN apt install -y curl \ 6 | apt-utils \ 7 | unzip \ 8 | tar \ 9 | curl \ 10 | xz-utils \ 11 | ocl-icd-libopencl1 \ 12 | opencl-headers \ 13 | clinfo \ 14 | build-essential \ 15 | gcc 16 | 17 | RUN mkdir -p /etc/OpenCL/vendors && \ 18 | echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd 19 | ENV NVIDIA_VISIBLE_DEVICES all 20 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility 21 | 22 | RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs > /rustup.sh 23 | RUN chmod +x /rustup.sh 24 | RUN /rustup.sh -y 25 | 26 | RUN apt install -y opencl-dev 27 | 28 | RUN bash -c 'export PATH="$PATH:$HOME/.cargo/bin";rustup default nightly' 29 | 30 | COPY . /opt/rllama 31 | RUN bash -c 'export PATH="$PATH:$HOME/.cargo/bin";cd /opt/rllama;RUSTFLAGS="-C target-feature=+sse2,+avx,+fma,+avx2" cargo build --release --features server,opencl' 32 | RUN ln -s /opt/rllama/target/release/rllama /usr/bin 33 | -------------------------------------------------------------------------------- /.docker/nvidia.md: -------------------------------------------------------------------------------- 1 | #rllama docker on nvidia 2 | 3 | ## Getting OpenCL to work inside docker. 4 | Please note that this also requires some packages and modifications on your host system in order to allow the containers to use nvidia GPU features such as **compute**. 
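Once the host packages are in place and you have built the image described in the Usage section below, a quick way to confirm that containers can actually see the GPU is to run `clinfo` (it is installed in the image) with the same GPU flags as the usage example. This is only a sanity check and assumes the `rllama:nvidia` tag built below:

```bash
docker run --rm --gpus all --privileged -it rllama:nvidia clinfo
```

If the NVIDIA OpenCL platform and your GPU show up in the output, `rllama` should be able to use them via `--opencl-device`.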
5 | 6 | 7 | For each of the described distro / distro-family you could follow the instructions at the given links below. 8 | 9 | **Note**: You also need an upto-date version of docker/docker-ce so be sure to follow the instructions to install docker for your distro from the [docker website](https://docs.docker.com/engine/install). 10 | 11 | **Note2**: I have only personally tested the instructions on fedora/nobara and hence, cannot guarantee the accuracy of the instructions for other distros. 12 | 13 | ### Fedora / Fedora-based 14 | **[https://gist.github.com/JuanM04/fcbed16d0f4405a286adebee5fd31cb2](https://gist.github.com/JuanM04/fcbed16d0f4405a286adebee5fd31cb2)** 15 | 16 | 17 | ### Debian / Debian-based / Ubuntu / Ubuntu-based 18 | **[https://www.howtogeek.com/devops/how-to-use-an-nvidia-gpu-with-docker-containers/](https://www.howtogeek.com/devops/how-to-use-an-nvidia-gpu-with-docker-containers/)** 19 | 20 | 21 | ### Arch / Arch-based 22 | **[https://wiki.archlinux.org/title/Docker#Run_GPU_accelerated_Docker_containers_with_NVIDIA_GPUs](https://wiki.archlinux.org/title/Docker#Run_GPU_accelerated_Docker_containers_with_NVIDIA_GPUs)** 23 | 24 | Feel free to contribute/improve the instructions for existing and other distros. 25 | 26 | ## Usage 27 | 1. 28 | ```bash 29 | docker build -f ./.docker/nvidia.dockerfile -t rllama:nvidia . 30 | ``` 31 | 2. 32 | ```bash 33 | docker run --rm --gpus all --privileged -v /models/LLaMA:/models:z -it rllama:nvidia \ 34 | rllama --model-path /models/7B \ 35 | --param-path /models/7B/params.json \ 36 | --tokenizer-path /models/tokenizer.model \ 37 | --prompt "hi I like cheese" 38 | ``` 39 | 40 | Replace `/models/LLaMA` with the directory you've downloaded your models to. The `:z` in `-v` flag may or may not be needed depending on your distribution (I needed it on Fedora Linux) -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | target 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "rllama" 3 | version = "0.3.0" 4 | edition = "2021" 5 | authors = ["Mikko Juola"] 6 | description = "Pure Rust implementation of LLaMA-family of models, executable" 7 | documentation = "https://github.com/Noeda/rllama" 8 | homepage = "https://github.com/Noeda/rllama" 9 | repository = "https://github.com/Noeda/rllama" 10 | license = "AGPL-3.0" 11 | keywords = ["llama", "machine-learning"] 12 | categories = ["command-line-utilities"] 13 | 14 | [lib] 15 | path = "src/lib.rs" 16 | 17 | [[bin]] 18 | name = "rllama" 19 | path = "src/main.rs" 20 | 21 | [dependencies] 22 | protobuf = "3.2" 23 | thiserror = "1.0" 24 | half = "2.2" 25 | num-complex = "0.4" 26 | embedded-profiling = "0.3" 27 | rand = "0.8" 28 | approx = "0.5" 29 | rayon = "1.7" 30 | clap = { version = "4.1", features = ["derive"] } 31 | indicatif = "0.17" 32 | colored = "2" 33 | serde = { version = "1", features = ["derive"] } 34 | serde_json = "1" 35 | mimalloc = "0.1" 36 | ocl = { version = "0.19", optional = true } 37 | rocket = { version = "0.4", features = ["sse"], optional = true } 38 | lazy_static = "1.4" 39 | zip = "0.6" 40 | ouroboros = "0.15" 
41 | 42 | [features] 43 | opencl = ["ocl"] 44 | server = ["rocket"] 45 | 46 | # We need protobuf compiler 47 | [build-dependencies] 48 | protobuf-codegen = "3.2" 49 | protobuf-parse = "3.2" 50 | 51 | [dev-dependencies] 52 | criterion = "0.4" 53 | 54 | [profile.release] 55 | panic = 'abort' 56 | debug = true 57 | 58 | [[bench]] 59 | path = "src/benches/benchmark.rs" 60 | name = "benchmark" 61 | harness = false 62 | -------------------------------------------------------------------------------- /LICENSE.third_parties: -------------------------------------------------------------------------------- 1 | proto/ directory contains a protobuf file from Google's 2 | https://github.com/google/sentencepiece repository. 3 | 4 | Here is their license: (note rllama as a whole is AGPL3) 5 | ----- 6 | 7 | 8 | Apache License 9 | Version 2.0, January 2004 10 | http://www.apache.org/licenses/ 11 | 12 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 13 | 14 | 1. Definitions. 15 | 16 | "License" shall mean the terms and conditions for use, reproduction, 17 | and distribution as defined by Sections 1 through 9 of this document. 18 | 19 | "Licensor" shall mean the copyright owner or entity authorized by 20 | the copyright owner that is granting the License. 21 | 22 | "Legal Entity" shall mean the union of the acting entity and all 23 | other entities that control, are controlled by, or are under common 24 | control with that entity. For the purposes of this definition, 25 | "control" means (i) the power, direct or indirect, to cause the 26 | direction or management of such entity, whether by contract or 27 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 28 | outstanding shares, or (iii) beneficial ownership of such entity. 29 | 30 | "You" (or "Your") shall mean an individual or Legal Entity 31 | exercising permissions granted by this License. 32 | 33 | "Source" form shall mean the preferred form for making modifications, 34 | including but not limited to software source code, documentation 35 | source, and configuration files. 36 | 37 | "Object" form shall mean any form resulting from mechanical 38 | transformation or translation of a Source form, including but 39 | not limited to compiled object code, generated documentation, 40 | and conversions to other media types. 41 | 42 | "Work" shall mean the work of authorship, whether in Source or 43 | Object form, made available under the License, as indicated by a 44 | copyright notice that is included in or attached to the work 45 | (an example is provided in the Appendix below). 46 | 47 | "Derivative Works" shall mean any work, whether in Source or Object 48 | form, that is based on (or derived from) the Work and for which the 49 | editorial revisions, annotations, elaborations, or other modifications 50 | represent, as a whole, an original work of authorship. For the purposes 51 | of this License, Derivative Works shall not include works that remain 52 | separable from, or merely link (or bind by name) to the interfaces of, 53 | the Work and Derivative Works thereof. 54 | 55 | "Contribution" shall mean any work of authorship, including 56 | the original version of the Work and any modifications or additions 57 | to that Work or Derivative Works thereof, that is intentionally 58 | submitted to Licensor for inclusion in the Work by the copyright owner 59 | or by an individual or Legal Entity authorized to submit on behalf of 60 | the copyright owner. 
For the purposes of this definition, "submitted" 61 | means any form of electronic, verbal, or written communication sent 62 | to the Licensor or its representatives, including but not limited to 63 | communication on electronic mailing lists, source code control systems, 64 | and issue tracking systems that are managed by, or on behalf of, the 65 | Licensor for the purpose of discussing and improving the Work, but 66 | excluding communication that is conspicuously marked or otherwise 67 | designated in writing by the copyright owner as "Not a Contribution." 68 | 69 | "Contributor" shall mean Licensor and any individual or Legal Entity 70 | on behalf of whom a Contribution has been received by Licensor and 71 | subsequently incorporated within the Work. 72 | 73 | 2. Grant of Copyright License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | copyright license to reproduce, prepare Derivative Works of, 77 | publicly display, publicly perform, sublicense, and distribute the 78 | Work and such Derivative Works in Source or Object form. 79 | 80 | 3. Grant of Patent License. Subject to the terms and conditions of 81 | this License, each Contributor hereby grants to You a perpetual, 82 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 83 | (except as stated in this section) patent license to make, have made, 84 | use, offer to sell, sell, import, and otherwise transfer the Work, 85 | where such license applies only to those patent claims licensable 86 | by such Contributor that are necessarily infringed by their 87 | Contribution(s) alone or by combination of their Contribution(s) 88 | with the Work to which such Contribution(s) was submitted. If You 89 | institute patent litigation against any entity (including a 90 | cross-claim or counterclaim in a lawsuit) alleging that the Work 91 | or a Contribution incorporated within the Work constitutes direct 92 | or contributory patent infringement, then any patent licenses 93 | granted to You under this License for that Work shall terminate 94 | as of the date such litigation is filed. 95 | 96 | 4. Redistribution. 
You may reproduce and distribute copies of the 97 | Work or Derivative Works thereof in any medium, with or without 98 | modifications, and in Source or Object form, provided that You 99 | meet the following conditions: 100 | 101 | (a) You must give any other recipients of the Work or 102 | Derivative Works a copy of this License; and 103 | 104 | (b) You must cause any modified files to carry prominent notices 105 | stating that You changed the files; and 106 | 107 | (c) You must retain, in the Source form of any Derivative Works 108 | that You distribute, all copyright, patent, trademark, and 109 | attribution notices from the Source form of the Work, 110 | excluding those notices that do not pertain to any part of 111 | the Derivative Works; and 112 | 113 | (d) If the Work includes a "NOTICE" text file as part of its 114 | distribution, then any Derivative Works that You distribute must 115 | include a readable copy of the attribution notices contained 116 | within such NOTICE file, excluding those notices that do not 117 | pertain to any part of the Derivative Works, in at least one 118 | of the following places: within a NOTICE text file distributed 119 | as part of the Derivative Works; within the Source form or 120 | documentation, if provided along with the Derivative Works; or, 121 | within a display generated by the Derivative Works, if and 122 | wherever such third-party notices normally appear. The contents 123 | of the NOTICE file are for informational purposes only and 124 | do not modify the License. You may add Your own attribution 125 | notices within Derivative Works that You distribute, alongside 126 | or as an addendum to the NOTICE text from the Work, provided 127 | that such additional attribution notices cannot be construed 128 | as modifying the License. 129 | 130 | You may add Your own copyright statement to Your modifications and 131 | may provide additional or different license terms and conditions 132 | for use, reproduction, or distribution of Your modifications, or 133 | for any such Derivative Works as a whole, provided Your use, 134 | reproduction, and distribution of the Work otherwise complies with 135 | the conditions stated in this License. 136 | 137 | 5. Submission of Contributions. Unless You explicitly state otherwise, 138 | any Contribution intentionally submitted for inclusion in the Work 139 | by You to the Licensor shall be under the terms and conditions of 140 | this License, without any additional terms or conditions. 141 | Notwithstanding the above, nothing herein shall supersede or modify 142 | the terms of any separate license agreement you may have executed 143 | with Licensor regarding such Contributions. 144 | 145 | 6. Trademarks. This License does not grant permission to use the trade 146 | names, trademarks, service marks, or product names of the Licensor, 147 | except as required for reasonable and customary use in describing the 148 | origin of the Work and reproducing the content of the NOTICE file. 149 | 150 | 7. Disclaimer of Warranty. Unless required by applicable law or 151 | agreed to in writing, Licensor provides the Work (and each 152 | Contributor provides its Contributions) on an "AS IS" BASIS, 153 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 154 | implied, including, without limitation, any warranties or conditions 155 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 156 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 157 | appropriateness of using or redistributing the Work and assume any 158 | risks associated with Your exercise of permissions under this License. 159 | 160 | 8. Limitation of Liability. In no event and under no legal theory, 161 | whether in tort (including negligence), contract, or otherwise, 162 | unless required by applicable law (such as deliberate and grossly 163 | negligent acts) or agreed to in writing, shall any Contributor be 164 | liable to You for damages, including any direct, indirect, special, 165 | incidental, or consequential damages of any character arising as a 166 | result of this License or out of the use or inability to use the 167 | Work (including but not limited to damages for loss of goodwill, 168 | work stoppage, computer failure or malfunction, or any and all 169 | other commercial damages or losses), even if such Contributor 170 | has been advised of the possibility of such damages. 171 | 172 | 9. Accepting Warranty or Additional Liability. While redistributing 173 | the Work or Derivative Works thereof, You may choose to offer, 174 | and charge a fee for, acceptance of support, warranty, indemnity, 175 | or other liability obligations and/or rights consistent with this 176 | License. However, in accepting such obligations, You may act only 177 | on Your own behalf and on Your sole responsibility, not on behalf 178 | of any other Contributor, and only if You agree to indemnify, 179 | defend, and hold each Contributor harmless for any liability 180 | incurred by, or claims asserted against, such Contributor by reason 181 | of your accepting any such warranty or additional liability. 182 | 183 | END OF TERMS AND CONDITIONS 184 | 185 | APPENDIX: How to apply the Apache License to your work. 186 | 187 | To apply the Apache License to your work, attach the following 188 | boilerplate notice, with the fields enclosed by brackets "[]" 189 | replaced with your own identifying information. (Don't include 190 | the brackets!) The text should be enclosed in the appropriate 191 | comment syntax for the file format. We also recommend that a 192 | file or class name and description of purpose be included on the 193 | same "printed page" as the copyright notice for easier 194 | identification within third-party archives. 195 | 196 | Copyright [yyyy] [name of copyright owner] 197 | 198 | Licensed under the Apache License, Version 2.0 (the "License"); 199 | you may not use this file except in compliance with the License. 200 | You may obtain a copy of the License at 201 | 202 | http://www.apache.org/licenses/LICENSE-2.0 203 | 204 | Unless required by applicable law or agreed to in writing, software 205 | distributed under the License is distributed on an "AS IS" BASIS, 206 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 207 | See the License for the specific language governing permissions and 208 | limitations under the License. 209 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RLLaMA 2 | 3 | RLLaMA is a pure Rust implementation of [LLaMA large language model inference.](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/). 4 | 5 | ## Supported features 6 | 7 | * Uses either `f16` and `f32` weights. 8 | * LLaMA-7B, LLaMA-13B, LLaMA-30B, LLaMA-65B all confirmed working 9 | * Hand-optimized AVX2 implementation 10 | * OpenCL support for GPU inference. 
11 | * Load the model only partially to GPU with the `--percentage-to-gpu` command line switch to run hybrid GPU/CPU inference. 12 | * Simple HTTP API support, with the possibility of doing token sampling on 13 | the client side 14 | * It can load the `Vicuna-13B` instruct-finetuned model (although currently there is no nice UX). 15 | 16 | ## Performance 17 | 18 | The current performance is as follows: 19 | 20 | ``` 21 | Pure Rust implementations: 22 | 23 | LLaMA-7B: AMD Ryzen 3950X: 552ms / token f16 (pure Rust) 24 | LLaMA-7B: AMD Ryzen 3950X: 1008ms / token f32 (pure Rust) 25 | LLaMA-13B: AMD Ryzen 3950X: 1029ms / token f16 (pure Rust) 26 | LLaMA-13B: AMD Ryzen 3950X: 1930ms / token f32 (pure Rust) 27 | LLaMA-30B: AMD Ryzen 5950X: 2112ms / token f16 (pure Rust) 28 | LLaMA-65B: AMD Ryzen 5950X: 4186ms / token f16 (pure Rust) 29 | 30 | OpenCL (all use f16): 31 | 32 | LLaMA-7B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 216ms / token (OpenCL on GPU) 33 | LLaMA-7B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 680ms / token (OpenCL on CPU) 34 | LLaMA-13B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 420ms / token (OpenCL on GPU) 35 | LLaMA-13B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 1232ms / token (OpenCL on CPU) 36 | LLaMA-30B: AMD Ryzen 5950X + OpenCL Ryzen 5950X: 4098ms / token (OpenCL on CPU) 37 | ``` 38 | 39 | Scroll to the bottom of this README.md to see benchmarks over time. 40 | 41 | ## Screenshot 42 | 43 | ![Screenshot of RLLaMA in action](rllama.gif) 44 | 45 | ## Install 46 | 47 | You can install with the `cargo` tool. RLLaMA uses intrinsics extensively and you 48 | will likely need to enable them when installing the executable. 49 | 50 | ``` 51 | RUSTFLAGS="-C target-feature=+sse2,+avx,+fma,+avx2" cargo install rllama 52 | ``` 53 | 54 | There is a `.cargo/config` file inside this repository that will enable these 55 | features if you install manually from this Git repository instead. 56 | 57 | ## Install (Docker path) 58 | 59 | There is a Dockerfile you can use if you'd rather just get started quickly and 60 | you are familiar with `docker`. You still need to download the models yourself. 61 | 62 | 63 | ### For CPU-only docker support: 64 | ``` 65 | docker build -f ./.docker/cpu.dockerfile -t rllama . 66 | ``` 67 | 68 | ``` 69 | docker run -v /models/LLaMA:/models:z -it rllama \ 70 | rllama --model-path /models/7B \ 71 | --param-path /models/7B/params.json \ 72 | --tokenizer-path /models/tokenizer.model \ 73 | --prompt "hi I like cheese" 74 | ``` 75 | 76 | Replace `/models/LLaMA` with the directory you've downloaded your models to. 77 | The `:z` in the `-v` flag may or may not be needed depending on your distribution 78 | (I needed it on Fedora Linux). 79 | 80 | ### For GPU-enabled docker support with nvidia: 81 | Follow the instructions [here](.docker/nvidia.md). 82 | 83 | ## LLaMA weights 84 | 85 | Refer to https://github.com/facebookresearch/llama/. As of now, you need to be 86 | approved to get the weights. 87 | 88 | For LLaMA-7B, make sure you have these files: 89 | 90 | ```shell 91 | * 7B/consolidated.00.pth 92 | * 7B/params.json 93 | * tokenizer.model 94 | ``` 95 | 96 | The `consolidated.00.pth` file is actually a zip archive. You need to unzip it: 97 | 98 | ```shell 99 | $ cd 7B 100 | $ unzip consolidated.00.pth 101 | $ mv consolidated consolidated.00 102 | ``` 103 | 104 | If you are using a larger model like LLaMA-13B, then you can skip the last step 105 | of renaming the `consolidated` directory. 106 | 107 | You should now be ready to generate some text.
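If you want to double-check the unzip step first: `rllama` reads the pickled tensor metadata from `consolidated.XX/data.pkl` inside the model directory (`XX` is the shard number, `00` for LLaMA-7B), so after the commands above this file should exist. A quick check, using the example paths from this section:

```shell
$ ls 7B/consolidated.00/data.pkl
7B/consolidated.00/data.pkl
```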
108 | 109 | ## Example 110 | 111 | Run LLaMA-7B with some weights cast to 16-bit floats: 112 | 113 | ```shell 114 | rllama --tokenizer-path /path/to/tokenizer.model \ 115 | --model-path /path/to/LLaMA/7B \ 116 | --param-path /path/to/LLaMA/7B/params.json \ 117 | --f16 \ 118 | --prompt "The meaning of life is" 119 | ``` 120 | 121 | Use `rllama --help` to see all the options. 122 | 123 | ## Partially load model to GPU 124 | 125 | `rllama` can load only some of the transformer blocks to GPU. There is a 126 | command line argument: 127 | 128 | `--percentage-to-gpu <value>` 129 | 130 | 1 means 100% and 0 means 0%. Values in-between load the model partially to GPU. 131 | 132 | You can use this to load LLaMA-13B or Vicuna-13B on a consumer GPU with 24 133 | gigabytes of memory at around `--percentage-to-gpu 0.9` before it fails with an out-of-memory 134 | error (if there are no competing programs on the computer that use GPU memory). 135 | 136 | ## Interactive mode 137 | 138 | There is a simple experimental interactive mode that tries to force a 139 | back-and-forth discussion with the model. 140 | 141 | ```shell 142 | rllama ... --start-interactive \ 143 | --interactive-system-prompt "Helpful assistant helps curious human." \ # (optional) 144 | --interactive-prompt-postfix " ###Assistant:" \ # (optional) 145 | --interactive-stop "###Human: " # (optional) 146 | ``` 147 | 148 | In this mode, you need to type your prompt before the AI starts doing its work. 149 | If the AI outputs the token sequence given in `--interactive-stop` (defaults to 150 | `###Human:`), it will ask for another input. 151 | 152 | The defaults match the Vicuna-13B model: 153 | 154 | ``` 155 | --interactive-system-prompt "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." 156 | --interactive-prompt-postfix " ###Assistant:" 157 | --interactive-prompt-prefix " " 158 | --interactive-stop "###Human:" 159 | ``` 160 | 161 | `--interactive-prompt-postfix` is appended automatically to your typed text and 162 | `--interactive-prompt-prefix` is prepended to the start of your typed text. Here 163 | is an example of an interactive mode command line with the default settings: 164 | 165 | ```shell 166 | rllama --f16 \ 167 | --param-path /models/vicuna13b/params.json \ 168 | --model-path /models/vicuna13b \ 169 | --tokenizer-path /stonks/LLaMA/tokenizer.model \ 170 | --start-interactive 171 | ``` 172 | 173 | As of this writing, the output is not formatted prettily for chat and there 174 | is no visual indication of when you are supposed to be typing. That will come 175 | later. 176 | 177 | ## Inference server 178 | 179 | `rllama` can run in an inference server mode with a simple HTTP JSON API. You 180 | need to enable the `server` feature for this. 181 | 182 | ``` 183 | cargo build --release --features server 184 | ``` 185 | 186 | The command line flags for this are: 187 | 188 | * `--inference-server` turns on the inference server. 189 | * `--inference-server-port` sets the port. The default port is 8080. 190 | * `--inference-server-host` sets the host. The default host is 127.0.0.1. 191 | * `--inference-server-max-concurrent-inferences` sets how many concurrent 192 | requests are allowed to be actively doing inference at the same time. The 193 | default is 5. 194 | * `--inference-server-api-path` sets which path serves the API requests.
The 195 | default path is `/rllama/v1/inference`. 196 | * `--inference-server-prompt-cache-size` sets how many previous prompt 197 | calculations should be cached. The default is 50. This speeds up token 198 | generation for prompts that were already requested before; however, it also 199 | increases memory use as the cache fills up. 200 | * `--inference-server-exit-after-one-query` will make the server exit with 201 | exit code 0 after it has served one HTTP query. This is used for 202 | troubleshooting and experiments. 203 | 204 | Prompts and flags related to token sampling are all ignored in inference server 205 | mode. Instead, they are obtained from each HTTP JSON API request. 206 | 207 | ### Inference server API 208 | 209 | There is an `examples/api_hello_world.py` script with a minimal API usage example. 210 | 211 | ``` 212 | POST /rllama/v1/inference 213 | ``` 214 | 215 | Expects a JSON body and `Accept: application/json` or `Accept: text/jsonl`. 216 | 217 | The expected JSON is as follows: 218 | 219 | ``` 220 | { 221 | "temperature": <float> 222 | "top_k": <integer> 223 | "top_p": <float> 224 | "repetition_penalty": <float> 225 | "stop_at_end_token": <bool> 226 | "max_seq_len": <integer> 228 | "max_new_tokens": <integer> 229 | "no_token_sampling": <bool> 230 | "prompt": <string> 231 | } 232 | ``` 233 | 234 | The form of the response depends on whether `no_token_sampling` is set to true or false. The 235 | response is in JSONL, i.e. multiple JSON dictionaries, separated by newlines. 236 | 237 | `no_token_sampling` can turn off `rllama`'s own token sampling. In this case, 238 | the probabilities for every token are returned instead. 239 | 240 | When no\_token\_sampling = false: 241 | 242 | ``` 243 | {<token>: {"p": <probability>, "is_end_token": <bool, might not be present>}} 244 | ``` 245 | 246 | * `token` contains the new token to be appended to output. It does not 247 | include the string you fed to the system originally. 248 | * `p` is the probability that this token was chosen. For example, if this 249 | value is 0.1, it means that this particular token had a 10% chance of being 250 | selected with the current token sampling settings. 251 | * `is_end_token` is `true` if the given token signifies the end of output. This 252 | field is not present otherwise. 253 | 254 | When no\_token\_sampling = true: 255 | 256 | ``` 257 | {<token>: {"p": <probability>, "is_end_token": <bool, might not be present>} \ 258 | ,<token>: {"p": <probability>, "is_end_token": <bool, might not be present>} \ 259 | ,...} 260 | ``` 261 | 262 | If you want to implement your own token sampling, you may want to set 263 | `max_new_tokens=1` and `stop_at_end_token=false` to suppress rllama's own 264 | sampling behavior entirely. 265 | 266 | `rllama` internally caches recently queried prompts and the intermediate 267 | computations so that it's able to continue quickly if you issue a query 268 | that is either the same as a previous query or a continuation of one. 269 | 270 | ## How to turn on OpenCL 271 | 272 | Use the `opencl` Cargo feature. 273 | 274 | ``` 275 | RUSTFLAGS="-C target-feature=+sse2,+avx,+fma,+avx2" cargo install rllama --features opencl 276 | ``` 277 | 278 | ``` 279 | rllama --tokenizer-path /path/to/tokenizer.model \ 280 | --model-path /path/to/LLaMA/7B \ 281 | --param-path /path/to/LLaMA/7B/params.json \ 282 | --opencl-device 0 \ 283 | --prompt "The meaning of life is" 284 | ``` 285 | 286 | With the `opencl` feature, there is also another argument, `--opencl-device`, that 287 | takes a number. That number selects the Nth OpenCL device found on the system. You 288 | can see the devices in the output when you run the program (e.g. see the 289 | screenshot below).
290 | 291 | Weights are always cast to 16-bit floats for OpenCL. 292 | 293 | ## Notes and future plans 294 | 295 | This is a hobby thing for me so don't expect updates or help. 296 | 297 | * There are various BLAS libraries like CLBlast to speed up matrix 298 | multiplication that probably outperform my handwritten code. 299 | * I've heard there is some thing called Tensor Cores on nVidia GPUs. Not 300 | accessible with OpenCL. But might be accessible on Vulkan with a an 301 | extension. Or with cuBLAS. 302 | 303 | ## Benchmarks 304 | 305 | I'm trying to track that I'm making this faster and not slower. 306 | 307 | For 50-length sequence generation: 308 | 309 | ``` 310 | cargo run --release -- 311 | --model-path /LLaMA/13B \ 312 | --param-path /LLaMA/13B/params.json \ 313 | --tokenizer-path /LLaMA/tokenizer.model \ 314 | --prompt "Computers are pretty complica" --max-seq-len 50 315 | 316 | # commit c9c861d199bd2d87d7e883e3087661c1e287f6c4 (13 March 2023) 317 | 318 | LLaMA-7B: AMD Ryzen 3950X: 1058ms / token 319 | LLaMA-13B: AMD Ryzen 3950X: 2005ms / token 320 | 321 | # commit 63d27dba9091823f8ba11a270ab5790d6f597311 (13 March 2023) 322 | # This one has one part of the transformer moved to GPU as a type of smoke test 323 | 324 | LLaMA-7B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 567ms / token 325 | LLaMA-7B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 956ms / token 326 | LLaMA-13B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 987ms / token 327 | LLaMA-13B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 1706ms / token 328 | 329 | # commit 35b0c372a87192761e17beb421699ea5ad4ac1ce (13 March 2023) 330 | # I moved some attention stuff to OpenCL too. 331 | 332 | LLaMA-7B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 283ms / token 333 | LLaMA-7B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 679ms / token 334 | LLaMA-13B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 335 | LLaMA-13B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 1226ms / token 336 | 337 | # commit de5dd592777b3a4f5a9e8c93c8aeef25b9294364 (15 March 2023) 338 | # The matrix multiplication on GPU is now much faster. It didn't have that much 339 | # effect overall though, but I got modest improvement on LLaMA-7B GPU. 340 | 341 | LLaMA-7B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 247ms / token 342 | LLaMA-7B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 680ms / token 343 | LLaMA-13B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 344 | LLaMA-13B: AMD Ryzen 3950X + OpenCL Ryzen 3950X: 1232ms / token 345 | LLaMA-30B: AMD Ryzen 5950X + OpenCL Ryzen 5950X: 4098ms / token 346 | 347 | # commit 3d0afcf24309f28ec540ed7645c35400a865ad6f (17 March 2023) 348 | # I've been focusing on making the ordinary non-OpenCL CPU implementation 349 | # faster and I got some gains, most importantly from multithreading. 350 | # There is Float16 support now, so I've added f16/f32 to these tables: 351 | # 352 | # I also managed to run LLaMA-65B for the first time. 353 | 354 | LLaMA-7B: AMD Ryzen 3950X: 552ms / token f16 355 | LLaMA-7B: AMD Ryzen 3950X: 1008ms / token f32 356 | LLaMA-13B: AMD Ryzen 3950X: 1029ms / token f16 357 | LLaMA-13B: AMD Ryzen 3950X: 1930ms / token f32 358 | LLaMA-30B: AMD Ryzen 5950X: 2112ms / token f16 359 | LLaMA-65B: AMD Ryzen 5950X: 4186ms / token f16 360 | 361 | # commit f5328ab5bd62fe9bd930539382b13e9033434a0b (5 April 2023) 362 | # I've worked on making Vicuna-13B runnable and added an option to only 363 | # partially use GPU. 
Improved one of the OpenCL kernels: 364 | 365 | LLaMA-7B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 420ms (at 90%/10% GPU/CPU split) 366 | LLaMA-13B: AMD Ryzen 3950X + OpenCL RTX 3090 Ti: 216ms (at 100% GPU) 367 | ``` 368 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | protobuf_codegen::Codegen::new() 3 | .pure() 4 | .out_dir("src/protomodels") 5 | .include("proto") 6 | .input("proto/sentencepiece_model.proto") 7 | .run() 8 | .unwrap(); 9 | } 10 | -------------------------------------------------------------------------------- /examples/api_hello_world.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | This script uses the rllama API to generate tokens. 5 | 6 | It does not print the tokens nicely. 7 | """ 8 | 9 | import requests 10 | 11 | def main(): 12 | url = 'http://127.0.0.1:8080/rllama/v1/inference' 13 | req = { 14 | 'prompt': 'Hello world!', 15 | 'max_seq_len': 1024, 16 | 'max_new_tokens': 200, 17 | 'no_token_sampling': False 18 | } 19 | res = requests.post(url, json=req, stream=True) 20 | for line in res.iter_lines(): 21 | print(line.decode('utf-8')) 22 | 23 | 24 | if __name__ == '__main__': 25 | main() 26 | -------------------------------------------------------------------------------- /proto/sentencepiece_model.proto: -------------------------------------------------------------------------------- 1 | // Copyright 2016 Google Inc. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License.! 14 | 15 | syntax = "proto2"; 16 | 17 | // TODO(taku): Needs to use LITE RUNTIME in OSS release. 18 | option optimize_for = LITE_RUNTIME; 19 | 20 | package sentencepiece; 21 | 22 | // TrainerSpec encodes a various parameters for SentencePiece training. 23 | // Next id: 53 24 | message TrainerSpec { 25 | /////////////////////////////////////////////////////////////////// 26 | // General parameters 27 | // 28 | // Input corpus files. 29 | // Trainer accepts the following two formats: 30 | // A) Monolingual: plain text, one sentence per line. 31 | // B) Bilingual: TSV, source sentence target sentence 32 | // When bilingual data is passed, shared vocabulary model is built. 33 | // Note that the input file must be raw corpus, not a preprocessed corpus. 34 | // Trainer only loads the first `input_sentence_size` sentences specified 35 | // with this parameter. 36 | repeated string input = 1; 37 | 38 | // Input corpus format: 39 | // "text": one-sentence-per-line text format (default) 40 | // "tsv": sentence freq 41 | optional string input_format = 7; 42 | 43 | // Output model file prefix. 44 | // .model and .vocab are generated. 45 | optional string model_prefix = 2; 46 | 47 | // Model type. only have UNIGRAM now. 
48 | enum ModelType { 49 | UNIGRAM = 1; // Unigram language model with dynamic algorithm 50 | BPE = 2; // Byte Pair Encoding 51 | WORD = 3; // Delimitered by whitespace. 52 | CHAR = 4; // tokenizes into character sequence 53 | } 54 | optional ModelType model_type = 3 [default = UNIGRAM]; 55 | 56 | // Vocabulary size. 8k is the default size. 57 | optional int32 vocab_size = 4 [default = 8000]; 58 | 59 | // List of the languages this model can accept. 60 | // Since the model is language-agnostic, this field is used as a reference. 61 | repeated string accept_language = 5; 62 | 63 | // Size of self-test samples, which are encoded in the model file. 64 | optional int32 self_test_sample_size = 6 [default = 0]; 65 | 66 | // Whether to use DP version of sentencepiece. Use it with TSV input format 67 | // (requires precomputed word tab counts to work). 68 | optional bool enable_differential_privacy = 50 [default = false]; 69 | // Set these parameters if you need DP version of sentencepiece. 70 | // std of noise to add. 71 | optional float differential_privacy_noise_level = 51 [default = 0.0]; 72 | // Clipping threshold to apply after adding noise. All the words with 73 | // frequency less than this value are dropped. 74 | optional uint64 differential_privacy_clipping_threshold = 52 [default = 0]; 75 | 76 | /////////////////////////////////////////////////////////////////// 77 | // Training parameters. 78 | // 79 | // Uses characters which cover the corpus with the ratio of `chars_coverage`. 80 | // This parameter determines the set of basic Alphabet of sentence piece. 81 | // 1.0 - `chars_coverage` characters are treated as UNK. 82 | // See also required_chars field. 83 | optional float character_coverage = 10 [default = 0.9995]; 84 | 85 | // Maximum size of sentences the trainer loads from `input` parameter. 86 | // Trainer simply loads the `input` files in sequence. 87 | // It is better to shuffle the input corpus randomly. 88 | optional uint64 input_sentence_size = 11 [default = 0]; 89 | optional bool shuffle_input_sentence = 19 [default = true]; 90 | 91 | // Maximum size of sentences to make seed sentence pieces. 92 | // Extended suffix array is constructed to extract frequent 93 | // sub-strings from the corpus. This uses 20N working space, 94 | // where N is the size of corpus. 95 | optional int32 mining_sentence_size = 12 [deprecated = true]; 96 | 97 | // Maximum size of sentences to train sentence pieces. 98 | optional int32 training_sentence_size = 13 [deprecated = true]; 99 | 100 | // The size of seed sentencepieces. 101 | // `seed_sentencepiece_size` must be larger than `vocab_size`. 102 | optional int32 seed_sentencepiece_size = 14 [default = 1000000]; 103 | 104 | // In every EM sub-iterations, keeps top 105 | // `shrinking_factor` * `current sentencepieces size` with respect to 106 | // the loss of the sentence piece. This value should be smaller than 1.0. 107 | optional float shrinking_factor = 15 [default = 0.75]; 108 | 109 | // The maximum sentence length in byte. The sentences with the length 110 | // larger than `max_sentence_length` is simply ignored. 111 | // Longer input tends to bring the following risks: 112 | // * Overflow during EM training (unigram language model only) 113 | // * Performance drop because of O(n log n) cost in BPE. 114 | optional int32 max_sentence_length = 18 [default = 4192]; 115 | 116 | // Number of threads in the training. 117 | optional int32 num_threads = 16 [default = 16]; 118 | 119 | // Number of EM sub iterations. 
120 | optional int32 num_sub_iterations = 17 [default = 2]; 121 | 122 | /////////////////////////////////////////////////////////////////// 123 | // SentencePiece parameters which control the shapes of sentence piece. 124 | // 125 | // Maximum length of sentencepiece. 126 | optional int32 max_sentencepiece_length = 20 [default = 16]; 127 | 128 | // Uses Unicode script to split sentence pieces. 129 | // When `split_by_unicode_script` is true, we do not allow sentence piece to 130 | // include multiple Unicode scripts, e.g. "F1" is not a valid piece. 131 | // Exception: CJ characters (Hiragana/Katakana/Han) are all handled 132 | // as one script type, since Japanese word can consist of multiple scripts. 133 | // This exception is always applied regardless of the accept-language 134 | // parameter. 135 | optional bool split_by_unicode_script = 21 [default = true]; 136 | 137 | // When `split_by_number` is true, put a boundary between number and 138 | // non-number transition. If we want to treat "F1" is one token, set this flag 139 | // to be false. 140 | optional bool split_by_number = 23 [default = true]; 141 | 142 | // Use a white space to split sentence pieces. 143 | // When `split_by_whitespace` is false, we may have the piece containing 144 | // a white space in the middle. e.g., "in_the". 145 | optional bool split_by_whitespace = 22 [default = true]; 146 | 147 | // Adds whitespace symbol (_) as a suffix instead of prefix. e.g., _hello => 148 | // hello_. When `treat_whitespace_as_suffix` is true, 149 | // NormalizerSpec::add_dummy_prefix will add the dummy whitespace to the end 150 | // of sentence. 151 | optional bool treat_whitespace_as_suffix = 24 [default = false]; 152 | 153 | // Allows pieces that only contain whitespaces instead of appearing only as 154 | // prefix or suffix of other pieces. 155 | optional bool allow_whitespace_only_pieces = 26 [default = false]; 156 | 157 | // Split all digits (0-9) into separate pieces. 158 | optional bool split_digits = 25 [default = false]; 159 | 160 | /////////////////////////////////////////////////////////////////// 161 | // Vocabulary management 162 | // 163 | // Defines control symbols used as an indicator to 164 | // change the behavior of the decoder. and are pre-defined. 165 | // We can use this field to encode various meta information, 166 | // including language indicator in multilingual model. 167 | // These symbols are not visible to users, but visible to 168 | // the decoder. Note that when the input sentence contains control symbols, 169 | // they are not treated as one token, but segmented into normal pieces. 170 | // Control symbols must be inserted independently from the segmentation. 171 | repeated string control_symbols = 30; 172 | 173 | // Defines user defined symbols. 174 | // These symbols are added with extremely high score 175 | // so they are always treated as one unique symbol in any context. 176 | // Typical usage of user_defined_symbols is placeholder for named entities. 177 | repeated string user_defined_symbols = 31; 178 | 179 | // Defines required characters. Each UTF8 character in this string is included 180 | // in the character set regardless of character_coverage value. Unlike 181 | // user_defined_symbols, these characters have scores based on the frequency 182 | // on input sentences, and the model can form subwords using characters 183 | // in this field. 184 | optional string required_chars = 36; 185 | 186 | // Decomposes unknown pieces into UTF-8 bytes. 
187 | optional bool byte_fallback = 35 [default = false]; 188 | 189 | // When creating the vocabulary file, defines whether or not to additionally 190 | // output the score for each piece. 191 | optional bool vocabulary_output_piece_score = 32 [default = true]; 192 | 193 | // `vocab_size` is treated as hard limit. Crash if 194 | // the model can not produce the vocab of size `vocab_size`, 195 | // When `hard_vocab_limit` is false, vocab_size is treated 196 | // as soft limit. Note that when model_type=char, 197 | // always assumes hard_vocab_limit = false. 198 | optional bool hard_vocab_limit = 33 [default = true]; 199 | 200 | // use all symbols for vocab extraction. This flag is valid 201 | // if model type is either CHAR or WORD 202 | optional bool use_all_vocab = 34 [default = false]; 203 | 204 | /////////////////////////////////////////////////////////////////// 205 | // Reserved special meta tokens. 206 | // * -1 is not used. 207 | // * unk_id must not be -1. 208 | // Id must starts with 0 and be contigous. 209 | optional int32 unk_id = 40 [default = 0]; // 210 | optional int32 bos_id = 41 [default = 1]; // 211 | optional int32 eos_id = 42 [default = 2]; // 212 | optional int32 pad_id = 43 [default = -1]; // (padding) 213 | optional string unk_piece = 45 [default = ""]; 214 | optional string bos_piece = 46 [default = ""]; 215 | optional string eos_piece = 47 [default = ""]; 216 | optional string pad_piece = 48 [default = ""]; 217 | 218 | // Encodes into U+2047 (DOUBLE QUESTION MARK), 219 | // since this character can be useful both for user and 220 | // developer. We can easily figure out that is emitted. 221 | optional string unk_surface = 44 [default = " \xE2\x81\x87 "]; 222 | 223 | // Increase bit depth to allow unigram model training on large 224 | // (>10M sentences) corpora. A Side-effect of enabling this flag 225 | // is increased memory usage. 226 | optional bool train_extremely_large_corpus = 49 [default = false]; 227 | 228 | // Customized extensions: the range of field numbers 229 | // are open to third-party extensions. 230 | extensions 200 to max; 231 | } 232 | 233 | // NormalizerSpec encodes a various parameters for string normalizaiton 234 | message NormalizerSpec { 235 | // name of normalization rule. 236 | optional string name = 1; 237 | 238 | // Pre-compiled normalization rule created by 239 | // Builder::GetPrecompiledCharsMap() or Builder::CompileCharsMap() method. 240 | // Usually this field is set by Builder::GetNormalizerSpec() method. 241 | optional bytes precompiled_charsmap = 2; 242 | 243 | // Adds dummy whitespace at the beginning of text in order to 244 | // treat "world" in "world" and "hello world" in the same way. 245 | optional bool add_dummy_prefix = 3 [default = true]; 246 | 247 | // Removes leading, trailing, and duplicate internal whitespace. 248 | optional bool remove_extra_whitespaces = 4 [default = true]; 249 | 250 | // Replaces whitespace with meta symbol. 251 | // This field must be true to train sentence piece model. 252 | optional bool escape_whitespaces = 5 [default = true]; 253 | 254 | // Custom normalization rule file in TSV format. 255 | // https://github.com/google/sentencepiece/blob/master/doc/normalization.md 256 | // This field is only used in SentencePieceTrainer::Train() method, which 257 | // compiles the rule into the binary rule stored in `precompiled_charsmap`. 258 | optional string normalization_rule_tsv = 6; 259 | 260 | // Customized extensions: the range of field numbers 261 | // are open to third-party extensions. 
262 | extensions 200 to max; 263 | } 264 | 265 | // Proto to store samples for self-testing. 266 | message SelfTestData { 267 | message Sample { 268 | optional string input = 1; 269 | optional string expected = 2; 270 | } 271 | repeated Sample samples = 1; 272 | 273 | // Customized extensions: the range of field numbers 274 | // are open to third-party extensions. 275 | extensions 200 to max; 276 | } 277 | 278 | // ModelProto stores model parameters. 279 | // SentencePieceProcessor is supposed to be self-contained. 280 | // All settings/parameters which may change the behavior must be encoded 281 | // in ModelProto. 282 | message ModelProto { 283 | message SentencePiece { 284 | enum Type { 285 | NORMAL = 1; // normal symbol 286 | UNKNOWN = 2; // unknown symbol. only for now. 287 | CONTROL = 3; // control symbols. , , <2ja> etc. 288 | USER_DEFINED = 4; // user defined symbols. 289 | // Typical usage of USER_DEFINED symbol 290 | // is placeholder. 291 | BYTE = 6; // byte symbols. Used when `byte_fallback` is true. 292 | UNUSED = 5; // this piece is not used. 293 | } 294 | optional string piece = 1; // piece must not be empty. 295 | optional float score = 2; 296 | optional Type type = 3 [default = NORMAL]; 297 | 298 | // Customized extensions: the range of field numbers 299 | // are open to third-party extensions. 300 | extensions 200 to max; 301 | } 302 | 303 | // Sentence pieces with scores. 304 | repeated SentencePiece pieces = 1; 305 | 306 | // Spec used to generate this model file. 307 | optional TrainerSpec trainer_spec = 2; 308 | 309 | // Spec for text normalization. 310 | optional NormalizerSpec normalizer_spec = 3; 311 | 312 | // Stores sample input and its expected segmentation to verify the model. 313 | optional SelfTestData self_test_data = 4; 314 | 315 | // Spec for text de-normalization. 316 | optional NormalizerSpec denormalizer_spec = 5; 317 | 318 | // Customized extensions: the range of field numbers 319 | // are open to third-party extensions. 
320 | extensions 200 to max; 321 | } 322 | -------------------------------------------------------------------------------- /rllama.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Noeda/rllama/1e1131faaaf7013ed19639ad96f252458efdb45b/rllama.gif -------------------------------------------------------------------------------- /src/benches/benchmark.rs: -------------------------------------------------------------------------------- 1 | extern crate rllama; 2 | #[cfg(feature = "opencl")] 3 | use rllama::tensor_opencl_support::OpenCL; 4 | 5 | use rllama::tensor::{Tensor, TensorDType}; 6 | 7 | use criterion::{black_box, criterion_group, criterion_main, Criterion}; 8 | 9 | #[cfg(feature = "opencl")] 10 | pub fn opencl_benchmarks(c: &mut Criterion) { 11 | let mut orig1 = Tensor::random(1, 1, TensorDType::Float16); 12 | let mut orig16 = Tensor::random(1024, 1024, TensorDType::Float16); 13 | let mut orig32 = Tensor::random(4096, 4096, TensorDType::Float16); 14 | let cl = OpenCL::new(false, 0).unwrap(); 15 | 16 | let mut mul_left = Tensor::random(1024, 1024, TensorDType::Float16); 17 | mul_left.to_gpu_inplace(&cl).unwrap(); 18 | let mut mul_right = Tensor::random(1024, 1024, TensorDType::Float16); 19 | mul_right.to_gpu_inplace(&cl).unwrap(); 20 | let mut mul_target = Tensor::zeros(1024, 1024, TensorDType::Float16); 21 | mul_target.to_gpu_inplace(&cl).unwrap(); 22 | 23 | let mut mul_left_cpu = Tensor::random(1024, 1024, TensorDType::Float32); 24 | let mut mul_right_cpu = Tensor::random(1024, 1024, TensorDType::Float32); 25 | let mut mul_target_cpu = Tensor::random(1024, 1024, TensorDType::Float32); 26 | 27 | let mut mul_left1 = Tensor::random(4096, 11000, TensorDType::Float16); 28 | let mut mul_right1 = Tensor::random(1, 11000, TensorDType::Float16); 29 | let mut mul_target1 = Tensor::zeros(4096, 1, TensorDType::Float16); 30 | let mut mul_target2 = Tensor::zeros(1, 4096, TensorDType::Float16); 31 | mul_left1.to_gpu_inplace(&cl).unwrap(); 32 | mul_right1.to_gpu_inplace(&cl).unwrap(); 33 | mul_target1.to_gpu_inplace(&cl).unwrap(); 34 | mul_target2.to_gpu_inplace(&cl).unwrap(); 35 | 36 | c.bench_function( 37 | "1x11000 to 4096x11000 matrix multiplication transposed on OpenCL", 38 | |b| { 39 | b.iter(|| { 40 | mul_target2 41 | .matrix_mul_inplace_transposed(black_box(&mul_right1), black_box(&mul_left1)); 42 | mul_target2.finish(); 43 | }) 44 | }, 45 | ); 46 | 47 | c.bench_function( 48 | "4096x11000 to 1x11000 matrix multiplication transposed on OpenCL", 49 | |b| { 50 | b.iter(|| { 51 | mul_target1 52 | .matrix_mul_inplace_transposed(black_box(&mul_left1), black_box(&mul_right1)); 53 | mul_target1.finish(); 54 | }) 55 | }, 56 | ); 57 | 58 | c.bench_function( 59 | "1024x1024 matrix multiplication transposed on OpenCL", 60 | |b| { 61 | b.iter(|| { 62 | mul_target 63 | .matrix_mul_inplace_transposed(black_box(&mul_left), black_box(&mul_right)); 64 | mul_target.finish(); 65 | }) 66 | }, 67 | ); 68 | 69 | c.bench_function("1024x1024 matrix multiplication transposed on CPU", |b| { 70 | b.iter(|| { 71 | let _ = mul_target_cpu.matrix_mul_inplace_transposed(&mul_left_cpu, &mul_right_cpu); 72 | }) 73 | }); 74 | 75 | c.bench_function("1x1 matrix from CPU to OpenCL device and back", |b| { 76 | b.iter(|| { 77 | let _ = orig1.to_gpu_inplace(&cl).unwrap(); 78 | let _ = orig1.to_cpu_inplace(); 79 | orig1.finish(); 80 | }) 81 | }); 82 | 83 | c.bench_function("1024x1024 matrix from CPU to OpenCL device and back", |b| { 84 | b.iter(|| { 85 | let _ = 
orig16.to_gpu_inplace(&cl).unwrap(); 86 | let _ = orig16.to_cpu_inplace(); 87 | orig16.finish(); 88 | }) 89 | }); 90 | 91 | c.bench_function("4096x4096 matrix from CPU to OpenCL device and back", |b| { 92 | b.iter(|| { 93 | let _ = orig32.to_gpu_inplace(&cl).unwrap(); 94 | let _ = orig32.to_cpu_inplace(); 95 | orig32.finish(); 96 | }) 97 | }); 98 | } 99 | 100 | pub fn tensor_benchmarks(c: &mut Criterion) { 101 | let orig16_1 = Tensor::full(16, 32, TensorDType::Float16, 3.0); 102 | let orig16_2 = Tensor::full(32, 512, TensorDType::Float16, -1.33); 103 | 104 | let orig32_1 = Tensor::full(16, 32, TensorDType::Float32, 3.0); 105 | let orig32_2 = Tensor::full(32, 512, TensorDType::Float32, -1.33); 106 | let orig32_2_transposed = orig32_2.transpose(); 107 | 108 | let mut result_16 = Tensor::zeros(16, 512, TensorDType::Float16); 109 | let mut result_32 = Tensor::zeros(16, 512, TensorDType::Float32); 110 | 111 | let orig_84096_1 = Tensor::zeros(8, 4096, TensorDType::Float32); 112 | let orig_84096_2 = Tensor::zeros(4096, 4096, TensorDType::Float32); 113 | let mut result_84096 = Tensor::zeros(8, 4096, TensorDType::Float32); 114 | 115 | let orig_84096_1_f16 = Tensor::zeros(8, 4096, TensorDType::Float16); 116 | let orig_84096_2_f16 = Tensor::zeros(4096, 4096, TensorDType::Float16); 117 | let mut result_84096_f16 = Tensor::zeros(8, 4096, TensorDType::Float16); 118 | 119 | let orig_f32 = Tensor::zeros(1024, 1024, TensorDType::Float32); 120 | let orig_f16 = Tensor::zeros(1024, 1024, TensorDType::Float16); 121 | 122 | let m1 = Tensor::random(1024, 128, TensorDType::Float32); 123 | let m2 = Tensor::random(1, 128, TensorDType::Float32); 124 | let m1_f16 = m1.to_f16(); 125 | let m2_f16 = m2.to_f16(); 126 | 127 | c.bench_function( 128 | "1024x128 * 1x128 matrix vector transposed multiplication, f32", 129 | |b| { 130 | b.iter(|| { 131 | let _ = m1.matrix_vector_mul_transposed(black_box(&m2)); 132 | }) 133 | }, 134 | ); 135 | 136 | c.bench_function( 137 | "1024x128 * 1x128 matrix vector transposed multiplication, f16", 138 | |b| { 139 | b.iter(|| { 140 | let _ = m1_f16.matrix_vector_mul_transposed(black_box(&m2_f16)); 141 | }) 142 | }, 143 | ); 144 | 145 | c.bench_function( 146 | "matrix multiplication 8x4096 @ 4096x4096 f16 in-place, transposed", 147 | |b| { 148 | b.iter(|| { 149 | let _ = result_84096_f16.matrix_mul_inplace_transposed( 150 | black_box(&orig_84096_1_f16), 151 | black_box(&orig_84096_2_f16), 152 | ); 153 | }) 154 | }, 155 | ); 156 | 157 | c.bench_function( 158 | "matrix multiplication 8x4096 @ 4096x4096 f32 in-place, transposed", 159 | |b| { 160 | b.iter(|| { 161 | let _ = result_84096.matrix_mul_inplace_transposed( 162 | black_box(&orig_84096_1), 163 | black_box(&orig_84096_2), 164 | ); 165 | }) 166 | }, 167 | ); 168 | 169 | c.bench_function( 170 | "matrix multiplication 8x4096 @ 4096x4096 f32 in-place", 171 | |b| { 172 | b.iter(|| { 173 | let _ = result_84096 174 | .matrix_mul_inplace(black_box(&orig_84096_1), black_box(&orig_84096_2)); 175 | }) 176 | }, 177 | ); 178 | 179 | c.bench_function("1024x1024 matrix from f32->f16", |b| { 180 | b.iter(|| { 181 | let _ = black_box(&orig_f32).to_f16(); 182 | }) 183 | }); 184 | 185 | c.bench_function("1024x1024 matrix from f16->f32", |b| { 186 | b.iter(|| { 187 | let _ = black_box(&orig_f16).to_f32(); 188 | }) 189 | }); 190 | 191 | c.bench_function("matrix multiplication f32 not in-place", |b| { 192 | b.iter(|| { 193 | let _ = black_box(&orig32_1).matrix_mul(black_box(&orig32_2)); 194 | }) 195 | }); 196 | c.bench_function("matrix multiplication f32 
naive", |b| { 197 | b.iter(|| { 198 | let _ = black_box(&orig32_1).matrix_mul_naive(black_box(&orig32_2)); 199 | }) 200 | }); 201 | c.bench_function("matrix multiplication f16 not in-place", |b| { 202 | b.iter(|| { 203 | let _ = black_box(&orig16_1).matrix_mul(black_box(&orig16_2)); 204 | }) 205 | }); 206 | c.bench_function("matrix multiplication f16 naive", |b| { 207 | b.iter(|| { 208 | let _ = black_box(&orig16_1).matrix_mul_naive(black_box(&orig16_2)); 209 | }) 210 | }); 211 | c.bench_function("matrix multiplication f16 in-place", |b| { 212 | b.iter(|| { 213 | let _ = result_16.matrix_mul_inplace(black_box(&orig16_1), black_box(&orig16_2)); 214 | }) 215 | }); 216 | c.bench_function("matrix multiplication f32 in-place", |b| { 217 | b.iter(|| { 218 | let _ = result_32.matrix_mul_inplace(black_box(&orig32_1), black_box(&orig32_2)); 219 | }) 220 | }); 221 | c.bench_function("matrix multiplication f32 in-place, transposed", |b| { 222 | b.iter(|| { 223 | let _ = result_32.matrix_mul_inplace_transposed( 224 | black_box(&orig32_1), 225 | black_box(&orig32_2_transposed), 226 | ); 227 | }) 228 | }); 229 | } 230 | 231 | #[cfg(feature = "opencl")] 232 | criterion_group!(benches, opencl_benchmarks, tensor_benchmarks); 233 | #[cfg(not(feature = "opencl"))] 234 | criterion_group!(benches, tensor_benchmarks); 235 | criterion_main!(benches); 236 | -------------------------------------------------------------------------------- /src/data_source.rs: -------------------------------------------------------------------------------- 1 | use crate::huggingface_loader; 2 | use crate::huggingface_loader::HugginfaceModel; 3 | use crate::unpickler; 4 | use crate::unpickler::Value; 5 | use ouroboros::self_referencing; 6 | use std::io::{Read, Seek}; 7 | use std::path::{Path, PathBuf}; 8 | use std::sync::{Arc}; 9 | use thiserror::Error; 10 | 11 | #[derive(Error, Debug)] 12 | pub enum DataSourceError { 13 | #[error("IO error: {0}")] 14 | IOError(#[from] std::io::Error), 15 | #[error("Unpickling error: {0}")] 16 | UnpicklingError(#[from] unpickler::UnpicklingError), 17 | #[error("HuggingFace error: {0}")] 18 | HuggingFaceError(#[from] crate::huggingface_loader::HugginfaceModelError), 19 | #[error("Unknown source")] 20 | UnknownSource, 21 | } 22 | 23 | // This is cloned a lot in transformers.rs, keep it cheap to clone 24 | #[derive(Clone)] 25 | pub enum DataSource { 26 | // The format used by original LLaMA release, unzipped manually as per rllama README.md 27 | // instructions 28 | LLaMASource(PathBuf, Arc>), 29 | // The huggingface format used by Vicuna-13B 30 | VicunaSource(PathBuf, Arc, Arc>), 31 | } 32 | 33 | pub struct DataSourceFile { 34 | reader: Box, 35 | } 36 | 37 | trait ReadSeek: Read + Seek {} 38 | 39 | impl ReadSeek for std::fs::File {} 40 | impl ReadSeek for ZipFileSeekWrap {} 41 | 42 | #[self_referencing] 43 | struct ZipFileSeekWrap { 44 | zipfile: PathBuf, 45 | name: String, 46 | archive: zip::ZipArchive>, 47 | #[borrows(mut archive)] 48 | #[not_covariant] 49 | reader: zip::read::ZipFile<'this>, 50 | } 51 | 52 | impl Read for ZipFileSeekWrap { 53 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result { 54 | self.with_mut(|s| s.reader.read(buf)) 55 | } 56 | } 57 | 58 | impl Seek for ZipFileSeekWrap { 59 | fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { 60 | self.with_mut(|mut s| { 61 | let reader = &mut s.reader; 62 | match pos { 63 | std::io::SeekFrom::Start(_pos) => { 64 | unimplemented!(); 65 | } 66 | std::io::SeekFrom::End(_pos) => { 67 | unimplemented!(); 68 | } 69 | 
std::io::SeekFrom::Current(pos) => { 70 | std::io::copy(&mut reader.by_ref().take(pos as u64), &mut std::io::sink()) 71 | } 72 | } 73 | }) 74 | } 75 | } 76 | 77 | impl Read for DataSourceFile { 78 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result { 79 | self.reader.read(buf) 80 | } 81 | } 82 | 83 | impl Seek for DataSourceFile { 84 | fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result { 85 | self.reader.seek(pos) 86 | } 87 | } 88 | 89 | impl DataSource { 90 | pub fn unpickled(&self) -> &[unpickler::Value] { 91 | match self { 92 | DataSource::LLaMASource(_path, unpickled) => unpickled, 93 | DataSource::VicunaSource(_path, _model, unpickled) => unpickled, 94 | } 95 | } 96 | 97 | pub fn open, P: AsRef>( 98 | &self, 99 | name: P, 100 | tensor_name: S, 101 | shard: usize, 102 | ) -> Result { 103 | let name: &Path = name.as_ref(); 104 | match self { 105 | DataSource::LLaMASource(path, _) => { 106 | let base = PathBuf::from(format!("consolidated.{:02}", shard)); 107 | let path = path.join(base).join(name); 108 | let reader = std::fs::File::open(path)?; 109 | Ok(DataSourceFile { 110 | reader: Box::new(reader), 111 | }) 112 | } 113 | DataSource::VicunaSource(path, model, _) => { 114 | if shard != 0 { 115 | panic!("Vicuna loader does not support shards"); 116 | } 117 | // TODO: this can potentially open the same zip file repeatedly, and decompress the 118 | // same data, if multiple tensors are in the same file. 119 | // 120 | // Also the zip has no real Seek so we "emulate" it by decompressing. Ugh. Whatever 121 | // it works. 122 | for (zipfile_name, contents, tensors) in model.zip_file_contents.iter() { 123 | let name_str: &str = name.to_str().unwrap(); 124 | if contents.contains(name_str) && tensors.contains(tensor_name.as_ref()) { 125 | let reader = std::io::BufReader::new(std::fs::File::open(zipfile_name)?); 126 | let mut archive = zip::ZipArchive::new(reader)?; 127 | let archive_len = archive.len(); 128 | let mut idx: usize = archive_len; 129 | for i in 0..archive_len { 130 | let file = archive.by_index(i)?; 131 | let file = huggingface_loader::remove_first_directory(file.name()); 132 | if file == name { 133 | idx = i; 134 | break; 135 | } 136 | } 137 | if idx == archive_len { 138 | return Err(std::io::Error::new( 139 | std::io::ErrorKind::NotFound, 140 | format!("file not found: {:?}", name), 141 | )); 142 | } 143 | return Ok(DataSourceFile { 144 | reader: Box::new( 145 | ZipFileSeekWrapBuilder { 146 | zipfile: zipfile_name.clone(), 147 | name: name.to_str().unwrap().to_string(), 148 | archive, 149 | reader_builder: move |archive| { 150 | archive.by_index(idx).unwrap() 151 | }, 152 | } 153 | .build(), 154 | ), 155 | }); 156 | } 157 | } 158 | return Err(std::io::Error::new( 159 | std::io::ErrorKind::NotFound, 160 | format!("file not found: {:?}", path), 161 | )); 162 | } 163 | } 164 | } 165 | 166 | pub fn from_llama_source>(path: P) -> Result { 167 | let path = path.as_ref(); 168 | let mut unpickle_results: Vec = vec![]; 169 | let mut part: usize = 0; 170 | loop { 171 | let model_path: PathBuf = path.clone().into(); 172 | let base_path = model_path.join(format!("consolidated.{:02}", part)); 173 | // The data file is in consolidated.XX/data.pkl where XX is the part number. 
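        // (Illustration, assuming the checkpoint was unzipped as described in README.md:
        //  a two-shard model would be read from
        //      <model-path>/consolidated.00/data.pkl
        //      <model-path>/consolidated.01/data.pkl
        //  and this loop stops at the first shard whose data.pkl is missing.)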
174 | let full_path = base_path.join("data.pkl"); 175 | let mut fs = match std::fs::File::open(&full_path) { 176 | Ok(fs) => fs, 177 | Err(err) => { 178 | if err.kind() == std::io::ErrorKind::NotFound { 179 | break; 180 | } else { 181 | return Err(err.into()); 182 | } 183 | } 184 | }; 185 | let mut bs = Vec::new(); 186 | fs.read_to_end(&mut bs)?; 187 | std::mem::drop(fs); 188 | let result = unpickler::unpickle(&bs)?; 189 | unpickle_results.push(result); 190 | part += 1; 191 | } 192 | Ok(Self::LLaMASource( 193 | path.to_path_buf(), 194 | Arc::new(unpickle_results), 195 | )) 196 | } 197 | 198 | pub fn from_inferred_source>(path: P) -> Result { 199 | // LLaMA source has a params.json and Vicuna/Huggingfac has a pytorch_model.bin.index.json 200 | let path = path.as_ref(); 201 | let params_path = path.join("params.json"); 202 | let pytorch_model_path = path.join("pytorch_model.bin.index.json"); 203 | if params_path.exists() { 204 | Self::from_llama_source(path) 205 | } else if pytorch_model_path.exists() { 206 | Self::from_vicuna_source(path) 207 | } else { 208 | Err(DataSourceError::UnknownSource) 209 | } 210 | } 211 | 212 | pub fn from_vicuna_source>(path: P) -> Result { 213 | let path = path.as_ref(); 214 | let model = HugginfaceModel::unpickle(path)?; 215 | let unpickled: Vec = vec![model.unpickles_flattened.clone()]; 216 | Ok(DataSource::VicunaSource( 217 | path.to_path_buf(), 218 | Arc::new(model), 219 | Arc::new(unpickled), 220 | )) 221 | } 222 | 223 | pub fn need_to_do_antitranspose(&self) -> bool { 224 | match self { 225 | Self::LLaMASource(_, _) => false, 226 | Self::VicunaSource(_, _, _) => true, 227 | } 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/embedding.rs: -------------------------------------------------------------------------------- 1 | use crate::data_source::DataSource; 2 | use crate::tensor::{FromPiecesDirection, Tensor, TensorBuilder}; 3 | 4 | use crate::unpickler::*; 5 | use std::collections::BTreeMap; 6 | 7 | 8 | pub struct Embedding { 9 | wgts: BTreeMap, 10 | } 11 | 12 | impl Embedding { 13 | pub fn from_unpickled(data_source: DataSource) -> Result { 14 | let mut builders: Vec = vec![]; 15 | let unpickled = data_source.unpickled(); 16 | for unpickle in unpickled.iter() { 17 | let (name, val) = 18 | match unpickle.get_str_key2("tok_embeddings.weight", "model.embed_tokens.weight") { 19 | Some(val) => val, 20 | None => { 21 | return Err(UnpicklingError::MissingField( 22 | "tok_embeddings.weight/model.embed_tokens.weight".to_string(), 23 | )) 24 | } 25 | }; 26 | builders.push( 27 | val.to_tensor_builder(name) 28 | .ok_or(UnpicklingError::InvalidTensorData)?, 29 | ); 30 | } 31 | 32 | let tensor = TensorBuilder::load_from_pieces2( 33 | &builders, 34 | "tok_embeddings.weight", 35 | "model.embed_tokens.weight", 36 | data_source.clone(), 37 | FromPiecesDirection::Cols, 38 | )?; 39 | let num_embeddings = tensor.rows(); 40 | 41 | let mut table: BTreeMap = BTreeMap::new(); 42 | for key in 0..num_embeddings { 43 | let row = tensor.row(key); 44 | table.insert(key as usize, row); 45 | } 46 | 47 | Ok(Self { wgts: table }) 48 | } 49 | 50 | pub fn get_embedding(&self, idx: usize) -> &Tensor { 51 | self.wgts.get(&idx).unwrap() 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/huggingface_loader.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * Understands HuggingFace format for models, or well at least as much as we need 
to. 3 | */ 4 | 5 | use crate::unpickler; 6 | use serde::{Deserialize, Serialize}; 7 | use std::collections::{BTreeMap, BTreeSet}; 8 | use std::io::Read; 9 | use std::path::{Path, PathBuf}; 10 | use thiserror::Error; 11 | 12 | #[derive(Error, Debug)] 13 | pub enum HugginfaceModelError { 14 | #[error("Error parsing JSON: {0}")] 15 | JSONError(#[from] serde_json::Error), 16 | #[error("IO error: {0}")] 17 | IOError(#[from] std::io::Error), 18 | #[error("ZIP error: {0}")] 19 | ZIPError(#[from] zip::result::ZipError), 20 | #[error("Unpickler error: {0}")] 21 | UnpicklingError(#[from] unpickler::UnpicklingError), 22 | } 23 | 24 | #[allow(dead_code)] 25 | pub struct HugginfaceModel { 26 | pub(crate) unpickles: Vec<(unpickler::Value, PathBuf)>, 27 | // (path, files, tensors) 28 | pub(crate) zip_file_contents: Vec<(PathBuf, BTreeSet, BTreeSet)>, 29 | pub(crate) unpickles_flattened: unpickler::Value, 30 | pub(crate) index: HugginfaceIndex, 31 | } 32 | 33 | #[derive(Serialize, Deserialize, Clone, Debug)] 34 | pub struct HugginfaceConfig { 35 | vocab_size: usize, 36 | hidden_size: usize, 37 | intermediate_size: usize, 38 | num_hidden_layers: usize, 39 | num_attention_heads: usize, 40 | max_position_embeddings: usize, 41 | rms_norm_eps: f32, 42 | architectures: Vec, 43 | 44 | bos_token_id: usize, 45 | eos_token_id: usize, 46 | 47 | torch_dtype: String, 48 | } 49 | 50 | #[derive(Serialize, Deserialize, Clone, Debug)] 51 | pub struct HugginfaceIndex { 52 | metadata: HugginfaceIndexMetadata, 53 | weight_map: BTreeMap, 54 | } 55 | 56 | #[derive(Serialize, Deserialize, Clone, Debug)] 57 | pub struct HugginfaceIndexMetadata { 58 | total_size: usize, 59 | } 60 | 61 | impl HugginfaceModel { 62 | pub fn unpickle>(path: P) -> Result { 63 | let path: &Path = path.as_ref(); 64 | 65 | let mut unpickles = vec![]; 66 | 67 | // Read config,json 68 | let config_json_path: PathBuf = path.join("config.json"); 69 | let config_json = std::fs::read_to_string(config_json_path)?; 70 | let _config: HugginfaceConfig = serde_json::from_str(&config_json)?; 71 | 72 | let index_json_path: PathBuf = path.join("pytorch_model.bin.index.json"); 73 | let index_json = std::fs::read_to_string(index_json_path)?; 74 | let index: HugginfaceIndex = serde_json::from_str(&index_json)?; 75 | 76 | // List all .bin files that contain the weights. 77 | let mut weight_files: Vec = vec![]; 78 | for entry in std::fs::read_dir(path)? { 79 | let entry = entry?; 80 | let path = entry.path(); 81 | if path.extension().unwrap_or_default() == "bin" { 82 | weight_files.push(path); 83 | } 84 | } 85 | 86 | // List all files in said zips 87 | let mut unpickles2 = vec![]; 88 | let mut zip_file_contents = vec![]; 89 | for file in weight_files.iter() { 90 | let mut files_in_zip = BTreeSet::new(); 91 | let mut tensors_in_zip = BTreeSet::new(); 92 | let reader = std::io::BufReader::new(std::fs::File::open(file)?); 93 | let mut archive = zip::ZipArchive::new(reader)?; 94 | for i in 0..archive.len() { 95 | let mut file = archive.by_index(i)?; 96 | // Remove the first directory. 
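                // (A torch.save() zip keeps its entries under a single top-level
                // directory, e.g. "<archive-name>/data.pkl" and "<archive-name>/data/0",
                // so stripping that first component leaves names like "data.pkl" that
                // the rest of the loader looks up.)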
97 | let file2 = remove_first_directory(file.name()); 98 | files_in_zip.insert(file2.to_str().unwrap().to_string()); 99 | // data.pkl 100 | if file.name().ends_with("data.pkl") { 101 | let mut data_unzipped: Vec = vec![]; 102 | file.read_to_end(&mut data_unzipped)?; 103 | let unpickled = unpickler::unpickle(&data_unzipped)?; 104 | for tensor in unpickled.keys() { 105 | tensors_in_zip.insert(tensor.to_string()); 106 | } 107 | unpickles2.push(unpickled.clone()); 108 | unpickles.push((unpickled, file.name().to_string().into())) 109 | } 110 | } 111 | zip_file_contents.push((file.clone(), files_in_zip, tensors_in_zip)); 112 | } 113 | // Flatten unpickles. 114 | let unpickles_flattened = crate::unpickler::Value::merge_dicts(&unpickles2); 115 | 116 | Ok(HugginfaceModel { 117 | unpickles, 118 | unpickles_flattened, 119 | zip_file_contents, 120 | index, 121 | }) 122 | } 123 | } 124 | 125 | pub fn remove_first_directory>(path: P) -> PathBuf { 126 | let path = path.as_ref(); 127 | let mut components = vec![]; 128 | for component in path.components().skip(1) { 129 | components.push(component); 130 | } 131 | PathBuf::from(components.iter().collect::()) 132 | } 133 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![feature(stdsimd)] 2 | #![feature(decl_macro)] 3 | 4 | pub mod data_source; 5 | pub mod embedding; 6 | pub mod huggingface_loader; 7 | pub mod model_params; 8 | pub mod protomodels; 9 | pub mod rllama_main; 10 | pub mod semaphore; 11 | pub mod simd_support; 12 | pub mod tensor; 13 | #[cfg(feature = "opencl")] 14 | pub mod tensor_opencl_support; 15 | pub mod token_sampler; 16 | pub mod tokenizer; 17 | pub mod transformer; 18 | pub mod unpickler; 19 | pub mod weight_compression; 20 | #[cfg(feature = "server")] 21 | #[macro_use] 22 | extern crate rocket; 23 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | #[cfg(not(target_feature = "avx2"))] 2 | compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2"); 3 | #[cfg(not(target_feature = "sse2"))] 4 | compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2"); 5 | #[cfg(not(target_feature = "fma"))] 6 | compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2"); 7 | #[cfg(not(target_feature = "avx"))] 8 | compile_error!("This library assumes availability of AVX and must be compiled with -C target-feature=+sse2,+avx,+fma,+avx2"); 9 | 10 | use mimalloc::MiMalloc; 11 | 12 | #[global_allocator] 13 | static GLOBAL: MiMalloc = MiMalloc; 14 | 15 | pub fn main() -> Result<(), Box> { 16 | rllama::rllama_main::main() 17 | } 18 | -------------------------------------------------------------------------------- /src/model_params.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | #[derive(Clone, Serialize, Deserialize)] 4 | pub struct ModelParams { 5 | #[serde(alias = "hidden_size")] 6 | pub dim: usize, 7 | #[serde(alias = "num_attention_heads")] 8 | pub n_heads: usize, 9 | #[serde(alias = "num_hidden_layers")] 10 | pub n_layers: usize, 11 | #[serde(alias = "rms_norm_eps")] 12 | pub norm_eps: f64, 13 | pub vocab_size: i64, 14 
| } 15 | -------------------------------------------------------------------------------- /src/protomodels/mod.rs: -------------------------------------------------------------------------------- 1 | // @generated 2 | 3 | pub mod sentencepiece_model; 4 | -------------------------------------------------------------------------------- /src/rllama_main.rs: -------------------------------------------------------------------------------- 1 | use crate::data_source::DataSource; 2 | use crate::embedding::Embedding; 3 | use crate::model_params::ModelParams; 4 | 5 | #[cfg(feature = "opencl")] 6 | use crate::tensor_opencl_support::OpenCL; 7 | use crate::token_sampler::TokenSampler; 8 | use crate::tokenizer::{TokenId, Tokenizer}; 9 | use crate::transformer::{DataSettings, Transformer}; 10 | 11 | #[cfg(feature = "server")] 12 | use crate::semaphore::Semaphore; 13 | #[cfg(feature = "server")] 14 | use crate::transformer::TransformerCaches; 15 | use clap::Parser; 16 | use colored::Colorize; 17 | #[cfg(feature = "server")] 18 | use rocket::{response::status, response::Stream, Data, State}; 19 | use serde::{Deserialize, Serialize}; 20 | #[cfg(feature = "server")] 21 | use std::collections::BTreeMap; 22 | use std::io::{Read, Write}; 23 | use std::sync::Arc; 24 | #[cfg(feature = "server")] 25 | use std::sync::RwLock; 26 | 27 | // Refer to README.md to see what all these options mean. 28 | #[derive(Parser, Clone)] 29 | #[command(author, version, about, long_about = None)] 30 | struct Cli { 31 | #[arg(long)] 32 | model_path: String, 33 | #[arg(long)] 34 | tokenizer_path: String, 35 | #[arg(long)] 36 | param_path: String, 37 | 38 | #[arg(short, long, action)] 39 | quiet: bool, 40 | 41 | #[arg(long)] 42 | prompt: Option, 43 | #[arg(long)] 44 | prompt_file: Option, 45 | 46 | #[arg(long)] 47 | interactive_system_prompt: Option, 48 | #[arg(long)] 49 | interactive_stop: Vec, 50 | #[arg(long)] 51 | interactive_prompt_postfix: Option, 52 | #[arg(long)] 53 | interactive_prompt_prefix: Option, 54 | #[arg(long, action)] 55 | start_interactive: bool, 56 | 57 | #[arg(long)] 58 | max_seq_len: Option, 59 | 60 | #[arg(long)] 61 | temperature: Option, 62 | #[arg(long)] 63 | top_p: Option, 64 | #[arg(long)] 65 | top_k: Option, 66 | #[arg(long)] 67 | repetition_penalty: Option, 68 | 69 | #[arg(long)] 70 | max_threads: Option, 71 | 72 | #[arg(long, action)] 73 | f16: bool, 74 | 75 | #[cfg(feature = "opencl")] 76 | #[arg(long)] 77 | opencl_device: Option, 78 | 79 | #[cfg(feature = "opencl")] 80 | #[arg(long)] 81 | percentage_to_gpu: Option, 82 | 83 | #[arg(long, action)] 84 | inference_server: bool, 85 | 86 | #[arg(long)] 87 | inference_server_port: Option, 88 | 89 | #[arg(long)] 90 | inference_server_host: Option, 91 | 92 | #[arg(long)] 93 | inference_server_max_concurrent_inferences: Option, 94 | 95 | #[arg(long)] 96 | inference_server_api_path: Option, 97 | 98 | #[arg(long)] 99 | inference_server_prompt_cache_size: Option, 100 | 101 | #[arg(long, action)] 102 | inference_server_exit_after_one_query: bool, 103 | } 104 | 105 | pub fn main() -> Result<(), Box> { 106 | let cli = Cli::parse(); 107 | let model_path = cli.model_path.clone(); 108 | let tokenizer_path = cli.tokenizer_path.clone(); 109 | let param_path = cli.param_path.clone(); 110 | let interactive_system_prompt = cli.interactive_system_prompt.clone().unwrap_or("A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, terse answers to the human's questions.### Human:".to_string()); 111 | let mut interactive_stop = cli.interactive_stop.clone(); 112 | if interactive_stop.is_empty() { 113 | // Desperado to catch all weird variants of ###Human the model might spit out. 114 | interactive_stop = vec![ 115 | "### Human:".to_string(), 116 | "###Human:".to_string(), 117 | "### Human: ".to_string(), 118 | "###Human: ".to_string(), 119 | " ### Human:".to_string(), 120 | " ###Human:".to_string(), 121 | " ### Human: ".to_string(), 122 | " ###Human: ".to_string(), 123 | "\n### Human:".to_string(), 124 | "\n###Human:".to_string(), 125 | "\n### Human: ".to_string(), 126 | "\n###Human: ".to_string(), 127 | "\n ### Human:".to_string(), 128 | "\n ###Human:".to_string(), 129 | "\n ### Human: ".to_string(), 130 | "\n ###Human: ".to_string(), 131 | ]; 132 | } 133 | let interactive_prompt_prefix = cli 134 | .interactive_prompt_prefix 135 | .clone() 136 | .unwrap_or(" ".to_string()); 137 | let interactive_prompt_postfix = cli 138 | .interactive_prompt_postfix 139 | .clone() 140 | .unwrap_or("### Assistant:".to_string()); 141 | let start_interactive = cli.start_interactive; 142 | #[cfg(not(feature = "server"))] 143 | if cli.inference_server { 144 | eprintln!("Inference server is not enabled in this build."); 145 | return Err("Inference server is not enabled in this build.".into()); 146 | } 147 | 148 | let max_threads: usize = match cli.max_threads { 149 | None => rayon::current_num_threads(), 150 | Some(max_threads) => { 151 | rayon::ThreadPoolBuilder::new() 152 | .num_threads(max_threads) 153 | .build_global() 154 | .unwrap(); 155 | max_threads 156 | } 157 | }; 158 | 159 | #[cfg(feature = "opencl")] 160 | let percentage_to_gpu: f32 = cli.percentage_to_gpu.unwrap_or(1.0); 161 | 162 | let mut be_quiet: bool = false; 163 | if !colored::control::SHOULD_COLORIZE.should_colorize() { 164 | be_quiet = true; 165 | } 166 | if cli.quiet { 167 | be_quiet = true; 168 | } 169 | if be_quiet { 170 | colored::control::SHOULD_COLORIZE.set_override(false); 171 | } 172 | 173 | // Custom println-like macro that respects be_quiet 174 | macro_rules! pln { 175 | ($($arg:tt)*) => { 176 | if !be_quiet { 177 | std::println!($($arg)*); 178 | } 179 | }; 180 | } 181 | 182 | #[cfg(feature = "opencl")] 183 | let opencl: Option = { 184 | let opencl_device = cli.opencl_device.unwrap_or(0); 185 | match OpenCL::new(!be_quiet, opencl_device) { 186 | Err(openclerr) => { 187 | eprintln!("OpenCL error: {}", openclerr); 188 | eprintln!("OpenCL is disabled because it failed to initialize."); 189 | None 190 | } 191 | Ok(opencl) => { 192 | println!("OpenCL initialized."); 193 | Some(opencl) 194 | } 195 | } 196 | }; 197 | 198 | #[cfg(feature = "opencl")] 199 | let has_opencl = opencl.is_some(); 200 | 201 | // Read ModelParams from param_path, we expect it to be JSON 202 | let mut fs = std::fs::File::open(¶m_path)?; 203 | let mut bs = Vec::new(); 204 | fs.read_to_end(&mut bs)?; 205 | std::mem::drop(fs); 206 | 207 | let prompt: String = match (&cli.prompt, &cli.prompt_file, start_interactive) { 208 | (Some(ref prompt), None, _) => { 209 | pln!("Using prompt: {}", prompt); 210 | prompt.clone() 211 | } 212 | (None, Some(ref prompt_file), _) => { 213 | pln!("Using prompt file: {}", prompt_file); 214 | let mut fs = std::fs::File::open(prompt_file)?; 215 | let mut bs = Vec::new(); 216 | fs.read_to_end(&mut bs)?; 217 | std::mem::drop(fs); 218 | String::from_utf8(bs)? 
219 | } 220 | (_, _, false) => { 221 | if cli.inference_server { 222 | "".to_string() 223 | } else { 224 | eprintln!("Please provide either a prompt or a prompt file."); 225 | return Err("Please provide either a prompt or a prompt file.".into()); 226 | } 227 | } 228 | (None, None, true) => "".to_string(), 229 | (_, _, true) => { 230 | eprintln!("Please provide either a prompt or a prompt file."); 231 | return Err("Please provide either a prompt or a prompt file.".into()); 232 | } 233 | }; 234 | 235 | pln!("Starting up. Loading tokenizer from {}...", tokenizer_path); 236 | let tok = Tokenizer::load(tokenizer_path.as_str())?; 237 | pln!("Tokenizer loaded. Loading model from {}...", model_path); 238 | 239 | let model_data_source = DataSource::from_inferred_source(model_path.clone())?; 240 | 241 | let params: ModelParams = serde_json::from_slice(&bs)?; 242 | pln!("Loaded model parameters from {}.", param_path); 243 | 244 | pln!("Loading embeddings from {}...", model_path); 245 | let emb = Embedding::from_unpickled(model_data_source.clone())?; 246 | 247 | let max_seq_len = cli.max_seq_len.unwrap_or(1024); 248 | 249 | let mut data_settings = { 250 | #[cfg(feature = "opencl")] 251 | { 252 | if let Some(opencl) = opencl { 253 | let ds = DataSettings::new(Some(opencl)); 254 | ds.percentage_to_gpu(percentage_to_gpu).use_opencl() 255 | } else { 256 | DataSettings::new(None) 257 | } 258 | } 259 | #[cfg(not(feature = "opencl"))] 260 | DataSettings::new() 261 | }; 262 | 263 | #[cfg(feature = "opencl")] 264 | if cli.f16 || has_opencl { 265 | data_settings = data_settings.force_f16(); 266 | } 267 | #[cfg(not(feature = "opencl"))] 268 | if cli.f16 { 269 | data_settings = data_settings.force_f16(); 270 | } 271 | 272 | pln!("Loading transformer weights from {}...", model_path); 273 | let tr = Transformer::from_unpickled( 274 | emb, 275 | params.dim, 276 | params.n_layers, 277 | params.n_heads, 278 | max_seq_len, 279 | params.norm_eps, 280 | data_settings, 281 | model_data_source, 282 | )?; 283 | pln!("All is loaded. Starting inference."); 284 | 285 | let tr: Arc = Arc::new(tr); 286 | let tok: Arc = Arc::new(tok); 287 | 288 | if cli.inference_server { 289 | #[cfg(feature = "server")] 290 | { 291 | server_inference(cli, tr, tok, be_quiet, max_seq_len, params, max_threads) 292 | } 293 | #[cfg(not(feature = "server"))] 294 | { 295 | eprintln!("The inference server feature is not enabled."); 296 | eprintln!("Please enable it with the \"inference-server\" feature."); 297 | Err("The inference server feature is not enabled.".into()) 298 | } 299 | } else { 300 | command_line_inference( 301 | cli.clone(), 302 | tr.clone(), 303 | tok.clone(), 304 | prompt.clone(), 305 | interactive_stop.clone(), 306 | interactive_system_prompt.clone(), 307 | interactive_prompt_prefix.clone(), 308 | interactive_prompt_postfix.clone(), 309 | start_interactive, 310 | be_quiet, 311 | max_seq_len, 312 | params.clone(), 313 | max_threads, 314 | ) 315 | } 316 | } 317 | 318 | #[cfg(feature = "server")] 319 | fn server_inference( 320 | cli: Cli, 321 | tr: Arc, 322 | tok: Arc, 323 | be_quiet: bool, 324 | max_seq_len: usize, 325 | _params: ModelParams, 326 | _max_threads: usize, 327 | ) -> Result<(), Box> { 328 | macro_rules! 
pln { 329 | ($($arg:tt)*) => { 330 | if !be_quiet { 331 | std::println!($($arg)*); 332 | } 333 | }; 334 | } 335 | 336 | let inference_server_port = cli.inference_server_port.unwrap_or(8080); 337 | let inference_server_host = cli 338 | .inference_server_host 339 | .clone() 340 | .unwrap_or("127.0.0.1".to_string()); 341 | let inference_server_max_concurrent_inferences = 342 | cli.inference_server_max_concurrent_inferences.unwrap_or(5); 343 | let inference_server_api_path = cli 344 | .inference_server_api_path 345 | .clone() 346 | .unwrap_or("/rllama/v1/inference".to_string()); 347 | let inference_server_prompt_cache_size = cli.inference_server_prompt_cache_size.unwrap_or(50); 348 | 349 | pln!( 350 | "Maximum concurrent inferences: {}", 351 | inference_server_max_concurrent_inferences 352 | ); 353 | pln!("Prompt cache size: {}", inference_server_prompt_cache_size); 354 | pln!("Maximum sequence length: {}", max_seq_len); 355 | pln!( 356 | "--- Starting HTTP server on {}:{}, answering to requests at {} ---", 357 | inference_server_host, 358 | inference_server_port, 359 | inference_server_api_path 360 | ); 361 | 362 | // If there are too many connections, they will hang until they get their turn. 363 | // Maybe can later implement return 503 slow down or something similar. 364 | let concurrent_requests_semaphore = Semaphore::new(inference_server_max_concurrent_inferences); 365 | 366 | let rocket_conf = rocket::Config::build(rocket::config::Environment::Production) 367 | .address(inference_server_host) 368 | .port(inference_server_port) 369 | .finalize() 370 | .unwrap(); 371 | 372 | let app = rocket::custom(rocket_conf) 373 | .mount(&inference_server_api_path, routes![handle_request]) 374 | .manage(InferenceServerState { 375 | transformer: tr, 376 | tokenizer: tok, 377 | max_seq_len, 378 | concurrent_requests_semaphore, 379 | attention_cache_repository: Arc::new(RwLock::new(AttentionCacheRepository::empty( 380 | inference_server_prompt_cache_size, 381 | ))), 382 | exit_after_one_query: cli.inference_server_exit_after_one_query, 383 | }); 384 | 385 | app.launch(); 386 | panic!("Starting web server failed."); 387 | } 388 | 389 | #[cfg(feature = "server")] 390 | fn is_false(b: &bool) -> bool { 391 | !b 392 | } 393 | 394 | #[derive(Serialize, Deserialize, Clone, Debug)] 395 | struct InferenceRequest { 396 | temperature: Option, 397 | top_k: Option, 398 | top_p: Option, 399 | repetition_penalty: Option, 400 | max_seq_len: Option, 401 | max_new_tokens: Option, 402 | no_token_sampling: Option, 403 | stop_at_end_token: Option, 404 | prompt: String, 405 | } 406 | 407 | #[cfg(feature = "server")] 408 | #[derive(Serialize, Deserialize, Clone, Debug)] 409 | struct PredResult { 410 | p: f32, 411 | #[serde(skip_serializing_if = "is_false")] 412 | is_end_token: bool, 413 | } 414 | 415 | #[cfg(feature = "server")] 416 | struct GeneratingSession { 417 | transformer: Arc, 418 | token_sampler: TokenSampler, 419 | tokenizer: Arc, 420 | attention_cache_repository: Arc>, 421 | tokens: Vec, 422 | req_max_seq_len: usize, 423 | req_max_new_tokens: usize, 424 | new_tokens_generated: usize, 425 | prev_pos: usize, 426 | no_token_sampling: bool, 427 | stop_at_end_token: bool, 428 | sent_stuff_last_time: bool, 429 | exit_after_one_query: bool, 430 | result: Vec, // stores JSONL lines to be returned from read() 431 | } 432 | 433 | #[cfg(feature = "server")] 434 | impl GeneratingSession { 435 | fn read_from_result(&mut self, buf: &mut [u8]) -> usize { 436 | if !self.result.is_empty() { 437 | if self.result.len() <= buf.len() { 
438 | for idx in 0..self.result.len() { 439 | buf[idx] = self.result[idx]; 440 | } 441 | let len = self.result.len(); 442 | self.sent_stuff_last_time = true; 443 | self.result.truncate(0); 444 | return len; 445 | } else { 446 | for idx in 0..buf.len() { 447 | buf[idx] = self.result[idx]; 448 | } 449 | self.result = self.result[buf.len()..].to_vec(); 450 | self.sent_stuff_last_time = true; 451 | return buf.len(); 452 | } 453 | } 454 | return 0; 455 | } 456 | } 457 | 458 | #[cfg(feature = "server")] 459 | impl Read for GeneratingSession { 460 | fn read(&mut self, buf: &mut [u8]) -> std::io::Result { 461 | if self.sent_stuff_last_time && self.result.is_empty() { 462 | // If we return WouldBlock every time we send something, it'll cause Rocket to 463 | // flush available data. 464 | self.sent_stuff_last_time = false; 465 | return Err(std::io::Error::new( 466 | std::io::ErrorKind::WouldBlock, 467 | "WouldBlock", 468 | )); 469 | } 470 | 471 | // Push more data to the upstream if we have something stored. 472 | let bytes_read = self.read_from_result(buf); 473 | if bytes_read > 0 { 474 | return Ok(bytes_read); 475 | } 476 | if self.tokens.len() >= self.req_max_seq_len { 477 | if self.exit_after_one_query { 478 | std::process::exit(0); 479 | } 480 | return Ok(0); 481 | } 482 | if self.new_tokens_generated >= self.req_max_new_tokens { 483 | if self.exit_after_one_query { 484 | std::process::exit(0); 485 | } 486 | return Ok(0); 487 | } 488 | 489 | let (mut caches, update_pos) = { 490 | let mut ac = self.attention_cache_repository.write().unwrap(); 491 | match ac.get(&self.tokens) { 492 | Some((c, pos)) if pos >= self.prev_pos => (c.true_clone(), pos), 493 | Some(_) => { 494 | std::mem::drop(ac); 495 | (self.transformer.make_caches(), 0) 496 | } 497 | None => { 498 | let caches = self.transformer.make_caches(); 499 | ac.put(self.tokens.clone(), caches.true_clone(), self.prev_pos); 500 | (caches, self.prev_pos) 501 | } 502 | } 503 | }; 504 | if update_pos > self.prev_pos { 505 | self.prev_pos = update_pos; 506 | } 507 | 508 | assert!(self.result.is_empty()); 509 | let predictions = 510 | self.transformer 511 | .forward(&self.tokens[self.prev_pos..], self.prev_pos, &mut caches); 512 | self.prev_pos = self.tokens.len(); 513 | let (highest_pred_idx, token_prob) = 514 | self.token_sampler 515 | .sample(&predictions, self.tokenizer.as_ref(), &self.tokens); 516 | self.tokens.push(highest_pred_idx as TokenId); 517 | { 518 | let mut ac = self.attention_cache_repository.write().unwrap(); 519 | ac.put(self.tokens.clone(), caches, self.prev_pos); 520 | } 521 | self.new_tokens_generated += 1; 522 | let token: &str = self.tokenizer.id_to_str(highest_pred_idx as TokenId); 523 | let mut is_end_token: bool = false; 524 | if token == "" && self.stop_at_end_token { 525 | self.new_tokens_generated = self.req_max_new_tokens; 526 | is_end_token = true; 527 | } 528 | 529 | let mut result: BTreeMap = BTreeMap::new(); 530 | if self.no_token_sampling { 531 | // All predictions go the line. 
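            // (I.e. with no_token_sampling, every vocabulary token and its probability
            // is emitted on this JSONL line, rather than just the single sampled token.)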
532 | let probs = self 533 | .token_sampler 534 | .logits_to_btreemap(&predictions, self.tokenizer.as_ref()); 535 | for (k, v) in probs.into_iter() { 536 | let mut is_end_token: bool = false; 537 | if k == "" { 538 | is_end_token = true; 539 | } 540 | result.insert( 541 | k, 542 | PredResult { 543 | p: v, 544 | is_end_token: is_end_token, 545 | }, 546 | ); 547 | } 548 | // Convert to JSON 549 | let json = serde_json::to_string(&result).unwrap(); 550 | self.result.extend(json.as_bytes()); 551 | self.result.push(b'\n'); 552 | return Ok(self.read_from_result(buf)); 553 | } else { 554 | result.insert( 555 | token.to_string(), 556 | PredResult { 557 | p: token_prob, 558 | is_end_token, 559 | }, 560 | ); 561 | let json = serde_json::to_string(&result).unwrap(); 562 | self.result.extend(json.as_bytes()); 563 | self.result.push(b'\n'); 564 | return Ok(self.read_from_result(buf)); 565 | } 566 | } 567 | } 568 | 569 | #[cfg(feature = "server")] 570 | struct AttentionCacheRepository { 571 | caches: BTreeMap, (TransformerCaches, usize, std::time::Instant)>, 572 | max_sz: usize, 573 | } 574 | 575 | #[cfg(feature = "server")] 576 | impl AttentionCacheRepository { 577 | fn empty(max_size: usize) -> AttentionCacheRepository { 578 | AttentionCacheRepository { 579 | caches: BTreeMap::new(), 580 | max_sz: max_size, 581 | } 582 | } 583 | 584 | /// Makes sure the cache repository is not larger than sz, evicts any older items. 585 | fn limit_size(&mut self, sz: usize) { 586 | if sz == 0 { 587 | self.caches = BTreeMap::new(); 588 | return; 589 | } 590 | // Slow algorithm but I guess our cache will never be unimaginably large so it's probably 591 | // fine 592 | while self.caches.len() > sz { 593 | let mut oldest_time = None; 594 | let mut oldest_key: Option<&Vec> = None; 595 | for (k, (_, _, time)) in self.caches.iter() { 596 | if oldest_time.is_none() || time < oldest_time.unwrap() { 597 | oldest_time = Some(time); 598 | oldest_key = Some(k); 599 | } 600 | } 601 | let oldest_key = oldest_key.unwrap().clone(); 602 | self.caches.remove(&oldest_key); 603 | } 604 | } 605 | 606 | fn get(&self, tokens: &[TokenId]) -> Option<(&TransformerCaches, usize)> { 607 | if let Some((caches, pos, _)) = self.caches.get(tokens) { 608 | Some((caches, *pos)) 609 | } else { 610 | None 611 | } 612 | } 613 | 614 | fn put(&mut self, tokens: Vec, caches: TransformerCaches, prev_pos: usize) { 615 | self.caches 616 | .insert(tokens, (caches, prev_pos, std::time::Instant::now())); 617 | self.limit_size(self.max_sz); 618 | } 619 | } 620 | 621 | #[cfg(feature = "server")] 622 | #[derive(Clone)] 623 | struct InferenceServerState { 624 | transformer: Arc, 625 | tokenizer: Arc, 626 | max_seq_len: usize, 627 | concurrent_requests_semaphore: Semaphore, 628 | attention_cache_repository: Arc>, 629 | exit_after_one_query: bool, 630 | } 631 | 632 | #[cfg(feature = "server")] 633 | #[post("/", data = "")] 634 | fn handle_request( 635 | state: State, 636 | input: Data, 637 | ) -> Result, status::BadRequest> { 638 | let _lock = state.concurrent_requests_semaphore.acquire(); 639 | let tr = state.transformer.clone(); 640 | let tok = state.tokenizer.clone(); 641 | 642 | let mut data = input.open(); 643 | let mut databuf: Vec = Vec::new(); 644 | data.read_to_end(&mut databuf).unwrap(); 645 | 646 | // Parse the JSON out of the request 647 | let request: InferenceRequest = match serde_json::from_slice(&databuf) { 648 | Err(_e) => { 649 | return Err(status::BadRequest(Some("Invalid JSON.".to_string()))); 650 | } 651 | Ok(ir) => ir, 652 | }; 653 | 654 | let 
stop_at_end_token = request.stop_at_end_token.unwrap_or(true); 655 | let temperature = request.temperature.unwrap_or(1.0); 656 | let top_k = request.top_k.unwrap_or(20); 657 | let top_p = request.top_p.unwrap_or(1.0); 658 | let repetition_penalty = request.repetition_penalty.unwrap_or(1.0); 659 | let mut req_max_seq_len = request.max_seq_len.unwrap_or(state.max_seq_len); 660 | if req_max_seq_len > state.max_seq_len { 661 | req_max_seq_len = state.max_seq_len; 662 | } 663 | let req_max_new_tokens = request.max_new_tokens.unwrap_or(20); 664 | let no_token_sampling = request.no_token_sampling.unwrap_or(false); 665 | let prompt = request.prompt; 666 | 667 | if temperature.is_nan() { 668 | return Err(status::BadRequest(Some( 669 | "Temperature must be a number.".to_string(), 670 | ))); 671 | } 672 | if top_k == 0 { 673 | return Err(status::BadRequest(Some( 674 | "Top-k must be greater than 0.".to_string(), 675 | ))); 676 | } 677 | if top_p.is_nan() { 678 | return Err(status::BadRequest(Some( 679 | "Top-p must be a number.".to_string(), 680 | ))); 681 | } 682 | if repetition_penalty.is_nan() { 683 | return Err(status::BadRequest(Some( 684 | "Repetition penalty must be a number.".to_string(), 685 | ))); 686 | } 687 | 688 | let token_sampler = TokenSampler::new() 689 | .temperature(temperature) 690 | .top_p(top_p) 691 | .top_k(top_k) 692 | .repetition_penalty(repetition_penalty); 693 | let toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); 694 | let gsession = GeneratingSession { 695 | transformer: tr, 696 | tokenizer: tok, 697 | attention_cache_repository: state.attention_cache_repository.clone(), 698 | token_sampler: token_sampler, 699 | tokens: toks_id, 700 | req_max_seq_len: req_max_seq_len, 701 | req_max_new_tokens: req_max_new_tokens, 702 | new_tokens_generated: 0, 703 | prev_pos: 0, 704 | no_token_sampling: no_token_sampling, 705 | stop_at_end_token: stop_at_end_token, 706 | sent_stuff_last_time: false, 707 | exit_after_one_query: state.exit_after_one_query, 708 | result: Vec::new(), 709 | }; 710 | 711 | return Ok(rocket::response::Stream::chunked(gsession, 1024)); 712 | } 713 | 714 | fn command_line_inference( 715 | cli: Cli, 716 | tr: Arc, 717 | tok: Arc, 718 | prompt: String, 719 | interactive_stop: Vec, 720 | interactive_system_prompt: String, 721 | interactive_prompt_prefix: String, 722 | interactive_prompt_postfix: String, 723 | start_interactive: bool, 724 | be_quiet: bool, 725 | max_seq_len: usize, 726 | params: ModelParams, 727 | max_threads: usize, 728 | ) -> Result<(), Box> { 729 | // Custom println-like macro that respects be_quiet 730 | macro_rules! pln { 731 | ($($arg:tt)*) => { 732 | if !be_quiet { 733 | std::println!($($arg)*); 734 | } 735 | }; 736 | } 737 | 738 | let mut prompt = prompt; 739 | 740 | if start_interactive && !prompt.is_empty() { 741 | return Err( 742 | "Cannot start interactive mode with a prompt. Use --interactive-system-prompt instead." 
743 | .into(), 744 | ); 745 | } 746 | if start_interactive { 747 | prompt = interactive_system_prompt.clone(); 748 | } 749 | 750 | let mut toks_id: Vec = tok.tokenize_to_ids(prompt.clone()); 751 | let mut toks_str: String = prompt.clone(); 752 | let mut prev_pos = 0; 753 | let mut token_sampler = TokenSampler::new() 754 | .temperature(1.0) 755 | .top_p(1.0) 756 | .top_k(20) 757 | .repetition_penalty(1.0); 758 | 759 | if let Some(temperature) = cli.temperature { 760 | token_sampler = token_sampler.temperature(temperature); 761 | } 762 | if let Some(top_p) = cli.top_p { 763 | token_sampler = token_sampler.top_p(top_p); 764 | } 765 | if let Some(top_k) = cli.top_k { 766 | token_sampler = token_sampler.top_k(top_k as usize); 767 | } 768 | if let Some(repetition_penalty) = cli.repetition_penalty { 769 | token_sampler = token_sampler.repetition_penalty(repetition_penalty); 770 | } 771 | pln!("---"); 772 | pln!(" dim: {}", params.dim); 773 | pln!(" n_heads: {}", params.n_heads); 774 | pln!(" n_layers: {}", params.n_layers); 775 | pln!(" norm_eps: {}", params.norm_eps); 776 | pln!(" vocab_size: {}", params.vocab_size); 777 | pln!("---"); 778 | pln!(" maximum number of threads: {}", max_threads); 779 | pln!("---"); 780 | pln!("Max sequence length: {}", max_seq_len); 781 | pln!("Temperature: {}", token_sampler.get_temperature()); 782 | pln!("Top P: {}", token_sampler.get_top_p()); 783 | pln!("Top K: {}", token_sampler.get_top_k()); 784 | pln!( 785 | "Repetition penalty: {}", 786 | token_sampler.get_repetition_penalty() 787 | ); 788 | if start_interactive { 789 | pln!( 790 | " Interactive mode stop token sequences: {:?}", 791 | interactive_stop 792 | ); 793 | pln!("---"); 794 | pln!("System prompt:"); 795 | pln!(" {}", interactive_system_prompt); 796 | pln!("---"); 797 | pln!("Interactive prompt prefix: {}", interactive_prompt_prefix); 798 | pln!("Interactive prompt postfix: {}", interactive_prompt_postfix); 799 | } 800 | pln!("---"); 801 | pln!( 802 | "{}", 803 | " This is the color of the initial prompt".truecolor(128, 128, 255) 804 | ); 805 | pln!( 806 | "{}", 807 | " This is the color of the generated text".truecolor(128, 255, 128) 808 | ); 809 | pln!("---"); 810 | print!("{}", prompt.as_str().truecolor(128, 128, 255)); 811 | 812 | let _ = std::io::stdout().flush(); 813 | 814 | let mut first_token_time: std::time::Duration = std::time::Duration::new(0, 0); 815 | let mut times_per_token: Vec = vec![]; 816 | let mut caches = tr.make_caches(); 817 | let mut first: bool = true; 818 | let mut stop_seen: bool = false; 819 | let mut interactive = start_interactive; 820 | let mut user_token: Vec = vec![]; 821 | while toks_id.len() < max_seq_len { 822 | let now = std::time::Instant::now(); 823 | let preds = tr.forward(&toks_id[prev_pos..], prev_pos, &mut caches); 824 | if interactive { 825 | let mut newinput = String::new(); 826 | std::io::stdin().read_line(&mut newinput)?; 827 | // removing new line from input 828 | if newinput.ends_with('\n') { 829 | let _ = newinput.pop(); 830 | } 831 | newinput = interactive_prompt_prefix.clone() + &newinput; 832 | newinput += &interactive_prompt_postfix; 833 | user_token = tok.tokenize_to_ids(newinput.clone()); 834 | 835 | // removing [start token] as it is already in the prompt, and tokenize_to_ids adds it. 
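            // (tokenize_to_ids() prepends the start-of-sequence token, id 1 for the
            // LLaMA tokenizer, and the running context already begins with one, so the
            // duplicate is dropped here before the user's turn is appended.)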
836 | let _ = user_token.remove(0); 837 | interactive = false; 838 | } 839 | let (highest_pred_idx, token_prob); 840 | 841 | if user_token.len() > 0 { 842 | highest_pred_idx = user_token.remove(0); 843 | token_prob = 0.0; 844 | } else { 845 | (highest_pred_idx, token_prob) = token_sampler.sample(&preds, &tok, &toks_id); 846 | } 847 | toks_id.push(highest_pred_idx as TokenId); 848 | 849 | for (tok_idx, tok_id) in toks_id[prev_pos + 1..].iter().enumerate() { 850 | if *tok_id == 1 { 851 | continue; 852 | } 853 | let mut tok_print: String = "".to_string(); 854 | let tok_str = tok.id_to_str(*tok_id); 855 | if tok_str == "" { 856 | tok_print += ""; 857 | stop_seen = true; 858 | } 859 | if tok_str == "<0x0A>" { 860 | tok_print += "\n"; 861 | } else { 862 | tok_print += tok_str.replace('▁', " ").as_str(); 863 | } 864 | toks_str += tok_print.as_str(); 865 | if first && tok_idx < toks_id.len() - 2 { 866 | // intentionally left empty, already print 867 | } else { 868 | let redness: f32 = token_prob * 255.0; 869 | let redness = if redness > 255.0 { 870 | 255 871 | } else if redness < 0.0 { 872 | 0 873 | } else { 874 | redness as u8 875 | }; 876 | print!( 877 | "{}", 878 | tok_print.truecolor(128 + redness / 2, 255 - redness / 2, 128) 879 | ); 880 | }; 881 | for stop_str in interactive_stop.iter() { 882 | if !first && toks_str.ends_with(stop_str.as_str()) { 883 | if start_interactive { 884 | interactive = true; 885 | } 886 | break; 887 | } 888 | } 889 | } 890 | if first { 891 | first_token_time = now.elapsed(); 892 | } else { 893 | times_per_token.push(now.elapsed()); 894 | } 895 | let _ = std::io::stdout().flush(); 896 | prev_pos = toks_id.len() - 1; 897 | first = false; 898 | if stop_seen { 899 | break; 900 | } 901 | } 902 | println!(); 903 | if stop_seen && !be_quiet { 904 | println!("Stop token seen. Stopping."); 905 | } 906 | if !be_quiet { 907 | println!("---"); 908 | println!( 909 | "Time taken to generate first token: {:?}ms", 910 | first_token_time.as_millis() 911 | ); 912 | if times_per_token.len() > 0 { 913 | println!( 914 | "Time taken per token (excluding first token): {:?}ms", 915 | times_per_token.iter().map(|t| t.as_millis()).sum::() 916 | / times_per_token.len() as u128 917 | ); 918 | } else { 919 | println!("No token generated"); 920 | } 921 | } 922 | Ok(()) 923 | } 924 | -------------------------------------------------------------------------------- /src/semaphore.rs: -------------------------------------------------------------------------------- 1 | // There is no semaphore in Rust standard library. wat?? 2 | // So I've made a simple one I can use out of a mutex and condition variable.. 
3 | 4 | use std::sync::{Arc, Condvar, Mutex, MutexGuard}; 5 | 6 | #[derive(Clone)] 7 | pub struct Semaphore { 8 | count: Arc>, 9 | waiters: Arc, 10 | } 11 | 12 | pub struct SemaphoreGuard<'a> { 13 | mutex_guard: MutexGuard<'a, usize>, 14 | } 15 | 16 | impl<'a> Drop for SemaphoreGuard<'a> { 17 | fn drop(&mut self) { 18 | *self.mutex_guard += 1; 19 | } 20 | } 21 | 22 | impl Semaphore { 23 | pub fn new(count: usize) -> Semaphore { 24 | Semaphore { 25 | count: Arc::new(Mutex::new(count)), 26 | waiters: Arc::new(Condvar::new()), 27 | } 28 | } 29 | 30 | pub fn acquire(&self) -> SemaphoreGuard { 31 | let mut count = self.count.lock().unwrap(); 32 | while *count == 0 { 33 | count = self.waiters.wait(count).unwrap(); 34 | } 35 | *count -= 1; 36 | SemaphoreGuard { mutex_guard: count } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/simd_support.rs: -------------------------------------------------------------------------------- 1 | // This file contains platform-specific SIMD so that rest of rllama does not need to care which 2 | // platform it is on. 3 | 4 | use core::arch::x86_64::*; 5 | use half::f16; 6 | 7 | pub type I32x8 = __m256i; 8 | pub type F32x8 = __m256; 9 | pub type I16x8 = __m128i; 10 | 11 | /* ------------------ */ 12 | /* Loading and storing things */ 13 | /* ------------------ */ 14 | 15 | #[inline] 16 | pub fn load_i16x8(ptr: *const I16x8) -> I16x8 { 17 | unsafe { _mm_loadu_si128(ptr) } 18 | } 19 | 20 | #[inline] 21 | pub fn store_i16x8(ptr: *mut I16x8, a: I16x8) { 22 | unsafe { _mm_storeu_si128(ptr, a) } 23 | } 24 | 25 | #[inline] 26 | pub fn load_f32x8(ptr: *const F32x8) -> F32x8 { 27 | unsafe { _mm256_loadu_ps(ptr as *const f32) } 28 | } 29 | 30 | #[inline] 31 | pub fn store_f32x8(ptr: *mut F32x8, a: F32x8) { 32 | unsafe { _mm256_storeu_ps(ptr as *mut f32, a) } 33 | } 34 | 35 | #[inline] 36 | pub fn gather_f32x8(ptr: *const f32, indices: I32x8) -> F32x8 { 37 | unsafe { _mm256_i32gather_ps(ptr, indices, 1) } 38 | } 39 | 40 | /* ------------------ */ 41 | /* Conversions */ 42 | /* ------------------ */ 43 | 44 | #[inline] 45 | pub fn i16x8_as_f16_to_f32x8(a: I16x8) -> F32x8 { 46 | unsafe { _mm256_cvtph_ps(a) } 47 | } 48 | 49 | #[inline] 50 | pub fn f32x8_to_i16x8_as_f16(a: F32x8) -> I16x8 { 51 | unsafe { _mm256_cvtps_ph(a, 0) } 52 | } 53 | 54 | /* 55 | * Constants, creating from constants 56 | */ 57 | 58 | pub fn f32x8_zero() -> F32x8 { 59 | unsafe { _mm256_setzero_ps() } 60 | } 61 | 62 | pub fn i16x8_zero() -> I16x8 { 63 | unsafe { _mm_setzero_si128() } 64 | } 65 | 66 | pub fn f32x8_singleton(value: f32) -> F32x8 { 67 | unsafe { _mm256_set1_ps(value) } 68 | } 69 | 70 | pub fn i32x8_from_values( 71 | val0: i32, 72 | val1: i32, 73 | val2: i32, 74 | val3: i32, 75 | val4: i32, 76 | val5: i32, 77 | val6: i32, 78 | val7: i32, 79 | ) -> I32x8 { 80 | unsafe { _mm256_set_epi32(val0, val1, val2, val3, val4, val5, val6, val7) } 81 | } 82 | 83 | /* 84 | * Operations 85 | */ 86 | 87 | // FMA 88 | 89 | // a * b + c 90 | pub fn fma_f32x8(a: F32x8, b: F32x8, c: F32x8) -> F32x8 { 91 | unsafe { _mm256_fmadd_ps(a, b, c) } 92 | } 93 | 94 | // Horizontal sums 95 | 96 | #[inline] 97 | pub fn horizontal_sum_f32x8(mut ymm: __m256) -> f32 { 98 | unsafe { 99 | let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); 100 | ymm = _mm256_add_ps(ymm, ymm2); 101 | ymm = _mm256_hadd_ps(ymm, ymm); 102 | ymm = _mm256_hadd_ps(ymm, ymm); 103 | _mm256_cvtss_f32(ymm) 104 | } 105 | } 106 | 107 | #[inline] 108 | pub fn horizontal_sum_and_f32_to_f16(mut ymm: __m256) -> f16 
{ 109 | unsafe { 110 | let ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); 111 | ymm = _mm256_add_ps(ymm, ymm2); 112 | ymm = _mm256_hadd_ps(ymm, ymm); 113 | ymm = _mm256_hadd_ps(ymm, ymm); 114 | f16::from_f32(_mm256_cvtss_f32(ymm)) 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/tensor_opencl_support.rs: -------------------------------------------------------------------------------- 1 | /* 2 | * OpenCL stuff to run (some) of the tensor operations. 3 | */ 4 | 5 | use ocl::{ 6 | enums::DeviceInfo, enums::DeviceInfoResult, Buffer, Context, Device, DeviceType, Event, Kernel, 7 | Platform, Program, Queue, 8 | }; 9 | use std::alloc::Layout; 10 | use std::sync::{Arc, RwLock}; 11 | use thiserror::Error; 12 | 13 | #[derive(Debug)] 14 | #[allow(dead_code)] 15 | struct Programs { 16 | matrix_mul_transposed_f16_program: Program, 17 | matrix_mul_transposed_f16: Kernel, 18 | matrix_mul_transposed_one_row_f16_program: Program, 19 | matrix_mul_transposed_one_row_f16: Kernel, 20 | matrix_mul_transposed_f16_cpu_optimized_program: Program, 21 | matrix_mul_transposed_f16_cpu_optimized: Kernel, 22 | silu_f16_program: Program, 23 | silu_f16: Kernel, 24 | hadamard_product_f16_program: Program, 25 | hadamard_product_f16: Kernel, 26 | transpose_f16_program: Program, 27 | transpose_f16: Kernel, 28 | } 29 | 30 | #[derive(Debug, Clone)] 31 | #[allow(dead_code)] 32 | pub struct OpenCL { 33 | ctx: Context, 34 | queue: Queue, 35 | programs: Arc>, 36 | is_cpu_device: bool, 37 | } 38 | 39 | #[derive(Debug)] 40 | pub struct OpenCLTensor { 41 | buf: Buffer, // really is f16 42 | initial_write_event: Option, 43 | last_event: Option, 44 | data: *const u16, 45 | data_layout: Layout, 46 | nitems: usize, 47 | rows: i64, 48 | cols: i64, 49 | cols_capacity: i64, 50 | queue: Queue, 51 | cl: OpenCL, 52 | } 53 | 54 | #[derive(Debug)] 55 | pub struct OpenCLEvent { 56 | event: ocl::Event, 57 | } 58 | 59 | impl Drop for OpenCLTensor { 60 | fn drop(&mut self) { 61 | if self.initial_write_event.is_some() { 62 | self.initial_write_event 63 | .as_ref() 64 | .unwrap() 65 | .wait_for() 66 | .unwrap(); 67 | } 68 | self.initial_write_event = None; 69 | if !self.data.is_null() { 70 | unsafe { 71 | std::alloc::dealloc(self.data as *mut u8, self.data_layout); 72 | } 73 | } 74 | } 75 | } 76 | 77 | #[derive(Error, Debug)] 78 | pub enum OpenCLError { 79 | #[error("OpenCL error: {0}")] 80 | OpenCL(#[from] ocl::Error), 81 | #[error("Cannot select device")] 82 | OpenCLDeviceSelection, 83 | } 84 | 85 | impl OpenCL { 86 | pub fn new(verbose: bool, nth_device: usize) -> Result { 87 | let platforms = Platform::list(); 88 | let mut devices: Vec<(Platform, Device)> = Vec::new(); 89 | for platform in platforms { 90 | for device in Device::list_all(platform)? { 91 | devices.push((platform, device)); 92 | } 93 | } 94 | if verbose { 95 | println!("Enumerating OpenCL devices:"); 96 | } 97 | for (idx, (_plat, device)) in devices.iter().enumerate() { 98 | if verbose { 99 | println!("OpenCL {} device: {}", idx, device.name()?,); 100 | } 101 | } 102 | if nth_device > devices.len() { 103 | return Err(OpenCLError::OpenCLDeviceSelection); 104 | } 105 | if verbose { 106 | println!("---"); 107 | println!("Selected OpenCL device: {}", devices[nth_device].1.name()?); 108 | } 109 | 110 | let ctx = Context::builder() 111 | .platform(devices[nth_device].0) 112 | .devices(devices[nth_device].1) 113 | .build()?; 114 | 115 | let is_cpu_device = match devices[nth_device].1.info(DeviceInfo::Type)? 
{ 116 | DeviceInfoResult::Type(DeviceType::CPU) => true, 117 | _ => false, 118 | }; 119 | 120 | let queue = Queue::new(&ctx, devices[nth_device].1, None)?; 121 | let programs = make_programs(&ctx, &queue)?; 122 | Ok(OpenCL { 123 | ctx: ctx, 124 | queue: queue, 125 | programs: Arc::new(RwLock::new(programs)), 126 | is_cpu_device, 127 | }) 128 | } 129 | 130 | pub fn flush(&self) { 131 | let _ = self.queue.flush(); 132 | } 133 | 134 | pub fn data_u16_to_gpu( 135 | &self, 136 | data: *const u16, 137 | data_layout: Layout, 138 | nitems: usize, 139 | rows: i64, 140 | cols: i64, 141 | cols_capacity: i64, 142 | ) -> Result { 143 | unsafe { 144 | let buf = Buffer::builder() 145 | .queue(self.queue.clone()) 146 | .len(nitems) 147 | .build()?; 148 | let mut event = Event::empty(); 149 | let data_slice: &[u16] = std::slice::from_raw_parts(data, nitems); 150 | buf.cmd() 151 | .write(data_slice) 152 | .block(false) 153 | .enew(&mut event) 154 | .enq()?; 155 | Ok(OpenCLTensor { 156 | buf, 157 | initial_write_event: Some(event), 158 | last_event: None, 159 | data, 160 | data_layout, 161 | nitems, 162 | rows, 163 | cols, 164 | cols_capacity, 165 | queue: self.queue.clone(), 166 | cl: self.clone(), 167 | }) 168 | } 169 | } 170 | } 171 | 172 | impl OpenCLTensor { 173 | pub fn cl(&self) -> OpenCL { 174 | self.cl.clone() 175 | } 176 | 177 | pub fn wait_until_ready(&mut self) { 178 | if self.last_event.is_some() { 179 | self.last_event.as_ref().unwrap().wait_for().unwrap(); 180 | self.last_event = None; 181 | } 182 | if self.initial_write_event.is_some() { 183 | self.initial_write_event 184 | .as_ref() 185 | .unwrap() 186 | .wait_for() 187 | .unwrap(); 188 | self.initial_write_event = None; 189 | } 190 | if !self.data.is_null() { 191 | unsafe { 192 | std::alloc::dealloc(self.data as *mut u8, self.data_layout); 193 | } 194 | self.data = std::ptr::null(); 195 | } 196 | } 197 | 198 | pub fn data_u16_from_gpu(&mut self, data: *mut u16) -> Result { 199 | unsafe { 200 | let mut event = Event::empty(); 201 | let data_slice: &mut [u16] = std::slice::from_raw_parts_mut(data, self.nitems); 202 | let b = self 203 | .buf 204 | .cmd() 205 | .read(data_slice) 206 | .block(false) 207 | .enew(&mut event); 208 | b.enq()?; 209 | self.last_event = Some(event.clone()); 210 | return Ok(OpenCLEvent { event }); 211 | } 212 | } 213 | 214 | /// Copies all values from another tensor 215 | pub fn copy_inplace(&mut self, other: &OpenCLTensor) -> Result { 216 | if other.rows != self.rows || other.cols != self.cols { 217 | panic!( 218 | "Cannot in-place copy tensors of different sizes: {}x{} <-- {}x{}", 219 | self.rows, self.cols, other.rows, other.cols 220 | ); 221 | } 222 | let mut event = Event::empty(); 223 | other 224 | .buf 225 | .cmd() 226 | .queue(&other.queue) 227 | .copy(&self.buf, None, None) 228 | .enew(&mut event) 229 | .enq()?; 230 | self.last_event = Some(event.clone()); 231 | Ok(OpenCLEvent { event }) 232 | } 233 | 234 | pub fn transpose_from(&mut self, other: &OpenCLTensor) -> Result { 235 | let prg = self.cl.programs.write().unwrap(); 236 | prg.transpose_f16.set_arg(0, self.buf.clone()).unwrap(); 237 | prg.transpose_f16.set_arg(1, other.buf.clone()).unwrap(); 238 | prg.transpose_f16 239 | .set_arg(2, self.cols_capacity as i32) 240 | .unwrap(); 241 | prg.transpose_f16 242 | .set_arg(3, other.cols_capacity as i32) 243 | .unwrap(); 244 | let mut event = Event::empty(); 245 | unsafe { 246 | let b = prg 247 | .transpose_f16 248 | .cmd() 249 | .queue(&self.queue) 250 | .global_work_size([self.rows as usize, self.cols as usize]) 
251 | .enew(&mut event); 252 | b.enq().unwrap(); 253 | } 254 | self.last_event = Some(event.clone()); 255 | Ok(OpenCLEvent { event }) 256 | } 257 | 258 | pub fn hadamard_product_inplace( 259 | &mut self, 260 | other: &OpenCLTensor, 261 | ) -> Result { 262 | let prg = self.cl.programs.write().unwrap(); 263 | prg.hadamard_product_f16.set_arg(0, self.buf.clone())?; 264 | prg.hadamard_product_f16.set_arg(1, other.buf.clone())?; 265 | prg.hadamard_product_f16 266 | .set_arg(2, self.cols_capacity as i32)?; 267 | prg.hadamard_product_f16 268 | .set_arg(3, other.cols_capacity as i32)?; 269 | let mut event = Event::empty(); 270 | unsafe { 271 | let b = prg 272 | .hadamard_product_f16 273 | .cmd() 274 | .queue(&self.queue) 275 | .global_work_size([self.rows as usize, self.cols as usize]) 276 | .enew(&mut event); 277 | b.enq()?; 278 | } 279 | self.last_event = Some(event.clone()); 280 | Ok(OpenCLEvent { event }) 281 | } 282 | 283 | pub fn silu_inplace(&mut self) -> Result { 284 | let prg = self.cl.programs.write().unwrap(); 285 | prg.silu_f16.set_arg(0, self.buf.clone())?; 286 | prg.silu_f16.set_arg(1, self.cols_capacity as i32)?; 287 | let mut event = Event::empty(); 288 | unsafe { 289 | let b = prg 290 | .silu_f16 291 | .cmd() 292 | .queue(&self.queue) 293 | .global_work_size([self.rows as usize, self.cols as usize]) 294 | .enew(&mut event); 295 | b.enq()?; 296 | } 297 | self.last_event = Some(event.clone()); 298 | Ok(OpenCLEvent { event }) 299 | } 300 | 301 | pub fn matrix_mul_inplace_transposed( 302 | &mut self, 303 | src: &OpenCLTensor, 304 | other: &OpenCLTensor, 305 | ) -> Result { 306 | if src.cols != other.cols { 307 | panic!( 308 | "OpenCL matrix_mul_inplace_transposed: src.cols must equal other.cols: {}x{} vs {}x{}", 309 | src.rows, src.cols, other.rows, other.cols 310 | ); 311 | } 312 | if self.rows != src.rows || self.cols != other.rows { 313 | panic!( 314 | "OpenCL matrix_mul_inplace_transposed: self.rows must equal src.rows and self.cols must equal other.cols: {}x{} vs {}x{} vs {}x{}", 315 | self.rows, self.cols, src.rows, src.cols, other.rows, other.cols 316 | ); 317 | } 318 | 319 | // Clear out the target memory. 320 | unsafe { self.buf.cmd().fill(0u16, None).block(false).enq()? 
}; 321 | 322 | let prg = self.cl.programs.write().unwrap(); 323 | 324 | // 0 = CPU optimized 325 | // 1 = GPU optimized 326 | // 2 = GPU optimized vector multiply (other.rows == 1) 327 | const CPU: u8 = 0; 328 | const GPU: u8 = 1; 329 | const GPU2: u8 = 2; 330 | let strategy: u8 = if self.cl.is_cpu_device { 331 | CPU 332 | } else { 333 | if src.rows == 1 { 334 | GPU2 335 | } else { 336 | GPU 337 | } 338 | }; 339 | 340 | let prg = if strategy == CPU { 341 | &prg.matrix_mul_transposed_f16_cpu_optimized 342 | } else if strategy == GPU { 343 | &prg.matrix_mul_transposed_f16 344 | } else { 345 | &prg.matrix_mul_transposed_one_row_f16 346 | }; 347 | prg.set_arg(0, self.buf.clone())?; 348 | prg.set_arg(1, src.buf.clone())?; 349 | prg.set_arg(2, other.buf.clone())?; 350 | prg.set_arg(3, src.cols_capacity as i32)?; 351 | prg.set_arg(4, other.cols_capacity as i32)?; 352 | prg.set_arg(5, self.cols_capacity as i32)?; 353 | prg.set_arg(6, self.rows as i32)?; 354 | prg.set_arg(7, self.cols as i32)?; 355 | prg.set_arg(8, src.cols as i32)?; 356 | let mut event = Event::empty(); 357 | 358 | let rows16 = if self.rows % 16 == 0 { 359 | self.rows 360 | } else { 361 | self.rows + 16 - (self.rows % 16) 362 | }; 363 | let cols16 = if self.cols % 16 == 0 { 364 | self.cols 365 | } else { 366 | self.cols + 16 - (self.cols % 16) 367 | }; 368 | 369 | unsafe { 370 | if strategy == CPU { 371 | let b = prg 372 | .cmd() 373 | .queue(&self.queue) 374 | .global_work_size([self.cols as usize, self.rows as usize]) 375 | .enew(&mut event); 376 | b.enq()?; 377 | } else if strategy == GPU { 378 | let b = prg 379 | .cmd() 380 | .queue(&self.queue) 381 | .global_work_size([cols16 as usize, rows16 as usize]) 382 | .local_work_size([16, 16]) 383 | .enew(&mut event); 384 | b.enq()?; 385 | } else if strategy == GPU2 { 386 | let b = prg 387 | .cmd() 388 | .queue(&self.queue) 389 | .global_work_size([cols16 as usize, 1]) 390 | .local_work_size([16, 1]) 391 | .enew(&mut event); 392 | b.enq()?; 393 | } else { 394 | let b = prg 395 | .cmd() 396 | .queue(&self.queue) 397 | .global_work_size([self.cols as usize, self.rows as usize]) 398 | .enew(&mut event); 399 | b.enq()?; 400 | } 401 | } 402 | self.last_event = Some(event.clone()); 403 | Ok(OpenCLEvent { event }) 404 | } 405 | } 406 | 407 | impl OpenCLEvent { 408 | #[inline] 409 | pub fn wait(&self) { 410 | self.event.wait_for().unwrap(); 411 | } 412 | } 413 | 414 | fn make_programs(ctx: &Context, queue: &Queue) -> Result { 415 | fn make_program_with_src(ctx: &Context, src: &str) -> Result { 416 | let program = Program::builder().src(src).build(&ctx)?; 417 | Ok(program) 418 | } 419 | 420 | let matrix_mul_transposed_f16_program = 421 | make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_F16_SRC)?; 422 | let matrix_mul_transposed_f16 = Kernel::builder() 423 | .program(&matrix_mul_transposed_f16_program) 424 | .name("matrix_mul_transposed_f16") 425 | .arg(None::<&Buffer>) 426 | .arg(None::<&Buffer>) 427 | .arg(None::<&Buffer>) 428 | .arg(&0) 429 | .arg(&0) 430 | .arg(&0) 431 | .arg(&0) 432 | .arg(&0) 433 | .arg(&0) 434 | .queue(queue.clone()) 435 | .build()?; 436 | let matrix_mul_transposed_f16_cpu_optimized_program = 437 | make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_F16_CPU_OPTIMIZED_SRC)?; 438 | let matrix_mul_transposed_f16_cpu_optimized = Kernel::builder() 439 | .program(&matrix_mul_transposed_f16_cpu_optimized_program) 440 | .name("matrix_mul_transposed_f16_cpu_optimized") 441 | .arg(None::<&Buffer>) 442 | .arg(None::<&Buffer>) 443 | .arg(None::<&Buffer>) 444 | .arg(&0) 445 | .arg(&0) 
446 | .arg(&0) 447 | .arg(&0) 448 | .arg(&0) 449 | .arg(&0) 450 | .queue(queue.clone()) 451 | .build()?; 452 | let matrix_mul_transposed_one_row_f16_program = 453 | make_program_with_src(ctx, MATRIX_MUL_TRANSPOSED_F16_ONE_ROW_SRC)?; 454 | let matrix_mul_transposed_one_row_f16 = Kernel::builder() 455 | .program(&matrix_mul_transposed_one_row_f16_program) 456 | .name("matrix_mul_transposed_one_row_f16") 457 | .arg(None::<&Buffer>) 458 | .arg(None::<&Buffer>) 459 | .arg(None::<&Buffer>) 460 | .arg(&0) 461 | .arg(&0) 462 | .arg(&0) 463 | .arg(&0) 464 | .arg(&0) 465 | .arg(&0) 466 | .queue(queue.clone()) 467 | .build()?; 468 | let silu_f16_program = make_program_with_src(ctx, SILU_F16_SRC)?; 469 | let silu_f16 = Kernel::builder() 470 | .program(&silu_f16_program) 471 | .name("silu_f16") 472 | .arg(None::<&Buffer>) 473 | .arg(&0) 474 | .queue(queue.clone()) 475 | .build()?; 476 | let hadamard_product_f16_program = make_program_with_src(ctx, HADAMARD_PRODUCT_F16_SRC)?; 477 | let hadamard_product_f16 = Kernel::builder() 478 | .program(&hadamard_product_f16_program) 479 | .name("hadamard_product_f16") 480 | .arg(None::<&Buffer>) 481 | .arg(None::<&Buffer>) 482 | .arg(&0) 483 | .arg(&0) 484 | .queue(queue.clone()) 485 | .build()?; 486 | let transpose_f16_program = make_program_with_src(ctx, TRANSPOSE_F16_SRC)?; 487 | let transpose_f16 = Kernel::builder() 488 | .program(&transpose_f16_program) 489 | .name("transpose_f16") 490 | .arg(None::<&Buffer>) 491 | .arg(None::<&Buffer>) 492 | .arg(&0) 493 | .arg(&0) 494 | .queue(queue.clone()) 495 | .build()?; 496 | Ok(Programs { 497 | matrix_mul_transposed_f16_program, 498 | matrix_mul_transposed_f16, 499 | matrix_mul_transposed_one_row_f16_program, 500 | matrix_mul_transposed_one_row_f16, 501 | matrix_mul_transposed_f16_cpu_optimized_program, 502 | matrix_mul_transposed_f16_cpu_optimized, 503 | silu_f16_program, 504 | silu_f16, 505 | hadamard_product_f16_program, 506 | hadamard_product_f16, 507 | transpose_f16_program, 508 | transpose_f16, 509 | }) 510 | } 511 | 512 | const MATRIX_MUL_TRANSPOSED_F16_SRC: &str = r#" 513 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 514 | 515 | __kernel void matrix_mul_transposed_f16( 516 | __global half *tgt, 517 | __global const half *left, 518 | __global const half *right, 519 | const int left_cols_capacity, 520 | const int right_cols_capacity, 521 | const int ncols_capacity, 522 | const int nrows, 523 | const int ncols, // size of target 524 | const int shared_sz 525 | ) { 526 | __local float lefttile[16][16]; 527 | __local float righttile[16][16]; 528 | 529 | const int global_x = get_global_id(0); 530 | const int global_y = get_global_id(1); 531 | const int local_x = get_local_id(0); 532 | const int local_y = get_local_id(1); 533 | const int num_tiles = (shared_sz + 15) / 16; 534 | 535 | float sum = 0.0f; 536 | for (int t = 0; t < num_tiles; ++t) { 537 | if (global_y < nrows) { 538 | lefttile[local_y][local_x] = vload_half(global_y * left_cols_capacity + t * 16 + local_x, left); 539 | } else { 540 | lefttile[local_y][local_x] = 0.0f; 541 | } 542 | if (global_x < ncols) { 543 | righttile[local_y][local_x] = vload_half(global_x * right_cols_capacity + t * 16 + local_y, right); 544 | } else { 545 | righttile[local_y][local_x] = 0.0f; 546 | } 547 | barrier(CLK_LOCAL_MEM_FENCE); 548 | for (int k = 0; k < 16; ++k) { 549 | sum += lefttile[local_y][k] * righttile[k][local_x]; 550 | } 551 | barrier(CLK_LOCAL_MEM_FENCE); 552 | } 553 | if (global_x < ncols && global_y < nrows) { 554 | vstore_half(sum, global_y * ncols_capacity + 
global_x, (__global half*) tgt); 555 | } 556 | } 557 | "#; 558 | 559 | const MATRIX_MUL_TRANSPOSED_F16_ONE_ROW_SRC: &str = r#" 560 | __kernel void matrix_mul_transposed_one_row_f16( 561 | __global half *tgt, 562 | __global const half *left, 563 | __global const half *right, 564 | const int left_cols_capacity, 565 | const int right_cols_capacity, 566 | const int ncols_capacity, 567 | const int nrows, 568 | const int ncols, // size of target 569 | const int shared_sz 570 | ) { 571 | // assertions: 572 | // nrows == 1 573 | // left_rows == 1 574 | __local float lefttile[16]; 575 | __local float righttile[16][16]; 576 | 577 | const int global_x = get_global_id(0); 578 | const int local_x = get_local_id(0); 579 | const int num_tiles = (shared_sz + 15) / 16; 580 | const int x_tile = (global_x / 16) * 16; 581 | 582 | float sum = 0.0f; 583 | if (x_tile + 15 < ncols) { 584 | for (int t = 0; t < num_tiles; ++t) { 585 | lefttile[local_x] = vload_half(t * 16 + local_x, left); 586 | for (int k = 0; k < 16; ++k) { 587 | righttile[k][local_x] = vload_half(t * 16 + local_x + (x_tile + k) * right_cols_capacity, right); 588 | } 589 | barrier(CLK_LOCAL_MEM_FENCE); 590 | for (int k = 0; k < 16; ++k) { 591 | sum += lefttile[k] * righttile[local_x][k]; 592 | } 593 | barrier(CLK_LOCAL_MEM_FENCE); 594 | } 595 | } else { 596 | for (int t = 0; t < num_tiles; ++t) { 597 | lefttile[local_x] = vload_half(t * 16 + local_x, left); 598 | for (int k = 0; k < 16; ++k) { 599 | if (x_tile + k >= ncols) { 600 | righttile[k][local_x] = 0.0f; 601 | } else { 602 | righttile[k][local_x] = vload_half(t * 16 + local_x + (x_tile + k) * right_cols_capacity, right); 603 | } 604 | } 605 | barrier(CLK_LOCAL_MEM_FENCE); 606 | for (int k = 0; k < 16; ++k) { 607 | sum += lefttile[k] * righttile[local_x][k]; 608 | } 609 | barrier(CLK_LOCAL_MEM_FENCE); 610 | } 611 | } 612 | 613 | if (global_x < ncols) { 614 | vstore_half(sum, global_x, (__global half*) tgt); 615 | } 616 | }"#; 617 | 618 | const MATRIX_MUL_TRANSPOSED_F16_CPU_OPTIMIZED_SRC: &str = r#" 619 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 620 | 621 | __kernel void matrix_mul_transposed_f16_cpu_optimized( 622 | __global half *tgt, 623 | __global const half *left, 624 | __global const half *right, 625 | const int left_cols_capacity, 626 | const int right_cols_capacity, 627 | const int ncols_capacity, 628 | const int nrows, 629 | const int ncols, // size of target 630 | const int shared_sz 631 | ) { 632 | const int tgt_col = get_global_id(0); 633 | const int tgt_row = get_global_id(1); 634 | int col_iterations = shared_sz / 16; 635 | if (shared_sz % 16 != 0) { 636 | col_iterations = col_iterations + 1; 637 | } 638 | float16 sum = 0; 639 | for (int col16 = 0; col16 < col_iterations; col16++) { 640 | const float16 left8 = vload_half16((tgt_row * left_cols_capacity)/16 + col16, (__global const half*) left); 641 | const float16 right8 = vload_half16((tgt_col * right_cols_capacity)/16 + col16, (__global const half*) right); 642 | // hadamard product FMA add it to sum 643 | // const float16 result8 = left8 * right8; 644 | // sum += result8; 645 | sum = fma(left8, right8, sum); 646 | } 647 | // Reduce as accurately as possible 648 | float sum1 = sum.s0 + sum.s1; 649 | float sum2 = sum.s2 + sum.s3; 650 | float sum3 = sum.s4 + sum.s5; 651 | float sum4 = sum.s6 + sum.s7; 652 | float sum5 = sum.s8 + sum.s9; 653 | float sum6 = sum.sa + sum.sb; 654 | float sum7 = sum.sc + sum.sd; 655 | float sum8 = sum.se + sum.sf; 656 | float sum11 = sum1 + sum2; 657 | float sum12 = sum3 + sum4; 658 | float 
sum13 = sum5 + sum6; 659 | float sum14 = sum7 + sum8; 660 | float sum21 = sum11 + sum12; 661 | float sum22 = sum13 + sum14; 662 | float total = sum21 + sum22; 663 | vstore_half(total, 0, (__global half*) &tgt[tgt_row * ncols_capacity + tgt_col]); 664 | } 665 | "#; 666 | 667 | /// Computes SILU for every f16 value in the tensor 668 | const SILU_F16_SRC: &str = r#" 669 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 670 | 671 | __kernel void silu_f16(__global half *tgt, 672 | const int ncols_capacity) 673 | { 674 | const int tgt_row = get_global_id(0); 675 | const int tgt_col = get_global_id(1); 676 | const float val = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt); 677 | const float result = val * (1.0 / (1.0 + exp(-val))); 678 | vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); 679 | } 680 | "#; 681 | 682 | /// Computes hadamard product of two identially sized tensors 683 | const HADAMARD_PRODUCT_F16_SRC: &str = r#" 684 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 685 | 686 | __kernel void hadamard_product_f16(__global half *tgt, 687 | __global const half *left, 688 | const int ncols_capacity, 689 | const int left_cols_capacity) { 690 | const int tgt_row = get_global_id(0); 691 | const int tgt_col = get_global_id(1); 692 | const float tgt_value = vload_half(tgt_row * ncols_capacity + tgt_col, (__global const half*) tgt); 693 | const float left_value = vload_half(tgt_row * left_cols_capacity + tgt_col, (__global const half*) left); 694 | const float result = tgt_value * left_value; 695 | vstore_half(result, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); 696 | } 697 | "#; 698 | 699 | /// Computes the transpose of a matrix 700 | const TRANSPOSE_F16_SRC: &str = r#" 701 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable 702 | 703 | __kernel void transpose_f16(__global half *tgt, 704 | __global const half *left, 705 | const int ncols_capacity, 706 | const int left_cols_capacity) 707 | { 708 | const int tgt_row = get_global_id(0); 709 | const int tgt_col = get_global_id(1); 710 | const int src_row = tgt_col; 711 | const int src_col = tgt_row; 712 | const float val = vload_half(src_row * left_cols_capacity + src_col, (__global const half*) left); 713 | vstore_half(val, tgt_row * ncols_capacity + tgt_col, (__global half*) tgt); 714 | } 715 | "#; 716 | -------------------------------------------------------------------------------- /src/token_sampler.rs: -------------------------------------------------------------------------------- 1 | use crate::tensor::Tensor; 2 | use crate::tokenizer::{TokenId, Tokenizer}; 3 | use rand::Rng; 4 | use std::collections::BTreeMap; 5 | 6 | pub struct TokenSampler { 7 | temperature: f32, 8 | top_p: f32, 9 | top_k: usize, 10 | repetition_penalty: f32, 11 | } 12 | 13 | impl Default for TokenSampler { 14 | fn default() -> Self { 15 | Self::new() 16 | } 17 | } 18 | 19 | impl TokenSampler { 20 | pub fn new() -> Self { 21 | Self { 22 | temperature: 0.2, 23 | top_p: 1.0, 24 | top_k: 1, // same as argmax 25 | repetition_penalty: 0.8, // 1.0 = no penalty. 
values above 1.0 make repetition 26 | // encouraged which can quickly devolve into repeating loop 27 | } 28 | } 29 | 30 | pub fn get_temperature(&self) -> f32 { 31 | self.temperature 32 | } 33 | 34 | pub fn get_top_p(&self) -> f32 { 35 | self.top_p 36 | } 37 | 38 | pub fn get_top_k(&self) -> usize { 39 | self.top_k 40 | } 41 | 42 | pub fn get_repetition_penalty(&self) -> f32 { 43 | self.repetition_penalty 44 | } 45 | 46 | pub fn temperature(self, temperature: f32) -> Self { 47 | Self { 48 | temperature, 49 | ..self 50 | } 51 | } 52 | 53 | pub fn top_p(self, top_p: f32) -> Self { 54 | Self { top_p, ..self } 55 | } 56 | 57 | pub fn top_k(self, top_k: usize) -> Self { 58 | Self { top_k, ..self } 59 | } 60 | 61 | pub fn repetition_penalty(self, repetition_penalty: f32) -> Self { 62 | Self { 63 | repetition_penalty, 64 | ..self 65 | } 66 | } 67 | 68 | pub fn logits_to_btreemap( 69 | &self, 70 | logits: &Tensor, 71 | tokenizer: &Tokenizer, 72 | ) -> BTreeMap { 73 | let mut result = BTreeMap::new(); 74 | for token_idx in 0..logits.rows() { 75 | result.insert( 76 | tokenizer.id_to_str(token_idx as TokenId).to_string(), 77 | logits.get_f32(token_idx, 0), 78 | ); 79 | } 80 | result 81 | } 82 | 83 | pub fn sample( 84 | &self, 85 | logits: &Tensor, 86 | _tokenizer: &Tokenizer, 87 | existing_tokens: &[TokenId], 88 | ) -> (TokenId, f32) { 89 | let mut times_used: BTreeMap = BTreeMap::new(); 90 | for token in existing_tokens { 91 | times_used 92 | .entry(*token) 93 | .and_modify(|e| *e += 1) 94 | .or_insert(1); 95 | } 96 | 97 | let nrows = logits.rows(); 98 | assert!(logits.cols() == 1); 99 | let mut logits = logits.transpose(); 100 | if self.temperature > 0.0 { 101 | logits = logits.scalar_multiply_f32(1.0 / self.temperature); 102 | } 103 | 104 | if self.repetition_penalty != 1.0 { 105 | for token_idx in 0..logits.rows() { 106 | if let Some(count) = times_used.get(&(token_idx as TokenId)) { 107 | let penalty = self.repetition_penalty.powf(*count as f32); 108 | logits.set_f32(0, token_idx, logits.get_f32(0, token_idx) * penalty); 109 | } 110 | } 111 | } 112 | let mut maxv: f32 = std::f32::NEG_INFINITY; 113 | for token_idx in 0..logits.rows() { 114 | let v = logits.get_f32(0, token_idx); 115 | if v > maxv { 116 | maxv = v; 117 | } 118 | } 119 | // To numerically stabilize, remove maxv from all logits 120 | // softmax(x + c) = softmax(x) where c is a constant, and we make use of htat 121 | for token_idx in 0..logits.rows() { 122 | logits.set_f32(0, token_idx, logits.get_f32(0, token_idx) - maxv); 123 | } 124 | logits = logits.softmax(); 125 | 126 | let mut logitsf: Vec<(TokenId, f32)> = Vec::with_capacity(nrows as usize); 127 | for i in 0..nrows { 128 | let score = logits.get_f32(0, i); 129 | logitsf.push((i as TokenId, score)); 130 | } 131 | logitsf.sort_unstable_by(|a, b| { 132 | match b.1.partial_cmp(&a.1) { 133 | Some(c) => c, 134 | None => { 135 | // Sort NaNs to bottom 136 | if b.1.is_nan() { 137 | std::cmp::Ordering::Less 138 | } else if a.1.is_nan() { 139 | return std::cmp::Ordering::Greater; 140 | } else { 141 | return std::cmp::Ordering::Equal; 142 | } 143 | } 144 | } 145 | }); 146 | 147 | logitsf.truncate(self.top_k); 148 | let mut p_accum: f32 = 0.0; 149 | for (idx, v) in logitsf.iter().enumerate() { 150 | p_accum += v.1; 151 | if p_accum >= self.top_p { 152 | logitsf.truncate(idx + 1); 153 | break; 154 | } 155 | } 156 | let mut total_p: f32 = 0.0; 157 | for v in logitsf.iter() { 158 | total_p += v.1; 159 | } 160 | let mut rng = rand::thread_rng(); 161 | let p: f32 = if total_p > 0.0 { 162 | 
rng.gen_range(0.0..=total_p) 163 | } else { 164 | 0.0 165 | }; 166 | p_accum = 0.0; 167 | for v in logitsf.into_iter() { 168 | p_accum += v.1; 169 | if p_accum >= p { 170 | return (v.0, v.1 / total_p); 171 | } 172 | } 173 | (0, 0.0) 174 | } 175 | } 176 | -------------------------------------------------------------------------------- /src/tokenizer.rs: -------------------------------------------------------------------------------- 1 | use crate::protomodels::sentencepiece_model::model_proto::sentence_piece; 2 | use crate::protomodels::sentencepiece_model::ModelProto; 3 | use protobuf::Message; 4 | use std::collections::BTreeMap; 5 | use std::io::Read; 6 | use std::path::Path; 7 | use thiserror::Error; 8 | 9 | pub type TokenId = i32; 10 | 11 | #[derive(Clone, Debug)] 12 | pub struct Tokenizer { 13 | pieces: BTreeMap, 14 | } 15 | 16 | #[derive(Clone, Debug, Copy, Eq, Ord, PartialEq, PartialOrd)] 17 | pub enum PieceType { 18 | Normal, 19 | Unknown, 20 | Control, 21 | UserDefined, 22 | Byte, 23 | Unused, 24 | } 25 | 26 | #[derive(Clone, Debug)] 27 | pub struct Piece { 28 | _tp: PieceType, 29 | // piece: String this is in the BTreeMap that holds the pieces 30 | _score: f32, 31 | idx: usize, 32 | } 33 | 34 | #[derive(Error, Debug)] 35 | pub enum TokenizerError { 36 | #[error("IO error")] 37 | IoError(#[from] std::io::Error), 38 | #[error("Protobuf error")] 39 | ProtobufError(#[from] protobuf::Error), 40 | #[error("Unknown piece type")] 41 | UnknownPieceType(String), 42 | } 43 | 44 | impl Tokenizer { 45 | pub fn load>(path: P) -> Result { 46 | let mut fs = std::fs::File::open(path)?; 47 | let mut buffer = Vec::new(); 48 | fs.read_to_end(&mut buffer)?; 49 | std::mem::drop(fs); 50 | let model = ModelProto::parse_from_bytes(&buffer)?; 51 | 52 | let mut pieces = BTreeMap::new(); 53 | for (idx, piece) in model.pieces.iter().enumerate() { 54 | let piece_str = piece.piece.clone(); 55 | if piece_str.is_none() { 56 | continue; 57 | } 58 | let piece_str = piece_str.unwrap(); 59 | let piece_type = match piece.type_ { 60 | None => sentence_piece::Type::NORMAL, 61 | Some(v) => match v.enum_value() { 62 | Err(_) => return Err(TokenizerError::UnknownPieceType(piece_str)), 63 | Ok(v) => v, 64 | }, 65 | }; 66 | 67 | let score = piece.score.unwrap_or(0.0); 68 | let tp = if piece_type == sentence_piece::Type::NORMAL { 69 | PieceType::Normal 70 | } else if piece_type == sentence_piece::Type::UNKNOWN { 71 | PieceType::Unknown 72 | } else if piece_type == sentence_piece::Type::CONTROL { 73 | PieceType::Control 74 | } else if piece_type == sentence_piece::Type::USER_DEFINED { 75 | PieceType::UserDefined 76 | } else if piece_type == sentence_piece::Type::BYTE { 77 | PieceType::Byte 78 | } else if piece_type == sentence_piece::Type::UNUSED { 79 | PieceType::Unused 80 | } else { 81 | return Err(TokenizerError::UnknownPieceType(piece_str)); 82 | }; 83 | pieces.insert( 84 | piece_str, 85 | Piece { 86 | _tp: tp, 87 | _score: score, 88 | idx, 89 | }, 90 | ); 91 | } 92 | 93 | Ok(Tokenizer { pieces }) 94 | } 95 | 96 | // Gives a string for a token id. 97 | // Panics if the id is out of range. 98 | pub fn id_to_str(&self, id: i32) -> &str { 99 | let id = id as usize; 100 | for (piece_str, piece_info) in self.pieces.iter() { 101 | if piece_info.idx == id { 102 | return piece_str; 103 | } 104 | } 105 | panic!("id out of range"); 106 | } 107 | 108 | // Tries to find a token from dictionary. 
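    // Returns Some(id) when `s` exactly matches a piece in the vocabulary and
    // None otherwise; like `id_to_str` above, this is a linear scan over the
    // piece map. `tokenize_to_pieces` below uses it to check whether the byte
    // piece "<0x0A>" exists before special-casing newlines.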
109 | pub fn str_to_id(&self, s: &str) -> Option { 110 | for (piece_str, piece_info) in self.pieces.iter() { 111 | if piece_str == s { 112 | return Some(piece_info.idx as i32); 113 | } 114 | } 115 | None 116 | } 117 | 118 | // Converts a string to a Vec<&str> 119 | // You may want to use tokenize_to_ids instead. 120 | // 121 | // This will not add start or end tokens; only the string is processed. 122 | // 123 | // I noticed LLaMa code adds an extra space character at the beginning of the string, this 124 | // function does not do that either. 125 | pub fn tokenize_to_pieces>(&self, s: S) -> Vec<&str> { 126 | let mut s: &str = s.as_ref(); 127 | let mut result: Vec<&str> = Vec::new(); 128 | 129 | // Very naive matching 130 | while !s.is_empty() { 131 | let mut best_candidate: &str = ""; 132 | let mut best_candidate_len: usize = 0; 133 | let mut skip_s: &str = ""; 134 | // Specially recognize newline. Otherwise it matches something we don't actually 135 | // want. 136 | if s.starts_with('\n') { 137 | if self.str_to_id("<0x0A>").is_some() { 138 | best_candidate = "<0x0A>"; 139 | best_candidate_len = best_candidate.len(); 140 | skip_s = &s[1..]; 141 | } else { 142 | best_candidate = "\\n"; 143 | } 144 | } else { 145 | for (piece_str, _piece_info) in self.pieces.iter() { 146 | if s.starts_with(piece_str) && best_candidate_len < piece_str.len() { 147 | best_candidate = piece_str; 148 | best_candidate_len = piece_str.len(); 149 | skip_s = &s[piece_str.len()..]; 150 | } 151 | } 152 | } 153 | if best_candidate_len == 0 { 154 | // Skip token. 155 | s = s.get(1..).unwrap_or(""); 156 | } else { 157 | result.push(best_candidate); 158 | s = skip_s; 159 | } 160 | } 161 | result 162 | } 163 | 164 | pub fn tokenize_to_ids>(&self, s: S) -> Vec { 165 | let mut s: String = format!("▁{}", s.as_ref()); 166 | // Replace all space characters with a special token. 
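        // "▁" is U+2581 (LOWER ONE EIGHTH BLOCK), the character SentencePiece
        // uses as its whitespace marker; the `format!` above prepends one so the
        // first word is matched the same way as a word that follows a space.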
167 | s = s.replace(' ', "▁"); 168 | 169 | let pieces = self.tokenize_to_pieces(s); 170 | let mut result = Vec::new(); 171 | result.push(1); // start token 172 | for piece in pieces { 173 | let piece_info = self.pieces.get(piece).unwrap(); 174 | result.push(piece_info.idx as i32); 175 | } 176 | result 177 | } 178 | } 179 | -------------------------------------------------------------------------------- /src/transformer.rs: -------------------------------------------------------------------------------- 1 | use crate::data_source::DataSource; 2 | use crate::embedding::Embedding; 3 | use crate::tensor::{FromPiecesDirection, Tensor, TensorDType}; 4 | #[cfg(feature = "opencl")] 5 | use crate::tensor_opencl_support::OpenCL; 6 | use crate::tokenizer::TokenId; 7 | 8 | use crate::unpickler::UnpicklingError; 9 | use indicatif::ProgressBar; 10 | use num_complex::Complex; 11 | use rayon::prelude::*; 12 | 13 | use std::sync::{Arc, RwLock}; 14 | 15 | type FreqsCis = Vec>>; 16 | 17 | #[allow(dead_code)] 18 | pub struct Transformer { 19 | freqs_cis: FreqsCis, 20 | emb: Embedding, 21 | dim: usize, 22 | n_layers: usize, 23 | n_heads: usize, 24 | n_local_heads: usize, 25 | max_seq_len: usize, 26 | head_dim: usize, 27 | 28 | norm: RMSNorm, 29 | output: Tensor, 30 | 31 | layers: Vec, 32 | 33 | data_settings: DataSettings, 34 | } 35 | 36 | // Clone is cheap 37 | #[derive(Clone)] 38 | pub struct DataSettings { 39 | #[cfg(feature = "opencl")] 40 | percentage_to_gpu: f32, 41 | #[cfg(feature = "opencl")] 42 | use_opencl_for_feedforward: bool, 43 | #[cfg(feature = "opencl")] 44 | use_opencl_for_attention: bool, 45 | #[cfg(feature = "opencl")] 46 | cl: Option, 47 | 48 | force_f16: bool, 49 | } 50 | 51 | // OpenCL is safe to send to threads but Rust doesn't know that 52 | unsafe impl Send for DataSettings {} 53 | unsafe impl Sync for DataSettings {} 54 | 55 | impl DataSettings { 56 | #[cfg(feature = "opencl")] 57 | pub fn new(cl: Option) -> Self { 58 | DataSettings { 59 | use_opencl_for_feedforward: false, 60 | use_opencl_for_attention: false, 61 | force_f16: false, 62 | percentage_to_gpu: 1.0, 63 | cl: cl.clone(), 64 | } 65 | } 66 | 67 | #[allow(clippy::new_without_default)] 68 | #[cfg(not(feature = "opencl"))] 69 | pub fn new() -> Self { 70 | DataSettings { force_f16: false } 71 | } 72 | 73 | #[cfg(feature = "opencl")] 74 | pub fn use_opencl(mut self) -> DataSettings { 75 | if self.cl.is_none() { 76 | panic!("OpenCL is not available, cannot call use_opencl() on DataSettings."); 77 | } 78 | self.use_opencl_for_feedforward = true; 79 | self.use_opencl_for_attention = true; 80 | self 81 | } 82 | 83 | #[cfg(feature = "opencl")] 84 | pub fn dont_use_opencl(mut self) -> DataSettings { 85 | self.use_opencl_for_feedforward = false; 86 | self.use_opencl_for_attention = false; 87 | self 88 | } 89 | 90 | #[cfg(feature = "opencl")] 91 | pub fn percentage_to_gpu(mut self, percentage: f32) -> DataSettings { 92 | self.percentage_to_gpu = percentage; 93 | if self.percentage_to_gpu >= 1.0 { 94 | self.percentage_to_gpu = 1.0; 95 | } 96 | if self.percentage_to_gpu < 0.0 { 97 | self.percentage_to_gpu = 0.0; 98 | } 99 | if self.percentage_to_gpu.is_nan() { 100 | self.percentage_to_gpu = 0.0; 101 | } 102 | self 103 | } 104 | 105 | pub fn force_f16(mut self) -> DataSettings { 106 | self.force_f16 = true; 107 | self 108 | } 109 | } 110 | 111 | pub struct TransformerCaches { 112 | layer_caches: Vec, 113 | } 114 | 115 | pub struct TransformerBlock { 116 | feed_forward: FeedForward, 117 | attn: Attention, 118 | ffn_norm: RMSNorm, 119 | 
attention_norm: RMSNorm, 120 | } 121 | 122 | pub struct AttentionCache { 123 | cache_k: Vec>>, 124 | cache_v: Vec>>, 125 | data_settings: DataSettings, 126 | } 127 | 128 | impl AttentionCache { 129 | fn new( 130 | max_seq_len: usize, 131 | n_local_heads: usize, 132 | head_dim: usize, 133 | data_settings: &DataSettings, 134 | ) -> Self { 135 | let mut cache_k = Vec::with_capacity(n_local_heads); 136 | let mut cache_v = Vec::with_capacity(n_local_heads); 137 | 138 | let dtype = if data_settings.force_f16 { 139 | TensorDType::Float16 140 | } else { 141 | TensorDType::Float32 142 | }; 143 | for _ in 0..n_local_heads { 144 | cache_k.push(Arc::new(RwLock::new(Tensor::zeros( 145 | head_dim as i64, 146 | max_seq_len as i64, 147 | dtype, 148 | )))); 149 | cache_v.push(Arc::new(RwLock::new(Tensor::zeros( 150 | head_dim as i64, 151 | max_seq_len as i64, 152 | dtype, 153 | )))); 154 | } 155 | AttentionCache { 156 | cache_k, 157 | cache_v, 158 | data_settings: data_settings.clone(), 159 | } 160 | } 161 | 162 | /// Cloning AttentionCache normally just makes new references to the same cache. 163 | /// This creates a true clone with copied tensors. 164 | fn true_clone(&self) -> AttentionCache { 165 | let mut cache_k = Vec::with_capacity(self.cache_k.len()); 166 | let mut cache_v = Vec::with_capacity(self.cache_v.len()); 167 | for idx in 0..self.cache_k.len() { 168 | let old_k = self.cache_k[idx].read().unwrap(); 169 | cache_k.push(Arc::new(RwLock::new(old_k.clone()))); 170 | let old_v = self.cache_v[idx].read().unwrap(); 171 | cache_v.push(Arc::new(RwLock::new(old_v.clone()))); 172 | } 173 | AttentionCache { 174 | cache_k, 175 | cache_v, 176 | data_settings: self.data_settings.clone(), 177 | } 178 | } 179 | 180 | fn shift_left(&mut self, shifts: usize) { 181 | for _ in 0..shifts { 182 | for idx in 0..self.cache_k.len() { 183 | let mut k = self.cache_k[idx].write().unwrap(); 184 | let mut v = self.cache_v[idx].write().unwrap(); 185 | let k_rows = k.rows(); 186 | let k_cols = k.cols(); 187 | for head_idx in 0..k_rows { 188 | for seq_idx in 0..k_cols - 1 { 189 | let kval = k.get_f32(head_idx, seq_idx + 1); 190 | let vval = v.get_f32(head_idx, seq_idx + 1); 191 | k.set_f32(head_idx, seq_idx, kval); 192 | v.set_f32(head_idx, seq_idx, vval); 193 | } 194 | } 195 | } 196 | } 197 | } 198 | } 199 | 200 | impl TransformerCaches { 201 | pub fn shift_left(&mut self, shifts: usize) { 202 | for layer in self.layer_caches.iter_mut() { 203 | layer.shift_left(shifts); 204 | } 205 | } 206 | 207 | pub fn true_clone(&self) -> TransformerCaches { 208 | let mut layer_caches = Vec::with_capacity(self.layer_caches.len()); 209 | for layer in self.layer_caches.iter() { 210 | layer_caches.push(layer.true_clone()); 211 | } 212 | TransformerCaches { layer_caches } 213 | } 214 | } 215 | 216 | pub struct RMSNorm { 217 | eps: f64, 218 | weight: Tensor, 219 | } 220 | 221 | #[allow(dead_code)] 222 | pub struct Attention { 223 | wq: Tensor, 224 | wk: Tensor, 225 | wv: Tensor, 226 | wo: Tensor, 227 | n_local_heads: usize, 228 | head_dim: usize, 229 | data_settings: DataSettings, 230 | } 231 | 232 | #[allow(dead_code)] 233 | pub struct FeedForward { 234 | w1: Tensor, 235 | w2: Tensor, 236 | w3: Tensor, 237 | data_settings: DataSettings, 238 | } 239 | 240 | impl Transformer { 241 | #[allow(clippy::too_many_arguments)] 242 | pub fn from_unpickled( 243 | emb: Embedding, 244 | dim: usize, 245 | n_layers: usize, 246 | n_heads: usize, 247 | max_seq_len: usize, 248 | eps: f64, 249 | data_settings: DataSettings, 250 | data_source: DataSource, 251 | ) 
-> Result { 252 | assert_eq!(dim % n_heads, 0); 253 | let head_dim = dim / n_heads; 254 | let n_local_heads = n_heads; // I think the local heads is an artifact of the original 255 | // implementation that used multi-GPU in the Facebook repo. 256 | // Should delete it later. 257 | 258 | let progress_bar = ProgressBar::new(n_layers as u64); 259 | let layers: Vec = (0..n_layers) 260 | .into_par_iter() 261 | .map(|layer_id| { 262 | let data_settings = { 263 | #[cfg(feature = "opencl")] 264 | { 265 | let max_layers = n_layers; 266 | let last_layer_on_gpu = (data_settings.percentage_to_gpu 267 | * (max_layers - 1) as f32) 268 | .round() as usize; 269 | if layer_id > last_layer_on_gpu { 270 | data_settings.clone().dont_use_opencl() 271 | } else { 272 | data_settings.clone() 273 | } 274 | } 275 | #[cfg(not(feature = "opencl"))] 276 | { 277 | data_settings.clone() 278 | } 279 | }; 280 | 281 | let result = TransformerBlock::from_unpickled( 282 | layer_id, 283 | eps, 284 | n_local_heads, 285 | head_dim, 286 | dim, 287 | data_settings, 288 | data_source.clone(), 289 | ); 290 | progress_bar.inc(1); 291 | result 292 | }) 293 | .collect::, UnpicklingError>>()?; 294 | std::mem::drop(progress_bar); 295 | 296 | let norm = RMSNorm::from_unpickled( 297 | "norm.weight".to_string(), 298 | "model.norm.weight".to_string(), 299 | eps, 300 | data_source.clone(), 301 | )?; 302 | let output = Tensor::from_unpickled_pieces2( 303 | "output.weight", 304 | "lm_head.weight", 305 | data_source.clone(), 306 | FromPiecesDirection::Rows, 307 | )? 308 | .to_f32(); 309 | 310 | Ok(Transformer { 311 | freqs_cis: compute_freqs_cis(dim / n_heads, max_seq_len, 10000.0), 312 | data_settings: data_settings.clone(), 313 | emb, 314 | dim, 315 | n_layers, 316 | n_heads, 317 | n_local_heads, 318 | max_seq_len, 319 | head_dim, 320 | 321 | norm, 322 | output, 323 | 324 | layers, 325 | }) 326 | } 327 | 328 | pub fn make_caches(&self) -> TransformerCaches { 329 | let mut result = vec![]; 330 | for _ in 0..self.n_layers { 331 | result.push(AttentionCache::new( 332 | self.max_seq_len, 333 | self.n_local_heads, 334 | self.head_dim, 335 | &self.data_settings, 336 | )); 337 | } 338 | TransformerCaches { 339 | layer_caches: result, 340 | } 341 | } 342 | 343 | pub fn forward( 344 | &self, 345 | tokens: &[TokenId], 346 | start_pos: usize, 347 | caches: &mut TransformerCaches, 348 | ) -> Tensor { 349 | assert!(caches.layer_caches.len() == self.n_layers); 350 | let mask: Option = if tokens.len() > 1 { 351 | Some(Tensor::full_triu( 352 | tokens.len() as i64, 353 | tokens.len() as i64, 354 | start_pos as i64 + 1, 355 | TensorDType::Float32, 356 | std::f32::NEG_INFINITY, 357 | )) 358 | } else { 359 | None 360 | }; 361 | let mut embs: Vec<&Tensor> = Vec::with_capacity(tokens.len()); 362 | for token in tokens.iter() { 363 | let emb = self.emb.get_embedding(*token as usize); 364 | embs.push(emb); 365 | } 366 | let mut emb_tensor: Tensor = Tensor::concat(&embs); 367 | std::mem::drop(embs); 368 | 369 | for (idx, layer) in self.layers.iter().enumerate() { 370 | emb_tensor = layer.forward( 371 | &emb_tensor, 372 | start_pos, 373 | &self.freqs_cis, 374 | &mask, 375 | &mut caches.layer_caches[idx], 376 | ); 377 | } 378 | let out = self.norm.forward(&emb_tensor); 379 | let out = out.row(out.rows() - 1); 380 | 381 | self.output.matrix_mul_transposed(&out) 382 | } 383 | } 384 | 385 | impl TransformerBlock { 386 | pub fn from_unpickled( 387 | layer_id: usize, 388 | eps: f64, 389 | n_local_heads: usize, 390 | head_dim: usize, 391 | dim: usize, 392 | data_settings: 
DataSettings, 393 | data_source: DataSource, 394 | ) -> Result { 395 | let ff = FeedForward::from_unpickled(layer_id, data_source.clone(), data_settings.clone())?; 396 | let attn = Attention::from_unpickled( 397 | layer_id, 398 | n_local_heads, 399 | head_dim, 400 | dim, 401 | data_settings, 402 | data_source.clone(), 403 | )?; 404 | let ffn_norm = RMSNorm::from_unpickled( 405 | format!("layers.{}.ffn_norm.weight", layer_id), 406 | format!("model.layers.{}.post_attention_layernorm.weight", layer_id), 407 | eps, 408 | data_source.clone(), 409 | )?; 410 | let attn_norm = RMSNorm::from_unpickled( 411 | format!("layers.{}.attention_norm.weight", layer_id), 412 | format!("model.layers.{}.input_layernorm.weight", layer_id), 413 | eps, 414 | data_source, 415 | )?; 416 | Ok(Self { 417 | feed_forward: ff, 418 | attn, 419 | ffn_norm, 420 | attention_norm: attn_norm, 421 | }) 422 | } 423 | 424 | pub fn forward( 425 | &self, 426 | x: &Tensor, 427 | start_pos: usize, 428 | freqs_cis: &FreqsCis, 429 | mask: &Option, 430 | attention_cache: &mut AttentionCache, 431 | ) -> Tensor { 432 | let mut attnorm_out = self.attention_norm.forward(x); 433 | let att_out = self.attn.forward( 434 | &mut attnorm_out, 435 | start_pos, 436 | freqs_cis, 437 | mask, 438 | attention_cache, 439 | ); 440 | std::mem::drop(attnorm_out); 441 | 442 | let h = x.add(&att_out); 443 | let mut att_out = self.ffn_norm.forward(&h); 444 | let att_out = self.feed_forward.forward(&mut att_out).transpose(); 445 | h.add(&att_out) 446 | } 447 | } 448 | 449 | impl RMSNorm { 450 | pub fn from_unpickled( 451 | name: String, 452 | name2: String, 453 | eps: f64, 454 | data_source: DataSource, 455 | ) -> Result { 456 | let weights = match Tensor::from_unpickled_pieces1( 457 | name, 458 | data_source.clone(), 459 | FromPiecesDirection::Rows, 460 | ) { 461 | Ok(w) => w, 462 | Err(_) => Tensor::from_unpickled_pieces1( 463 | name2, 464 | data_source.clone(), 465 | FromPiecesDirection::Rows, 466 | )?, 467 | }; 468 | let weights = weights.to_f32(); 469 | 470 | Ok(Self { 471 | eps, 472 | weight: weights, 473 | }) 474 | } 475 | 476 | fn forward(&self, x: &Tensor) -> Tensor { 477 | let inner = x.pow(2.0).mean_cols().add_scalar(self.eps as f32); 478 | let out1 = x.scalar_multiply_broadcast(&inner.rsqrt()); 479 | out1.hadamard_product_broadcast(&self.weight) 480 | } 481 | } 482 | 483 | impl FeedForward { 484 | pub fn from_unpickled( 485 | layer_id: usize, 486 | data_source: DataSource, 487 | data_settings: DataSettings, 488 | ) -> Result { 489 | let mut w1 = Tensor::from_unpickled_pieces2( 490 | format!("layers.{}.feed_forward.w1.weight", layer_id), 491 | format!("model.layers.{}.mlp.gate_proj.weight", layer_id), 492 | data_source.clone(), 493 | FromPiecesDirection::Rows, 494 | )?; 495 | let mut w2 = Tensor::from_unpickled_pieces2( 496 | format!("layers.{}.feed_forward.w2.weight", layer_id), 497 | format!("model.layers.{}.mlp.down_proj.weight", layer_id), 498 | data_source.clone(), 499 | FromPiecesDirection::Cols, 500 | )?; 501 | let mut w3 = Tensor::from_unpickled_pieces2( 502 | format!("layers.{}.feed_forward.w3.weight", layer_id), 503 | format!("model.layers.{}.mlp.up_proj.weight", layer_id), 504 | data_source.clone(), 505 | FromPiecesDirection::Rows, 506 | )?; 507 | 508 | if data_settings.force_f16 { 509 | w1 = w1.to_f16(); 510 | w2 = w2.to_f16(); 511 | w3 = w3.to_f16(); 512 | } 513 | 514 | #[cfg(feature = "opencl")] 515 | { 516 | if data_settings.use_opencl_for_feedforward { 517 | w1 = w1.to_f16(); 518 | w2 = w2.to_f16(); 519 | w3 = w3.to_f16(); 520 | let 
ds = data_settings.clone(); 521 | w1.to_gpu_inplace(&ds.cl.as_ref().unwrap().clone()).unwrap(); 522 | w2.to_gpu_inplace(&ds.cl.as_ref().unwrap().clone()).unwrap(); 523 | w3.to_gpu_inplace(&ds.cl.unwrap()).unwrap(); 524 | } 525 | } 526 | // w1, w2, w3 maybe be f32 or f16 depending on source data. 527 | 528 | Ok(Self { 529 | w1, 530 | w2, 531 | w3, 532 | data_settings, 533 | }) 534 | } 535 | 536 | pub fn forward(&self, x: &mut Tensor) -> Tensor { 537 | let original_x_dtype = x.dtype(); 538 | if x.dtype() != self.w1.dtype() { 539 | *x = x.to_same_type(&self.w1); 540 | } 541 | #[cfg(feature = "opencl")] 542 | let x_was_on_cpu: bool; 543 | #[cfg(feature = "opencl")] 544 | { 545 | x_was_on_cpu = x.is_on_cpu(); 546 | if self.data_settings.use_opencl_for_feedforward { 547 | x.to_gpu_inplace(self.data_settings.cl.as_ref().unwrap()) 548 | .unwrap(); 549 | } 550 | } 551 | let (mut w1_out, mut w3_out) = rayon::join( 552 | || self.w1.matrix_mul_transposed(x), 553 | || self.w3.matrix_mul_transposed(x), 554 | ); 555 | 556 | // Float16 not supported for some of these ops on CPU. 557 | if w1_out.is_on_cpu() && w1_out.dtype() == TensorDType::Float16 { 558 | w1_out = w1_out.to_f32(); 559 | w3_out = w3_out.to_f32(); 560 | } 561 | let w1_out = w1_out.silu(); 562 | let mut w1w3_out = w1_out.hadamard_product(&w3_out).transpose(); 563 | if w1w3_out.dtype() != self.w2.dtype() { 564 | w1w3_out = w1w3_out.to_same_type(&self.w2); 565 | } 566 | #[cfg(not(feature = "opencl"))] 567 | { 568 | self.w2 569 | .matrix_mul_transposed(&w1w3_out) 570 | .into_dtype(original_x_dtype) 571 | } 572 | #[cfg(feature = "opencl")] 573 | { 574 | let mut result = self.w2.matrix_mul_transposed(&w1w3_out); 575 | if x_was_on_cpu { 576 | result.to_cpu_inplace().unwrap(); 577 | result 578 | } else { 579 | result 580 | } 581 | } 582 | } 583 | } 584 | 585 | impl Attention { 586 | pub fn from_unpickled( 587 | layer_id: usize, 588 | n_local_heads: usize, 589 | head_dim: usize, 590 | dim: usize, 591 | data_settings: DataSettings, 592 | data_source: DataSource, 593 | ) -> Result { 594 | let mut wq = Tensor::from_unpickled_pieces2( 595 | format!("layers.{}.attention.wq.weight", layer_id), 596 | format!("model.layers.{}.self_attn.q_proj.weight", layer_id), 597 | data_source.clone(), 598 | FromPiecesDirection::Rows, 599 | )?; 600 | let mut wk = Tensor::from_unpickled_pieces2( 601 | format!("layers.{}.attention.wk.weight", layer_id), 602 | format!("model.layers.{}.self_attn.k_proj.weight", layer_id), 603 | data_source.clone(), 604 | FromPiecesDirection::Rows, 605 | )?; 606 | let mut wv = Tensor::from_unpickled_pieces2( 607 | format!("layers.{}.attention.wv.weight", layer_id), 608 | format!("model.layers.{}.self_attn.v_proj.weight", layer_id), 609 | data_source.clone(), 610 | FromPiecesDirection::Rows, 611 | )?; 612 | let mut wo = Tensor::from_unpickled_pieces2( 613 | format!("layers.{}.attention.wo.weight", layer_id), 614 | format!("model.layers.{}.self_attn.o_proj.weight", layer_id), 615 | data_source.clone(), 616 | FromPiecesDirection::Cols, 617 | )?; 618 | 619 | if data_source.need_to_do_antitranspose() { 620 | wq = wq.huggingface_llama_model_antitranspose(n_local_heads, dim); 621 | wk = wk.huggingface_llama_model_antitranspose(n_local_heads, dim); 622 | } 623 | 624 | if data_settings.force_f16 { 625 | wq = wq.to_f16(); 626 | wk = wk.to_f16(); 627 | wv = wv.to_f16(); 628 | wo = wo.to_f16(); 629 | } 630 | 631 | #[cfg(feature = "opencl")] 632 | { 633 | if data_settings.use_opencl_for_attention { 634 | wq = wq.to_f16(); 635 | wk = wk.to_f16(); 636 | wv 
= wv.to_f16(); 637 | wo = wo.to_f16(); 638 | let ds = data_settings.clone(); 639 | wq.to_gpu_inplace(&ds.cl.as_ref().unwrap().clone()).unwrap(); 640 | wk.to_gpu_inplace(&ds.cl.as_ref().unwrap().clone()).unwrap(); 641 | wv.to_gpu_inplace(&ds.cl.as_ref().unwrap().clone()).unwrap(); 642 | wo.to_gpu_inplace(&ds.cl.unwrap()).unwrap(); 643 | } 644 | } 645 | 646 | Ok(Self { 647 | wq, 648 | wk, 649 | wv, 650 | wo, 651 | n_local_heads, 652 | head_dim, 653 | data_settings, 654 | }) 655 | } 656 | 657 | fn forward( 658 | &self, 659 | x: &mut Tensor, 660 | start_pos: usize, 661 | freqs_cis: &FreqsCis, 662 | mask: &Option, 663 | attention_cache: &mut AttentionCache, 664 | ) -> Tensor { 665 | let original_x_dtype = x.dtype(); 666 | if x.dtype() != self.wq.dtype() { 667 | *x = x.to_same_type(&self.wq); 668 | } 669 | 670 | #[cfg(feature = "opencl")] 671 | let x_was_on_cpu: bool; 672 | #[cfg(feature = "opencl")] 673 | { 674 | x_was_on_cpu = x.is_on_cpu(); 675 | if self.data_settings.use_opencl_for_attention { 676 | x.to_gpu_inplace(self.data_settings.cl.as_ref().unwrap()) 677 | .unwrap(); 678 | } 679 | } 680 | 681 | let seq_len = x.rows(); 682 | #[cfg(feature = "opencl")] 683 | let (xq_out, xk_out, xv_out) = { 684 | let mut xq_out = x.matrix_mul_transposed(&self.wq); 685 | let mut xk_out = x.matrix_mul_transposed(&self.wk); 686 | let mut xv_out = x.matrix_mul_transposed(&self.wv); 687 | xq_out.to_cpu_inplace().unwrap(); 688 | xk_out.to_cpu_inplace().unwrap(); 689 | xv_out.to_cpu_inplace().unwrap(); 690 | (xq_out.to_f32(), xk_out.to_f32(), xv_out.to_f32()) 691 | }; 692 | 693 | #[cfg(not(feature = "opencl"))] 694 | let (xq_out, (xk_out, xv_out)) = rayon::join( 695 | || x.matrix_mul_transposed(&self.wq).to_f32(), 696 | || { 697 | rayon::join( 698 | || x.matrix_mul_transposed(&self.wk).to_f32(), 699 | || x.matrix_mul_transposed(&self.wv).to_f32(), 700 | ) 701 | }, 702 | ); 703 | 704 | let mut xq_views: Vec = Vec::with_capacity(seq_len as usize); 705 | let mut xk_views: Vec = Vec::with_capacity(seq_len as usize); 706 | let mut xv_views: Vec = Vec::with_capacity(seq_len as usize); 707 | 708 | for idx in 0..seq_len { 709 | let xq_row = xq_out 710 | .row(idx) 711 | .view(self.n_local_heads as i64, self.head_dim as i64); 712 | let xk_row = xk_out 713 | .row(idx) 714 | .view(self.n_local_heads as i64, self.head_dim as i64); 715 | let xv_row = xv_out 716 | .row(idx) 717 | .view(self.n_local_heads as i64, self.head_dim as i64); 718 | 719 | let (xq_row, xk_row) = 720 | apply_rotary_emb(&xq_row, &xk_row, freqs_cis, idx as usize, start_pos); 721 | 722 | xq_views.push(xq_row); 723 | xk_views.push(xk_row); 724 | xv_views.push(xv_row); 725 | } 726 | 727 | let output: Vec = (0..self.n_local_heads) 728 | .into_par_iter() 729 | .map(|idx| { 730 | let mut concat_vec: Vec = vec![]; 731 | for idx2 in 0..seq_len { 732 | concat_vec.push(xq_views[idx2 as usize].row(idx as i64)); 733 | } 734 | let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); 735 | let xq_row = Tensor::concat(&concat_vec2); 736 | 737 | concat_vec.truncate(0); 738 | for idx2 in 0..seq_len { 739 | concat_vec.push(xk_views[idx2 as usize].row(idx as i64)); 740 | } 741 | let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); 742 | let xk_row = Tensor::concat(&concat_vec2).transpose(); 743 | 744 | concat_vec.truncate(0); 745 | for idx2 in 0..seq_len { 746 | concat_vec.push(xv_views[idx2 as usize].row(idx as i64)); 747 | } 748 | let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); 749 | let xv_row = Tensor::concat(&concat_vec2); 750 | 751 | let mut 
cache_k = attention_cache.cache_k[idx].write().unwrap(); 752 | let mut cache_v = attention_cache.cache_v[idx].write().unwrap(); 753 | 754 | for pos in start_pos..start_pos + seq_len as usize { 755 | for dim in 0..self.head_dim { 756 | let k = xk_row.get_f32(dim as i64, (pos - start_pos) as i64); 757 | cache_k.set_f32(dim as i64, pos as i64, k); 758 | let v = xv_row.get_f32((pos - start_pos) as i64, dim as i64); 759 | cache_v.set_f32(dim as i64, pos as i64, v); 760 | } 761 | } 762 | let keys = cache_k.clip_cols(start_pos + seq_len as usize); 763 | let values = cache_v.clip_cols(start_pos + seq_len as usize); 764 | 765 | let keys = keys.into_same_type(&xq_row); 766 | let values = values.into_same_type(&xq_row); 767 | 768 | let m = xq_row 769 | .matrix_mul(&keys) 770 | .scalar_multiply_f32(1.0 / (self.head_dim as f32).sqrt()); 771 | 772 | match mask { 773 | Some(ref mask) => m 774 | .add(mask) 775 | .to_f32() 776 | .softmax() 777 | .matrix_mul_transposed(&values), 778 | None => m.softmax().matrix_mul_transposed(&values), 779 | } 780 | }) 781 | .collect(); 782 | 783 | let output2: Vec = (0..seq_len) 784 | .into_par_iter() 785 | .map(|idx| { 786 | let mut concat_vec: Vec = vec![]; 787 | for output in &output { 788 | concat_vec.push(output.row(idx)); 789 | } 790 | let concat_vec2: Vec<&Tensor> = concat_vec.iter().collect(); 791 | #[cfg(not(feature = "opencl"))] 792 | { 793 | let xq_row = Tensor::concat(&concat_vec2).view(1, self.wo.rows()); 794 | xq_row 795 | .into_same_type(&self.wo) 796 | .matrix_mul_transposed(&self.wo) 797 | } 798 | #[cfg(feature = "opencl")] 799 | { 800 | let mut xq_row = Tensor::concat(&concat_vec2) 801 | .view(1, self.wo.rows()) 802 | .to_f16(); 803 | if self.wo.is_on_gpu() { 804 | xq_row 805 | .to_gpu_inplace(&self.data_settings.cl.as_ref().unwrap()) 806 | .unwrap(); 807 | let mut result = xq_row.matrix_mul_transposed(&self.wo); 808 | result.to_cpu_inplace().unwrap(); 809 | result.to_f32() 810 | } else { 811 | xq_row.matrix_mul_transposed(&self.wo) 812 | } 813 | } 814 | }) 815 | .collect(); 816 | 817 | let output3: Vec<&Tensor> = output2.iter().collect(); 818 | let output2: Tensor = Tensor::concat(&output3); 819 | output2.into_dtype(original_x_dtype) 820 | } 821 | } 822 | 823 | fn apply_rotary_emb( 824 | xq: &Tensor, 825 | xk: &Tensor, 826 | freqs_cis: &FreqsCis, 827 | seq_idx: usize, 828 | start_pos: usize, 829 | ) -> (Tensor, Tensor) { 830 | assert!(xq.cols() % 2 == 0); 831 | assert!(xk.cols() % 2 == 0); 832 | let mut xq_out: Tensor = xq.clone(); 833 | let mut xk_out: Tensor = xk.clone(); 834 | for row in 0..xq.rows() { 835 | for col in 0..xq.cols() / 2 { 836 | let f_real = freqs_cis[seq_idx + start_pos][col as usize].re as f32; 837 | let f_imag = freqs_cis[seq_idx + start_pos][col as usize].im as f32; 838 | let xq_real = xq.get_f32(row, col * 2); 839 | let xq_imag = xq.get_f32(row, col * 2 + 1); 840 | let xk_real = xk.get_f32(row, col * 2); 841 | let xk_imag = xk.get_f32(row, col * 2 + 1); 842 | 843 | // multiply with freqs_cis 844 | let xq_realpart = xq_real * f_real - xq_imag * f_imag; 845 | let xq_imagpart = xq_real * f_imag + xq_imag * f_real; 846 | let xk_realpart = xk_real * f_real - xk_imag * f_imag; 847 | let xk_imagpart = xk_real * f_imag + xk_imag * f_real; 848 | 849 | xq_out.set_f32(row, col * 2, xq_realpart); 850 | xq_out.set_f32(row, col * 2 + 1, xq_imagpart); 851 | xk_out.set_f32(row, col * 2, xk_realpart); 852 | xk_out.set_f32(row, col * 2 + 1, xk_imagpart); 853 | } 854 | } 855 | (xq_out, xk_out) 856 | } 857 | 858 | fn compute_freqs_cis(dim: usize, end: 
usize, theta: f64) -> FreqsCis { 859 | let mut freqs = Vec::new(); 860 | for idx in 0..(dim / 2) { 861 | let freq = 1.0 / (theta.powf(idx as f64 * 2.0 / dim as f64)); 862 | freqs.push(freq); 863 | } 864 | 865 | let mut result: Vec> = Vec::new(); 866 | for x in 0..end { 867 | let mut row = Vec::new(); 868 | for freq in freqs.iter() { 869 | let freq = freq * (x as f64); 870 | row.push(freq); 871 | } 872 | result.push(row); 873 | } 874 | 875 | let mut resultc: Vec>> = Vec::new(); 876 | for row in result.into_iter() { 877 | let mut rowc = Vec::new(); 878 | for freq in row { 879 | let cis = Complex::from_polar(1.0, freq); 880 | rowc.push(cis); 881 | } 882 | resultc.push(rowc); 883 | } 884 | resultc 885 | } 886 | -------------------------------------------------------------------------------- /src/unpickler.rs: -------------------------------------------------------------------------------- 1 | use std::collections::{BTreeMap, BTreeSet}; 2 | use std::path::PathBuf; 3 | 4 | pub struct Unpickler {} 5 | 6 | use crate::tensor::{TensorBuilder, TensorDType, TensorError}; 7 | use thiserror::Error; 8 | 9 | #[derive(Error, Debug)] 10 | pub enum UnpicklingError { 11 | #[error("Unpickling error: {0}")] 12 | UnpicklingError(String), 13 | #[error("UTF-8 decoding error")] 14 | Utf8Error(#[from] std::str::Utf8Error), 15 | #[error("Missing field")] 16 | MissingField(String), 17 | #[error("Tensor conversion operation failed")] 18 | TensorError(#[from] TensorError), 19 | #[error("Data has incorrect format to be converted to a tensor")] 20 | InvalidTensorData, 21 | } 22 | 23 | #[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)] 24 | pub enum Value { 25 | Mark(usize), 26 | String(String), 27 | Global(String, String), // module name, attribute name 28 | Integer64(i64), 29 | Tuple(Vec), 30 | PersistentId(Box), 31 | Bool(bool), 32 | Reduce(Box, Box), 33 | Dict(BTreeMap), 34 | } 35 | 36 | impl Value { 37 | // Gets a value from a dictionary, assuming Value is a dictionary. 38 | // 39 | // Returns None if the key is not found, or the value is not a dictionary. 40 | pub fn get(&self, key: &Value) -> Option<&Value> { 41 | match self { 42 | Value::Dict(d) => d.get(key), 43 | _ => None, 44 | } 45 | } 46 | 47 | // Same as get() but uses a string as key. 48 | pub fn get_str_key>(&self, key: S) -> Option<&Value> { 49 | self.get(&Value::String(key.as_ref().to_string())) 50 | } 51 | 52 | // Same as get_str_key but tries two keys, returning the first one that is found. 53 | pub fn get_str_key2, S2: AsRef>( 54 | &self, 55 | key: S, 56 | key2: S2, 57 | ) -> Option<(String, &Value)> { 58 | let key = key.as_ref(); 59 | let key2 = key2.as_ref(); 60 | match self.get_str_key(key) { 61 | Some(v) => Some((key.to_string(), v)), 62 | None => match self.get_str_key(key2) { 63 | Some(v) => Some((key2.to_string(), v)), 64 | None => None, 65 | }, 66 | } 67 | } 68 | 69 | // Returns all keys as a set of strings, if the value is a dictionary. Otherwise returns empty set. 70 | pub fn keys(&self) -> BTreeSet { 71 | match self { 72 | Value::Dict(d) => { 73 | let mut result = BTreeSet::new(); 74 | for (k, _v) in d.iter() { 75 | match k { 76 | Value::String(s) => { 77 | result.insert(s.clone()); 78 | } 79 | _ => {} 80 | } 81 | } 82 | result 83 | } 84 | _ => BTreeSet::new(), 85 | } 86 | } 87 | 88 | // Merges value dictionaries together 89 | // 90 | // Panics if there are duplicate keys. 
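    // (The implementation below overwrites a duplicate key with the value from
    // the later dictionary; the panic fires when an argument is not a Dict.)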
91 | pub fn merge_dicts(dicts: &[Self]) -> Self { 92 | if dicts.is_empty() { 93 | return Value::Dict(BTreeMap::new()); 94 | } 95 | let mut result = dicts[0].clone(); 96 | for dict in dicts.iter().skip(1) { 97 | match (result, dict) { 98 | (Value::Dict(mut d1), Value::Dict(d2)) => { 99 | for (k, v) in d2 { 100 | d1.insert(k.clone(), v.clone()); 101 | } 102 | result = Value::Dict(d1); 103 | } 104 | _ => panic!("Can only merge dictionaries"), 105 | } 106 | } 107 | result 108 | } 109 | 110 | pub fn get_global(&self) -> Option<(&str, &str)> { 111 | match self { 112 | Value::Global(module_name, attribute_name) => Some((module_name, attribute_name)), 113 | _ => None, 114 | } 115 | } 116 | 117 | pub fn get_str(&self) -> Option<&str> { 118 | match self { 119 | Value::String(s) => Some(s), 120 | _ => None, 121 | } 122 | } 123 | 124 | pub fn get_int64(&self) -> Option { 125 | match self { 126 | Value::Integer64(i) => Some(*i), 127 | _ => None, 128 | } 129 | } 130 | 131 | pub fn get_persistent_id(&self) -> Option<&Value> { 132 | match self { 133 | Value::PersistentId(v) => Some(v), 134 | _ => None, 135 | } 136 | } 137 | 138 | pub fn get_tuple(&self) -> Option<&[Value]> { 139 | match self { 140 | Value::Tuple(v) => Some(v), 141 | _ => None, 142 | } 143 | } 144 | 145 | // Assume that the value represents a tensor in PyTorch and return instructions how to actually 146 | // load the values. 147 | pub fn to_tensor_builder(&self, tensor_name: String) -> Option { 148 | match self { 149 | Value::Reduce(call, args) => match **call { 150 | Value::Global(ref module_name, ref attribute_name) => { 151 | if module_name == "torch._utils" && attribute_name == "_rebuild_tensor_v2" { 152 | match **args { 153 | Value::Tuple(ref args) => self.to_tensor_builder2(tensor_name, args), 154 | _ => None, 155 | } 156 | } else { 157 | None 158 | } 159 | } 160 | _ => None, 161 | }, 162 | _ => None, 163 | } 164 | } 165 | 166 | fn to_tensor_builder2(&self, tensor_name: String, args: &[Value]) -> Option { 167 | if args.len() == 6 { 168 | Self::to_tensor_builder2_6items(tensor_name, args) 169 | } else { 170 | None 171 | } 172 | } 173 | 174 | fn to_tensor_builder2_6items(tensor_name: String, args: &[Value]) -> Option { 175 | let storagev: &Value = args[0].get_persistent_id()?; 176 | let storage_args: &[Value] = storagev.get_tuple()?; 177 | let storage_mark: &str = storage_args[0].get_str()?; 178 | if storage_mark != "storage" { 179 | return None; 180 | } 181 | 182 | let (storage_module, storage_type) = storage_args[1].get_global()?; 183 | if storage_module != "torch" { 184 | return None; 185 | } 186 | let dtype: TensorDType = match storage_type { 187 | "HalfStorage" => TensorDType::Float16, 188 | _ => { 189 | return None; 190 | } 191 | }; 192 | let storage_filename: &str = storage_args[2].get_str()?; 193 | let nitems: i64 = storage_args[4].get_int64()?; 194 | 195 | let offset: i64 = args[1].get_int64()?; 196 | 197 | let shape: &[Value] = args[2].get_tuple()?; 198 | let stride: &[Value] = args[3].get_tuple()?; 199 | 200 | if shape.len() != 2 && shape.len() != 1 { 201 | return None; 202 | } 203 | if stride.len() != 2 && stride.len() != 1 { 204 | return None; 205 | } 206 | 207 | let (rows, cols) = if shape.len() == 2 { 208 | (shape[0].get_int64()?, shape[1].get_int64()?) 
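            // 2-D case: rows and cols come straight from the shape tuple; the
            // 1-D branch below treats the single dimension as the column count
            // of a one-row tensor.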
209 | } else { 210 | let cols = shape[0].get_int64()?; 211 | (1, cols) 212 | }; 213 | 214 | let (row_stride, col_stride) = if stride.len() == 1 { 215 | let (r, c) = (stride[0].get_int64()?, 1); 216 | if r != 1 { 217 | return None; 218 | } 219 | (r, c) 220 | } else { 221 | (stride[0].get_int64()?, stride[1].get_int64()?) 222 | }; 223 | 224 | if col_stride != 1 { 225 | return None; 226 | } 227 | if row_stride != cols && stride.len() == 2 { 228 | return None; 229 | } 230 | 231 | Some(TensorBuilder { 232 | src_path: PathBuf::from(storage_filename), 233 | tensor_name, 234 | dtype, 235 | stride: row_stride, 236 | rows, 237 | cols, 238 | nitems, 239 | offset, 240 | }) 241 | 242 | /* Args should look like this (took random example from debug print) : 243 | 0 PERSISTENT_ID 244 | TUPLE 245 | STRING "storage" 246 | GLOBAL "torch" "HalfStorage" 247 | STRING "0" (filename) 248 | STRING "cpu" 249 | INTEGER 131072000 (number of items) 250 | 1 INTEGER 0 251 | 2 TUPLE 252 | INTEGER 32000 253 | INTEGER 4096 254 | 3 TUPLE 255 | INTEGER 4096 256 | INTEGER 1 257 | 4 BOOL false (this is about gradient) 258 | 5 REDUCE (no idea why this is here) 259 | GLOBAL "collections" "OrderedDict" 260 | TUPLE 261 | 262 | Sometimes arguments 2 and 3 are missing. 263 | */ 264 | } 265 | 266 | // Print a nice representation of the value to stdout. Used for good old printf debugging. 267 | pub fn debug_print(&self) { 268 | self.debug_print_go(0); 269 | } 270 | 271 | fn debug_print_go(&self, indent: usize) { 272 | if indent > 0 { 273 | print!("{:indent$}", "", indent = indent); 274 | } 275 | match self { 276 | Value::Mark(_) => { 277 | println!("MARK"); 278 | } 279 | Value::String(s) => { 280 | println!("STRING {:?}", s); 281 | } 282 | Value::Global(module_name, attribute_name) => { 283 | println!("GLOBAL {:?} {:?}", module_name, attribute_name); 284 | } 285 | Value::Integer64(i) => { 286 | println!("INTEGER {:?}", i); 287 | } 288 | Value::Tuple(v) => { 289 | println!("TUPLE"); 290 | for i in v { 291 | i.debug_print_go(indent + 2); 292 | } 293 | } 294 | Value::PersistentId(v) => { 295 | println!("PERSISTENT_ID"); 296 | v.debug_print_go(indent + 2); 297 | } 298 | Value::Bool(b) => { 299 | println!("BOOL {:?}", b); 300 | } 301 | Value::Reduce(v1, v2) => { 302 | println!("REDUCE"); 303 | v1.debug_print_go(indent + 2); 304 | v2.debug_print_go(indent + 2); 305 | } 306 | Value::Dict(d) => { 307 | println!("DICT"); 308 | for (k, v) in d { 309 | k.debug_print_go(indent + 2); 310 | v.debug_print_go(indent + 2); 311 | } 312 | } 313 | } 314 | } 315 | } 316 | 317 | pub fn unpickle(bytes: &[u8]) -> Result { 318 | // The LLaMA file is in pickle 2 format, check that header is there 319 | if bytes.len() < 2 { 320 | return Err(UnpicklingError::UnpicklingError( 321 | "Data is too short to be a pickle".to_string(), 322 | )); 323 | } 324 | 325 | if bytes[0] != 128 || bytes[1] != 2 { 326 | return Err(UnpicklingError::UnpicklingError( 327 | "No magic header using Pickle 2 protocol".to_string(), 328 | )); 329 | } 330 | 331 | let mut memo: BTreeMap = BTreeMap::new(); 332 | let mut stack: Vec = vec![]; 333 | 334 | // Decode frames 335 | let mut bytes: &[u8] = &bytes[2..]; 336 | while !bytes.is_empty() { 337 | let frame_opcode = bytes[0]; 338 | if frame_opcode == 125 { 339 | // empty dict 340 | stack.push(Value::Dict(BTreeMap::new())); 341 | bytes = &bytes[1..]; 342 | continue; 343 | } 344 | if frame_opcode == 113 { 345 | // binput 346 | if bytes.len() < 2 { 347 | return Err(UnpicklingError::UnpicklingError( 348 | "Unexpected end of data while handling 
BINPUT".to_string(), 349 | )); 350 | } 351 | if stack.is_empty() { 352 | return Err(UnpicklingError::UnpicklingError( 353 | "Stack is empty while handling BINPUT".to_string(), 354 | )); 355 | } 356 | let key = bytes[1]; 357 | memo.insert(key as u32, stack.last().unwrap().clone()); 358 | bytes = &bytes[2..]; 359 | continue; 360 | } 361 | if frame_opcode == 40 { 362 | // mark 363 | stack.push(Value::Mark(stack.len())); 364 | bytes = &bytes[1..]; 365 | continue; 366 | } 367 | if frame_opcode == 88 { 368 | // binunicode 369 | if bytes.len() < 5 { 370 | return Err(UnpicklingError::UnpicklingError( 371 | "Unexpected end of data while handling BINUNICODE".to_string(), 372 | )); 373 | } 374 | let len = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); 375 | if bytes.len() < 5 + len as usize { 376 | return Err(UnpicklingError::UnpicklingError( 377 | "Unexpected end of data while handling BINUNICODE".to_string(), 378 | )); 379 | } 380 | let string = std::str::from_utf8(&bytes[5..5 + len as usize])?; 381 | stack.push(Value::String(string.to_string())); 382 | bytes = &bytes[5 + len as usize..]; 383 | continue; 384 | } 385 | if frame_opcode == 99 { 386 | // global 387 | // followed by newline terminated module name and attribute name 388 | bytes = &bytes[1..]; 389 | let mut module_name = String::new(); 390 | while !bytes.is_empty() && bytes[0] != 10 { 391 | module_name.push(bytes[0] as char); 392 | bytes = &bytes[1..]; 393 | if bytes.is_empty() { 394 | return Err(UnpicklingError::UnpicklingError( 395 | "Unexpected end of data while handling GLOBAL".to_string(), 396 | )); 397 | } 398 | } 399 | bytes = &bytes[1..]; 400 | let mut attribute_name = String::new(); 401 | while !bytes.is_empty() && bytes[0] != 10 { 402 | attribute_name.push(bytes[0] as char); 403 | bytes = &bytes[1..]; 404 | if bytes.is_empty() { 405 | return Err(UnpicklingError::UnpicklingError( 406 | "Unexpected end of data while handling GLOBAL".to_string(), 407 | )); 408 | } 409 | } 410 | bytes = &bytes[1..]; 411 | stack.push(Value::Global(module_name, attribute_name)); 412 | continue; 413 | } 414 | if frame_opcode == 74 { 415 | // binint 416 | if bytes.len() < 5 { 417 | return Err(UnpicklingError::UnpicklingError( 418 | "Unexpected end of data while handling BININT".to_string(), 419 | )); 420 | } 421 | let value = i32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); 422 | stack.push(Value::Integer64(value as i64)); 423 | bytes = &bytes[5..]; 424 | continue; 425 | } 426 | if frame_opcode == 116 { 427 | // tuple 428 | let mut tuple = vec![]; 429 | if stack.is_empty() { 430 | return Err(UnpicklingError::UnpicklingError( 431 | "Stack is empty while handling TUPLE".to_string(), 432 | )); 433 | } 434 | let mut ok = false; 435 | while !stack.is_empty() { 436 | let top = stack.pop().unwrap(); 437 | if let Value::Mark(_mark) = top { 438 | tuple.reverse(); 439 | stack.push(Value::Tuple(tuple)); 440 | ok = true; 441 | break; 442 | } 443 | tuple.push(top); 444 | } 445 | if !ok { 446 | return Err(UnpicklingError::UnpicklingError( 447 | "No mark while handling TUPLE".to_string(), 448 | )); 449 | } 450 | bytes = &bytes[1..]; 451 | continue; 452 | } 453 | if frame_opcode == 81 { 454 | // binpersid 455 | if stack.is_empty() { 456 | return Err(UnpicklingError::UnpicklingError( 457 | "Stack is empty while handling BINPERSID".to_string(), 458 | )); 459 | } 460 | let top = stack.pop().unwrap(); 461 | stack.push(Value::PersistentId(Box::new(top))); 462 | bytes = &bytes[1..]; 463 | continue; 464 | } 465 | if frame_opcode == 75 { 466 | // 
binint1 467 | if bytes.len() < 2 { 468 | return Err(UnpicklingError::UnpicklingError( 469 | "Unexpected end of data while handling BININT1".to_string(), 470 | )); 471 | } 472 | let value = bytes[1]; 473 | stack.push(Value::Integer64(value as i64)); 474 | bytes = &bytes[2..]; 475 | continue; 476 | } 477 | if frame_opcode == 77 { 478 | // binint2 479 | if bytes.len() < 3 { 480 | return Err(UnpicklingError::UnpicklingError( 481 | "Unexpected end of data while handling BININT2".to_string(), 482 | )); 483 | } 484 | let value = u16::from_le_bytes([bytes[1], bytes[2]]); // BININT2 is an unsigned 2-byte integer 485 | stack.push(Value::Integer64(value as i64)); 486 | bytes = &bytes[3..]; 487 | continue; 488 | } 489 | if frame_opcode == 134 { 490 | // tuple2 491 | let mut tuple = vec![]; 492 | if stack.len() < 2 { 493 | return Err(UnpicklingError::UnpicklingError( 494 | "Stack does not have enough items while handling TUPLE2".to_string(), 495 | )); 496 | } 497 | tuple.push(stack.pop().unwrap()); 498 | tuple.push(stack.pop().unwrap()); 499 | tuple.reverse(); 500 | stack.push(Value::Tuple(tuple)); 501 | bytes = &bytes[1..]; 502 | continue; 503 | } 504 | if frame_opcode == 137 { 505 | // newfalse 506 | stack.push(Value::Bool(false)); 507 | bytes = &bytes[1..]; 508 | continue; 509 | } 510 | if frame_opcode == 41 { 511 | // empty tuple 512 | stack.push(Value::Tuple(vec![])); 513 | bytes = &bytes[1..]; 514 | continue; 515 | } 516 | if frame_opcode == 82 { 517 | // reduce 518 | if stack.len() < 2 { 519 | return Err(UnpicklingError::UnpicklingError( 520 | "Stack does not have enough items while handling REDUCE".to_string(), 521 | )); 522 | } 523 | let arg_tuple = stack.pop().unwrap(); 524 | let callable = stack.pop().unwrap(); 525 | stack.push(Value::Reduce(Box::new(callable), Box::new(arg_tuple))); 526 | bytes = &bytes[1..]; 527 | continue; 528 | } 529 | if frame_opcode == 104 { 530 | // binget 531 | if bytes.len() < 2 { 532 | return Err(UnpicklingError::UnpicklingError( 533 | "Unexpected end of data while handling BINGET".to_string(), 534 | )); 535 | } 536 | let idx = bytes[1]; 537 | match memo.get(&(idx as u32)) { 538 | None => { 539 | return Err(UnpicklingError::UnpicklingError( 540 | "BINGET index out of range".to_string(), 541 | )); 542 | } 543 | Some(memo_value) => { 544 | stack.push(memo_value.clone()); 545 | } 546 | } 547 | bytes = &bytes[2..]; 548 | continue; 549 | } 550 | if frame_opcode == 133 { 551 | // tuple1 552 | let mut tuple = vec![]; 553 | if stack.is_empty() { 554 | return Err(UnpicklingError::UnpicklingError( 555 | "Stack is empty while handling TUPLE1".to_string(), 556 | )); 557 | } 558 | tuple.push(stack.pop().unwrap()); 559 | stack.push(Value::Tuple(tuple)); 560 | bytes = &bytes[1..]; 561 | continue; 562 | } 563 | if frame_opcode == 114 { 564 | // long binput 565 | if bytes.len() < 5 { 566 | return Err(UnpicklingError::UnpicklingError( 567 | "Unexpected end of data while handling LONG_BINPUT".to_string(), 568 | )); 569 | } 570 | let key = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); 571 | if stack.is_empty() { 572 | return Err(UnpicklingError::UnpicklingError( 573 | "Stack is empty while handling LONG_BINPUT".to_string(), 574 | )); 575 | } 576 | memo.insert(key, stack.last().unwrap().clone()); 577 | bytes = &bytes[5..]; 578 | continue; 579 | } 580 | if frame_opcode == 117 { 581 | // setitems 582 | if stack.is_empty() { 583 | return Err(UnpicklingError::UnpicklingError( 584 | "Stack is empty while handling SETITEMS".to_string(), 585 | )); 586 | } 587 | let mut ok = false; 588 | let mut keyvalues: BTreeMap<Value, Value> =
BTreeMap::new(); 589 | while !stack.is_empty() { 590 | let value = stack.pop().unwrap(); 591 | if let Value::Mark(_mark) = value { 592 | ok = true; 593 | break; 594 | } 595 | if stack.is_empty() { 596 | return Err(UnpicklingError::UnpicklingError( 597 | "Stack is empty while handling SETITEMS".to_string(), 598 | )); 599 | } 600 | let key = stack.pop().unwrap(); 601 | if let Value::Mark(_mark) = key { 602 | return Err(UnpicklingError::UnpicklingError( 603 | "Unexpected mark while handling SETITEMS".to_string(), 604 | )); 605 | } 606 | keyvalues.insert(key, value); 607 | } 608 | if !ok { 609 | return Err(UnpicklingError::UnpicklingError( 610 | "No mark while handling SETITEMS".to_string(), 611 | )); 612 | } 613 | if stack.is_empty() { 614 | return Err(UnpicklingError::UnpicklingError( 615 | "Stack is empty while handling SETITEMS".to_string(), 616 | )); 617 | } 618 | let mut dict = stack.pop().unwrap(); 619 | match dict { 620 | Value::Dict(ref mut dict) => { 621 | for (key, value) in keyvalues { 622 | dict.insert(key, value); 623 | } 624 | } 625 | _ => { 626 | return Err(UnpicklingError::UnpicklingError( 627 | "SETITEMS on non-dict".to_string(), 628 | )); 629 | } 630 | } 631 | stack.push(dict); 632 | bytes = &bytes[1..]; 633 | continue; 634 | } 635 | if frame_opcode == 106 { 636 | // long_binget 637 | if bytes.len() < 5 { 638 | return Err(UnpicklingError::UnpicklingError( 639 | "Unexpected end of data while handling LONG_BINGET".to_string(), 640 | )); 641 | } 642 | let idx = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]); 643 | match memo.get(&{ idx }) { 644 | None => { 645 | return Err(UnpicklingError::UnpicklingError( 646 | "LONG_BINGET index out of range".to_string(), 647 | )); 648 | } 649 | Some(memo_value) => { 650 | stack.push(memo_value.clone()); 651 | } 652 | } 653 | bytes = &bytes[5..]; 654 | continue; 655 | } 656 | if frame_opcode == 46 { 657 | // stop 658 | // bytes = &bytes[1..]; 659 | break; 660 | } 661 | return Err(UnpicklingError::UnpicklingError(format!( 662 | "Unknown opcode: {}", 663 | frame_opcode 664 | ))); 665 | } 666 | 667 | // Stack should have just one item, our final value 668 | if stack.len() != 1 { 669 | return Err(UnpicklingError::UnpicklingError( 670 | "Stack does not have exactly one item after unpickling".to_string(), 671 | )); 672 | } 673 | 674 | Ok(stack.pop().unwrap()) 675 | } 676 | -------------------------------------------------------------------------------- /src/weight_compression.rs: -------------------------------------------------------------------------------- 1 | use crate::tensor::Tensor; 2 | use rand::thread_rng; 3 | 4 | pub fn quantize(tensor: &Tensor) -> Tensor { 5 | /* 6 | * This is a simplistic rounding quantizer. It splits each row in a tensor to 16 buckets and 7 | * takes the average value in said buckets as the quantized weight. 
8 | */ 9 | let mut result = Tensor::zeros(tensor.rows(), tensor.cols(), tensor.dtype()); 10 | for row in 0..tensor.rows() { 11 | let mut values: Vec<f32> = Vec::with_capacity(tensor.cols() as usize); 12 | if row % 500 == 0 { 13 | println!("{}", row); 14 | } 15 | values.truncate(0); 16 | let mut mi: f32 = std::f32::MAX; 17 | let mut ma: f32 = std::f32::MIN; 18 | 19 | for col in 0..tensor.cols() { 20 | let val = tensor.get_f32(row, col); 21 | if val < mi { 22 | mi = val; 23 | } 24 | if val > ma { 25 | ma = val; 26 | } 27 | values.push(val); 28 | } 29 | values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); 30 | let mut allowed_values: Vec<f32> = Vec::with_capacity(16); 31 | let _rng = thread_rng(); 32 | for i in 0..16 { 33 | let start_idx = i * values.len() / 16; 34 | let end_idx = (i + 1) * values.len() / 16; 35 | 36 | let mut avg = 0.0; 37 | for j in start_idx..end_idx { 38 | avg += values[j]; 39 | } 40 | avg /= (end_idx - start_idx) as f32; 41 | allowed_values.push(avg); 42 | } 43 | allowed_values[0] = mi; 44 | allowed_values[15] = ma; 45 | allowed_values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap()); 46 | 47 | for col in 0..tensor.cols() { 48 | let val = tensor.get_f32(row, col); 49 | let mut best = 0; 50 | let mut best_dist = std::f32::MAX; 51 | for i in 0..16 { 52 | let dist = (val - allowed_values[i] as f32).abs(); 53 | if dist < best_dist { 54 | best = i; 55 | best_dist = dist; 56 | } 57 | } 58 | result.set_f32(row, col, allowed_values[best] as f32); 59 | } 60 | } 61 | result 62 | } 63 | --------------------------------------------------------------------------------
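The quantizer in src/weight_compression.rs above maps every weight in a row to one of 16 representative values: the row is sorted and split into 16 equal buckets, each bucket contributes its average as a representative, the first and last representatives are clamped to the row minimum and maximum, and each weight is then snapped to its nearest representative. The sketch below is illustrative only and is not part of the repository; it assumes a plain `Vec<f32>` row and a hypothetical `quantize_row` helper instead of the crate's `Tensor` type.

```rust
// Illustrative sketch (not from the rllama source): the same 16-bucket
// averaging idea, applied to a plain slice of f32 weights.
fn quantize_row(row: &[f32]) -> Vec<f32> {
    assert!(row.len() >= 16, "sketch assumes at least 16 values per row");

    // Sort a copy of the row so it can be split into 16 equal buckets.
    let mut sorted: Vec<f32> = row.to_vec();
    sorted.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());

    // One representative per bucket: the bucket average, with the extreme
    // buckets clamped to the row minimum and maximum.
    let mut allowed: Vec<f32> = (0..16)
        .map(|i| {
            let start = i * sorted.len() / 16;
            let end = (i + 1) * sorted.len() / 16;
            sorted[start..end].iter().sum::<f32>() / (end - start) as f32
        })
        .collect();
    allowed[0] = sorted[0];
    allowed[15] = *sorted.last().unwrap();
    allowed.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());

    // Snap every original weight to its nearest representative value.
    row.iter()
        .map(|&v| {
            *allowed
                .iter()
                .min_by(|a, b| {
                    (v - **a).abs().partial_cmp(&(v - **b).abs()).unwrap()
                })
                .unwrap()
        })
        .collect()
}

fn main() {
    let row: Vec<f32> = (0..64).map(|i| (i as f32 * 0.37).sin()).collect();
    let quantized = quantize_row(&row);

    // Count how many distinct values survive quantization (at most 16).
    let mut distinct = quantized.clone();
    distinct.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap());
    distinct.dedup();
    println!("distinct quantized values: {}", distinct.len());
}
```

Running the sketch shows a 64-value row collapsing to at most 16 distinct weights; a row like that could, in principle, be stored as 4-bit indices plus a 16-entry per-row table, though the function above (like the original) writes the snapped f32 values back into a full-size tensor.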