├── .github └── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── .gitignore ├── .gitmodules ├── CITATION.cff ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── docs ├── ChatGPT Image Apr 27, 2025, 02_45_40 PM.png ├── GPULlama3_ROADMAP.md ├── TORNADOVM_TRANSFORMER_OPTIMIZATIONS.md ├── inter-output.gif ├── intruct-output.gif ├── java-tornado-gpu.jpg └── performance.png ├── llama-tornado ├── pom.xml ├── set_paths ├── set_paths.cmd └── src └── main └── java └── com └── example ├── LlamaApp.java ├── aot └── AOT.java ├── auxiliary ├── ChatFormat.java ├── Parallel.java ├── Timer.java └── Tuple2.java ├── core ├── model │ ├── GGMLType.java │ ├── GGUF.java │ └── tensor │ │ ├── ArrayFloatTensor.java │ │ ├── F16FloatTensor.java │ │ ├── FloatTensor.java │ │ ├── GGMLTensorEntry.java │ │ ├── Q4_0FloatTensor.java │ │ └── Q8_0FloatTensor.java └── types │ ├── Float16.java │ ├── MetadataValueType.java │ └── Pair.java ├── inference ├── CategoricalSampler.java ├── Sampler.java ├── ToppSampler.java ├── engine │ └── impl │ │ ├── Configuration.java │ │ ├── Llama.java │ │ └── Options.java └── operation │ └── RoPE.java ├── loader └── weights │ ├── ModelLoader.java │ ├── State.java │ └── Weights.java ├── tokenizer ├── impl │ └── Tokenizer.java └── vocabulary │ └── Vocabulary.java └── tornadovm ├── FloatArrayUtils.java ├── TornadoVMLayerPlanner.java ├── TornadoVMMasterPlan.java ├── TransformerComputeKernels.java └── TransformerComputeKernelsLayered.java /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. 
See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Maven 2 | target/ 3 | pom.xml.tag 4 | pom.xml.releaseBackup 5 | pom.xml.versionsBackup 6 | pom.xml.next 7 | release.properties 8 | dependency-reduced-pom.xml 9 | buildNumber.properties 10 | .mvn/timing.properties 11 | .mvn/wrapper/maven-wrapper.jar 12 | 13 | # model files 14 | *.gguf 15 | *.ggml 16 | *.bin 17 | 18 | # IDE - IntelliJ IDEA 19 | .idea/ 20 | *.iws 21 | *.iml 22 | *.ipr 23 | 24 | # IDE - Eclipse 25 | .settings/ 26 | .classpath 27 | .project 28 | .factorypath 29 | bin/ 30 | .metadata/ 31 | .settings/ 32 | .loadpath 33 | 34 | # IDE - VS Code 35 | .vscode/ 36 | .code-workspace 37 | 38 | # Java 39 | *.class 40 | *.log 41 | *.jar 42 | *.war 43 | *.nar 44 | *.ear 45 | *.zip 46 | *.tar.gz 47 | *.rar 48 | hs_err_pid* 49 | replay_pid* 50 | 51 | # OS specific 52 | .DS_Store 53 | Thumbs.db 54 | Desktop.ini 55 | 56 | # Misc 57 | *.swp 58 | *.bak 59 | *.tmp 60 | *.temp 61 | log/ 62 | logs/target/ 63 | .idea/ 64 | *.iml 65 | .settings/ 66 | .project 67 | .classpath 68 | .vscode/ 69 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/tornadovm"] 2 | path = external/tornadovm 3 | url = https://github.com/beehive-lab/TornadoVM.git 4 | branch = master 5 | 6 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Papadimitriou" 5 | given-names: "Michail" 6 | - family-names: "Xekalaki" 7 | given-names: "Mary" 8 | - family-names: "Fumero" 9 | given-names: "Juan" 10 | - family-names: "Stratikopolos" 11 | given-names: "Athanasios" 12 | - family-names: "Papadakis" 13 | given-names: "Orion" 14 | - family-names: "Kotselidis" 15 | given-names: "Christos" 16 | title: "GPULlama3.java" 17 | license: MIT License 18 | version: 0.1.0-beta 19 | date-released: "2025-05-30" 20 | url: "https://github.com/beehive-lab/GPULlama3.java" 21 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute 2 | 3 | We welcome contributions! 4 | Please follow the instructions below for your Pull Requests (PRs). 5 | 6 | ## How to submit your changes 7 | 8 | 1. **Fork** the repository in GitHub. 9 | 2. **Clone** the project. 10 | 3. **Create a new branch from the `develop` branch**: 11 | ```bash 12 | $ git checkout -b fix/my/branch 13 | ``` 14 | 4. **Commit your changes**: 15 | ```bash 16 | $ git add 17 | $ git commit -a -m "My new feature/fix" 18 | ``` 19 | 5. **Push** your work to your repository: 20 | ```bash 21 | $ git push -u myRepo feat/my/branch 22 | ``` 23 | 6. Create a **Pull Request** (PR) to the `develop` branch. 24 | 7. When you open PR, there are a few GitHub actions. One of them is the checker for the **Contributor License Agreement**, [CLA](https://cla-assistant.io/beehive-lab/llama3.java-tornadovm), if you haven't signed before, you will be prompted with the link to sign the CLA. Use the same email as you commit email. 25 | 26 | Please, ensure that your changes are merged with the latest changes in the `develop` branch, and the code follows the code conventions (see below). 27 | 28 | ### What's next? 29 | 30 | We check the PR and test it internally. Be aware we are a very small team. 
Thus, depending on the PR, it might take some time for us to review it since we check the PR with our regression tests and benchmarks for all backends (OCL/SPIR-V/PTX platforms) as well as with different drivers and Operating Systems. 31 | 32 | ## Code of Conduct 33 | 34 | For the PR process as well as any issues and discussions we follow this [CODE_OF_CONDUCT](https://github.com/beehive-lab/llama3.java-tornadovm/blob/master/CODE_OF_CONDUCT.md). 35 | 36 | 37 | ## How does the review process work? 38 | 39 | 1. We have a few GitHub actions, such as code formatter, documentation rendering and checks for the CLA (Contributor License Agreement). 40 | 2. As mentioned earlier, if you haven't signed the CLA yet, you will be redirected to the llama3.java-tornadovm CLA webpage, where you can read and review it. 41 | If you agree with the terms, then you will sign it. 42 | 3. After that, the llama3.java-tornadovm team can process your PR to be able to merge it into the llama3.java-tornadovm's codebase. 43 | 4. At least two researchers/engineers from the llama3.java-tornadovm team will review your PR. 44 | **Expect a few comments, questions and possible changes.** 45 | This is a completely normal part of the process; it helps us avoid introducing untested code for specific devices and leads to better documentation. 46 | We are proud to say that llama3.java-tornadovm builds on 10+ years of active development, with many different researchers and developers. Thus, we prioritize code maintainability and reproducibility, and we will work together to guarantee this as much as possible.
47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Beehive lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Simple Makefile for Maven build without tests 2 | .PHONY: build clean package help 3 | 4 | # Default target 5 | all: package 6 | 7 | # Build the project (clean and package without tests) 8 | build: clean package 9 | 10 | # Clean the project 11 | clean: 12 | mvn clean 13 | 14 | # Package the project without running tests 15 | package: 16 | mvn package -DskipTests 17 | 18 | 19 | # Combined clean and package 20 | package-with-clean: 21 | mvn clean package -DskipTests 22 | 23 | # Display help 24 | help: 25 | @echo "Available targets:" 26 | @echo " all - Same as 'package' (default)" 27 | @echo " build - Clean and package (without tests)" 28 | @echo " clean - Clean the project" 29 | @echo " package - Package without running tests" 30 | @echo " package-with-clean - Clean and package in one command" 31 | @echo " help - Show this help message" 32 | -------------------------------------------------------------------------------- /docs/ChatGPT Image Apr 27, 2025, 02_45_40 PM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beehive-lab/GPULlama3.java/90a719fe59ef621d026e12a262b409f802a4417c/docs/ChatGPT Image Apr 27, 2025, 02_45_40 PM.png -------------------------------------------------------------------------------- /docs/GPULlama3_ROADMAP.md: -------------------------------------------------------------------------------- 1 | ### 🚧 Work-in-progress Features 2 | 3 | - [ ] **LangChain4j integration** 4 | - [ ] **Additional quantization formats** 5 | - [ ] Q8 6 | - [ ] Q4 7 | - [ ] INT8 native support for GPUs 8 | - [ ] **Additional architectures and model format** 9 | - [ ] Mistral/Mixtral models 10 | - [ ] Qwen 11 | - [ ] Gemma/Gemma2 models 12 | - [ ] TinyLlamas 13 | - [ ] SafeTensors format 14 | - [ ] PyTorch checkpoint 
loading 15 | - [ ] Automatic model conversion utilities 16 | - [ ] **Advanced inference capabilities** 17 | - [ ] Batch inference support 18 | - [ ] Speculative decoding 19 | - [ ] **Performance optimizations** 20 | - [ ] Multi-GPU support 21 | - [X] Memory-efficient attention mechanisms 22 | - [ ] More Kernel fusion improvements 23 | - [ ] **GraalVM Native Image** 24 | -------------------------------------------------------------------------------- /docs/TORNADOVM_TRANSFORMER_OPTIMIZATIONS.md: -------------------------------------------------------------------------------- 1 | ## TornadoVM Transformer Optimizations 2 | 3 | ### Core Numerical Optimizations 4 | - **Quantized Weight Support** 5 | - Optimized implementations for FP16 format 6 | - [*Experimental*] support for Q8 and Q4 with dequantize to FP16 7 | 8 | ### Memory and Caching Optimizations 9 | - **Key-Value Cache** 10 | - Efficiently stores past key-values for autoregressive generation 11 | - Organized by layer, position, and dimension for fast access 12 | - **Scale Caching** 13 | - Avoids redundant decompression of quantized weights 14 | - Caches scale factors for efficient block processing 15 | - **Optimized GPU Memory Transfers** 16 | - Minimizes host-device data movement 17 | - One-time transfer of static data (weights, caches) 18 | - Per-execution transfer of dynamic data (position, activations) 19 | - **Device-to-Device Data Consumption** 20 | - Efficient data transfer between operations 21 | - Reduces PCI-E bandwidth bottlenecks 22 | 23 | ### Algorithmic Optimizations 24 | - **Parallel Reduction RMS Normalization** 25 | - Implements two-phase reduction for efficient normalization 26 | - Work group optimization for parallel sums 27 | - **Rotary Position Embeddings (RoPE)** 28 | - Optimized implementation for positional encoding 29 | - Efficient rotation of query and key vectors 30 | - **Optimized Float16 Decoding** 31 | - Fast decoder for half-precision floating point format 32 | - Special case 
handling for better performance 33 | - **Parallelized Attention** 34 | - Computes attention heads in parallel 35 | - Optimized softmax with max subtraction for numerical stability 36 | - **Fused Feed-Forward Networks** 37 | - Combines operations for SwiGLU variant used in Llama models 38 | - Optimized SiLU and GELU activation functions 39 | 40 | ### GPU Execution Optimizations 41 | - **Layered Execution Planning** 42 | - Organizes computation as separate layer-based task graphs 43 | - Strategic scheduling of operations 44 | - **Work Group Optimization** 45 | - Tailored worker grid configurations for different operations 46 | - Matches GPU hardware characteristics 47 | - **Local Memory Optimization** 48 | - Strategic use of local/shared memory for reductions 49 | - Optimizes bandwidth-intensive operations 50 | -------------------------------------------------------------------------------- /docs/inter-output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beehive-lab/GPULlama3.java/90a719fe59ef621d026e12a262b409f802a4417c/docs/inter-output.gif -------------------------------------------------------------------------------- /docs/intruct-output.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beehive-lab/GPULlama3.java/90a719fe59ef621d026e12a262b409f802a4417c/docs/intruct-output.gif -------------------------------------------------------------------------------- /docs/java-tornado-gpu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beehive-lab/GPULlama3.java/90a719fe59ef621d026e12a262b409f802a4417c/docs/java-tornado-gpu.jpg -------------------------------------------------------------------------------- /docs/performance.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/beehive-lab/GPULlama3.java/90a719fe59ef621d026e12a262b409f802a4417c/docs/performance.png -------------------------------------------------------------------------------- /llama-tornado: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | llama-tornado: GPU-accelerated LLaMA.java runner with TornadoVM 4 | Run LLaMA models using either OpenCL or PTX backends. 5 | """ 6 | 7 | import argparse 8 | import os 9 | import subprocess 10 | import sys 11 | import time 12 | import platform 13 | from pathlib import Path 14 | from typing import List, Optional, Dict, Any 15 | from enum import Enum 16 | 17 | class Backend(Enum): 18 | OPENCL = "opencl" 19 | PTX = "ptx" 20 | 21 | class LlamaRunner: 22 | """Main class for managing LLaMA model execution with GPU acceleration.""" 23 | 24 | def __init__(self): 25 | self.java_home = os.environ.get('JAVA_HOME') 26 | self.tornado_sdk = os.environ.get('TORNADO_SDK') 27 | self.llama_root = os.environ.get('LLAMA_ROOT') 28 | 29 | if not all([self.java_home, self.tornado_sdk, self.llama_root]): 30 | print("Error: Required environment variables not set") 31 | print("Please ensure JAVA_HOME, TORNADO_SDK, and LLAMA_ROOT are defined") 32 | print("Note: check set_path in root dir -> source set_path") 33 | sys.exit(1) 34 | 35 | def _validate_paths(self): 36 | """Validate that required paths exist.""" 37 | paths_to_check = { 38 | 'JAVA_HOME': self.java_home, 39 | 'TORNADO_SDK': self.tornado_sdk, 40 | 'LLAMA_ROOT': self.llama_root 41 | } 42 | 43 | for name, path in paths_to_check.items(): 44 | if not Path(path).exists(): 45 | print(f"Error: {name} path does not exist: {path}") 46 | sys.exit(1) 47 | 48 | @staticmethod 49 | def module_path_colon_sep(paths: List[str]) -> str: 50 | """Return OS-specific separator for Java module paths.""" 51 | return ";".join(paths) if platform.system() == "Windows" else ":".join(paths) 52 | 53 | def 
_build_base_command(self, args: argparse.Namespace) -> List[str]: 54 | """Build the base Java command with JVM options.""" 55 | cmd = [ 56 | f"{self.java_home}/bin/java", 57 | "-server", 58 | "-XX:+UnlockExperimentalVMOptions", 59 | "-XX:+EnableJVMCI", 60 | f"-Xms{args.heap_min}", 61 | f"-Xmx{args.heap_max}", 62 | "--enable-preview", 63 | f"-Djava.library.path={self.tornado_sdk}/lib", 64 | "-Djdk.module.showModuleResolution=false", 65 | "--module-path", self.module_path_colon_sep([".", f"{self.tornado_sdk}/share/java/tornado"]), 66 | ] 67 | 68 | # TornadoVM configuration 69 | tornado_config = [ 70 | "-Dtornado.load.api.implementation=uk.ac.manchester.tornado.runtime.tasks.TornadoTaskGraph", 71 | "-Dtornado.load.runtime.implementation=uk.ac.manchester.tornado.runtime.TornadoCoreRuntime", 72 | "-Dtornado.load.tornado.implementation=uk.ac.manchester.tornado.runtime.common.Tornado", 73 | "-Dtornado.load.annotation.implementation=uk.ac.manchester.tornado.annotation.ASMClassVisitor", 74 | "-Dtornado.load.annotation.parallel=uk.ac.manchester.tornado.api.annotations.Parallel", 75 | ] 76 | cmd.extend(tornado_config) 77 | 78 | # GPU options 79 | if args.use_gpu: 80 | cmd.append("-Duse.tornadovm=true") 81 | 82 | if args.verbose_init: 83 | cmd.append("-Dllama.EnableTimingForTornadoVMInit=true") 84 | 85 | # Debug options 86 | debug_config = [] 87 | 88 | if args.debug: 89 | debug_config.extend([ 90 | "-Dtornado.debug=true", 91 | "-Dtornado.threadInfo=True" if args.threads else "-Dtornado.threadInfo=false" 92 | ]) 93 | else: 94 | debug_config.extend([ 95 | "-Dtornado.threadInfo=True" if args.threads else "-Dtornado.threadInfo=false", 96 | "-Dtornado.debug=false" 97 | ]) 98 | 99 | # Additional debug options 100 | debug_config.extend([ 101 | "-Dtornado.fullDebug=True" if args.full_dump else "-Dtornado.fullDebug=false", 102 | "-Dtornado.printKernel=True" if args.print_kernel else "-Dtornado.printKernel=false", 103 | "-Dtornado.print.bytecodes=True" if args.print_bytecodes else 
"-Dtornado.print.bytecodes=false" 104 | ]) 105 | 106 | cmd.extend(debug_config) 107 | 108 | # Additional TornadoVM settings 109 | tornado_runtime_config = [ 110 | f"-Dtornado.device.memory={args.gpu_memory}", 111 | f"-Dtornado.profiler={str(args.profiler).lower()}", 112 | "-Dtornado.log.profiler=false", 113 | f"-Dtornado.profiler.dump.dir={args.profiler_dump_dir}", 114 | "-Dtornado.enable.fastMathOptimizations=true", 115 | "-Dtornado.enable.mathOptimizations=false", 116 | "-Dtornado.enable.nativeFunctions=true", 117 | "-Dtornado.loop.interchange=true", 118 | f"-Dtornado.eventpool.maxwaitevents={args.max_wait_events}" 119 | ] 120 | cmd.extend(tornado_runtime_config) 121 | 122 | # Backend-specific configuration 123 | if args.backend == Backend.OPENCL: 124 | # OpenCL specific flags 125 | cmd.append(f"-Dtornado.opencl.compiler.flags={args.opencl_flags}") 126 | 127 | # Module configuration - varies by backend 128 | module_config = [ 129 | f"--upgrade-module-path", f"{self.tornado_sdk}/share/java/graalJars", 130 | f"@{self.tornado_sdk}/etc/exportLists/common-exports", 131 | ] 132 | # Add backend-specific exports and modules 133 | if args.backend == Backend.OPENCL: 134 | module_config.extend([ 135 | f"@{self.tornado_sdk}/etc/exportLists/opencl-exports", 136 | "--add-modules", "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.opencl", 137 | ]) 138 | elif args.backend == Backend.PTX: 139 | module_config.extend([ 140 | f"@{self.tornado_sdk}/etc/exportLists/ptx-exports", 141 | "--add-modules", "ALL-SYSTEM,jdk.incubator.vector,tornado.runtime,tornado.annotation,tornado.drivers.common,tornado.drivers.ptx", 142 | ]) 143 | 144 | module_config.extend([ 145 | "-cp", f"{self.llama_root}/target/gpu-llama3-1.0-SNAPSHOT.jar", 146 | "com.example.LlamaApp" 147 | ]) 148 | cmd.extend(module_config) 149 | 150 | return cmd 151 | 152 | def _add_llama_args(self, cmd: List[str], args: argparse.Namespace) -> List[str]: 153 | """Add 
LLaMA-specific arguments to the command.""" 154 | llama_args = [ 155 | "-m", args.model_path, 156 | "--temperature", str(args.temperature), 157 | "--top-p", str(args.top_p), 158 | "--seed", str(args.seed), 159 | "--max-tokens", str(args.max_tokens), 160 | "--stream", str(args.stream).lower(), 161 | "--echo", str(args.echo).lower() 162 | ] 163 | 164 | if args.prompt: 165 | llama_args.extend(["-p", args.prompt]) 166 | 167 | if args.system_prompt: 168 | llama_args.extend(["-sp", args.system_prompt]) 169 | 170 | if args.interactive: 171 | llama_args.append("--interactive") 172 | elif args.instruct: 173 | llama_args.append("--instruct") 174 | 175 | return cmd + llama_args 176 | 177 | def run(self, args: argparse.Namespace) -> int: 178 | """Execute the LLaMA model with the specified arguments.""" 179 | self._validate_paths() 180 | 181 | # Build the complete command 182 | cmd = self._build_base_command(args) 183 | cmd = self._add_llama_args(cmd, args) 184 | 185 | # Print command if requested (before verbose output) 186 | if args.show_command: 187 | print("Full Java command:") 188 | print("-" * 80) 189 | 190 | # Create a properly formatted command for easy copy-paste 191 | escaped_cmd = [] 192 | for arg in cmd: 193 | # Escape arguments that contain spaces or special characters 194 | if ' ' in arg or '"' in arg or "'" in arg: 195 | escaped_cmd.append(f'"{arg}"') 196 | else: 197 | escaped_cmd.append(arg) 198 | 199 | # Print as a continuous line that can be easily copied 200 | print(' '.join(escaped_cmd)) 201 | print("-" * 80) 202 | print() 203 | 204 | # If user only wants to see the command without executing 205 | if not args.execute_after_show: 206 | print("Command built successfully. 
Exiting without execution.") 207 | print("Use --execute-after-show to run the command after displaying it.") 208 | return 0 209 | 210 | if args.verbose: 211 | print("Executing command:") 212 | for arg in cmd: 213 | print(f" {arg}") 214 | print() 215 | 216 | # Execute the command 217 | try: 218 | result = subprocess.run(cmd, check=True) 219 | return result.returncode 220 | except subprocess.CalledProcessError as e: 221 | print(f"Error: Command failed with return code {e.returncode}") 222 | return e.returncode 223 | except KeyboardInterrupt: 224 | print("\nOperation cancelled by user") 225 | return 130 226 | except Exception as e: 227 | print(f"Error: {e}") 228 | return 1 229 | 230 | def load_env_from_script(): 231 | system = platform.system() 232 | 233 | if system == "Windows": 234 | # Call set_paths.cmd and capture output as environment 235 | result = subprocess.run( 236 | ["cmd.exe", "/c", "set_paths.cmd && set"], 237 | capture_output=True, text=True, shell=False 238 | ) 239 | if result.returncode != 0: 240 | print("Failed to run set_paths.cmd") 241 | sys.exit(1) 242 | 243 | # Parse environment variables from output 244 | for line in result.stdout.splitlines(): 245 | if '=' in line: 246 | key, value = line.strip().split('=', 1) 247 | os.environ[key] = value 248 | 249 | elif system in ("Linux", "Darwin"): 250 | # Source the set_paths file and capture env 251 | command = ['bash', '-c', 'source ./set_paths && env'] 252 | result = subprocess.run(command, capture_output=True, text=True) 253 | if result.returncode != 0: 254 | print("Failed to source set_paths") 255 | sys.exit(1) 256 | 257 | for line in result.stdout.splitlines(): 258 | if '=' in line: 259 | key, value = line.strip().split('=', 1) 260 | os.environ[key] = value 261 | else: 262 | print(f"Unsupported OS: {system}") 263 | sys.exit(1) 264 | 265 | def create_parser() -> argparse.ArgumentParser: 266 | """Create and configure the argument parser.""" 267 | parser = argparse.ArgumentParser( 268 | 
prog="llama-tornado", 269 | description="GPU-accelerated LLaMA.java model runner using TornadoVM", 270 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 271 | ) 272 | 273 | # Required arguments 274 | parser.add_argument("--model", dest="model_path", required=True, 275 | help="Path to the LLaMA model file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)") 276 | 277 | # LLaMA arguments 278 | llama_group = parser.add_argument_group("LLaMA Configuration") 279 | llama_group.add_argument("--prompt", help="Input prompt for the model") 280 | llama_group.add_argument("-sp", "--system-prompt", help="System prompt for the model") 281 | llama_group.add_argument("--temperature", type=float, default=0.1, 282 | help="Sampling temperature (0.0 to 2.0)") 283 | llama_group.add_argument("--top-p", type=float, default=0.95, 284 | help="Top-p sampling parameter") 285 | llama_group.add_argument("--seed", type=int, default=None, 286 | help="Random seed (default: current timestamp)") 287 | llama_group.add_argument("-n", "--max-tokens", type=int, default=512, 288 | help="Maximum number of tokens to generate") 289 | llama_group.add_argument("--stream", type=bool, default=True, 290 | help="Enable streaming output") 291 | llama_group.add_argument("--echo", type=bool, default=False, 292 | help="Echo the input prompt") 293 | 294 | # Mode selection 295 | mode_group = parser.add_argument_group("Mode Selection") 296 | mode_group.add_argument("-i", "--interactive", action="store_true", 297 | help="Run in interactive/chat mode") 298 | mode_group.add_argument("--instruct", action="store_true", default=True, 299 | help="Run in instruction mode (default)") 300 | 301 | # Hardware configuration 302 | hw_group = parser.add_argument_group("Hardware Configuration") 303 | hw_group.add_argument("--gpu", dest="use_gpu", action="store_true", 304 | help="Enable GPU acceleration") 305 | hw_group.add_argument("--opencl", dest="backend", action="store_const", const=Backend.OPENCL, 306 | help="Use OpenCL backend 
(default)") 307 | hw_group.add_argument("--ptx", dest="backend", action="store_const", const=Backend.PTX, 308 | help="Use PTX/CUDA backend") 309 | hw_group.add_argument("--gpu-memory", default="7GB", 310 | help="GPU memory allocation") 311 | hw_group.add_argument("--heap-min", default="20g", 312 | help="Minimum JVM heap size") 313 | hw_group.add_argument("--heap-max", default="20g", 314 | help="Maximum JVM heap size") 315 | 316 | # Debug and profiling 317 | debug_group = parser.add_argument_group("Debug and Profiling") 318 | debug_group.add_argument("--debug", action="store_true", 319 | help="Enable debug output") 320 | debug_group.add_argument("--profiler", action="store_true", 321 | help="Enable TornadoVM profiler") 322 | debug_group.add_argument("--profiler-dump-dir", 323 | default="/home/mikepapadim/repos/gpu-llama3.java/prof.json", 324 | help="Directory for profiler output") 325 | 326 | # TornadoVM Execution Verbose options 327 | verbose_group = parser.add_argument_group("TornadoVM Execution Verbose") 328 | verbose_group.add_argument("--print-bytecodes", dest="print_bytecodes", action="store_true", 329 | help="Print bytecodes (tornado.print.bytecodes=true)") 330 | verbose_group.add_argument("--print-threads", dest="threads", action="store_true", 331 | help="Print thread information (tornado.threadInfo=true)") 332 | verbose_group.add_argument("--print-kernel", dest="print_kernel", action="store_true", 333 | help="Print kernel information (tornado.printKernel=true)") 334 | verbose_group.add_argument("--full-dump", dest="full_dump", action="store_true", 335 | help="Enable full debug dump (tornado.fullDebug=true)") 336 | verbose_group.add_argument("--verbose-init", dest="verbose_init", action="store_true", 337 | help="Enable timers for TornadoVM initialization (llama.EnableTimingForTornadoVMInit=true)") 338 | 339 | 340 | # Command display options 341 | command_group = parser.add_argument_group("Command Display Options") 342 | 
command_group.add_argument("--show-command", action="store_true", 343 | help="Display the full Java command that will be executed") 344 | command_group.add_argument("--execute-after-show", action="store_true", 345 | help="Execute the command after showing it (use with --show-command)") 346 | 347 | # Advanced options 348 | advanced_group = parser.add_argument_group("Advanced Options") 349 | advanced_group.add_argument("--opencl-flags", 350 | default="-cl-denorms-are-zero -cl-no-signed-zeros -cl-finite-math-only", 351 | help="OpenCL compiler flags") 352 | advanced_group.add_argument("--max-wait-events", type=int, default=32000, 353 | help="Maximum wait events for TornadoVM event pool") 354 | advanced_group.add_argument("--verbose", "-v", action="store_true", 355 | help="Verbose output") 356 | 357 | return parser 358 | 359 | def main(): 360 | """Main entry point.""" 361 | load_env_from_script() 362 | parser = create_parser() 363 | args = parser.parse_args() 364 | 365 | # Set default seed if not provided 366 | if args.seed is None: 367 | args.seed = int(time.time()) 368 | 369 | # Set default backend to OpenCL if not specified 370 | if not hasattr(args, 'backend') or args.backend is None: 371 | args.backend = Backend.OPENCL 372 | 373 | # Handle mode selection logic 374 | if args.interactive: 375 | args.instruct = False 376 | 377 | # Create and run the LLaMA runner 378 | runner = LlamaRunner() 379 | return runner.run(args) 380 | 381 | if __name__ == "__main__": 382 | sys.exit(main()) 383 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | com.example 8 | gpu-llama3 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 21 13 | 21 14 | UTF-8 15 | 16 | 17 | 18 | 19 | junit 20 | junit 21 | 4.13.2 22 | test 23 | 24 | 25 | tornado 26 | tornado-api 27 | 1.1.1-dev 28 | 29 | 30 | 31 | tornado 32 | tornado-runtime 33 | 1.1.1-dev 34 | 35 | 36 | 37 | 38 | 39 
#!/bin/bash

# Environment setup script for running LLaMA3 with TornadoVM GPU acceleration
# This script configures all necessary environment variables for development and runtime

# Resolve root of this project (LLaMA3) and TornadoVM
export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"

# Set the path to TornadoVM SDK binaries
export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"

# Add TornadoVM and LLaMA bin directories to PATH
export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}/bin"

# Optional: Set JAVA_HOME if required
# export JAVA_HOME=/path/to/graalvm
# export PATH="${JAVA_HOME}/bin:${PATH}"

echo "[INFO] Environment configured for LLaMA3 with TornadoVM at: $TORNADO_ROOT"

# ===== Notes =====
# After sourcing this script:
#   1. TornadoVM will be available for GPU computation
#   2. LLaMA3 command-line tools will be in your PATH
#   3. You can run LLaMA3 with GPU acceleration using TornadoVM
#
# To use this script:  source ./set_paths
#                 or:  . ./set_paths
You can run LLaMA3 with GPU acceleration using TornadoVM 28 | REM 29 | REM To use this script: call set_paths.cmd 30 | -------------------------------------------------------------------------------- /src/main/java/com/example/LlamaApp.java: -------------------------------------------------------------------------------- 1 | package com.example; 2 | 3 | import com.example.aot.AOT; 4 | import com.example.auxiliary.ChatFormat; 5 | import com.example.core.model.tensor.FloatTensor; 6 | import com.example.inference.CategoricalSampler; 7 | import com.example.inference.Sampler; 8 | import com.example.inference.ToppSampler; 9 | import com.example.inference.engine.impl.Llama; 10 | import com.example.inference.engine.impl.Options; 11 | import com.example.loader.weights.ModelLoader; 12 | import com.example.loader.weights.State; 13 | import com.example.tornadovm.FloatArrayUtils; 14 | import com.example.tornadovm.TornadoVMMasterPlan; 15 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 16 | 17 | import java.io.IOException; 18 | import java.util.ArrayList; 19 | import java.util.List; 20 | import java.util.Scanner; 21 | import java.util.Set; 22 | import java.util.function.IntConsumer; 23 | import java.util.random.RandomGenerator; 24 | import java.util.random.RandomGeneratorFactory; 25 | 26 | public class LlamaApp { 27 | // Configuration flags for hardware acceleration and optimizations 28 | public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration 29 | public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation 30 | public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration 31 | public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", 
"true")); // Show performance metrics in interactive mode 32 | 33 | /** 34 | * Creates and configures a sampler for token generation based on specified parameters. 35 | * 36 | *

This method selects an appropriate sampling strategy for next-token prediction 37 | * in language model inference. It supports several sampling approaches:

38 | * 39 | * 44 | * 45 | *

The method handles both {@link FloatTensor} and {@link FloatArray} logits types 46 | * to support both CPU and GPU execution paths.

47 | * 48 | * @param vocabularySize The size of the model's vocabulary 49 | * @param temperature A value controlling randomness in sampling: 50 | * 56 | * @param topp The cumulative probability threshold for nucleus sampling (0.0-1.0). 57 | * 61 | * @param rngSeed Seed value for the random number generator to ensure reproducibility 62 | * 63 | * @return A configured {@link Sampler} that implements the selected sampling strategy 64 | * and handles both tensor and array-based logits 65 | * 66 | * @throws IllegalArgumentException if logits are of an unsupported type 67 | */ 68 | static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) { 69 | Sampler sampler; 70 | if (temperature == 0.0f) { 71 | // greedy argmax sampling: take the token with the highest probability 72 | sampler = Sampler.TENSOR_ARGMAX; // Use TENSOR_ARGMAX instead of ARGMAX 73 | } else { 74 | // we sample from this distribution to get the next token 75 | RandomGenerator rng = RandomGeneratorFactory.getDefault().create(rngSeed); 76 | Sampler innerSampler; 77 | // Determine whether to use top-p (nucleus) sampling 78 | if (topp <= 0 || topp >= 1) { 79 | // If topp is outside (0,1), use standard categorical sampling 80 | // This samples directly from the probability distribution 81 | innerSampler = new CategoricalSampler(rng); 82 | } else { 83 | // Use top-p (nucleus) sampling with the specified threshold 84 | // This restricts sampling to only the most likely tokens that 85 | // cumulatively comprise the top p probability mass 86 | innerSampler = new ToppSampler(vocabularySize, topp, rng); 87 | } 88 | 89 | // Create a sampler that: 90 | // 1. Applies temperature scaling to the logits 91 | // 2. Converts logits to probabilities using softmax 92 | // 3. 
Delegates the actual sampling to the appropriate inner sampler 93 | sampler = logits -> { 94 | // Handle different logits formats to support both CPU and GPU paths 95 | if (logits instanceof FloatTensor) { 96 | // For CPU path using FloatTensor 97 | FloatTensor tensorLogits = (FloatTensor) logits; 98 | // Apply temperature scaling - lower values make distribution more peaked 99 | tensorLogits.divideInPlace(0, tensorLogits.size(), temperature); 100 | // Convert logits to probabilities using softmax 101 | tensorLogits.softmaxInPlace(0, tensorLogits.size()); 102 | } else if (logits instanceof FloatArray) { 103 | // For GPU path using FloatArray 104 | FloatArray arrayLogits = (FloatArray) logits; 105 | // Apply the same operations but using FloatArray-specific methods for TornadoVM data types 106 | FloatArrayUtils.divideInPlace(arrayLogits, 0, arrayLogits.getSize(), temperature); 107 | FloatArrayUtils.softmaxInPlace(arrayLogits, 0, arrayLogits.getSize()); 108 | } else { 109 | // If logits are neither FloatTensor nor FloatArray, throw an exception 110 | throw new IllegalArgumentException("Unsupported logits type: " + (logits != null ? 
logits.getClass().getName() : "null")); 111 | } 112 | return innerSampler.sampleToken(logits); 113 | }; 114 | } 115 | return sampler; 116 | } 117 | 118 | static void runInteractive(Llama model, Sampler sampler, Options options) { 119 | State state = null; 120 | List conversationTokens = new ArrayList<>(); 121 | ChatFormat chatFormat = new ChatFormat(model.tokenizer()); 122 | conversationTokens.add(chatFormat.beginOfText); 123 | if (options.systemPrompt() != null) { 124 | conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); 125 | } 126 | int startPosition = 0; 127 | Scanner in = new Scanner(System.in); 128 | 129 | // Initialize TornadoVM plan once at the beginning if GPU path is enabled 130 | TornadoVMMasterPlan tornadoVMPlan = null; 131 | 132 | try { 133 | while (true) { 134 | System.out.print("> "); 135 | System.out.flush(); 136 | String userText = in.nextLine(); 137 | if (List.of("quit", "exit").contains(userText)) { 138 | break; 139 | } 140 | if (state == null) { 141 | state = model.createNewState(); 142 | } 143 | 144 | if (USE_TORNADOVM && tornadoVMPlan == null) { 145 | tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); 146 | } 147 | 148 | conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText))); 149 | conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); 150 | Set stopTokens = chatFormat.getStopTokens(); 151 | 152 | List responseTokens; 153 | IntConsumer tokenConsumer = token -> { 154 | if (options.stream()) { 155 | if (!model.tokenizer().isSpecialToken(token)) { 156 | System.out.print(model.tokenizer().decode(List.of(token))); 157 | } 158 | } 159 | }; 160 | 161 | // Choose between GPU and CPU path based on configuration 162 | if (USE_TORNADOVM) { 163 | // GPU path using TornadoVM 164 | responseTokens = Llama.generateTokensGPU(model, state, startPosition, 
conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), 165 | sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan); 166 | } else { 167 | // CPU path 168 | responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, 169 | options.echo(), tokenConsumer); 170 | } 171 | 172 | // Include stop token in the prompt history, but not in the response displayed to the user. 173 | conversationTokens.addAll(responseTokens); 174 | startPosition = conversationTokens.size(); 175 | Integer stopToken = null; 176 | if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { 177 | stopToken = responseTokens.getLast(); 178 | responseTokens.removeLast(); 179 | } 180 | if (!options.stream()) { 181 | String responseText = model.tokenizer().decode(responseTokens); 182 | System.out.println(responseText); 183 | } 184 | if (stopToken == null) { 185 | System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX"); 186 | break; 187 | } 188 | System.out.print("\n"); 189 | 190 | // Optionally print performance metrics after each response 191 | if (SHOW_PERF_INTERACTIVE) { 192 | Llama.LastRunMetrics.printMetrics(); 193 | } 194 | } 195 | } finally { 196 | // Clean up TornadoVM resources when exiting the chat loop 197 | if (USE_TORNADOVM && tornadoVMPlan != null) { 198 | try { 199 | tornadoVMPlan.freeTornadoExecutionPlan(); 200 | } catch (Exception e) { 201 | System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage()); 202 | } 203 | } 204 | } 205 | } 206 | 207 | static void runInstructOnce(Llama model, Sampler sampler, Options options) { 208 | State state = model.createNewState(); 209 | ChatFormat chatFormat = new ChatFormat(model.tokenizer()); 210 | TornadoVMMasterPlan tornadoVMPlan = null; 211 | 
212 | List promptTokens = new ArrayList<>(); 213 | promptTokens.add(chatFormat.beginOfText); 214 | if (options.systemPrompt() != null) { 215 | promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt()))); 216 | } 217 | promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt()))); 218 | promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); 219 | List responseTokens; 220 | 221 | // Define the token consumer 222 | IntConsumer tokenConsumer = token -> { 223 | if (options.stream()) { 224 | if (!model.tokenizer().isSpecialToken(token)) { 225 | System.out.print(model.tokenizer().decode(List.of(token))); 226 | } 227 | } 228 | }; 229 | 230 | Set stopTokens = chatFormat.getStopTokens(); 231 | if (USE_TORNADOVM) { 232 | tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model); 233 | // Call generateTokensGPU without the token consumer parameter 234 | responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? 
tokenConsumer : null, tornadoVMPlan); 235 | } else { 236 | // CPU path still uses the token consumer 237 | responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer); 238 | } 239 | 240 | if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) { 241 | responseTokens.removeLast(); 242 | } 243 | if (!options.stream()) { 244 | String responseText = model.tokenizer().decode(responseTokens); 245 | System.out.println(responseText); 246 | } 247 | 248 | Llama.LastRunMetrics.printMetrics(); 249 | 250 | if (tornadoVMPlan != null) { 251 | tornadoVMPlan.freeTornadoExecutionPlan(); 252 | } 253 | } 254 | 255 | public static void main(String[] args) throws IOException { 256 | Options options = Options.parseOptions(args); 257 | Llama model; 258 | if (USE_AOT) { 259 | model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens()); 260 | } else { 261 | model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true); 262 | } 263 | Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed()); 264 | if (options.interactive()) { 265 | runInteractive(model, sampler, options); 266 | } else { 267 | runInstructOnce(model, sampler, options); 268 | } 269 | } 270 | } 271 | 272 | 273 | 274 | -------------------------------------------------------------------------------- /src/main/java/com/example/aot/AOT.java: -------------------------------------------------------------------------------- 1 | package com.example.aot; 2 | 3 | import com.example.auxiliary.Timer; 4 | import com.example.core.model.GGUF; 5 | import com.example.core.model.tensor.GGMLTensorEntry; 6 | import com.example.inference.engine.impl.Llama; 7 | import com.example.inference.engine.impl.Options; 8 | import com.example.loader.weights.ModelLoader; 9 | import com.example.loader.weights.Weights; 10 | 11 | import java.io.IOException; 12 | import 
java.nio.channels.FileChannel; 13 | import java.nio.file.Files; 14 | import java.nio.file.Path; 15 | import java.nio.file.StandardOpenOption; 16 | import java.util.Map; 17 | import java.util.Objects; 18 | 19 | /** 20 | * Support for AOT preloading of GGUF metadata with GraalVM's Native Image. 21 | * 22 | *

23 | * To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf} 24 | * to the native-image builder command. At runtime, the preloaded model will be used 25 | * iff the specified and preloaded file names (base name) match. 26 | */ 27 | public final class AOT { 28 | AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF; 29 | 30 | 31 | record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map tensorInfos) {} 32 | 33 | private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF")); 34 | 35 | private static PartialModel preLoadGGUF(String modelPath) { 36 | if (modelPath == null || modelPath.isEmpty()) { 37 | return null; 38 | } 39 | try { 40 | Path path = Path.of(modelPath); 41 | if (!Files.exists(path) || !Files.isRegularFile(path)) { 42 | throw new IllegalArgumentException("Cannot pre-load model: " + path); 43 | } 44 | GGUF gguf = GGUF.loadModel(path); 45 | try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) { 46 | return new PartialModel( 47 | path.getFileName().toString(), 48 | ModelLoader.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), 49 | gguf.getTensorDataOffset(), 50 | gguf.getTensorInfos() 51 | ); 52 | } 53 | } catch (IOException e) { 54 | throw new RuntimeException(e); 55 | } 56 | } 57 | 58 | /** 59 | * Tries to reuse a compatible AOT preloaded model. 60 | * The file name (base name) must match with the preloaded file name. 61 | * No checksum/hash is checked for performance reasons. 
62 | */ 63 | public static com.example.inference.engine.impl.Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException { 64 | AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF; 65 | if (preLoaded == null) { 66 | return null; // no pre-loaded model stored 67 | } 68 | String optionsModel = modelPath.getFileName().toString(); 69 | String preLoadedModel = preLoaded.modelFileName(); 70 | if (!Objects.equals(optionsModel, preLoadedModel)) { 71 | // Preloaded and specified model file names didn't match. 72 | return null; 73 | } 74 | Llama baseModel = preLoaded.model(); 75 | try (var timer = Timer.log("Load tensors from pre-loaded model"); 76 | var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) { 77 | // Load only the tensors (mmap slices). 78 | Map tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos()); 79 | Weights weights = ModelLoader.loadWeights(tensorEntries, baseModel.configuration()); 80 | return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights); 81 | } 82 | } 83 | } 84 | 85 | -------------------------------------------------------------------------------- /src/main/java/com/example/auxiliary/ChatFormat.java: -------------------------------------------------------------------------------- 1 | package com.example.auxiliary; 2 | 3 | import com.example.tokenizer.impl.Tokenizer; 4 | 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.Set; 9 | 10 | public class ChatFormat { 11 | 12 | final Tokenizer tokenizer; 13 | public final int beginOfText; 14 | final int endHeader; 15 | final int startHeader; 16 | final int endOfTurn; 17 | final int endOfText; 18 | final int endOfMessage; 19 | final Set stopTokens; 20 | 21 | public ChatFormat(Tokenizer tokenizer) { 22 | this.tokenizer = tokenizer; 23 | Map specialTokens = this.tokenizer.getSpecialTokens(); 24 | this.beginOfText = 
specialTokens.get("<|begin_of_text|>"); 25 | this.startHeader = specialTokens.get("<|start_header_id|>"); 26 | this.endHeader = specialTokens.get("<|end_header_id|>"); 27 | this.endOfTurn = specialTokens.get("<|eot_id|>"); 28 | this.endOfText = specialTokens.get("<|end_of_text|>"); 29 | this.endOfMessage = specialTokens.getOrDefault("<|eom_id|>", -1); // only in 3.1 30 | this.stopTokens = Set.of(endOfText, endOfTurn); 31 | } 32 | 33 | public Tokenizer getTokenizer() { 34 | return tokenizer; 35 | } 36 | 37 | public Set getStopTokens() { 38 | return stopTokens; 39 | } 40 | 41 | public List encodeHeader(ChatFormat.Message message) { 42 | List tokens = new ArrayList<>(); 43 | tokens.add(startHeader); 44 | tokens.addAll(this.tokenizer.encodeAsList(message.role().name())); 45 | tokens.add(endHeader); 46 | tokens.addAll(this.tokenizer.encodeAsList("\n")); 47 | return tokens; 48 | } 49 | 50 | public List encodeMessage(ChatFormat.Message message) { 51 | List tokens = this.encodeHeader(message); 52 | tokens.addAll(this.tokenizer.encodeAsList(message.content().strip())); 53 | tokens.add(endOfTurn); 54 | return tokens; 55 | } 56 | 57 | public List encodeDialogPrompt(boolean appendAssistantTurn, List dialog) { 58 | List tokens = new ArrayList<>(); 59 | tokens.add(beginOfText); 60 | for (ChatFormat.Message message : dialog) { 61 | tokens.addAll(this.encodeMessage(message)); 62 | } 63 | if (appendAssistantTurn) { 64 | // Add the start of an assistant message for the model to complete. 
65 | tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, ""))); 66 | } 67 | return tokens; 68 | } 69 | 70 | public record Message(ChatFormat.Role role, String content) { 71 | } 72 | 73 | public record Role(String name) { 74 | public static ChatFormat.Role SYSTEM = new ChatFormat.Role("system"); 75 | public static ChatFormat.Role USER = new ChatFormat.Role("user"); 76 | public static ChatFormat.Role ASSISTANT = new ChatFormat.Role("assistant"); 77 | 78 | @Override 79 | public String toString() { 80 | return name; 81 | } 82 | } 83 | } -------------------------------------------------------------------------------- /src/main/java/com/example/auxiliary/Parallel.java: -------------------------------------------------------------------------------- 1 | package com.example.auxiliary; 2 | 3 | import java.util.function.IntConsumer; 4 | import java.util.function.LongConsumer; 5 | import java.util.stream.IntStream; 6 | import java.util.stream.LongStream; 7 | 8 | public final class Parallel { 9 | public static void parallelFor(int startInclusive, int endExclusive, IntConsumer action) { 10 | IntStream.range(startInclusive, endExclusive).parallel().forEach(action); 11 | } 12 | 13 | 14 | public static void parallelForLong(long startInclusive, long endExclusive, LongConsumer action) { 15 | if (startInclusive == 0 && endExclusive == 1) { 16 | action.accept(0); 17 | return; 18 | } 19 | LongStream.range(startInclusive, endExclusive).parallel().forEach(action); 20 | } 21 | } -------------------------------------------------------------------------------- /src/main/java/com/example/auxiliary/Timer.java: -------------------------------------------------------------------------------- 1 | package com.example.auxiliary; 2 | 3 | import java.util.concurrent.TimeUnit; 4 | 5 | public interface Timer extends AutoCloseable { 6 | @Override 7 | void close(); // no Exception 8 | 9 | static Timer log(String label) { 10 | return log(label, TimeUnit.MILLISECONDS); 11 | } 
/**
 * An immutable, generic 2-tuple (pair) of values.
 *
 * <p>Both elements may be {@code null}; {@code equals} and {@code hashCode}
 * are null-safe accordingly.</p>
 *
 * @param <T> type of the first element
 * @param <U> type of the second element
 */
public class Tuple2<T, U> {
    private final T first;
    private final U second;

    public Tuple2(T first, U second) {
        this.first = first;
        this.second = second;
    }

    public T getFirst() {
        return first;
    }

    public U getSecond() {
        return second;
    }

    @Override
    public String toString() {
        return "Tuple2{" +
                "first=" + first +
                ", second=" + second +
                '}';
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Tuple2<?, ?> tuple2 = (Tuple2<?, ?>) o;
        // BUGFIX: the previous first.equals(...) threw NullPointerException for
        // tuples holding null elements, which the constructor permits.
        if (first == null ? tuple2.first != null : !first.equals(tuple2.first)) {
            return false;
        }
        return second == null ? tuple2.second == null : second.equals(tuple2.second);
    }

    @Override
    public int hashCode() {
        // Null-safe; preserves the original hash for non-null elements.
        int result = (first == null) ? 0 : first.hashCode();
        result = 31 * result + ((second == null) ? 0 : second.hashCode());
        return result;
    }
}
/**
 * GGML tensor data types as defined by the GGUF format, with their per-block
 * byte sizes. Constants declared with {@code Integer.MAX_VALUE} are recognized
 * but unsupported by this loader.
 */
public enum GGMLType {
    // Floating point types
    F32(Float.BYTES),
    F16(GGMLType.FLOAT16_BYTES),
    Q4_0(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
    Q4_1(2 * GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
    UNSUPPORTED_Q4_2(Integer.MAX_VALUE), // support has been removed
    UNSUPPORTED_Q4_3(Integer.MAX_VALUE), // support has been removed
    Q5_0(Integer.MAX_VALUE),
    Q5_1(Integer.MAX_VALUE),
    Q8_0(GGMLType.FLOAT16_BYTES + 32 * Byte.BYTES, 32),
    Q8_1(32 * Byte.BYTES + 2 * Float.BYTES, 32),
    // k-quantizations
    Q2_K(Integer.MAX_VALUE),
    Q3_K(Integer.MAX_VALUE),
    Q4_K(2 * GGMLType.FLOAT16_BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 2, GGMLType.QK_K),
    Q5_K(2 * GGMLType.FLOAT16_BYTES + ((GGMLType.QK_K / 16) / 8 * 6) + GGMLType.QK_K / 8 + GGMLType.QK_K / 2, GGMLType.QK_K),
    Q6_K(GGMLType.QK_K / 2 + GGMLType.QK_K / 4 + GGMLType.QK_K / 16 + GGMLType.FLOAT16_BYTES, GGMLType.QK_K),
    Q8_K(Integer.MAX_VALUE),

    IQ2_XXS(Integer.MAX_VALUE),
    IQ2_XS(Integer.MAX_VALUE),
    IQ3_XXS(Integer.MAX_VALUE),
    IQ1_S(Integer.MAX_VALUE),
    IQ4_NL(Integer.MAX_VALUE),
    IQ3_S(Integer.MAX_VALUE),
    IQ2_S(Integer.MAX_VALUE),
    IQ4_XS(Integer.MAX_VALUE),

    I8(Byte.BYTES),
    I16(Short.BYTES),
    I32(Integer.BYTES),
    I64(Long.BYTES),
    F64(Double.BYTES),
    IQ1_M(Integer.MAX_VALUE),
    BF16(GGMLType.BFLOAT16_BYTES),
    Q4_0_4_4(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
    Q4_0_4_8(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
    Q4_0_8_8(GGMLType.FLOAT16_BYTES + 16 * Byte.BYTES, 32),
    TQ1_0(Integer.MAX_VALUE),
    TQ2_0(Integer.MAX_VALUE);

    public static final int BFLOAT16_BYTES = 2;
    public static final int FLOAT16_BYTES = 2;
    /** Super-block size used by the k-quantizations. */
    public static final int QK_K = 256; // or 64?

    private static final GGMLType[] VALUES = values();

    /** Bytes occupied by one block of this type. */
    private final int typeSize;

    /** Number of elements per block (1 for non-quantized types). */
    private final int blockSize;

    GGMLType(int typeSize) {
        this(typeSize, 1);
    }

    GGMLType(int typeSize, int blockSize) {
        assert blockSize > 0;
        assert typeSize > 0;
        assert isPowerOf2(blockSize);
        this.typeSize = typeSize;
        this.blockSize = blockSize;
    }

    public int getTypeSize() {
        return typeSize;
    }

    public int getBlockSize() {
        return blockSize;
    }

    /** Maps a GGUF type id (the enum ordinal) to its constant. */
    public static GGMLType fromId(int id) {
        return VALUES[id];
    }

    /**
     * Total bytes needed to store {@code numberOfElements} elements of this type.
     *
     * <p>BUGFIX: the long result was previously funnelled through
     * {@code Math.toIntExact}, throwing {@code ArithmeticException} for tensors
     * larger than 2 GiB even though the method is declared to return long.</p>
     *
     * @param numberOfElements element count; must be a multiple of the block size
     * @return size in bytes
     */
    public long byteSizeFor(int numberOfElements) {
        long t = numberOfElements * (long) getTypeSize();
        assert t % getBlockSize() == 0;
        return t / getBlockSize();
    }

    private static boolean isPowerOf2(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }
}
ByteBuffer.allocate(Short.BYTES).order(ByteOrder.LITTLE_ENDIAN); 27 | private final ByteBuffer BB_4 = ByteBuffer.allocate(Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN); 28 | private final ByteBuffer BB_8 = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.LITTLE_ENDIAN); 29 | private int magic; 30 | private int version; 31 | private int tensorCount; // uint64_t 32 | private int alignment; 33 | private int metadata_kv_count; // uint64_t 34 | private Map metadata; 35 | private Map tensorInfos; 36 | private long tensorDataOffset; 37 | 38 | public static GGUF loadModel(Path modelPath) throws IOException { 39 | try (FileChannel fileChannel = FileChannel.open(modelPath); var ignored = Timer.log("Parse " + modelPath)) { 40 | GGUF gguf = new GGUF(); 41 | gguf.loadModelImpl(fileChannel); 42 | return gguf; 43 | } 44 | } 45 | 46 | public static Map loadTensors(FileChannel fileChannel, long tensorDataOffset, Map tensorInfos) throws IOException { 47 | Arena arena = Arena.ofAuto(); 48 | MemorySegment tensorData = fileChannel.map(FileChannel.MapMode.READ_ONLY, tensorDataOffset, fileChannel.size() - tensorDataOffset, arena); 49 | Map tensorEntries = HashMap.newHashMap(tensorInfos.size()); 50 | for (Map.Entry entry : tensorInfos.entrySet()) { 51 | GGUFTensorInfo ti = entry.getValue(); 52 | int numberOfElements = FloatTensor.numberOfElements(ti.dimensions()); 53 | int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements)); 54 | MemorySegment memorySegment = tensorData.asSlice(ti.offset(), sizeInBytes); 55 | tensorEntries.put(ti.name(), new GGMLTensorEntry(tensorData, ti.name(), ti.ggmlType(), ti.dimensions(), memorySegment)); 56 | } 57 | return tensorEntries; 58 | } 59 | 60 | public Map getTensorInfos() { 61 | return tensorInfos; 62 | } 63 | 64 | public long getTensorDataOffset() { 65 | return tensorDataOffset; 66 | } 67 | 68 | public Map getMetadata() { 69 | return metadata; 70 | } 71 | 72 | private void loadModelImpl(FileChannel fileChannel) throws IOException { 
73 | // The header of the file. 74 | readHeader(fileChannel); // gguf_header_t header; 75 | // Tensor infos, which can be used to locate the tensor data. 76 | // gguf_tensor_info_t tensor_infos[header.tensor_count]; 77 | this.tensorInfos = HashMap.newHashMap(tensorCount); 78 | for (int i = 0; i < tensorCount; ++i) { 79 | GGUF.GGUFTensorInfo ti = readTensorInfo(fileChannel); 80 | assert !tensorInfos.containsKey(ti.name); 81 | tensorInfos.put(ti.name, ti); 82 | } 83 | // Padding to the nearest multiple of `ALIGNMENT`. 84 | // uint8_t _padding[ALIGNMENT - (sizeof(header + tensor_infos) % ALIGNMENT)]; 85 | //long _padding = -fileChannel.position() & (ALIGNMENT - 1); 86 | long _padding = getAlignment() - (fileChannel.position() % getAlignment()); 87 | fileChannel.position(fileChannel.position() + _padding); 88 | // Tensor data. 89 | // 90 | // This is arbitrary binary data corresponding to the weights of the model. This data should be close 91 | // or identical to the data in the original model file, but may be different due to quantization or 92 | // other optimizations for inference. Any such deviations should be recorded in the metadata or as 93 | // part of the architecture definition. 94 | // 95 | // Each tensor's data must be stored within this array, and located through its `tensor_infos` entry. 96 | // The offset of each tensor's data must be a multiple of `ALIGNMENT`, and the space between tensors 97 | // should be padded to `ALIGNMENT` bytes. 98 | // uint8_t tensor_data[]; 99 | this.tensorDataOffset = fileChannel.position(); 100 | } 101 | 102 | private GGMLType readGGMLType(FileChannel fileChannel) throws IOException { 103 | int ggmlTypeId = readInt(fileChannel); // ggml_type type; 104 | return GGMLType.fromId(ggmlTypeId); 105 | } 106 | 107 | private GGUF.GGUFTensorInfo readTensorInfo(FileChannel fileChannel) throws IOException { 108 | // The name of the tensor. It is a standard GGUF string, with the caveat that 109 | // it must be at most 64 bytes long. 
110 | String name = readString(fileChannel); // gguf_string_t name; 111 | assert name.length() <= 64; 112 | // The number of dimensions in the tensor. 113 | // Currently at most 4, but this may change in the future. 114 | int n_dimensions = readInt(fileChannel); // uint32_t n_dimensions; 115 | assert n_dimensions <= 4; 116 | // The dimensions of the tensor. 117 | int[] dimensions = new int[n_dimensions]; // uint64_t dimensions[n_dimensions]; 118 | for (int i = 0; i < n_dimensions; ++i) { 119 | dimensions[i] = Math.toIntExact(readLong(fileChannel)); 120 | } 121 | // The type of the tensor. 122 | GGMLType ggmlType = readGGMLType(fileChannel); // ggml_type type; 123 | // The offset of the tensor's data in this file in bytes. 124 | // This offset is relative to `tensor_data`, not to the start 125 | // of the file, to make it easier for writers to write the file. 126 | // Readers should consider exposing this offset relative to the 127 | // file to make it easier to read the data. 128 | // Must be a multiple of `ALIGNMENT`. 129 | long offset = readLong(fileChannel); // uint64_t offset; 130 | assert offset % getAlignment() == 0; 131 | return new GGUF.GGUFTensorInfo(name, dimensions, ggmlType, offset); 132 | } 133 | 134 | private String readString(FileChannel fileChannel) throws IOException { 135 | // A string in GGUF. 136 | // The length of the string, in bytes. 137 | int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len; 138 | // The string as a UTF-8 non-null-terminated string. 139 | byte[] bytes = new byte[len]; // char string[len]; 140 | int bytesRead = fileChannel.read(ByteBuffer.wrap(bytes)); 141 | assert len == bytesRead; 142 | return new String(bytes, StandardCharsets.UTF_8); 143 | } 144 | 145 | private Pair readKeyValuePair(FileChannel fileChannel) throws IOException { 146 | // The key of the metadata. It is a standard GGUF string, with the following caveats: 147 | // - It must be a valid ASCII string. 
148 | // - It must be a hierarchical key, where each segment is `lower_snake_case` and separated by a `.`. 149 | // - It must be at most 2^16-1/65535 bytes long. 150 | // Any keys that do not follow these rules are invalid. 151 | String key = readString(fileChannel); // gguf_string_t key; 152 | assert key.length() < (1 << 16); 153 | assert key.codePoints().allMatch(cp -> ('a' <= cp && cp <= 'z') || ('0' <= cp && cp <= '9') || cp == '_' || cp == '.'); 154 | Object value = readMetadataValue(fileChannel); 155 | return new Pair<>(key, value); 156 | } 157 | 158 | private Object readMetadataValue(FileChannel fileChannel) throws IOException { 159 | // The type of the value. 160 | // Must be one of the `gguf_metadata_value_type` values. 161 | MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type value_type; 162 | // The value. 163 | return readMetadataValueOfType(value_type, fileChannel); // gguf_metadata_value_t value; 164 | } 165 | 166 | void readHeader(FileChannel fileChannel) throws IOException { 167 | // Magic number to announce that this is a GGUF file. 168 | // Must be `GGUF` at the byte level: `0x47` `0x47` `0x55` `0x46`. 169 | // Your executor might do little-endian byte order, so it might be 170 | // check for 0x46554747 and letting the endianness cancel out. 171 | // Consider being *very* explicit about the byte order here. 172 | this.magic = readInt(fileChannel); // uint32_t magic; 173 | if (magic != GGUF_MAGIC) { 174 | throw new IllegalArgumentException("unsupported header.magic " + magic); 175 | } 176 | // The version of the format implemented. 177 | // Must be `3` for version described in this spec. 178 | // 179 | // This version should only be increased for structural changes to the format. 180 | // Changes that do not affect the structure of the file should instead update the metadata 181 | // to signify the change. 
182 | this.version = readInt(fileChannel); // uint32_t version; 183 | if (!SUPPORTED_GGUF_VERSIONS.contains(version)) { 184 | throw new IllegalArgumentException("unsupported header.version " + version); 185 | } 186 | // The number of tensors in the file. 187 | // This is explicit, instead of being included in the metadata, to ensure it is always present 188 | // for loading the tensors. 189 | this.tensorCount = Math.toIntExact(readLong(fileChannel)); // uint64_t tensor_count; 190 | // The number of metadata key-value pairs. 191 | this.metadata_kv_count = Math.toIntExact(readLong(fileChannel)); // uint64_t metadata_kv_count; 192 | // The metadata key-value pairs. 193 | // gguf_metadata_kv_t metadata_kv[metadata_kv_count]; 194 | 195 | this.metadata = HashMap.newHashMap(metadata_kv_count); 196 | for (int i = 0; i < metadata_kv_count; ++i) { 197 | Pair keyValue = readKeyValuePair(fileChannel); 198 | assert !metadata.containsKey(keyValue.first()); 199 | metadata.put(keyValue.first(), keyValue.second()); 200 | } 201 | } 202 | 203 | private Object readArray(FileChannel fileChannel) throws IOException { 204 | // Any value type is valid, including arrays. 205 | MetadataValueType value_type = readMetadataValueType(fileChannel); // gguf_metadata_value_type type; 206 | // Number of elements, not bytes 207 | int len = Math.toIntExact(readLong(fileChannel)); // uint64_t len; 208 | // The array of values. 
209 | // gguf_metadata_value_t array[len]; 210 | switch (value_type) { 211 | case UINT8, INT8 -> { 212 | byte[] bytes = new byte[len]; 213 | for (int i = 0; i < len; ++i) { 214 | bytes[i] = readByte(fileChannel); 215 | } 216 | return bytes; 217 | } 218 | case UINT16, INT16 -> { 219 | short[] shorts = new short[len]; 220 | for (int i = 0; i < len; ++i) { 221 | shorts[i] = readShort(fileChannel); 222 | } 223 | return shorts; 224 | } 225 | case UINT32, INT32 -> { 226 | int[] ints = new int[len]; 227 | for (int i = 0; i < len; ++i) { 228 | ints[i] = readInt(fileChannel); 229 | } 230 | return ints; 231 | } 232 | case FLOAT32 -> { 233 | float[] floats = new float[len]; 234 | for (int i = 0; i < len; ++i) { 235 | floats[i] = readFloat(fileChannel); 236 | } 237 | return floats; 238 | } 239 | case BOOL -> { 240 | boolean[] booleans = new boolean[len]; 241 | for (int i = 0; i < len; ++i) { 242 | booleans[i] = readBoolean(fileChannel); 243 | } 244 | return booleans; 245 | } 246 | case STRING -> { 247 | String[] strings = new String[len]; 248 | for (int i = 0; i < len; ++i) { 249 | strings[i] = readString(fileChannel); 250 | } 251 | return strings; 252 | } 253 | case ARRAY -> { 254 | Object[] arrays = new Object[len]; 255 | for (int i = 0; i < len; ++i) { 256 | arrays[i] = readArray(fileChannel); 257 | } 258 | return arrays; 259 | } 260 | default -> throw new UnsupportedOperationException("read array of " + value_type); 261 | } 262 | } 263 | 264 | private Object readMetadataValueOfType(MetadataValueType valueType, FileChannel fileChannel) throws IOException { 265 | return switch (valueType) { 266 | case UINT8, INT8 -> readByte(fileChannel); 267 | case UINT16, INT16 -> readShort(fileChannel); 268 | case UINT32, INT32 -> readInt(fileChannel); 269 | case FLOAT32 -> readFloat(fileChannel); 270 | case UINT64, INT64 -> readLong(fileChannel); 271 | case FLOAT64 -> readDouble(fileChannel); 272 | case BOOL -> readBoolean(fileChannel); 273 | case STRING -> readString(fileChannel); 274 | 
case ARRAY -> readArray(fileChannel); 275 | }; 276 | } 277 | 278 | private byte readByte(FileChannel fileChannel) throws IOException { 279 | int bytesRead = fileChannel.read(BB_1); 280 | assert bytesRead == 1; 281 | return BB_1.clear().get(0); 282 | } 283 | 284 | private boolean readBoolean(FileChannel fileChannel) throws IOException { 285 | return readByte(fileChannel) != 0; 286 | } 287 | 288 | private short readShort(FileChannel fileChannel) throws IOException { 289 | int bytesRead = fileChannel.read(BB_2); 290 | assert bytesRead == 2; 291 | return BB_2.clear().getShort(0); 292 | } 293 | 294 | private int readInt(FileChannel fileChannel) throws IOException { 295 | int bytesRead = fileChannel.read(BB_4); 296 | assert bytesRead == 4; 297 | return BB_4.clear().getInt(0); 298 | } 299 | 300 | private long readLong(FileChannel fileChannel) throws IOException { 301 | int bytesRead = fileChannel.read(BB_8); 302 | assert bytesRead == 8; 303 | return BB_8.clear().getLong(0); 304 | } 305 | 306 | private float readFloat(FileChannel fileChannel) throws IOException { 307 | return Float.intBitsToFloat(readInt(fileChannel)); 308 | } 309 | 310 | private double readDouble(FileChannel fileChannel) throws IOException { 311 | return Double.longBitsToDouble(readLong(fileChannel)); 312 | } 313 | 314 | private MetadataValueType readMetadataValueType(FileChannel fileChannel) throws IOException { 315 | int index = readInt(fileChannel); 316 | return MetadataValueType.fromIndex(index); 317 | } 318 | 319 | public int getAlignment() { 320 | if (alignment != 0) { 321 | return alignment; 322 | } 323 | alignment = (int) metadata.getOrDefault("general.alignment", DEFAULT_ALIGNMENT); 324 | assert Integer.bitCount(alignment) == 1 : "alignment must be a power of two"; 325 | return alignment; 326 | } 327 | 328 | public record GGUFTensorInfo(String name, int[] dimensions, GGMLType ggmlType, long offset) { 329 | } 330 | } -------------------------------------------------------------------------------- 
/src/main/java/com/example/core/model/tensor/ArrayFloatTensor.java: -------------------------------------------------------------------------------- 1 | package com.example.core.model.tensor; 2 | 3 | import com.example.core.model.GGMLType; 4 | import jdk.incubator.vector.FloatVector; 5 | import jdk.incubator.vector.VectorSpecies; 6 | 7 | import java.lang.foreign.MemorySegment; 8 | import java.util.Arrays; 9 | 10 | import static com.example.LlamaApp.USE_VECTOR_API; 11 | 12 | public final class ArrayFloatTensor extends FloatTensor { 13 | 14 | final float[] values; 15 | 16 | ArrayFloatTensor(float[] values) { 17 | this.values = values; 18 | } 19 | 20 | public static FloatTensor allocate(int... dims) { 21 | int numberOfElements = FloatTensor.numberOfElements(dims); 22 | return new ArrayFloatTensor(new float[numberOfElements]); 23 | } 24 | 25 | @Override 26 | public int size() { 27 | return values.length; 28 | } 29 | 30 | @Override 31 | public float getFloat(int index) { 32 | return values[index]; 33 | } 34 | 35 | @Override 36 | public void setFloat(int index, float value) { 37 | values[index] = value; 38 | } 39 | 40 | @Override 41 | public GGMLType type() { 42 | return GGMLType.F32; 43 | } 44 | 45 | @Override 46 | public MemorySegment asMemorySegment() { 47 | return MemorySegment.ofArray(values); 48 | } 49 | 50 | @Override 51 | public FloatTensor fillInPlace(int thisOffset, int size, float value) { 52 | Arrays.fill(values, thisOffset, thisOffset + size, value); 53 | return this; 54 | } 55 | 56 | @Override 57 | public FloatVector getFloatVector(VectorSpecies species, int index) { 58 | if (!USE_VECTOR_API) { 59 | throw new UnsupportedOperationException(); 60 | } 61 | return FloatVector.fromArray(species, values, index); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/com/example/core/model/tensor/F16FloatTensor.java: -------------------------------------------------------------------------------- 1 | 
package com.example.core.model.tensor; 2 | 3 | import com.example.core.model.GGMLType; 4 | import jdk.incubator.vector.FloatVector; 5 | import jdk.incubator.vector.ShortVector; 6 | import jdk.incubator.vector.VectorOperators; 7 | import jdk.incubator.vector.VectorSpecies; 8 | 9 | import java.lang.foreign.MemorySegment; 10 | import java.nio.ByteOrder; 11 | 12 | public final class F16FloatTensor extends FloatTensor { 13 | 14 | final int size; 15 | final MemorySegment memorySegment; 16 | 17 | public F16FloatTensor(int size, MemorySegment memorySegment) { 18 | this.size = size; 19 | this.memorySegment = memorySegment; 20 | } 21 | 22 | @Override 23 | public int size() { 24 | return size; 25 | } 26 | 27 | @Override 28 | public void setFloat(int index, float value) { 29 | throw new UnsupportedOperationException("setFloat"); 30 | } 31 | 32 | @Override 33 | public FloatVector getFloatVector(VectorSpecies species, int index) { 34 | throw new UnsupportedOperationException("getFloatVector"); 35 | } 36 | 37 | @Override 38 | public GGMLType type() { 39 | return GGMLType.F16; 40 | } 41 | 42 | @Override 43 | public MemorySegment asMemorySegment() { 44 | return null; 45 | } 46 | 47 | @Override 48 | public float getFloat(int index) { 49 | assert 0 <= index && index < size; 50 | return Float.float16ToFloat(readShort(memorySegment, index * GGMLType.FLOAT16_BYTES)); 51 | } 52 | 53 | @Override 54 | public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { 55 | if (FloatTensor.USE_VECTOR_API) { 56 | return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); 57 | } else { 58 | return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); 59 | } 60 | } 61 | 62 | private static float vectorDot(F16FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { 63 | assert S_SPECIES_HALF.length() == F_SPECIES.length(); 64 | FloatVector val = FloatVector.zero(F_SPECIES); 65 | int upperBound = F_SPECIES.loopBound(size); 66 
| for (int i = 0; i < upperBound; i += F_SPECIES.length()) { 67 | FloatVector thatVector = that.getFloatVector(F_SPECIES, thatOffset + i); 68 | ShortVector bits16 = ShortVector.fromMemorySegment(S_SPECIES_HALF, thiz.memorySegment, (thisOffset + i) * (long) GGMLType.FLOAT16_BYTES, ByteOrder.LITTLE_ENDIAN); 69 | 70 | var bits32 = bits16.castShape(I_SPECIES, 0).reinterpretAsInts(); // (int) bits16 71 | // Does not support infinities nor NaNs, preserves sign, emulate DAZ (denormals-are-zero). 72 | // Expects well-formed float16 values only (e.g. model weights). 73 | // Fast Float16 to Float32 Conversion: 74 | // 75 | // ┌─[15]─┬─[14]───···───[10]─┬─[9]────····────[0]─┐ 76 | // │ Sign │ Exponent (5 bits) │ Mantissa (10 bits) │ Float16 Layout (16 bits) 77 | // └──────┴───────────────────┴────────────────────┘ 78 | // │ │ │ 79 | // ▼ ▼ ▼ 80 | // ┌─[31]─┬─[30]───···───[23]─┬─[22]────···────[0]─┐ 81 | // │ Sign │ Exponent (8 bits) │ Mantissa (23 bits) │ Float32 Layout (32 bits) 82 | // └──────┴───────────────────┴────────────────────┘ 83 | // 84 | // Shifts and adjustments: 85 | // - Sign: float16[15] -> float32[31] (shift 16 bits up) 86 | // - Exponent: float16[10-14] -> float32[23-30] (+ bias adjustment) 87 | // - Mantissa: float16[0-9] -> float32[13-22] (shift 13 bits up) 88 | // 89 | // exp = bits32 & 0x7C00 90 | // zeroExponentMask = exp == 0 ? 0 : ~0 91 | var zeroExponentMask = bits32.and(0x7C00).neg().lanewise(VectorOperators.ASHR, 31); // = (-exp) >> 31 92 | bits32 = bits32.and(0x8000).lanewise(VectorOperators.LSHL, 16) // sign 93 | .or( 94 | // exponent and mantissa combined 95 | bits32.and(0x7FFF).add(0x1C000).lanewise(VectorOperators.LSHL, 13) 96 | .and(zeroExponentMask) // -0, +0 and DAZ (denormals-are-zero) 97 | 98 | ); 99 | 100 | FloatVector thizVector = bits32.reinterpretAsFloats(); // Float.intBitsToFloat(vi) 101 | val = thizVector.fma(thatVector, val); 102 | } 103 | float result = val.reduceLanes(VectorOperators.ADD); 104 | // Remaining entries. 
105 | if (upperBound < size) { 106 | result += scalarDot(thiz, thisOffset + upperBound, that, thatOffset + upperBound, size - upperBound); 107 | } 108 | 109 | return result; 110 | } 111 | } -------------------------------------------------------------------------------- /src/main/java/com/example/core/model/tensor/FloatTensor.java: -------------------------------------------------------------------------------- 1 | package com.example.core.model.tensor; 2 | 3 | import com.example.auxiliary.Parallel; 4 | import com.example.core.model.GGMLType; 5 | import jdk.incubator.vector.FloatVector; 6 | import jdk.incubator.vector.VectorShape; 7 | import jdk.incubator.vector.VectorSpecies; 8 | import sun.misc.Unsafe; 9 | 10 | import java.lang.foreign.MemorySegment; 11 | import java.lang.foreign.ValueLayout; 12 | import java.lang.reflect.Field; 13 | import java.util.Arrays; 14 | 15 | /** 16 | * Over-simplified, shapeless, float tensor. 17 | *

18 | * Not a strict tensor, but rather just a sequence of floats, not required to be backed by memory 19 | * e.g. can represent a sequence of quantized floats. 20 | */ 21 | public abstract class FloatTensor { 22 | static final int VECTOR_BIT_SIZE = Integer.getInteger("llama.VectorBitSize", VectorShape.preferredShape().vectorBitSize()); 23 | static final boolean USE_VECTOR_API = VECTOR_BIT_SIZE != 0; 24 | 25 | // The use of Unsafe in this file is a temporary workaround to support native-image. 26 | static final Unsafe UNSAFE; 27 | 28 | static { 29 | try { 30 | Field f = Unsafe.class.getDeclaredField("theUnsafe"); 31 | f.setAccessible(true); 32 | UNSAFE = (Unsafe) f.get(null); 33 | } catch (NoSuchFieldException | IllegalAccessException e) { 34 | throw new RuntimeException(e); 35 | } 36 | } 37 | 38 | // Preferred vector size for the fast multiplication routines. 39 | // (Apple Silicon) NEON only supports up-to 128bit vectors. 40 | static final VectorSpecies F_SPECIES; 41 | static final VectorSpecies I_SPECIES; 42 | static final VectorSpecies S_SPECIES_HALF; 43 | 44 | static { 45 | if (USE_VECTOR_API) { 46 | F_SPECIES = VectorShape.forBitSize(VECTOR_BIT_SIZE).withLanes(float.class); 47 | I_SPECIES = F_SPECIES.withLanes(int.class); 48 | S_SPECIES_HALF = VectorShape.forBitSize(F_SPECIES.vectorBitSize() / 2).withLanes(short.class); 49 | assert F_SPECIES.length() == S_SPECIES_HALF.length(); 50 | } else { 51 | F_SPECIES = null; 52 | I_SPECIES = null; 53 | S_SPECIES_HALF = null; 54 | } 55 | } 56 | 57 | public static short readShort(MemorySegment memorySegment, long offset) { 58 | // The MemorySegment.get* methods should be used instead. 59 | return UNSAFE.getShort(memorySegment.address() + offset); 60 | } 61 | 62 | public static byte readByte(MemorySegment memorySegment, long offset) { 63 | // The MemorySegment.get* methods should be used instead. 
64 | return UNSAFE.getByte(memorySegment.address() + offset); 65 | } 66 | 67 | // Preferred vector size for the fast multiplication routines. 68 | // (Apple Silicon) NEON only supports up-to 128bit vectors. 69 | 70 | public abstract int size(); 71 | 72 | public abstract float getFloat(int index); 73 | 74 | public abstract void setFloat(int index, float value); 75 | 76 | protected abstract FloatVector getFloatVector(VectorSpecies species, int offset); 77 | 78 | protected abstract GGMLType type(); 79 | 80 | public abstract MemorySegment asMemorySegment(); 81 | 82 | public static int numberOfElements(int... dimensions) { 83 | assert Arrays.stream(dimensions).allMatch(i -> i > 0); 84 | return Arrays.stream(dimensions).reduce(Math::multiplyExact).orElseThrow(); 85 | } 86 | 87 | static float scalarDot(FloatTensor thiz, int thisOffset, FloatTensor that, int thatOffset, int size) { 88 | float result = 0f; 89 | for (int j = 0; j < size; j++) { 90 | result += thiz.getFloat(thisOffset + j) * that.getFloat(thatOffset + j); 91 | } 92 | return result; 93 | } 94 | 95 | public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { 96 | return scalarDot(this, thisOffset, that, thatOffset, size); 97 | } 98 | 99 | public void matmul(FloatTensor that, FloatTensor out, int dim0, int dim1) { 100 | Parallel.parallelFor(0, dim0, i -> out.setFloat(i, dot(i * dim1, that, 0, dim1))); 101 | } 102 | 103 | public void matmul(int context, FloatTensor[] that, FloatTensor[] out, int dim0, int dim1) { 104 | if (that.length != out.length) { 105 | throw new IllegalArgumentException(String.format("that.len=%d, out.len=%d", that.length, out.length)); 106 | } 107 | Parallel.parallelForLong(0, dim0 * context, ti -> { 108 | int idxArr = (int) (ti / dim0); 109 | int i = (int) (ti % dim0); 110 | out[idxArr].setFloat(i, dot(i * dim1, that[idxArr], 0, dim1)); 111 | }); 112 | } 113 | 114 | @FunctionalInterface 115 | public interface AggregateFunction { 116 | float apply(float acc, float 
value); 117 | } 118 | 119 | public float reduce(int thisOffset, int size, float seed, AggregateFunction reduce) { 120 | float result = seed; 121 | for (int i = 0; i < size; ++i) { 122 | result = reduce.apply(result, getFloat(thisOffset + i)); 123 | } 124 | return result; 125 | } 126 | 127 | float sum(int thisOffset, int size) { 128 | return reduce(thisOffset, size, 0f, Float::sum); 129 | } 130 | 131 | float max(int thisOffset, int size) { 132 | return reduce(thisOffset, size, Float.NEGATIVE_INFINITY, Float::max); 133 | } 134 | 135 | public void copyTo(int thisOffset, FloatTensor that, int thatOffset, int size) { 136 | that.mapWithIndexInPlace(thatOffset, size, (value, index) -> this.getFloat(index - thatOffset + thisOffset)); 137 | } 138 | 139 | int argmax(int thisOffset, int size) { 140 | assert size > 0; 141 | int maxIndex = thisOffset; 142 | float maxValue = this.getFloat(maxIndex); 143 | int endIndex = thisOffset + size; 144 | for (int i = thisOffset; i < endIndex; ++i) { 145 | float f = this.getFloat(i); 146 | if (f > maxValue) { 147 | maxValue = f; 148 | maxIndex = i; 149 | } 150 | } 151 | return maxIndex; 152 | } 153 | 154 | public int argmax() { 155 | return argmax(0, size()); 156 | } 157 | 158 | @FunctionalInterface 159 | public interface MapFunction { 160 | float apply(float value); 161 | } 162 | 163 | @FunctionalInterface 164 | public interface MapWithIndexFunction { 165 | float apply(float value, int index); 166 | } 167 | 168 | FloatTensor mapInPlace(int thisOffset, int size, MapFunction mapFunction) { 169 | int endIndex = thisOffset + size; 170 | for (int i = thisOffset; i < endIndex; ++i) { 171 | setFloat(i, mapFunction.apply(getFloat(i))); 172 | } 173 | return this; 174 | } 175 | 176 | public FloatTensor mapInPlace(MapFunction mapFunction) { 177 | return mapInPlace(0, size(), mapFunction); 178 | } 179 | 180 | public FloatTensor mapWithIndexInPlace(int thisOffset, int size, FloatTensor.MapWithIndexFunction mapWithIndexFunction) { 181 | int endOffset = 
thisOffset + size; 182 | for (int i = thisOffset; i < endOffset; ++i) { 183 | setFloat(i, mapWithIndexFunction.apply(getFloat(i), i)); 184 | } 185 | return this; 186 | } 187 | 188 | FloatTensor addInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) { 189 | return mapWithIndexInPlace(thisOffset, size, (value, index) -> value + that.getFloat(index - thisOffset + thatOffset)); 190 | } 191 | 192 | public FloatTensor addInPlace(FloatTensor that) { 193 | return addInPlace(0, that, 0, size()); 194 | } 195 | 196 | FloatTensor multiplyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size) { 197 | return mapWithIndexInPlace(thisOffset, size, (value, index) -> value * that.getFloat(index - thisOffset + thatOffset)); 198 | } 199 | 200 | public FloatTensor multiplyInPlace(FloatTensor that) { 201 | return multiplyInPlace(0, that, 0, size()); 202 | } 203 | 204 | public FloatTensor divideInPlace(int thisOffset, int size, float value) { 205 | return mapInPlace(thisOffset, size, f -> f / value); 206 | } 207 | 208 | public FloatTensor fillInPlace(int thisOffset, int size, float value) { 209 | return mapInPlace(thisOffset, size, unused -> value); 210 | } 211 | 212 | public FloatTensor softmaxInPlace(int thisOffset, int size) { 213 | // find max value (for numerical stability) 214 | float maxVal = max(thisOffset, size); 215 | // exp and sum 216 | mapInPlace(thisOffset, size, f -> (float) Math.exp(f - maxVal)); 217 | float sum = sum(thisOffset, size); 218 | // normalize 219 | return divideInPlace(thisOffset, size, sum); 220 | } 221 | 222 | public FloatTensor saxpyInPlace(int thisOffset, FloatTensor that, int thatOffset, int size, float a) { 223 | // this[thatOffset ... thatOffset + size) = a * that[thatOffset ... thatOffset + size) + this[thisOffset ... 
thisOffset + size) 224 | for (int i = 0; i < size; ++i) { 225 | setFloat(thisOffset + i, a * that.getFloat(thatOffset + i) + this.getFloat(thisOffset + i)); 226 | } 227 | return this; 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/main/java/com/example/core/model/tensor/GGMLTensorEntry.java: -------------------------------------------------------------------------------- 1 | package com.example.core.model.tensor; 2 | 3 | import com.example.core.model.GGMLType; 4 | 5 | import java.lang.foreign.MemorySegment; 6 | 7 | public record GGMLTensorEntry(MemorySegment mappedFile, String name, GGMLType ggmlType, int[] shape, 8 | MemorySegment memorySegment) { 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/com/example/core/model/tensor/Q4_0FloatTensor.java: -------------------------------------------------------------------------------- 1 | package com.example.core.model.tensor; 2 | 3 | import com.example.LlamaApp; 4 | import com.example.core.model.GGMLType; 5 | import com.example.core.types.Float16; 6 | import jdk.incubator.vector.ByteVector; 7 | import jdk.incubator.vector.FloatVector; 8 | import jdk.incubator.vector.VectorOperators; 9 | import jdk.incubator.vector.VectorSpecies; 10 | 11 | import java.lang.foreign.MemorySegment; 12 | import java.nio.ByteOrder; 13 | 14 | /** 15 | * {@link FloatTensor} quantized in the {@link GGMLType#Q4_0} format. 16 | *

17 | * This tensor implementation is not compatible with {@link FloatTensor}, but 18 | * {@link #dot(int, FloatTensor, int, int)} has a vectorized implementation that is used when 19 | * the second argument implements {@link FloatTensor}. 20 | */ 21 | public final class Q4_0FloatTensor extends FloatTensor { 22 | 23 | final int size; 24 | final MemorySegment memorySegment; 25 | 26 | public Q4_0FloatTensor(int size, MemorySegment memorySegment) { 27 | this.size = size; 28 | this.memorySegment = memorySegment; 29 | } 30 | 31 | @Override 32 | public int size() { 33 | return size; 34 | } 35 | 36 | @Override 37 | public void setFloat(int index, float value) { 38 | throw new UnsupportedOperationException("setFloat"); 39 | } 40 | 41 | @Override 42 | protected FloatVector getFloatVector(VectorSpecies species, int index) { 43 | throw new UnsupportedOperationException("getFloatVector"); 44 | } 45 | 46 | @Override 47 | public GGMLType type() { 48 | return GGMLType.Q4_0; 49 | } 50 | 51 | @Override 52 | public MemorySegment asMemorySegment() { 53 | return memorySegment; 54 | } 55 | 56 | @Override 57 | public float getFloat(int index) { 58 | assert 0 <= index && index < size; 59 | int blockIndex = index / GGMLType.Q4_0.getBlockSize(); 60 | int blockOffset = blockIndex * GGMLType.Q4_0.getTypeSize(); 61 | float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); 62 | byte quant; 63 | int modIndex = index % GGMLType.Q4_0.getBlockSize(); 64 | if (modIndex < GGMLType.Q4_0.getBlockSize() / 2) { 65 | quant = (byte) (readByte(memorySegment, blockOffset + Float16.BYTES + modIndex) & 0x0F); 66 | } else { 67 | quant = (byte) ((readByte(memorySegment, blockOffset + Float16.BYTES + modIndex - GGMLType.Q4_0.getBlockSize() / 2) >>> 4) & 0x0F); 68 | } 69 | quant -= 8; 70 | return quant * scale; 71 | } 72 | 73 | @Override 74 | public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { 75 | if (LlamaApp.USE_VECTOR_API) { 76 | return vectorDot(this, thisOffset, 
(ArrayFloatTensor) that, thatOffset, size); 77 | } else { 78 | return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); 79 | } 80 | } 81 | 82 | private static float vectorDot(Q4_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { 83 | float result = 0f; 84 | int j = 0; 85 | 86 | // Align thisOffset + j to type().getBlockSize(). 87 | assert Integer.bitCount(GGMLType.Q4_0.getBlockSize()) == 1 : "power of 2"; 88 | int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q4_0.getBlockSize() - 1)); 89 | if (alignmentBound > 0) { 90 | result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); 91 | j += alignmentBound; 92 | } 93 | assert (thisOffset + j) % GGMLType.Q4_0.getBlockSize() == 0; 94 | 95 | FloatVector val = FloatVector.zero(F_SPECIES); 96 | int blockOffset = (thisOffset + j) / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getTypeSize(); 97 | int upperBound = size / GGMLType.Q4_0.getBlockSize() * GGMLType.Q4_0.getBlockSize(); 98 | for (; j < upperBound; j += GGMLType.Q4_0.getBlockSize(), blockOffset += GGMLType.Q4_0.getTypeSize()) { 99 | float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); 100 | var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); 101 | var B_SPECIES = ByteVector.SPECIES_128; 102 | var wBytes = ByteVector.fromMemorySegment(B_SPECIES, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); 103 | var loBytes = wBytes.and((byte) 0xF).sub((byte) 8); 104 | var hiBytes = wBytes.lanewise(VectorOperators.LSHR, 4).sub((byte) 8); 105 | if (F_SPECIES.vectorBitSize() == 256) { 106 | var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 0)); 107 | var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(loBytes.castShape(F_SPECIES, 1)); 108 | var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * 
F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 0)); 109 | var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()).mul(hiBytes.castShape(F_SPECIES, 1)); 110 | val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); 111 | } else if (F_SPECIES.vectorBitSize() == 128) { 112 | // This loop cannot be unrolled, why? 113 | for (int i = 0; i < 2; ++i) { 114 | var tmp = i == 0 ? loBytes : hiBytes; 115 | var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 0) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 0)); 116 | var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 1) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 1)); 117 | var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 2) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 2)); 118 | var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + (i * 4 + 3) * F_SPECIES.length()).mul(tmp.castShape(F_SPECIES, 3)); 119 | val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); 120 | } 121 | } else { 122 | throw new UnsupportedOperationException(F_SPECIES.toString()); 123 | } 124 | } 125 | result += val.reduceLanes(VectorOperators.ADD); 126 | 127 | // Remaining entries. 
128 | if (j < size) { 129 | result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); 130 | } 131 | 132 | return result; 133 | } 134 | } -------------------------------------------------------------------------------- /src/main/java/com/example/core/model/tensor/Q8_0FloatTensor.java: -------------------------------------------------------------------------------- 1 | package com.example.core.model.tensor; 2 | 3 | 4 | import com.example.core.model.GGMLType; 5 | import com.example.core.types.Float16; 6 | import jdk.incubator.vector.ByteVector; 7 | import jdk.incubator.vector.FloatVector; 8 | import jdk.incubator.vector.VectorOperators; 9 | import jdk.incubator.vector.VectorSpecies; 10 | 11 | import java.lang.foreign.MemorySegment; 12 | import java.lang.foreign.ValueLayout; 13 | import java.nio.ByteOrder; 14 | 15 | import static com.example.LlamaApp.USE_VECTOR_API; 16 | 17 | public final class Q8_0FloatTensor extends FloatTensor { 18 | 19 | final int size; 20 | final MemorySegment memorySegment; 21 | 22 | public Q8_0FloatTensor(int size, MemorySegment memorySegment) { 23 | this.size = size; 24 | this.memorySegment = memorySegment; 25 | } 26 | 27 | @Override 28 | public int size() { 29 | return size; 30 | } 31 | 32 | public MemorySegment getMemorySegment() { 33 | return memorySegment; 34 | } 35 | 36 | @Override 37 | public void setFloat(int index, float value) { 38 | throw new UnsupportedOperationException("setFloat"); 39 | } 40 | 41 | @Override 42 | protected FloatVector getFloatVector(VectorSpecies species, int index) { 43 | throw new UnsupportedOperationException("getFloatVector"); 44 | } 45 | 46 | @Override 47 | public GGMLType type() { 48 | return GGMLType.Q8_0; 49 | } 50 | 51 | @Override 52 | public MemorySegment asMemorySegment() { 53 | return memorySegment; 54 | } 55 | 56 | @Override 57 | public float getFloat(int index) { 58 | assert 0 <= index && index < size; 59 | int blockIndex = index / GGMLType.Q8_0.getBlockSize(); 60 | int 
withinBlockIndex = index % GGMLType.Q8_0.getBlockSize(); 61 | int blockOffset = blockIndex * GGMLType.Q8_0.getTypeSize(); 62 | byte quant = readByte(memorySegment, blockOffset + Float16.BYTES + withinBlockIndex); 63 | float scale = Float.float16ToFloat(readShort(memorySegment, blockOffset)); 64 | return quant * scale; 65 | } 66 | 67 | 68 | public static final ValueLayout.OfShort JAVA_SHORT_LE = ValueLayout.JAVA_SHORT.withOrder(ByteOrder.LITTLE_ENDIAN); 69 | 70 | @Override 71 | public float dot(int thisOffset, FloatTensor that, int thatOffset, int size) { 72 | if (USE_VECTOR_API) { 73 | return vectorDot(this, thisOffset, (ArrayFloatTensor) that, thatOffset, size); 74 | } else { 75 | return FloatTensor.scalarDot(this, thisOffset, that, thatOffset, size); 76 | } 77 | } 78 | 79 | private static float vectorDot(Q8_0FloatTensor thiz, int thisOffset, ArrayFloatTensor that, int thatOffset, int size) { 80 | float result = 0f; 81 | int j = 0; 82 | 83 | // Align thisOffset + startIndex to type().getBlockSize(). 
84 | assert Integer.bitCount(GGMLType.Q8_0.getBlockSize()) == 1 : "power of 2"; 85 | int alignmentBound = Math.min(size, -thisOffset & (GGMLType.Q8_0.getBlockSize() - 1)); 86 | if (alignmentBound > 0) { 87 | result += FloatTensor.scalarDot(thiz, thisOffset, that, thatOffset, alignmentBound); 88 | j += alignmentBound; 89 | } 90 | assert (thisOffset + j) % GGMLType.Q8_0.getBlockSize() == 0; 91 | 92 | FloatVector val = FloatVector.zero(F_SPECIES); 93 | int blockOffset = (thisOffset + j) / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getTypeSize(); 94 | int upperBound = size / GGMLType.Q8_0.getBlockSize() * GGMLType.Q8_0.getBlockSize(); 95 | for (; j < upperBound; j += GGMLType.Q8_0.getBlockSize(), blockOffset += GGMLType.Q8_0.getTypeSize()) { 96 | float wScaleValue = Float.float16ToFloat(readShort(thiz.memorySegment, blockOffset)); 97 | var wScale = FloatVector.broadcast(F_SPECIES, wScaleValue); 98 | if (F_SPECIES.vectorBitSize() == 256) { 99 | var wBytes = ByteVector.fromMemorySegment(ByteVector.SPECIES_256, thiz.memorySegment, blockOffset + Float16.BYTES, ByteOrder.LITTLE_ENDIAN); 100 | var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0)); 101 | var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1)); 102 | var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + 2 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 2)); 103 | var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + 3 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 3)); 104 | val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); 105 | } 106 | else if (F_SPECIES.vectorBitSize() == 128) { 107 | VectorSpecies B_128 = ByteVector.SPECIES_128; 108 | // This loop cannot be unrolled, why? 
109 | for (int i = 0; i < 2; ++i) { 110 | var wBytes = ByteVector.fromMemorySegment(B_128, thiz.memorySegment, blockOffset + Float16.BYTES + i * B_128.vectorByteSize(), ByteOrder.LITTLE_ENDIAN); 111 | var sum0 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 0 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 0)); 112 | var sum1 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 1 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 1)); 113 | var sum2 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 2 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 2)); 114 | var sum3 = that.getFloatVector(F_SPECIES, thatOffset + j + i * 16 + 3 * F_SPECIES.length()).mul(wBytes.castShape(F_SPECIES, 3)); 115 | val = sum0.add(sum1).add(sum2).add(sum3).fma(wScale, val); 116 | } 117 | } else { 118 | throw new UnsupportedOperationException(F_SPECIES.toString()); 119 | } 120 | } 121 | result += val.reduceLanes(VectorOperators.ADD); 122 | 123 | // Remaining entries. 124 | if (j < size) { 125 | result += FloatTensor.scalarDot(thiz, thisOffset + j, that, thatOffset + j, size - j); 126 | } 127 | 128 | return result; 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /src/main/java/com/example/core/types/Float16.java: -------------------------------------------------------------------------------- 1 | package com.example.core.types; 2 | 3 | public final class Float16 { 4 | public static final int BYTES = 2; 5 | } 6 | -------------------------------------------------------------------------------- /src/main/java/com/example/core/types/MetadataValueType.java: -------------------------------------------------------------------------------- 1 | package com.example.core.types; 2 | 3 | public enum MetadataValueType { 4 | // The value is a 8-bit unsigned integer. 5 | UINT8(1), 6 | // The value is a 8-bit signed integer. 7 | INT8(1), 8 | // The value is a 16-bit unsigned little-endian integer. 
9 | UINT16(2), 10 | // The value is a 16-bit signed little-endian integer. 11 | INT16(2), 12 | // The value is a 32-bit unsigned little-endian integer. 13 | UINT32(4), 14 | // The value is a 32-bit signed little-endian integer. 15 | INT32(4), 16 | // The value is a 32-bit IEEE754 floating point number. 17 | FLOAT32(4), 18 | // The value is a boolean. 19 | // 1-byte value where 0 is false and 1 is true. 20 | // Anything else is invalid, and should be treated as either the model being invalid or the reader being buggy. 21 | BOOL(1), 22 | // The value is a UTF-8 non-null-terminated string, with length prepended. 23 | STRING(-8), 24 | // The value is an array of other values, with the length and type prepended. 25 | // Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes. 26 | ARRAY(-8), 27 | // The value is a 64-bit unsigned little-endian integer. 28 | UINT64(8), 29 | // The value is a 64-bit signed little-endian integer. 30 | INT64(8), 31 | // The value is a 64-bit IEEE754 floating point number. 
32 | FLOAT64(8); 33 | private final int byteSize; 34 | 35 | MetadataValueType(int byteSize) { 36 | this.byteSize = byteSize; 37 | } 38 | 39 | private static final MetadataValueType[] VALUES = values(); 40 | 41 | public static MetadataValueType fromIndex(int index) { 42 | return VALUES[index]; 43 | } 44 | 45 | public int byteSize() { 46 | return byteSize; 47 | } 48 | } -------------------------------------------------------------------------------- /src/main/java/com/example/core/types/Pair.java: -------------------------------------------------------------------------------- 1 | package com.example.core.types; 2 | 3 | public record Pair(First first, Second second) { 4 | } 5 | -------------------------------------------------------------------------------- /src/main/java/com/example/inference/CategoricalSampler.java: -------------------------------------------------------------------------------- 1 | package com.example.inference; 2 | 3 | import com.example.core.model.tensor.FloatTensor; 4 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 5 | 6 | import java.util.random.RandomGenerator; 7 | 8 | /** 9 | * A sampler that samples from a categorical distribution. 10 | * Supports both FloatTensor and FloatArray implementations. 11 | */ 12 | public record CategoricalSampler(RandomGenerator rng) implements Sampler { 13 | 14 | @Override 15 | public int sampleToken(Object tensor) { 16 | if (tensor instanceof FloatTensor) { 17 | return sampleFromFloatTensor((FloatTensor) tensor); 18 | } else if (tensor instanceof FloatArray) { 19 | return sampleFromFloatArray((FloatArray) tensor); 20 | } 21 | throw new IllegalArgumentException("Unsupported tensor type: " + 22 | (tensor != null ? tensor.getClass().getName() : "null")); 23 | } 24 | 25 | /** 26 | * Sample from a FloatTensor probability distribution. 
27 | * 28 | * @param logits The FloatTensor containing probabilities 29 | * @return The sampled token index 30 | */ 31 | private int sampleFromFloatTensor(FloatTensor logits) { 32 | // sample index from probabilities (they must sum to 1!) 33 | float random0to1 = rng.nextFloat(1f); 34 | float cdf = 0.0f; 35 | for (int i = 0; i < logits.size(); i++) { 36 | cdf += logits.getFloat(i); 37 | if (random0to1 < cdf) { 38 | return i; 39 | } 40 | } 41 | return logits.size() - 1; // in case of rounding errors 42 | } 43 | 44 | /** 45 | * Sample from a FloatArray probability distribution. 46 | * 47 | * @param logits The FloatArray containing probabilities 48 | * @return The sampled token index 49 | */ 50 | private int sampleFromFloatArray(FloatArray logits) { 51 | // sample index from probabilities (they must sum to 1!) 52 | float random0to1 = rng.nextFloat(1f); 53 | float cdf = 0.0f; 54 | for (int i = 0; i < logits.getSize(); i++) { 55 | cdf += logits.get(i); 56 | if (random0to1 < cdf) { 57 | return i; 58 | } 59 | } 60 | return logits.getSize() - 1; // in case of rounding errors 61 | } 62 | } -------------------------------------------------------------------------------- /src/main/java/com/example/inference/Sampler.java: -------------------------------------------------------------------------------- 1 | package com.example.inference; 2 | 3 | import com.example.core.model.tensor.FloatTensor; 4 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 5 | 6 | /** 7 | * Generic interface for sampling tokens from probability distributions. 8 | * Supports both FloatTensor and FloatArray tensor implementations. 9 | */ 10 | @FunctionalInterface 11 | public interface Sampler { 12 | /** 13 | * Sample a token from the provided tensor. 14 | * 15 | * @param tensor The tensor containing probabilities/logits 16 | * @return The selected token index 17 | */ 18 | int sampleToken(Object tensor); 19 | 20 | /** 21 | * Argmax implementation for FloatTensor. 
22 | */ 23 | Sampler TENSOR_ARGMAX = tensor -> { 24 | if (tensor instanceof FloatTensor) { 25 | return ((FloatTensor) tensor).argmax(); 26 | } else if (tensor instanceof FloatArray) { 27 | return argmaxFloatArray((FloatArray) tensor); 28 | } 29 | throw new IllegalArgumentException("Unsupported tensor type: " + 30 | (tensor != null ? tensor.getClass().getName() : "null")); 31 | }; 32 | 33 | /** 34 | * Legacy ARGMAX for backward compatibility. 35 | * @deprecated Use TENSOR_ARGMAX instead 36 | */ 37 | @Deprecated 38 | Sampler ARGMAX = TENSOR_ARGMAX; 39 | 40 | /** 41 | * Find the index of the maximum value in a FloatArray. 42 | * 43 | * @param array The FloatArray to find the maximum value in 44 | * @return The index of the maximum value 45 | */ 46 | static int argmaxFloatArray(FloatArray array) { 47 | float maxValue = Float.NEGATIVE_INFINITY; 48 | int maxIndex = 0; 49 | 50 | for (int i = 0; i < array.getSize(); i++) { 51 | float value = array.get(i); 52 | if (value > maxValue) { 53 | maxValue = value; 54 | maxIndex = i; 55 | } 56 | } 57 | 58 | return maxIndex; 59 | } 60 | } -------------------------------------------------------------------------------- /src/main/java/com/example/inference/ToppSampler.java: -------------------------------------------------------------------------------- 1 | package com.example.inference; 2 | 3 | import com.example.core.model.tensor.FloatTensor; 4 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 5 | 6 | import java.util.Comparator; 7 | import java.util.random.RandomGenerator; 8 | 9 | /** 10 | * Top-p sampling (nucleus sampling) implementation supporting both FloatTensor and FloatArray. 11 | * Samples from the smallest set of tokens that exceed probability topp. 
12 | */ 13 | public final class ToppSampler implements Sampler { 14 | 15 | final int[] indices; 16 | final float topp; 17 | final RandomGenerator rng; 18 | 19 | public ToppSampler(int maxNumberOfElements, float topp, RandomGenerator rng) { 20 | this.indices = new int[maxNumberOfElements]; 21 | this.topp = topp; 22 | this.rng = rng; 23 | } 24 | 25 | static void swap(int[] array, int from, int to) { 26 | int tmp = array[from]; 27 | array[from] = array[to]; 28 | array[to] = tmp; 29 | } 30 | 31 | static void siftDown(int[] array, int from, int n, Comparator comparator) { 32 | int prev = from, next; 33 | while ((next = 2 * prev + 1) < n) { 34 | int r = 2 * prev + 2; 35 | if (r < n && comparator.compare(array[r], array[next]) < 0) { 36 | next = r; 37 | } 38 | if (comparator.compare(array[next], array[prev]) < 0) { 39 | swap(array, prev, next); 40 | prev = next; 41 | } else { 42 | break; 43 | } 44 | } 45 | } 46 | 47 | @Override 48 | public int sampleToken(Object tensor) { 49 | if (tensor instanceof FloatTensor) { 50 | return sampleFromFloatTensor((FloatTensor) tensor); 51 | } else if (tensor instanceof FloatArray) { 52 | return sampleFromFloatArray((FloatArray) tensor); 53 | } 54 | throw new IllegalArgumentException("Unsupported tensor type: " + 55 | (tensor != null ? tensor.getClass().getName() : "null")); 56 | } 57 | 58 | /** 59 | * Implementation of top-p sampling for FloatTensor. 
60 | */ 61 | private int sampleFromFloatTensor(FloatTensor logits) { 62 | // Create a comparator that compares indices based on their values in the tensor 63 | Comparator comparator = Comparator.comparingDouble(logits::getFloat).reversed(); 64 | 65 | int n = logits.size(); 66 | int head = 0; 67 | int tail = n - 1; 68 | // values smaller than (1 - topp) / (n - 1) cannot be part of the result 69 | // so for efficiency we crop these out as candidates before sorting 70 | float cutoff = (1.0f - topp) / (n - 1); 71 | for (int i = 0; i < indices.length; i++) { 72 | if (logits.getFloat(i) >= cutoff) { 73 | indices[head++] = i; 74 | } else { 75 | indices[tail--] = i; 76 | } 77 | } 78 | 79 | return processTopP(logits, comparator, head); 80 | } 81 | 82 | /** 83 | * Implementation of top-p sampling for FloatArray. 84 | */ 85 | private int sampleFromFloatArray(FloatArray logits) { 86 | // Create a comparator that compares indices based on their values in the array 87 | Comparator comparator = (a, b) -> Float.compare(logits.get(b), logits.get(a)); // reversed order 88 | 89 | int n = logits.getSize(); 90 | int head = 0; 91 | int tail = n - 1; 92 | // values smaller than (1 - topp) / (n - 1) cannot be part of the result 93 | // so for efficiency we crop these out as candidates before sorting 94 | float cutoff = (1.0f - topp) / (n - 1); 95 | for (int i = 0; i < indices.length; i++) { 96 | if (logits.get(i) >= cutoff) { 97 | indices[head++] = i; 98 | } else { 99 | indices[tail--] = i; 100 | } 101 | } 102 | 103 | return processTopP(logits, comparator, head); 104 | } 105 | 106 | /** 107 | * Common implementation for processing top-p sampling once indices are prepared. 108 | * Uses a type-specific value getter function to access tensor values. 
109 | */ 110 | private int processTopP(Object logits, Comparator comparator, int n0) { 111 | // build heap O(n0) 112 | for (int i = n0 / 2 - 1; i >= 0; --i) { 113 | siftDown(indices, i, n0, comparator); 114 | } 115 | 116 | // truncate the list where cumulative probability of the largest k elements exceeds topp 117 | // O(k lg n0) 118 | float cumulativeProb = 0.0f; 119 | int lastIndex = 0; 120 | for (int i = n0 - 1; i >= 0; i--) { 121 | swap(indices, 0, i); 122 | 123 | float value; 124 | if (logits instanceof FloatTensor) { 125 | value = ((FloatTensor) logits).getFloat(indices[i]); 126 | } else { 127 | value = ((FloatArray) logits).get(indices[i]); 128 | } 129 | 130 | cumulativeProb += value; 131 | if (cumulativeProb > topp) { 132 | lastIndex = i; 133 | break; // we've exceeded topp by including lastIndex 134 | } 135 | siftDown(indices, 0, i - 1, comparator); 136 | } 137 | 138 | // sample from the truncated list 139 | float r = rng.nextFloat(1f) * cumulativeProb; 140 | float cdf = 0.0f; 141 | for (int i = n0 - 1; i >= lastIndex; i--) { 142 | float value; 143 | if (logits instanceof FloatTensor) { 144 | value = ((FloatTensor) logits).getFloat(indices[i]); 145 | } else { 146 | value = ((FloatArray) logits).get(indices[i]); 147 | } 148 | 149 | cdf += value; 150 | if (r < cdf) { 151 | return indices[i]; 152 | } 153 | } 154 | 155 | return indices[lastIndex]; // in case of rounding errors 156 | } 157 | } -------------------------------------------------------------------------------- /src/main/java/com/example/inference/engine/impl/Configuration.java: -------------------------------------------------------------------------------- 1 | package com.example.inference.engine.impl; 2 | 3 | public final class Configuration { 4 | /** Transformer embedding dimension */ 5 | public final int dim; 6 | 7 | /** Hidden dimension size for feed-forward network layers */ 8 | public final int hiddenDim; 9 | 10 | /** Number of transformer layers in the model */ 11 | public final int 
numberOfLayers; 12 | 13 | /** Number of attention heads for queries */ 14 | public final int numberOfHeads; 15 | 16 | /** Number of key/value heads (can be fewer than query heads in multi-query attention) */ 17 | public final int numberOfKeyValueHeads; 18 | 19 | /** Size of the vocabulary (token set) */ 20 | public final int vocabularySize; 21 | 22 | /** Maximum sequence length the model can process */ 23 | public final int contextLength; 24 | 25 | /** Epsilon value for RMSNorm layers (stabilizes normalization) */ 26 | public final float rmsNormEps; 27 | 28 | /** Base value for RoPE (Rotary Position Embedding) calculations */ 29 | public final float ropeTheta; 30 | 31 | /** Size of each attention head (derived from dim / numberOfHeads) */ 32 | public final int headSize; 33 | 34 | /** Key/value dimension (derived from dim * numberOfKeyValueHeads / numberOfHeads) */ 35 | public final int kvDim; 36 | 37 | /** Multiplier for key/value sharing in multi-query attention */ 38 | public final int kvMul; 39 | 40 | /** 41 | 42 | /** 43 | * Constructs a new Configuration with the specified parameters. 
44 | * 45 | * @param dim Transformer embedding dimension 46 | * @param hiddenDim Hidden dimension for feed-forward layers 47 | * @param numberOfLayers Number of transformer layers 48 | * @param numberOfHeads Number of attention heads 49 | * @param numberOfKeyValueHeads Number of key/value heads 50 | * @param vocabularySize Size of the vocabulary 51 | * @param contextLength Maximum sequence length 52 | * @param rmsNormEps Epsilon for RMSNorm 53 | * @param ropeTheta Base value for RoPE calculations 54 | */ 55 | public Configuration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, float rmsNormEps, float ropeTheta) { 56 | this.dim = dim; 57 | this.hiddenDim = hiddenDim; 58 | this.numberOfLayers = numberOfLayers; 59 | this.numberOfHeads = numberOfHeads; 60 | this.numberOfKeyValueHeads = numberOfKeyValueHeads; 61 | this.vocabularySize = vocabularySize; 62 | this.contextLength = contextLength; 63 | this.rmsNormEps = rmsNormEps; 64 | this.ropeTheta = ropeTheta; 65 | this.headSize = dim / numberOfHeads; 66 | this.kvDim = dim * numberOfKeyValueHeads / numberOfHeads; 67 | this.kvMul = numberOfHeads / numberOfKeyValueHeads; 68 | } 69 | 70 | /** 71 | * Creates a new Configuration with a different context length. 
72 | * 73 | * @param newContextLength The new context length to use 74 | * @return A new Configuration instance with updated context length, 75 | * or the current instance if newContextLength is negative 76 | */ 77 | public Configuration withContextLength(int newContextLength) { 78 | if (newContextLength < 0) { 79 | return this; // no change 80 | } 81 | return new Configuration( 82 | this.dim, 83 | this.hiddenDim, 84 | this.numberOfLayers, 85 | this.numberOfHeads, 86 | this.numberOfKeyValueHeads, 87 | this.vocabularySize, 88 | newContextLength, 89 | this.rmsNormEps, 90 | this.ropeTheta 91 | ); 92 | } 93 | } 94 | 95 | -------------------------------------------------------------------------------- /src/main/java/com/example/inference/engine/impl/Llama.java: -------------------------------------------------------------------------------- 1 | package com.example.inference.engine.impl; 2 | 3 | import com.example.auxiliary.Parallel; 4 | import com.example.core.model.tensor.FloatTensor; 5 | import com.example.inference.Sampler; 6 | import com.example.loader.weights.State; 7 | import com.example.loader.weights.Weights; 8 | import com.example.tokenizer.impl.Tokenizer; 9 | import com.example.tornadovm.TornadoVMMasterPlan; 10 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 11 | 12 | import java.lang.foreign.MemorySegment; 13 | import java.nio.FloatBuffer; 14 | import java.util.ArrayList; 15 | import java.util.List; 16 | import java.util.Set; 17 | import java.util.function.IntConsumer; 18 | 19 | public record Llama(Configuration configuration, Tokenizer tokenizer, Weights weights) { 20 | private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16); 21 | 22 | public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) { 23 | // calculate sum of squares 24 | float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi); 25 | ss /= size; 26 | ss += rmsNormEps; 27 | ss = (float) (1.0 / Math.sqrt(ss)); 
28 | // normalize and scale 29 | final float finalss = ss; // for the lambda 30 | out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index))); 31 | } 32 | 33 | public static FloatTensor forwardJava(Llama model, State state, int token, int position) { 34 | // a few convenience variables 35 | Configuration config = model.configuration(); 36 | Weights weights = model.weights(); 37 | int dim = config.dim; 38 | int headSize = config.headSize; 39 | int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads; 40 | int kvMul = config.numberOfHeads / config.numberOfKeyValueHeads; // integer multiplier of the kv sharing in multiquery 41 | float sqrtHeadSize = (float) Math.sqrt(headSize); 42 | 43 | // copy the token embedding into x 44 | weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim); 45 | 46 | // forward all the layers 47 | for (int l = 0; l < config.numberOfLayers; l++) { 48 | // attention rmsnorm 49 | rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps); 50 | 51 | // qkv matmuls for this position 52 | 53 | weights.wq[l].matmul(state.xb, state.q, dim, dim); 54 | weights.wk[l].matmul(state.xb, state.k, kvDim, dim); 55 | weights.wv[l].matmul(state.xb, state.v, kvDim, dim); 56 | 57 | // RoPE relative positional encoding: complex-valued rotate q and k in each head 58 | for (int i = 0; i < dim; i += 2) { 59 | int head_dim = i % headSize; 60 | float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2)); 61 | float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2)); 62 | int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only 63 | for (int v = 0; v < rotn; v++) { 64 | FloatTensor vec = v == 0 ? 
state.q : state.k; // the vector to rotate (query or key) 65 | float v0 = vec.getFloat(i); 66 | float v1 = vec.getFloat(i + 1); 67 | vec.setFloat(i, v0 * fcr - v1 * fci); 68 | vec.setFloat(i + 1, v0 * fci + v1 * fcr); 69 | } 70 | } 71 | 72 | // save key,value at this time step (position) to our kv cache 73 | //int loff = l * config.seq_len * kvDim; 74 | // kv cache layer offset for convenience 75 | state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); 76 | state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); 77 | 78 | int curLayer = l; 79 | 80 | // multihead attention. iterate over all heads 81 | Parallel.parallelFor(0, config.numberOfHeads, h -> { 82 | // get the query vector for this head 83 | // float* q = s.q + h * headSize; 84 | int qOffset = h * headSize; 85 | 86 | // attention scores for this head 87 | // float* att = s.att + h * config.seq_len; 88 | int attOffset = h * config.contextLength; 89 | 90 | // iterate over all timesteps, including the current one 91 | for (int t = 0; t <= position; t++) { 92 | // get the key vector for this head and at this timestep 93 | // float* k = s.key_cache + loff + t * dim + h * headSize; 94 | int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; 95 | // calculate the attention score as the dot product of q and k 96 | float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize); 97 | score /= sqrtHeadSize; 98 | // save the score to the attention buffer 99 | state.att.setFloat(attOffset + t, score); 100 | } 101 | 102 | // softmax the scores to get attention weights, from 0..position inclusively 103 | state.att.softmaxInPlace(attOffset, position + 1); 104 | 105 | // weighted sum of the values, store back into xb 106 | // float* xb = s.xb + h * headSize; 107 | int xbOffset = h * headSize; 108 | // memset(xb, 0, headSize * sizeof(float)); 109 | state.xb.fillInPlace(xbOffset, headSize, 0f); 110 | 111 | for (int t = 0; t <= position; t++) { 112 | // get the value 
vector for this head and at this timestep 113 | // float* v = s.value_cache + loff + t * dim + h * headSize; 114 | int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize; 115 | // get the attention weight for this timestep 116 | float a = state.att.getFloat(attOffset + t); 117 | // accumulate the weighted value into xb 118 | state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a); 119 | } 120 | }); 121 | 122 | // final matmul to get the output of the attention 123 | weights.wo[l].matmul(state.xb, state.xb2, dim, dim); 124 | 125 | // residual connection back into x 126 | state.x.addInPlace(state.xb2); 127 | 128 | // ffn rmsnorm 129 | rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps); 130 | 131 | // System.out.println("x " + weights.w1.toString() + " " + weights.w2.toString() + " " + weights.w3.toString()); 132 | // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x)) 133 | // first calculate self.w1(x) and self.w3(x) 134 | weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim, dim); 135 | weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim, dim); 136 | 137 | // SwiGLU non-linearity 138 | // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid 139 | state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value))); 140 | 141 | // elementwise multiply with w3(x) 142 | state.hb.multiplyInPlace(state.hb2); 143 | 144 | // final matmul to get the output of the ffn 145 | weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim); 146 | 147 | // residual connection 148 | state.x.addInPlace(state.xb); 149 | } 150 | 151 | rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps); 152 | 153 | weights.wcls.matmul(state.x, state.logits, config.vocabularySize, dim); 154 | 155 | return state.logits; 156 | } 157 | 158 | /** 159 | * Performs the initial embedding lookup and triggers the TornadoVM accelerated forward pass for an LLM token. 160 | * 161 | *

This method handles the first phase of processing a token through the transformer model: 162 | *

    163 | *
  1. Copies the token embedding from the model's embedding table to the state's buffer
  164 | *   2. Delegates the transformer layer processing to TornadoVM through the master plan
  165 | *
166 | * 167 | *

The token embedding lookup happens on the CPU using {@link MemorySegment} operations, 168 | * while the subsequent transformer layers processing is offloaded to the accelerator through 169 | * TornadoVM for improved performance. 170 | * 171 | * @param model 172 | * The Llama model containing weights and configuration parameters 173 | * @param state 174 | * The current execution state holding input/output tensors and temporary buffers 175 | * @param token 176 | * The input token ID to process 177 | * @param position 178 | * The position of this token in the sequence context window 179 | * @param tornadoVMMasterPlan 180 | * The execution plan for TornadoVM acceleration 181 | * @return FloatTensor containing the output logits for token prediction 182 | */ 183 | public static FloatArray forwardTornadoVM( // 184 | Llama model, // 185 | State state, // 186 | int token, // 187 | int position, // 188 | TornadoVMMasterPlan tornadoVMMasterPlan) { // 189 | 190 | MemorySegment.copy(model.weights.tokenEmbeddingTable.getSegment(), token * model.configuration.dim * Float.BYTES, state.wrapX.getSegment(), 0, model.configuration.dim * Float.BYTES); 191 | 192 | return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); 193 | } 194 | 195 | public static List generateTokensGPU(Llama model, State state, 196 | int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, 197 | TornadoVMMasterPlan tornadoVMPlan) { 198 | // === Setup and Initialization === 199 | long startNanos = System.nanoTime(); 200 | long inferenceStartNanos = 0; 201 | 202 | // Pre-validate the max tokens to avoid checking in the loop 203 | int actualMaxTokens = Math.min(maxTokens > 0 ? 
maxTokens : model.configuration().contextLength, model.configuration().contextLength); 204 | 205 | // Preallocate with expected capacity to avoid resizing 206 | List generatedTokens = new ArrayList<>(Math.min(256, actualMaxTokens - promptTokens.size())); // Conservative estimate 207 | 208 | // === Token Generation Loop === 209 | int currentToken = state.latestToken; 210 | int nextToken; 211 | int promptIndex = 0; 212 | int pos = startPosition; 213 | 214 | // Use more efficient direct array access for prompt tokens if possible 215 | int[] promptTokenArray = null; 216 | if (promptTokens instanceof ArrayList) { 217 | // Try to extract the underlying array for faster access 218 | try { 219 | // This is a performance optimization that may not work on all JVMs 220 | promptTokenArray = promptTokens.stream().mapToInt(Integer::intValue).toArray(); 221 | } catch (Exception e) { 222 | // Fall back to list access 223 | } 224 | } 225 | 226 | // Main generation loop 227 | while (pos < actualMaxTokens) { 228 | // GPU Forward Pass - No conditional check since we know we're using GPU 229 | FloatArray logits = forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan); 230 | 231 | // Process prompt tokens if still remaining 232 | if (promptIndex < promptTokens.size()) { 233 | // Get next prompt token (using array access if available) 234 | nextToken = promptTokenArray != null ? 
promptTokenArray[promptIndex++] : promptTokens.get(promptIndex++); 235 | 236 | if (echo) { 237 | // Decode and output token 238 | System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); 239 | } 240 | } else { 241 | // Mark first inference token 242 | if (inferenceStartNanos == 0) { 243 | inferenceStartNanos = System.nanoTime(); 244 | } 245 | 246 | // Sample next token - use GPU sampling if available 247 | nextToken = sampler.sampleToken(logits); 248 | 249 | // Add token consumer support 250 | if (onTokenGenerated != null) { 251 | onTokenGenerated.accept(nextToken); 252 | } 253 | 254 | // Output if needed 255 | if (echo && onTokenGenerated == null) { 256 | System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); 257 | } 258 | 259 | // Store token 260 | generatedTokens.add(nextToken); 261 | 262 | // Check stop condition 263 | if (stopTokens.contains(nextToken)) { 264 | break; 265 | } 266 | } 267 | 268 | // Update for next iteration 269 | currentToken = nextToken; 270 | state.latestToken = currentToken; 271 | pos++; 272 | } 273 | 274 | // === Performance Metrics === 275 | long endNanos = System.nanoTime(); 276 | double totalSeconds = (endNanos - startNanos) / 1_000_000_000.0; 277 | int totalTokens = promptIndex + generatedTokens.size(); 278 | 279 | // Set metrics for tokens achieved 280 | LastRunMetrics.setMetrics(totalTokens, totalSeconds); 281 | 282 | return generatedTokens; 283 | } 284 | 285 | public static List generateTokens(Llama model, State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, 286 | IntConsumer onTokenGenerated) { 287 | // Initialize TornadoVM plan if enabled 288 | 289 | // Start timing the whole process 290 | long startNanos = System.nanoTime(); 291 | long inferenceStartNanos = 0; 292 | 293 | Object logits; 294 | // Validate and adjust maxTokens if necessary 295 | if (maxTokens < 0 || 
model.configuration().contextLength < maxTokens) { 296 | maxTokens = model.configuration().contextLength; 297 | } 298 | 299 | // Storage for generated tokens 300 | List generatedTokens = new ArrayList<>(); 301 | 302 | // Initialize token variables 303 | int currentToken = state.latestToken; 304 | int nextToken; 305 | int promptIndex = 0; 306 | int pos = startPosition; 307 | 308 | while (pos < maxTokens) { 309 | 310 | logits = forwardJava(model, state, currentToken, pos); 311 | 312 | // Handle token processing 313 | if (promptIndex < promptTokens.size()) { 314 | // We're still processing the prompt tokens 315 | nextToken = promptTokens.get(promptIndex++); 316 | if (echo) { 317 | System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); 318 | } 319 | } else { 320 | // Mark the start of actual generation (after prompt processing) 321 | if (inferenceStartNanos == 0) { 322 | inferenceStartNanos = System.nanoTime(); 323 | } 324 | 325 | // Sample the next token 326 | nextToken = sampler.sampleToken(logits); 327 | 328 | // Output the token if echo is enabled 329 | if (echo) { 330 | System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken)))); 331 | } 332 | 333 | // Track the generated token 334 | generatedTokens.add(nextToken); 335 | 336 | // Notify via callback if provided 337 | if (onTokenGenerated != null) { 338 | onTokenGenerated.accept(nextToken); 339 | } 340 | 341 | // Check for stop condition 342 | if (stopTokens.contains(nextToken)) { 343 | break; 344 | } 345 | } 346 | 347 | // Update for next iteration 348 | currentToken = nextToken; 349 | state.latestToken = currentToken; 350 | pos++; 351 | } 352 | 353 | // Calculate and print performance metrics 354 | long endNanos = System.nanoTime(); 355 | double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0; 356 | int totalTokens = promptIndex + generatedTokens.size(); 357 | 358 | LastRunMetrics.setMetrics(totalTokens, 
totalTimeSeconds); 359 | 360 | return generatedTokens; 361 | } 362 | 363 | 364 | public State createNewState() { 365 | State state = new State(configuration(), -1); 366 | state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); 367 | return state; 368 | } 369 | 370 | public State createNewState(int batchsize) { 371 | State state = new State(configuration(), batchsize); 372 | state.latestToken = tokenizer.getSpecialTokens().get("<|begin_of_text|>"); 373 | return state; 374 | } 375 | 376 | /** 377 | * Record to store metrics from the last model run. 378 | * @param totalTokens The total number of tokens processed 379 | * @param totalSeconds The total time in seconds 380 | */ 381 | public record LastRunMetrics(int totalTokens, double totalSeconds) { 382 | /** 383 | * Singleton instance to store the latest metrics 384 | */ 385 | private static LastRunMetrics latestMetrics; 386 | 387 | /** 388 | * Sets the metrics for the latest run 389 | * 390 | * @param tokens The total number of tokens processed 391 | * @param seconds The total time in seconds 392 | */ 393 | public static void setMetrics(int tokens, double seconds) { 394 | latestMetrics = new LastRunMetrics(tokens, seconds); 395 | } 396 | 397 | /** 398 | * Prints the metrics from the latest run to stderr 399 | */ 400 | public static void printMetrics() { 401 | if (latestMetrics != null) { 402 | double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds(); 403 | System.err.printf("\n\nachieved tok/s: %.2f. 
Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds()); 404 | } 405 | } 406 | } 407 | 408 | } 409 | 410 | -------------------------------------------------------------------------------- /src/main/java/com/example/inference/engine/impl/Options.java: -------------------------------------------------------------------------------- 1 | package com.example.inference.engine.impl; 2 | 3 | import java.io.PrintStream; 4 | import java.nio.file.Path; 5 | import java.nio.file.Paths; 6 | 7 | public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive, 8 | float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) { 9 | 10 | public static final int DEFAULT_MAX_TOKENS = 1024; 11 | 12 | public Options { 13 | require(modelPath != null, "Missing argument: --model is required"); 14 | require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"" ); 15 | require(0 <= temperature, "Invalid argument: --temperature must be non-negative"); 16 | require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]"); 17 | } 18 | 19 | static void require(boolean condition, String messageFormat, Object... 
args) { 20 | if (!condition) { 21 | System.out.println("ERROR " + messageFormat.formatted(args)); 22 | System.out.println(); 23 | printUsage(System.out); 24 | System.exit(-1); 25 | } 26 | } 27 | 28 | static void printUsage(PrintStream out) { 29 | out.println("Usage: jbang Llama3.java [options]"); 30 | out.println(); 31 | out.println("Options:"); 32 | out.println(" --model, -m required, path to .gguf file"); 33 | out.println(" --interactive, --chat, -i run in chat mode"); 34 | out.println(" --instruct run in instruct (once) mode, default mode"); 35 | out.println(" --prompt, -p input prompt"); 36 | out.println(" --system-prompt, -sp (optional) system prompt"); 37 | out.println(" --temperature, -temp temperature in [0,inf], default 0.1"); 38 | out.println(" --top-p p value in top-p (nucleus) sampling in [0,1] default 0.95"); 39 | out.println(" --seed random seed, default System.nanoTime()"); 40 | out.println(" --max-tokens, -n number of steps to run for < 0 = limited by context length, default " + DEFAULT_MAX_TOKENS); 41 | out.println(" --stream print tokens during generation; may cause encoding artifacts for non ASCII text, default true"); 42 | out.println(" --echo print ALL tokens to stderr, if true, recommended to set --stream=false, default false"); 43 | out.println(); 44 | } 45 | 46 | public static Options parseOptions(String[] args) { 47 | String prompt = "Tell me a story with Java"; // Hardcoded for testing 48 | String systemPrompt = null; 49 | float temperature = 0.1f; 50 | float topp = 0.95f; 51 | Path modelPath = null; 52 | long seed = System.nanoTime(); 53 | // Keep max context length small for low-memory devices. 
54 | int maxTokens = DEFAULT_MAX_TOKENS; 55 | boolean interactive = false; 56 | boolean stream = true; 57 | boolean echo = false; 58 | 59 | for (int i = 0; i < args.length; i++) { 60 | String optionName = args[i]; 61 | require(optionName.startsWith("-"), "Invalid option %s", optionName); 62 | switch (optionName) { 63 | case "--interactive", "--chat", "-i" -> interactive = true; 64 | case "--instruct" -> interactive = false; 65 | case "--help", "-h" -> { 66 | printUsage(System.out); 67 | System.exit(0); 68 | } 69 | default -> { 70 | String nextArg; 71 | if (optionName.contains("=")) { 72 | String[] parts = optionName.split("=", 2); 73 | optionName = parts[0]; 74 | nextArg = parts[1]; 75 | } else { 76 | require(i + 1 < args.length, "Missing argument for option %s", optionName); 77 | nextArg = args[i + 1]; 78 | i += 1; // skip arg 79 | } 80 | switch (optionName) { 81 | case "--prompt", "-p" -> prompt = nextArg; 82 | case "--system-prompt", "-sp" -> systemPrompt = nextArg; 83 | case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg); 84 | case "--top-p" -> topp = Float.parseFloat(nextArg); 85 | case "--model", "-m" -> modelPath = Paths.get(nextArg); 86 | case "--seed", "-s" -> seed = Long.parseLong(nextArg); 87 | case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg); 88 | case "--stream" -> stream = Boolean.parseBoolean(nextArg); 89 | case "--echo" -> echo = Boolean.parseBoolean(nextArg); 90 | default -> require(false, "Unknown option: %s", optionName); 91 | } 92 | } 93 | } 94 | } 95 | return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo); 96 | } 97 | } 98 | -------------------------------------------------------------------------------- /src/main/java/com/example/inference/operation/RoPE.java: -------------------------------------------------------------------------------- 1 | package com.example.inference.operation; 2 | 3 | import com.example.core.types.Pair; 4 | 5 | public 
final class RoPE { 6 | public static Pair precomputeFreqsCis(int contextLength, int headSize, double theta, 7 | boolean ropeScaling, float scaleFactor, float loFreqFactor, float hiFreqFactor, float oldContextLength) { 8 | assert headSize % 2 == 0; 9 | float[] cr = new float[contextLength * (headSize / 2)]; 10 | float[] ci = new float[contextLength * (headSize / 2)]; 11 | int n = 0; 12 | for (int pos = 0; pos < contextLength; ++pos) { 13 | for (int i = 0; i < headSize; i += 2) { 14 | float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize)); 15 | if (ropeScaling) { 16 | // Llama 3.1 scaling 17 | float loFreqWavelen = oldContextLength / loFreqFactor; 18 | float hiFreqWavelen = oldContextLength / hiFreqFactor; 19 | float wavelen = (float) (2.0 * Math.PI / freq); 20 | if (wavelen < hiFreqWavelen) { 21 | freq = freq; 22 | } else if (wavelen > loFreqWavelen) { 23 | freq = freq / scaleFactor; 24 | } else { 25 | float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor); 26 | freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq; 27 | } 28 | } 29 | float val = pos * freq; 30 | cr[n] = (float) Math.cos(val); 31 | ci[n] = (float) Math.sin(val); 32 | n++; 33 | } 34 | } 35 | assert contextLength * (headSize / 2) == n; 36 | return new Pair<>(cr, ci); 37 | } 38 | } -------------------------------------------------------------------------------- /src/main/java/com/example/loader/weights/ModelLoader.java: -------------------------------------------------------------------------------- 1 | package com.example.loader.weights; 2 | 3 | import com.example.LlamaApp; 4 | import com.example.auxiliary.Timer; 5 | import com.example.core.model.GGMLType; 6 | import com.example.core.model.GGUF; 7 | import com.example.core.model.tensor.F16FloatTensor; 8 | import com.example.core.model.tensor.FloatTensor; 9 | import com.example.core.model.tensor.GGMLTensorEntry; 10 | import com.example.core.model.tensor.Q4_0FloatTensor; 11 | import 
com.example.core.model.tensor.Q8_0FloatTensor; 12 | import com.example.core.types.Pair; 13 | import com.example.inference.engine.impl.Configuration; 14 | import com.example.inference.engine.impl.Llama; 15 | import com.example.inference.operation.RoPE; 16 | import com.example.tokenizer.impl.Tokenizer; 17 | import com.example.tokenizer.vocabulary.Vocabulary; 18 | import uk.ac.manchester.tornado.api.types.HalfFloat; 19 | import uk.ac.manchester.tornado.api.types.arrays.ByteArray; 20 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 21 | import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; 22 | 23 | import java.io.IOException; 24 | import java.nio.ByteOrder; 25 | import java.nio.FloatBuffer; 26 | import java.nio.channels.FileChannel; 27 | import java.nio.file.Path; 28 | import java.nio.file.StandardOpenOption; 29 | import java.util.Arrays; 30 | import java.util.List; 31 | import java.util.Map; 32 | import java.util.function.IntFunction; 33 | import java.util.stream.Collectors; 34 | import java.util.stream.IntStream; 35 | 36 | public final class ModelLoader { 37 | private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; 38 | 39 | private static final String LLAMA_3_PATTERN = "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; 40 | 41 | public static Llama loadModel(Path ggufPath, int contextLength, boolean loadWeights) throws IOException { 42 | GGUF gguf = GGUF.loadModel(ggufPath); 43 | FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ); 44 | return loadModel(fileChannel, gguf, contextLength, loadWeights); 45 | } 46 | 47 | public static Llama loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights) throws IOException { 48 | try (var ignored = Timer.log("Load LlaMa model")) { 49 | Map metadata = gguf.getMetadata(); 50 | Vocabulary vocabulary = Vocabulary.loadVocabulary(metadata); 51 | Tokenizer tokenizer 
= createTokenizer(metadata, vocabulary); 52 | 53 | Configuration config = new Configuration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"), 54 | (int) metadata.get("llama.attention.head_count"), 55 | 56 | metadata.containsKey("llama.attention.head_count_kv") ? (int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"), 57 | 58 | vocabulary.size(), (int) metadata.get("llama.context_length"), (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), 59 | (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength); 60 | 61 | Weights weights = null; 62 | if (loadWeights) { 63 | Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); 64 | weights = loadWeights(tensorEntries, config); 65 | } 66 | return new Llama(config, tokenizer, weights); 67 | } 68 | } 69 | 70 | public static Weights loadWeights(Map tensorEntries, Configuration config) { 71 | boolean ropeScaling = tensorEntries.containsKey("rope_freqs"); 72 | RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor 73 | 1.0f, // loFreqFactor 74 | 3.0f, // hiFreqFactor 75 | 8192 // oldContextLength 76 | ); 77 | 78 | Pair ropeFreqs = RoPE.precomputeFreqsCis(config.contextLength, // Maximum sequence length the model can process 79 | config.headSize, // Dimension of each attention head 80 | config.ropeTheta, // Base frequency parameter (typically 10000.0) 81 | ropeScaling, // Whether to apply frequency scaling (determined by model type) 82 | ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling) 83 | ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies 84 | ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision 85 | ropeConfig.oldContextLength // Original context length the model 
was trained with 86 | ); 87 | 88 | GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight"); 89 | GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings); 90 | 91 | if (LlamaApp.USE_TORNADOVM) { 92 | System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); 93 | return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); 94 | } else { 95 | return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); 96 | } 97 | } 98 | 99 | private static Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, 100 | GGMLTensorEntry outputWeight) { 101 | return new Weights( 102 | // Load directly to TornadoVM format 103 | loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), 104 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), 105 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), 106 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), 107 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), 108 | loadArrayAsFloatArrayFromBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), 109 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), 110 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), 111 | loadArrayAsHalfFloatArray(config.numberOfLayers, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), 112 | FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()); 113 | } 114 | 115 | /** 116 | * Creates weights in standard format only 117 | */ 118 | private static Weights createStandardWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, 119 | GGMLTensorEntry outputWeight) { 120 | return new Weights(loadQuantized(tokenEmbeddings), loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), 121 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), 122 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), 123 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), 124 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), 125 | loadArrayOfFloatBuffer(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), 126 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), 127 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), 128 | loadArrayOfQuantized(config.numberOfLayers, i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), toFloatBuffer(tensorEntries.get("output_norm.weight")), 129 | FloatBuffer.wrap(ropeFreqs.first()), FloatBuffer.wrap(ropeFreqs.second()), loadQuantized(outputWeight), outputWeight.ggmlType()); 130 | } 131 | 132 | private static Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { 133 | String[] mergeLines = (String[]) metadata.get("tokenizer.ggml.merges"); 134 | List> merges = Arrays.stream(mergeLines).map(line -> line.split(" ")) 135 | .map(parts -> new Pair<>(vocabulary.getIndex(parts[0]).orElseThrow(), vocabulary.getIndex(parts[1]).orElseThrow())).toList(); 136 | 137 | int allTokens = vocabulary.size(); 138 | int baseTokens = 128000; // assume all tokens after the base ones are special. 139 | int reservedSpecialTokens = allTokens - baseTokens; 140 | List specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList(); 141 | 142 | assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent()); 143 | 144 | Map specialTokens = IntStream.range(0, specialTokensList.size()).boxed().collect(Collectors.toMap(i -> specialTokensList.get(i), i -> baseTokens + i)); 145 | 146 | return new Tokenizer(vocabulary, merges, LLAMA_3_PATTERN, specialTokens); 147 | } 148 | 149 | public static FloatTensor loadQuantized(GGMLTensorEntry entry) { 150 | GGMLType ggmlType = entry.ggmlType(); 151 | return switch (ggmlType) { 152 | // case F32 -> new F32FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); 153 | case Q8_0 -> new Q8_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); 154 | case Q4_0 -> new Q4_0FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); 155 | case F16 -> new F16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); 156 | default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); 157 | }; 158 | } 159 | 160 | public static FloatArray[] 
loadArrayAsFloatArray(int size, IntFunction getTensorEntry) { 161 | FloatArray[] array = new FloatArray[size]; 162 | for (int i = 0; i < size; i++) { 163 | array[i] = loadTensorAsFloatArray(getTensorEntry.apply(i)); 164 | } 165 | return array; 166 | } 167 | 168 | public static HalfFloatArray[] loadArrayAsHalfFloatArray(int size, IntFunction getTensorEntry) { 169 | HalfFloatArray[] array = new HalfFloatArray[size]; 170 | for (int i = 0; i < size; i++) { 171 | array[i] = loadTensorAsHalfFloatArray(getTensorEntry.apply(i)); 172 | } 173 | return array; 174 | } 175 | 176 | public static FloatArray floatBufferToFloatArray(GGMLTensorEntry tensorEntry) { 177 | if (tensorEntry.ggmlType() == GGMLType.F32) { 178 | FloatBuffer buffer = tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); 179 | return FloatArray.fromFloatBuffer(buffer); 180 | } else { 181 | throw new UnsupportedOperationException("Conversion to FloatArray from " + tensorEntry.ggmlType()); 182 | } 183 | } 184 | 185 | public static FloatArray[] loadArrayAsFloatArrayFromBuffer(int size, IntFunction getTensorEntry) { 186 | FloatArray[] array = new FloatArray[size]; 187 | for (int i = 0; i < size; i++) { 188 | array[i] = floatBufferToFloatArray(getTensorEntry.apply(i)); 189 | } 190 | return array; 191 | } 192 | 193 | public static ByteArray createByteArrayFromTensor(GGMLTensorEntry entry) { 194 | FloatTensor tensor = loadQuantized(entry); 195 | return ByteArray.fromSegment(tensor.asMemorySegment()); 196 | } 197 | 198 | public static FloatArray loadTensorAsFloatArray(GGMLTensorEntry entry) { 199 | if (entry.ggmlType() == GGMLType.F32) { 200 | // For F32, we can directly create FloatArray from memory 201 | FloatBuffer buffer = entry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); 202 | FloatArray array = new FloatArray(buffer.remaining()); 203 | for (int i = 0; i < buffer.remaining(); i++) { 204 | array.set(i, buffer.get()); 205 | } 206 | return 
array; 207 | } else { 208 | // For quantized formats, we need to load through FloatTensor 209 | FloatTensor tensor = loadQuantized(entry); 210 | FloatArray array = new FloatArray(tensor.size()); 211 | for (int i = 0; i < tensor.size(); i++) { 212 | array.set(i, tensor.getFloat(i)); 213 | } 214 | return array; 215 | } 216 | } 217 | 218 | public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) { 219 | if (entry.ggmlType() == GGMLType.F32) { 220 | System.out.println("Loading F32 tensor as HalfFloatArray"); 221 | return null; 222 | } else { 223 | // For quantized formats, we need to load through FloatTensor 224 | FloatTensor tensor = loadQuantized(entry); 225 | HalfFloatArray array = new HalfFloatArray(tensor.size()); 226 | for (int i = 0; i < tensor.size(); i++) { 227 | HalfFloat x = new HalfFloat(tensor.getFloat(i)); 228 | array.set(i, x); 229 | } 230 | return array; 231 | } 232 | } 233 | 234 | public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction getTensorEntry) { 235 | FloatTensor[] array = new FloatTensor[size]; 236 | for (int i = 0; i < size; i++) { 237 | array[i] = loadQuantized(getTensorEntry.apply(i)); 238 | } 239 | return array; 240 | } 241 | 242 | public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction getTensorEntry) { 243 | FloatBuffer[] array = new FloatBuffer[size]; 244 | for (int i = 0; i < size; i++) { 245 | array[i] = toFloatBuffer(getTensorEntry.apply(i)); 246 | } 247 | return array; 248 | } 249 | 250 | public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) { 251 | GGMLType ggmlType = tensorEntry.ggmlType(); 252 | return switch (ggmlType) { 253 | case F32 -> tensorEntry.memorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer(); 254 | default -> throw new UnsupportedOperationException("Conversion to " + ggmlType); 255 | }; 256 | } 257 | 258 | // Helper class to encapsulate RoPE configuration parameters 259 | private static class RopeConfig { 260 | final float 
package com.example.loader.weights;

import com.example.core.model.tensor.ArrayFloatTensor;
import com.example.core.model.tensor.FloatTensor;
import com.example.inference.engine.impl.Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;

/**
 * Mutable per-inference state of the transformer: the current wave of
 * activation buffers, the attention KV cache, and TornadoVM
 * {@link FloatArray} mirrors of those buffers used when the forward pass
 * runs on an accelerator.
 */
public final class State implements Cloneable { // FIX: implements Cloneable so clone() does not always throw

    // current wave of activations
    public final FloatTensor x; // activation at current time stamp (dim,)
    public final FloatTensor xb; // same, but inside a residual branch (dim,)
    public final FloatTensor xb2; // an additional buffer just for convenience (dim,)
    public final FloatTensor hb; // buffer for hidden dimension in the ffn (hidden_dim,)
    public final FloatTensor hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    public final FloatTensor q; // query (dim,)
    public final FloatTensor k; // key (dim,)
    public final FloatTensor v; // value (dim,)
    public final FloatTensor att; // buffer for scores/attention values (n_heads, seq_len)
    public final FloatTensor logits; // output logits
    public final int batchsize;

    // kv cache
    public final FloatTensor[] keyCache; // (n_layer, seq_len, kv_dim)
    public final FloatTensor[] valueCache; // (n_layer, seq_len, kv_dim)

    // Wrappers for TornadoVM compatibility (FloatArray data structure for TornadoVM acceleration)
    // TornadoVM uses FloatArray for more efficient handling of data, particularly when running on GPU or other accelerators.
    public final FloatArray wrapLogits; // FloatArray wrapper for the logits tensor, compatible with TornadoVM for GPU execution.
    public final FloatArray wrapXb; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage.
    public final FloatArray wrapXb2; // FloatArray wrapper for xb2, another residual buffer to aid in computations with TornadoVM.
    public final FloatArray wrapHb; // FloatArray wrapper for hb (hidden dimension buffer for FFN), optimized for TornadoVM.
    public final FloatArray wrapHb2; // FloatArray wrapper for hb2, additional hidden buffer for FFN, for compatibility with TornadoVM.
    public final FloatArray wrapX; // FloatArray wrapper for the current activation tensor, optimized for TornadoVM.

    public final FloatArray wrapQ; // FloatArray wrapper for the query tensor, optimized for TornadoVM.
    public final FloatArray wrapK; // FloatArray wrapper for the key tensor, optimized for TornadoVM.
    public final FloatArray wrapV; // FloatArray wrapper for the value tensor, optimized for TornadoVM.
    public final FloatArray wrapAtt; // FloatArray wrapper for the attention scores, optimized for TornadoVM.
    public final FloatArray wrapKeyCache; // FloatArray wrapper for the key cache, optimized for TornadoVM.
    public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM.
    public final IntArray positionHolder; // single-element holder for the current sequence position, updated each step

    // Reduction scratch buffers; sizes depend on the local workgroup size below.
    public int localSize;
    public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
    public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size.
    public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.

    public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.

    /**
     * Allocates all CPU tensors and their TornadoVM FloatArray mirrors.
     *
     * @param config    model hyper-parameters (dims, layer/head counts, context length, vocab size)
     * @param batchsize requested batch size, stored in {@link #batchsize}
     */
    public State(Configuration config, int batchsize) {
        // FIX: the parameter was previously ignored and the field hard-coded to -1.
        // Callers that passed -1 observe identical behavior.
        this.batchsize = batchsize;

        this.x = ArrayFloatTensor.allocate(config.dim);
        this.xb = ArrayFloatTensor.allocate(config.dim);
        this.xb2 = ArrayFloatTensor.allocate(config.dim);
        this.hb = ArrayFloatTensor.allocate(config.hiddenDim);
        this.hb2 = ArrayFloatTensor.allocate(config.hiddenDim);
        this.q = ArrayFloatTensor.allocate(config.dim);
        this.k = ArrayFloatTensor.allocate(config.dim);
        this.v = ArrayFloatTensor.allocate(config.dim);
        this.att = ArrayFloatTensor.allocate(config.numberOfHeads, config.contextLength);
        this.logits = ArrayFloatTensor.allocate(config.vocabularySize);
        // Grouped-query attention: the KV projection width is dim scaled by kv-heads / heads.
        int kvDim = (config.dim * config.numberOfKeyValueHeads) / config.numberOfHeads;
        this.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);
        this.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength, kvDim)).limit(config.numberOfLayers).toArray(FloatTensor[]::new);

        this.wrapX = new FloatArray(config.dim);
        this.wrapXb = new FloatArray(config.dim);
        this.wrapXb2 = new FloatArray(config.dim);
        this.wrapHb = new FloatArray(config.hiddenDim);
        this.wrapHb2 = new FloatArray(config.hiddenDim);

        this.wrapLogits = new FloatArray(config.vocabularySize);
        this.wrapQ = new FloatArray(config.dim);
        this.wrapK = new FloatArray(config.dim);
        this.wrapV = new FloatArray(config.dim);

        // The flattened GPU-side caches cover all layers: layer-major (contextLength * kvDim) slabs.
        this.wrapKeyCache = new FloatArray(config.contextLength * kvDim * config.numberOfLayers);
        this.wrapValueCache = new FloatArray(config.contextLength * kvDim * config.numberOfLayers);
        this.wrapValueCache.init(0.f);
        this.wrapKeyCache.init(0.f);
        this.wrapAtt = new FloatArray(config.numberOfHeads * config.contextLength);
        this.positionHolder = new IntArray(1);
        this.latestToken = -1;

        // Workgroup size used by the reduction kernels; scratch arrays hold
        // 1 slot for the final result + one partial sum per workgroup.
        this.localSize = 256;
        this.temp = new FloatArray(1 + ((config.dim + localSize - 1) / localSize));
        this.tempFFN = new FloatArray(1 + ((config.dim + localSize - 1) / localSize));
        // NOTE(review): sized from config.dim like the others; if the logits reduction
        // spans vocabularySize elements this may need (vocabularySize + localSize - 1) / localSize — TODO confirm.
        this.tempLogits = new FloatArray(1 + ((config.dim + localSize - 1) / localSize));
    }

    /**
     * Shallow copy: tensor and FloatArray references are shared with the original.
     * Works because State now implements {@link Cloneable}; previously
     * super.clone() always threw {@link CloneNotSupportedException}.
     */
    @Override
    public State clone() throws CloneNotSupportedException {
        return (State) super.clone();
    }
}
public class Weights {
    // token embedding table
    public final FloatTensor token_embedding_table; // (vocab_size, dim)
    // weights for rmsnorms
    public final FloatBuffer[] rms_att_weight; // (layer, dim) rmsnorm weights
    // weights for matmuls
    public final FloatTensor[] wq; // (layer, n_heads * head_size)
    public final FloatTensor[] wk; // (layer, n_kv_heads, head_size)
    public final FloatTensor[] wv; // (layer, n_kv_heads * head_size)
    public final FloatTensor[] wo; // (layer, n_heads * head_size, dim)
    public final FloatBuffer[] rms_ffn_weight; // (layer, dim)

    // weights for ffn
    public final FloatTensor[] w1; // (layer, hidden_dim, dim)
    public final FloatTensor[] w2; // (layer, dim, hidden_dim)
    public final FloatTensor[] w3; // (layer, hidden_dim, dim)
    // classifier weights for the output logits
    public final FloatTensor wcls; // (vocab_size, dim)
    public final HalfFloatArray wclsHalfFloat; // half-precision classifier weights for TornadoVM
    // final rmsnorm
    public final FloatBuffer rms_final_weight; // (dim,)
    // freq_cis for RoPE relatively positional embeddings
    public final FloatBuffer freq_cis_real; // (seq_len, head_size/2)
    public final FloatBuffer freq_cis_imag; // (seq_len, head_size/2)

    // Layered data structures: TornadoVM-compatible mirrors of the weights above,
    // one array entry per transformer layer. Null when running in standard mode.
    public FloatArray[] rms_att_weightLayered; // (layer, dim) rmsnorm weights
    public HalfFloatArray[] wqLayered; // (layer, n_heads * head_size)
    public HalfFloatArray[] wkLayered; // (layer, n_kv_heads, head_size)
    public HalfFloatArray[] wvLayered; // (layer, n_kv_heads * head_size)
    public HalfFloatArray[] woLayered; // (layer, n_heads * head_size, dim)
    public FloatArray[] rms_ffn_weightLayered; // (layer, dim)
    public HalfFloatArray[] w1Layered; // (layer, hidden_dim, dim)
    public HalfFloatArray[] w2Layered; // (layer, dim, hidden_dim)
    public HalfFloatArray[] w3Layered; // (layer, hidden_dim, dim)
    public FloatArray rms_final_weight_as_floatArray; // final rmsnorm weights as FloatArray
    public FloatArray tokenEmbeddingTable; // (vocab_size, dim)
    public FloatArray freq_cis_realFlat; // (seq_len, head_size/2)
    public FloatArray freq_cis_imagFlat; // (seq_len, head_size/2)
    // quantization type the weights were loaded with (e.g. Q4_0, Q8_0, F16)
    public GGMLType weightType;

    /**
     * Constructor for standard (non-TornadoVM) mode. All TornadoVM "Layered"
     * mirror fields are left null.
     *
     * @param token_embedding_table
     *            Token embeddings matrix
     * @param rms_att_weight
     *            RMSNorm weights for attention layers
     * @param wq
     *            Query weight matrices
     * @param wk
     *            Key weight matrices
     * @param wv
     *            Value weight matrices
     * @param wo
     *            Output projection matrices
     * @param rms_ffn_weight
     *            RMSNorm weights for FFN layers
     * @param w1
     *            First FFN weight matrices
     * @param w2
     *            Second FFN weight matrices
     * @param w3
     *            Third FFN weight matrices (gate)
     * @param rms_final_weight
     *            Final layer normalization weights
     * @param freq_cis_real
     *            RoPE cosine components
     * @param freq_cis_imag
     *            RoPE sine components
     * @param wcls
     *            Classifier weights for output logits
     * @param weightType
     *            Quantization type of the loaded weights
     */
    public Weights(FloatTensor token_embedding_table, FloatBuffer[] rms_att_weight, FloatTensor[] wq, FloatTensor[] wk, FloatTensor[] wv, FloatTensor[] wo, FloatBuffer[] rms_ffn_weight,
            FloatTensor[] w1, FloatTensor[] w2, FloatTensor[] w3, FloatBuffer rms_final_weight, FloatBuffer freq_cis_real, FloatBuffer freq_cis_imag, FloatTensor wcls, GGMLType weightType) {
        // Standard format
        this.token_embedding_table = token_embedding_table;
        this.rms_att_weight = rms_att_weight;
        this.wq = wq;
        this.wk = wk;
        this.wv = wv;
        this.wo = wo;
        this.rms_ffn_weight = rms_ffn_weight;
        this.w1 = w1;
        this.w2 = w2;
        this.w3 = w3;
        this.wcls = wcls;
        this.rms_final_weight = rms_final_weight;
        this.freq_cis_real = freq_cis_real;
        this.freq_cis_imag = freq_cis_imag;
        this.weightType = weightType;

        // TornadoVM format (null when not using TornadoVM)
        this.tokenEmbeddingTable = null;
        this.rms_att_weightLayered = null;
        this.wqLayered = null;
        this.wkLayered = null;
        this.wvLayered = null;
        this.woLayered = null;
        this.rms_ffn_weightLayered = null;
        this.w1Layered = null;
        this.w2Layered = null;
        this.w3Layered = null;
        this.rms_final_weight_as_floatArray = null;
        this.freq_cis_realFlat = null;
        this.freq_cis_imagFlat = null;
        this.wclsHalfFloat = null;
    }

    /**
     * Constructor for TornadoVM mode: only the "Layered"/FloatArray mirrors are
     * populated; all standard-format tensor fields are left null.
     */
    public Weights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered, HalfFloatArray[] woLayered,
            FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered, HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray,
            FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray, GGMLType weightType) {
        // Standard format (null when using TornadoVM)
        this.token_embedding_table = null;
        this.rms_att_weight = null;
        this.wq = null;
        this.wk = null;
        this.wv = null;
        this.wo = null;
        this.rms_ffn_weight = null;
        this.w1 = null;
        this.w2 = null;
        this.w3 = null;
        this.wcls = null;
        this.rms_final_weight = null;
        this.freq_cis_real = null;
        this.freq_cis_imag = null;

        // TornadoVM format
        this.tokenEmbeddingTable = tokenEmbeddingTable;
        this.rms_att_weightLayered = rms_att_weightLayered;
        this.wqLayered = wqLayered;
        this.wkLayered = wkLayered;
        this.wvLayered = wvLayered;
        this.woLayered = woLayered;
        this.rms_ffn_weightLayered = rms_ffn_weightLayered;
        this.w1Layered = w1Layered;
        this.w2Layered = w2Layered;
        this.w3Layered = w3Layered;
        this.rms_final_weight_as_floatArray = rms_final_weight_as_floatArray;
        this.freq_cis_realFlat = freq_cis_realFlat;
        this.freq_cis_imagFlat = freq_cis_imagFlat;
        this.wclsHalfFloat = wclsByteArray;
        this.weightType = weightType;
    }

}
null ? Pattern.compile(regexPattern) : null; 44 | this.specialTokens = new HashMap<>(specialTokens); 45 | this.merges = new HashMap<>(); 46 | for (Pair pair : merges) { 47 | int firstIndex = pair.first(); 48 | int secondIndex = pair.second(); 49 | int mergeIndex = vocabulary.getIndex(vocabulary.get(firstIndex) + vocabulary.get(secondIndex)).orElseThrow(); 50 | this.merges.put(pair, mergeIndex); 51 | } 52 | } 53 | 54 | private int[] encodeImpl(String text) { 55 | return encode(text, Set.of()).stream().mapToInt(i -> i).toArray(); 56 | } 57 | 58 | /** 59 | * Unlike {@link #encodeOrdinary(String)}, this function handles special tokens. 60 | * allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens 61 | * if none_raise, then an error is raised if any special token is encountered in text 62 | * this is the default tiktoken behavior right now as well 63 | * any other behavior is either annoying, or a major footgun. 64 | */ 65 | List encode(String text, Set allowedSpecial) { 66 | // decode the user desire w.r.t. handling of special tokens 67 | Set special = allowedSpecial; 68 | assert getSpecialTokens().keySet().containsAll(special); 69 | if (special.isEmpty()) { 70 | // shortcut: if no special tokens, just use the ordinary encoding 71 | return encodeOrdinary(text); 72 | } 73 | 74 | // otherwise, we have to be careful with potential special tokens in text 75 | // we handle special tokens by splitting the text 76 | // based on the occurrence of any exact match with any of the special tokens 77 | // we can use re.split for this. 
note that surrounding the pattern with () 78 | // makes it into a capturing group, so the special tokens will be included 79 | String specialPattern = special 80 | .stream() 81 | .map(Pattern::quote) 82 | .collect(Collectors.joining("|", "(", ")")); 83 | 84 | String[] specialChunks = text.split(specialPattern); 85 | // now all the special characters are separated from the rest of the text 86 | // all chunks of text are encoded separately, then results are joined 87 | List ids = new ArrayList<>(); 88 | for (String part : specialChunks) { 89 | if (special.contains(part)) { 90 | // this is a special token, encode it separately as a special case 91 | ids.add(getSpecialTokens().get(part)); 92 | } else { 93 | // this is an ordinary sequence, encode it normally 94 | ids.addAll(encodeOrdinary(part)); 95 | } 96 | } 97 | return ids; 98 | } 99 | 100 | private static List findAll(Pattern pattern, String text) { 101 | List allMatches = new ArrayList<>(); 102 | Matcher matcher = pattern.matcher(text); 103 | while (matcher.find()) { 104 | allMatches.add(matcher.group()); 105 | } 106 | return allMatches; 107 | } 108 | 109 | /** 110 | * Encoding that ignores any special tokens. 
111 | */ 112 | public List encodeOrdinary(String text) { 113 | // split text into chunks of text by categories defined in regex pattern 114 | List textChunks = findAll(compiledPattern, text); 115 | // all chunks of text are encoded separately, then results are joined 116 | List ids = new ArrayList<>(); 117 | for (String chunk : textChunks) { 118 | List chunkIds = encodeChunk(chunk); 119 | ids.addAll(chunkIds); 120 | } 121 | return ids; 122 | } 123 | 124 | private Map, Integer> getStats(List ids) { 125 | Map, Integer> map = new HashMap<>(); 126 | for (int i = 0; i + 1 < ids.size(); i++) { 127 | Pair key = new Pair<>(ids.get(i), ids.get(i + 1)); 128 | map.put(key, map.getOrDefault(key, 0) + 1); 129 | } 130 | return map; 131 | } 132 | 133 | private List encodeChunk(String chunk) { 134 | // return the token ids 135 | // let's begin. first, convert all bytes to integers in range 0..255 136 | List ids = new ArrayList<>(); 137 | for (int b : chunk.toCharArray()) { 138 | int tokenIndex = this.vocabulary.getIndex(String.valueOf((char) b)).orElseThrow(); 139 | ids.add(tokenIndex); 140 | } 141 | 142 | while (ids.size() >= 2) { 143 | // find the pair with the lowest merge index 144 | Map, Integer> stats = getStats(ids); 145 | Pair pair = stats.keySet().stream().min(Comparator.comparingInt(key -> this.merges.getOrDefault(key, Integer.MAX_VALUE))).orElseThrow(); 146 | // subtle: if there are no more merges available, the key will 147 | // result in an inf for every single pair, and the min will be 148 | // just the first pair in the list, arbitrarily 149 | // we can detect this terminating case by a membership check 150 | if (!this.merges.containsKey(pair)) { 151 | break; // nothing else can be merged anymore 152 | } 153 | // otherwise let's merge the best pair (lowest merge index) 154 | int idx = this.merges.get(pair); 155 | ids = merge(ids, pair, idx); 156 | } 157 | return ids; 158 | } 159 | 160 | private static List merge(List ids, Pair pair, int idx) { 161 | List newids = 
new ArrayList<>(); 162 | int i = 0; 163 | while (i < ids.size()) { 164 | // if not at the very last position AND the pair matches, replace it 165 | if (ids.get(i).equals(pair.first()) && i < ids.size() - 1 && ids.get(i + 1).equals(pair.second())) { 166 | newids.add(idx); 167 | i += 2; 168 | } else { 169 | newids.add(ids.get(i)); 170 | i += 1; 171 | } 172 | } 173 | return newids; 174 | } 175 | 176 | public String decodeImpl(List tokens) { 177 | StringBuilder sb = new StringBuilder(); 178 | for (int token : tokens) { 179 | String tokenString = vocabulary.get(token); 180 | sb.append(tokenString); 181 | } 182 | return sb.toString(); 183 | } 184 | 185 | /** 186 | * Returns list of utf-8 byte and a corresponding list of unicode strings. 187 | * The reversible bpe codes work on unicode strings. 188 | * This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 189 | * When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 190 | * This is a significant percentage of your normal, say, 32K bpe vocab. 191 | * To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 192 | * And avoids mapping to whitespace/control characters the bpe code barfs on. 
193 | */ 194 | private static Map bytesToUnicode() { 195 | List bs = new ArrayList<>(); 196 | IntStream.rangeClosed('!', '~').forEach(bs::add); 197 | IntStream.rangeClosed('¡', '¬').forEach(bs::add); 198 | IntStream.rangeClosed('®', 'ÿ').forEach(bs::add); 199 | 200 | List cs = new ArrayList<>(bs); 201 | int n = 0; 202 | for (int b = 0; b < 256; ++b) { 203 | if (!bs.contains(b)) { 204 | bs.add(b); 205 | cs.add(256 + n); 206 | n += 1; 207 | } 208 | } 209 | 210 | // return dict(zip(bs, cs)) 211 | return IntStream.range(0, bs.size()) 212 | .boxed() 213 | .collect(Collectors.toMap(bs::get, cs::get)); 214 | } 215 | 216 | static final Map BYTE_ENCODER = bytesToUnicode(); 217 | static final Map BYTE_DECODER = BYTE_ENCODER.entrySet() 218 | .stream() 219 | .collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); 220 | 221 | public int[] encode(String text) { 222 | StringBuilder sb = new StringBuilder(); 223 | byte[] bytes = text.getBytes(StandardCharsets.UTF_8); 224 | for (byte b : bytes) { 225 | sb.appendCodePoint(BYTE_ENCODER.get(Byte.toUnsignedInt(b))); 226 | } 227 | return encodeImpl(sb.toString()); 228 | } 229 | 230 | public static String replaceControlCharacters(int[] codePoints) { 231 | // we don't want to print control characters 232 | // which distort the output (e.g. 
\n or much worse) 233 | // https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python/19016117#19016117 234 | // http://www.unicode.org/reports/tr44/#GC_Values_Table\ 235 | StringBuilder chars = new StringBuilder(); 236 | for (int cp : codePoints) { 237 | if (Character.getType(cp) == Character.CONTROL && cp != '\n') { 238 | chars.append("\\u").append(HexFormat.of().toHexDigits(cp, 4)); // escape 239 | } else { 240 | chars.appendCodePoint(cp); // this character is ok 241 | } 242 | } 243 | return chars.toString(); 244 | } 245 | 246 | public static String replaceControlCharacters(String str) { 247 | return replaceControlCharacters(str.codePoints().toArray()); 248 | } 249 | 250 | public List encodeAsList(String text) { 251 | return Arrays.stream(encode(text)).boxed().toList(); 252 | } 253 | 254 | public String decode(List tokens) { 255 | String decoded = decodeImpl(tokens); 256 | int[] decodedBytesAsInts = decoded.codePoints().map(BYTE_DECODER::get).toArray(); 257 | byte[] rawBytes = new byte[decodedBytesAsInts.length]; 258 | for (int i = 0; i < decoded.length(); i++) { 259 | rawBytes[i] = (byte) decodedBytesAsInts[i]; 260 | } 261 | return new String(rawBytes, StandardCharsets.UTF_8); 262 | } 263 | } 264 | 265 | -------------------------------------------------------------------------------- /src/main/java/com/example/tokenizer/vocabulary/Vocabulary.java: -------------------------------------------------------------------------------- 1 | package com.example.tokenizer.vocabulary; 2 | 3 | import java.util.Map; 4 | import java.util.OptionalInt; 5 | import java.util.stream.Collectors; 6 | import java.util.stream.IntStream; 7 | 8 | 9 | public record Vocabulary(String[] tokens, float[] scores, Map tokenToIndex) { 10 | private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2"; 11 | 12 | public Vocabulary(String[] vocabulary, float[] scores) { 13 | this(vocabulary, scores, 14 | IntStream.range(0, vocabulary.length) 15 | .boxed() 16 | 
.collect(Collectors.toMap(i -> vocabulary[i], i -> i)) 17 | ); 18 | } 19 | 20 | public String get(int tokenIndex) { 21 | return tokens[tokenIndex]; 22 | } 23 | 24 | public OptionalInt getIndex(String token) { 25 | Integer value = tokenToIndex.get(token); 26 | return value != null ? OptionalInt.of(value) : OptionalInt.empty(); 27 | } 28 | 29 | public static Vocabulary loadVocabulary(Map metadata) { 30 | String model = (String) metadata.get("tokenizer.ggml.model"); 31 | if (!TOKENIZER_LLAMA_3_MODEL.equals(model)) { 32 | throw new IllegalArgumentException("expected " + TOKENIZER_LLAMA_3_MODEL + " but found " + model); 33 | } 34 | String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); 35 | return new Vocabulary(tokens, null); 36 | } 37 | 38 | public int size() { 39 | return tokens.length; 40 | } 41 | } -------------------------------------------------------------------------------- /src/main/java/com/example/tornadovm/FloatArrayUtils.java: -------------------------------------------------------------------------------- 1 | package com.example.tornadovm; 2 | 3 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 4 | import uk.ac.manchester.tornado.api.math.TornadoMath; 5 | 6 | /** 7 | * Helper class for FloatArray operations that mirror the functionality of FloatTensor methods. 8 | * Provides utility methods for common tensor operations on TornadoVM's FloatArray type. 9 | */ 10 | public final class FloatArrayUtils { 11 | 12 | private FloatArrayUtils() { 13 | // Utility class, not meant to be instantiated 14 | } 15 | 16 | /** 17 | * Divides all elements in the specified range of a FloatArray by a value in-place. 18 | * Mirrors the functionality of FloatTensor.divideInPlace(). 
19 | * 20 | * @param array The FloatArray to modify 21 | * @param start The starting index (inclusive) 22 | * @param end The ending index (exclusive) 23 | * @param value The value to divide by 24 | * @return The modified FloatArray for method chaining 25 | */ 26 | public static FloatArray divideInPlace(FloatArray array, int start, int end, float value) { 27 | for (int i = start; i < end; i++) { 28 | array.set(i, array.get(i) / value); 29 | } 30 | return array; 31 | } 32 | 33 | /** 34 | * Divides all elements in a FloatArray by a value in-place. 35 | * 36 | * @param array The FloatArray to modify 37 | * @param value The value to divide by 38 | * @return The modified FloatArray for method chaining 39 | */ 40 | public static FloatArray divideInPlace(FloatArray array, float value) { 41 | return divideInPlace(array, 0, array.getSize(), value); 42 | } 43 | 44 | /** 45 | * Applies the softmax function to a range of elements in a FloatArray in-place. 46 | * Mirrors the functionality of FloatTensor.softmaxInPlace(). 
47 | * 48 | * @param array The FloatArray to modify 49 | * @param start The starting index (inclusive) 50 | * @param end The ending index (exclusive) 51 | * @return The modified FloatArray for method chaining 52 | */ 53 | public static FloatArray softmaxInPlace(FloatArray array, int start, int end) { 54 | // Find max value for numerical stability 55 | float maxVal = Float.NEGATIVE_INFINITY; 56 | for (int i = start; i < end; i++) { 57 | float val = array.get(i); 58 | if (val > maxVal) { 59 | maxVal = val; 60 | } 61 | } 62 | 63 | // Apply exp(x-max) to each element and calculate sum 64 | float sum = 0.0f; 65 | for (int i = start; i < end; i++) { 66 | float exp; 67 | if (TornadoVMSupport.isTornadoVMEnabled()) { 68 | // Use TornadoMath for GPU execution if possible 69 | exp = TornadoMath.exp(array.get(i) - maxVal); 70 | } else { 71 | // Fallback to standard Math 72 | exp = (float) Math.exp(array.get(i) - maxVal); 73 | } 74 | array.set(i, exp); 75 | sum += exp; 76 | } 77 | 78 | // Normalize by sum 79 | if (sum == 0.0f) { 80 | // Handle edge case, divide evenly 81 | float value = 1.0f / (end - start); 82 | for (int i = start; i < end; i++) { 83 | array.set(i, value); 84 | } 85 | } else { 86 | // Normal case, divide by sum 87 | for (int i = start; i < end; i++) { 88 | array.set(i, array.get(i) / sum); 89 | } 90 | } 91 | 92 | return array; 93 | } 94 | 95 | /** 96 | * Applies the softmax function to all elements in a FloatArray in-place. 97 | * 98 | * @param array The FloatArray to modify 99 | * @return The modified FloatArray for method chaining 100 | */ 101 | public static FloatArray softmaxInPlace(FloatArray array) { 102 | return softmaxInPlace(array, 0, array.getSize()); 103 | } 104 | 105 | /** 106 | * Finds the index of the maximum value in a FloatArray. 107 | * Mirrors the functionality of FloatTensor.argmax(). 
108 | * 109 | * @param array The FloatArray to search 110 | * @param start The starting index (inclusive) 111 | * @param end The ending index (exclusive) 112 | * @return The index of the maximum value 113 | */ 114 | public static int argmax(FloatArray array, int start, int end) { 115 | float maxValue = Float.NEGATIVE_INFINITY; 116 | int maxIndex = start; 117 | 118 | for (int i = start; i < end; i++) { 119 | float value = array.get(i); 120 | if (value > maxValue) { 121 | maxValue = value; 122 | maxIndex = i; 123 | } 124 | } 125 | 126 | return maxIndex; 127 | } 128 | 129 | /** 130 | * Finds the index of the maximum value in a FloatArray. 131 | * 132 | * @param array The FloatArray to search 133 | * @return The index of the maximum value 134 | */ 135 | public static int argmax(FloatArray array) { 136 | return argmax(array, 0, array.getSize()); 137 | } 138 | 139 | /** 140 | * Helper class to check if TornadoVM is enabled. 141 | * This allows us to decide whether to use TornadoMath or standard Math. 
 */
    private static class TornadoVMSupport {
        // Resolved once at class-load time; both conditions must hold.
        private static final boolean TORNADO_VM_ENABLED;

        static {
            boolean enabled;
            try {
                // Try to access a TornadoVM-specific class; absence disables the fast path.
                Class.forName("uk.ac.manchester.tornado.api.math.TornadoMath");
                // Check for system property (opt-in; defaults to false even when the class is present)
                enabled = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false"));
            } catch (ClassNotFoundException e) {
                enabled = false;
            }
            TORNADO_VM_ENABLED = enabled;
        }

        static boolean isTornadoVMEnabled() {
            return TORNADO_VM_ENABLED;
        }
    }
}

// -------------------------------------------------------------------------
// /src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java
// -------------------------------------------------------------------------

package com.example.tornadovm;

import com.example.auxiliary.Tuple2;
import com.example.inference.engine.impl.Configuration;
import com.example.inference.engine.impl.Llama;
import com.example.loader.weights.State;
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.TornadoRuntime;
import uk.ac.manchester.tornado.api.runtime.TornadoRuntimeProvider;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

import java.util.List;

/**
 * Orchestrates the TornadoVM execution of a full forward pass: one task graph
 * for preprocessing, one per transformer layer, and one for the final logits
 * projection (see the graph-index helpers below).
 */
public class TornadoVMMasterPlan {
    // Opt-in timing printouts for plan creation / warmup / weight transfer.
    private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));

    private final State state;
    private final Configuration config;
    public GridScheduler scheduler;
    public TornadoExecutionPlan executionPlan;
    List<ImmutableTaskGraph> taskGraphs;

    // Builds the per-layer task graphs; the plan layout differs between
    // NVIDIA and non-NVIDIA backends.
    public TornadoVMMasterPlan(State state, Llama model, boolean isNvidia) {
        TornadoVMLayerPlanner
tornadoVMLayerPlanner = new TornadoVMLayerPlanner(state, model);
        Tuple2<List<ImmutableTaskGraph>, GridScheduler> tornadoVMPlan = isNvidia ? tornadoVMLayerPlanner.setupTornadoForwardPlanLayered() : tornadoVMLayerPlanner.setupTornadoForwardPlanLayeredNonNvidia();
        this.taskGraphs = tornadoVMPlan.getFirst();
        this.scheduler = tornadoVMPlan.getSecond();
        this.state = state;
        this.config = model.configuration();
        this.executionPlan = new TornadoExecutionPlan(taskGraphs.toArray(new ImmutableTaskGraph[taskGraphs.size()]));
    }

    /**
     * Initializes the TornadoVM plan for GPU acceleration with optional timing.
     * This method handles:
     * 1. Creation of the TornadoVM master plan
     * 2. Warming up the JIT compiler for better performance
     * 3. Copying read-only model weights to the GPU
     *
     * @param state The model state containing KV cache
     * @param model The Llama model instance
     * @return The initialized TornadoVMMasterPlan ready for inference
     */
    public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Llama model) {
        // Initialize timing variables outside conditional blocks to avoid scope issues
        long startTime = System.nanoTime();
        long planCreationTime = 0;
        long warmupTime = 0;

        // Start a timing message if enabled
        if (ENABLE_TORNADOVM_INIT_TIME) {
            System.err.println("\nStarting TornadoVM initialization...");
        }

        // 1. Pre-allocate the TornadoVM plan
        // NOTE(review): only backend index 0 is inspected to decide the NVIDIA-specific
        // plan layout; multi-backend setups may pick the wrong variant — TODO confirm.
        TornadoRuntime coreRuntime = TornadoRuntimeProvider.getTornadoRuntime();
        boolean isNvidia = coreRuntime.getBackend(0).getDefaultDevice().getPlatformName().toLowerCase().contains("nvidia");
        TornadoVMMasterPlan tornadoVMPlan = new TornadoVMMasterPlan(state, model, isNvidia);

        // Record time after plan creation
        if (ENABLE_TORNADOVM_INIT_TIME) {
            planCreationTime = System.nanoTime();
            System.err.printf("TornadoVM GPU execution plan creation: %.2f ms\n", (planCreationTime - startTime) / 1_000_000.0);
        }

        // 2. Perform warmup to ensure JIT compilation is complete before first inference
        tornadoVMPlan.executionPlan.withPreCompilation(); // Force JIT compilation from Java to GPU code

        // Record time after warmup
        if (ENABLE_TORNADOVM_INIT_TIME) {
            warmupTime = System.nanoTime();
            System.err.printf("Java to GPU JIT compiler warmup: %.2f ms\n", (warmupTime - planCreationTime) / 1_000_000.0);
        }

        // 3. Perform copy-in of read-only weights and objects
        tornadoVMPlan.forceCopyInReadOnlyDataLayered(); // Force copy-in read-only weights

        // Record final timing information
        if (ENABLE_TORNADOVM_INIT_TIME) {
            long copyTime = System.nanoTime();
            System.err.printf("Transfer read-only weights to GPU: %.2f ms\n", (copyTime - warmupTime) / 1_000_000.0);
            System.err.printf("Finished TornadoVM initialization...\n \n");
        }

        return tornadoVMPlan;
    }

    /**
     * Executes the forward pass of a LLaMA transformer model using TornadoVM acceleration.
     * This method processes the transformer layers in sequence for a particular token position in the context
     * window.
     *

     * The execution happens in three phases:
     * <ol>
     *   <li>Initial token embedding lookup (already done before calling this method)</li>
     *   <li>Sequential processing through each transformer layer using TornadoVM</li>
     *   <li>Final projection to logits using TornadoVM</li>
     * </ol>
     *
     * @param position
     *            The current position in the sequence being processed
     * @return FloatArray containing the output logits for token prediction
     */
    public FloatArray tornadoVMForwardExecuteLayered(int position) {
        // @formatter:off
        // 1. Execute the preprocessing graph (e.g., input preparation, memory initialization)
        executionPlan.withGraph(getPreprocessingGraphIndex())
                     .withGridScheduler(scheduler)
                     .execute();

        // Publish the current sequence position for the attention layers.
        // NOTE(review): this happens AFTER the preprocessing graph runs — presumably graph 0
        // does not read positionHolder; confirm against the TornadoVMLayerPlanner task graphs.
        state.positionHolder.set(0, position);

        // 2. Execute each transformer layer graph sequentially
        // Each graph computes attention and feed-forward transformations for one layer
        for (int layer = 0; layer < config.numberOfLayers; layer++) {
            executionPlan.withGraph(getLayerGraphIndex(layer))
                         .withGridScheduler(scheduler)
                         .execute();
        }

        // 3. Execute the final graph that projects the last hidden state to output logits
        executionPlan.withGraph(getFinalLogitsGraphIndex())
                     .withGridScheduler(scheduler)
                     .execute();

        // @formatter:on
        // Return the logits (used for token prediction)
        return state.wrapLogits;
    }

    /**
     * Returns the graph index for the pre-processing step (e.g., token embedding).
     * Graph layout: [0] preprocessing, [1..numberOfLayers] transformer layers, [last] logits.
     */
    private int getPreprocessingGraphIndex() {
        return 0;
    }

    /**
     * Returns the graph index for the given transformer layer.
     *
     * @param layerIndex Index of the transformer layer (0-based)
     */
    private int getLayerGraphIndex(int layerIndex) {
        return 1 + layerIndex;
    }

    /**
     * Returns the graph index for the final projection to logits.
153 | */ 154 | private int getFinalLogitsGraphIndex() { 155 | return taskGraphs.size() - 1; 156 | } 157 | 158 | /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration 159 | /// just once to copy the data into the read-only data layer. 160 | public void forceCopyInReadOnlyDataLayered() { 161 | // Execute all TornadoVM graphs 162 | state.wrapX.init(0.0f); 163 | state.positionHolder.init(0); 164 | 165 | // Execute activation update graph 166 | executionPlan.withGraph(0).withGridScheduler(scheduler).execute(); 167 | 168 | // Execute layer processing graphs 169 | for (int layer = 0; layer < config.numberOfLayers; layer++) { 170 | executionPlan.withGraph(layer + 1).withGridScheduler(scheduler).execute(); 171 | } 172 | 173 | // Execute logits graph 174 | executionPlan.withGraph(config.numberOfLayers + 1).withGridScheduler(scheduler).execute(); 175 | } 176 | 177 | /** 178 | * Frees the device memory allocated for the TornadoVM execution plan. 179 | * This method should be called when the execution plan is no longer needed 180 | * to release resources and avoid memory leaks. 181 | */ 182 | public void freeTornadoExecutionPlan() { 183 | executionPlan.freeDeviceMemory(); 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /src/main/java/com/example/tornadovm/TransformerComputeKernels.java: -------------------------------------------------------------------------------- 1 | package com.example.tornadovm; 2 | 3 | import uk.ac.manchester.tornado.api.KernelContext; 4 | import uk.ac.manchester.tornado.api.math.TornadoMath; 5 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray; 6 | 7 | public class TransformerComputeKernels { 8 | 9 | /** 10 | * Default constructor for the TransformerComputeKernels class. 
     */
    public TransformerComputeKernels() {
    }

    /**
     * No-op kernel whose only purpose is to touch {@code buffer} so the TornadoVM
     * runtime schedules a host-to-device copy of it. The guard
     * {@code dummy > Float.MAX_VALUE} is false for every finite float (it could only
     * fire on positive infinity), so the write is effectively dead — it exists solely
     * to keep the read from being optimized away.
     *
     * @param buffer Device buffer whose copy-in is being forced
     */
    public static void emptyTaskToForceCopyIn(FloatArray buffer) {
        float dummy = buffer.get(0);
        if (dummy > Float.MAX_VALUE) {
            buffer.set(0, dummy);
        }
    }

    /**
     * Performs RMS (Root Mean Square) normalization using parallel reduction.
     * This is a two-phase reduction: first within work groups, then across work groups.
     *
     * Phase 1: Each work group computes a partial sum of squares of its slice of {@code x}
     *          and stores it at {@code output[groupId + 1]}.
     * Phase 2: The first global thread combines all partial sums and stores the final
     *          normalization factor 1/sqrt(mean(x^2) + ermsNorm) at {@code output[0]}.
     *
     * @param context Kernel execution context
     * @param output Array to store partial sums (indices 1..numGroups) and final normalization factor (index 0)
     * @param x Input array to normalize
     * @param size Number of elements to process
     * @param ermsNorm Epsilon value for numerical stability (epsilon * epsilon)
     * @param localMemSize Size of local memory allocation (work group size)
     */
    public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
        int gid = context.globalIdx;
        int lid = context.localIdx;
        int groupId = context.groupIdx;
        int groupSize = context.localGroupSizeX;

        // Allocate local memory with the provided size
        float[] localX = context.allocateFloatLocalArray(localMemSize);

        // Load input value and compute square; out-of-range threads contribute 0
        if (gid < size) {
            localX[lid] = x.get(gid);
            localX[lid] = localX[lid] * localX[lid];
        } else {
            localX[lid] = 0.0f;
        }

        // Tree reduction within the work group; barrier before each halving so every
        // thread sees the previous round's writes.
        // NOTE(review): assumes groupSize is a power of two — otherwise some lanes are dropped.
        for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
            context.localBarrier();
            if (lid < stride) {
                localX[lid] += localX[lid + stride];
            }
        }

        // Each workgroup stores its partial sum in a different location
        // (slot 0 is reserved for the final scale factor)
        if (lid == 0) {
            output.set(groupId + 1, localX[0]);
        }

        // Only the first thread in the first workgroup computes the final normalization factor.
        // NOTE(review): there is no device-wide barrier between the partial-sum writes above and
        // these reads — this relies on TornadoVM's kernel scheduling semantics making the
        // partials visible; confirm against the TornadoVM execution model.
        if (gid == 0) {
            // Combine partial sums from all workgroups.
            // NOTE(review): loop bound assumes the launch uses exactly size / localMemSize
            // workgroups (i.e., local work size == localMemSize and size divides evenly) —
            // verify against the grid configured by TornadoVMLayerPlanner.
            float ss = 0.0f;
            for (int i = 1; i <= (size / localMemSize); i++) { // Assuming 8 workgroups
                ss += output.get(i);
            }

            ss /= size;                        // mean of squares
            ss += ermsNorm;                    // numerical-stability epsilon
            ss = 1.0f / TornadoMath.sqrt(ss);  // inverse RMS
            output.set(0, ss); // Store the final scale factor
        }
    }

    /**
     * Applies the computed normalization factor to scale weights.
     * This is the second phase of RMS normalization: an in-place elementwise update
     * output[gid] = weights[gid] * temp[0] * output[gid], one thread per element.
     *
     * NOTE(review): no bounds guard on gid — assumes the global work size equals the
     * array length; confirm against the grid scheduler setup.
     *
     * @param context Kernel execution context
     * @param output Array holding values to scale; overwritten with the normalized result
     * @param weights Weight values to normalize
     * @param temp Temporary array containing a normalization factor at index 0
     */
    public static void reductionOneBlock2WithLogits(KernelContext context, FloatArray output, FloatArray weights, FloatArray temp) {
        int gid = context.globalIdx;
        float ss = temp.get(0);
        output.set(gid, weights.get(gid) * (ss * output.get(gid)));
    }

}
--------------------------------------------------------------------------------