├── .github ├── FUNDING.yml ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature_request.yaml ├── pull_request_template.md └── workflows │ ├── build.yml │ └── warmup.yml ├── .gitignore ├── BUILDING.md ├── LICENSE ├── README.md ├── api ├── OAI │ ├── router.ts │ ├── types │ │ ├── chatCompletions.ts │ │ └── completions.ts │ └── utils │ │ ├── chatCompletion.ts │ │ ├── completion.ts │ │ └── generation.ts ├── core │ ├── router.ts │ └── types │ │ ├── auth.ts │ │ ├── health.ts │ │ ├── model.ts │ │ ├── template.ts │ │ └── token.ts ├── middleware │ ├── authMiddleware.ts │ ├── checkModelMiddleware.ts │ └── requestLogMiddleware.ts └── server.ts ├── assets ├── icon.ico └── icon.png ├── bindings ├── CMakeLists.txt ├── bindings.ps1 ├── bindings.sh ├── bindings.ts ├── generationResources.ts ├── grammar.ts ├── job.ts ├── lib.ts ├── minimal_cpp_test.cpp ├── readbackBuffer.ts ├── samplers.ts ├── server │ ├── c_library.cpp │ ├── c_library.h │ ├── generation_resources.hpp │ ├── inference_args.hpp │ ├── json_status.hpp │ ├── presampler.hpp │ ├── processor.hpp │ ├── readback_buffer.hpp │ ├── request.hpp │ ├── rule_stream.hpp │ ├── samplers.hpp │ ├── sequence_stream.hpp │ ├── server_basic_example.cpp │ ├── slot.hpp │ ├── tokenization.hpp │ └── trie.hpp ├── symbols.ts ├── types.ts └── utils.ts ├── common ├── args.ts ├── auth.ts ├── config.ts ├── configModels.ts ├── errors.ts ├── logging.ts ├── modelContainer.ts ├── myZod.ts ├── networking.ts ├── samplerOverrides.ts ├── sampling.ts ├── templating.ts └── utils.ts ├── config_sample.yml ├── deno.json ├── deno.lock ├── generateGitSha.ts ├── lib └── place_libs_here.txt ├── main.ts ├── minimal_test_setup.ts ├── models └── place_your_models_here.txt ├── sampler_overrides └── sample_preset.yml ├── templates ├── alpaca.jinja ├── chatml.jinja └── place_your_templates_here.txt └── types ├── jinja.d.ts └── utils.ts /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | ko_fi: kingbri 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: Bug report 2 | description: Report code related issues 3 | title: "[BUG]" 4 | labels: bug 5 | body: 6 | 7 | - type: markdown 8 | attributes: 9 | value: | 10 | ### Disclaimer: 11 | Github Issues are **only** for code related bugs. 12 | If you do not understand how to startup or use TabbyAPI, please ask in the [Discord Server](https://discord.gg/sYQxnuD7Fj) 13 | 14 | - type: dropdown 15 | attributes: 16 | label: OS 17 | options: 18 | - Windows 19 | - macOS 20 | - Linux 21 | validations: 22 | required: true 23 | 24 | - type: dropdown 25 | attributes: 26 | label: GPU Library 27 | description: Ex. CUDA, ROCm 28 | options: 29 | - CUDA 30 | - AMD ROCm 31 | - Metal 32 | - CPU 33 | validations: 34 | required: true 35 | 36 | - type: input 37 | attributes: 38 | label: YALS commit sha 39 | description: Enter the commit SHA you're using (found on startup) 40 | placeholder: "ex. a1b4da3" 41 | validations: 42 | required: true 43 | 44 | - type: textarea 45 | attributes: 46 | label: Describe the bug 47 | description: A clear and concise description of what the bug is. 48 | validations: 49 | required: true 50 | 51 | - type: textarea 52 | attributes: 53 | label: Reproduction steps 54 | description: Walk us through how the bug occurred and how to make it happen. 
55 | validations: 56 | required: true 57 | 58 | - type: textarea 59 | attributes: 60 | label: Expected behavior 61 | description: What was expected to happen? 62 | validations: 63 | required: true 64 | 65 | - type: textarea 66 | attributes: 67 | label: Logs 68 | description: If applicable, add logs and call stacks to help explain your problem. 69 | validations: 70 | required: false 71 | 72 | - type: textarea 73 | attributes: 74 | label: Additional context 75 | description: Add any other context about the problem here. 76 | validations: 77 | required: false 78 | 79 | - type: checkboxes 80 | attributes: 81 | label: Acknowledgements 82 | description: Before submitting this issue, please make sure you have completed the following checklist. 83 | options: 84 | - label: I have looked for similar issues before submitting this one. 85 | required: true 86 | - label: I have read the disclaimer, and this issue is related to a code bug. If I have a question, I will use the Discord server. 87 | required: true 88 | - label: I understand that the developers have lives and my issue will be answered when possible. 89 | required: true 90 | - label: I understand the developers of this program are human, and I will ask my questions politely. 91 | required: true 92 | 93 | - type: markdown 94 | attributes: 95 | value: | 96 | ## Thanks! 97 | Well-formatted issues improve YALS and make the development process smoother. 98 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest a new idea 3 | title: "[REQUEST]" 4 | body: 5 | 6 | - type: textarea 7 | attributes: 8 | label: Problem 9 | description: Is the feature request related to a problem? If so, please describe. 10 | placeholder: A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 11 | validations: 12 | required: false 13 | 14 | - type: textarea 15 | attributes: 16 | label: Solution 17 | description: Describe the solution you'd like. 18 | placeholder: A clear and concise description of what you want to happen. 19 | validations: 20 | required: true 21 | 22 | - type: textarea 23 | attributes: 24 | label: Alternatives 25 | description: What alternative options did you consider? 26 | validations: 27 | required: false 28 | 29 | - type: textarea 30 | attributes: 31 | label: Explanation 32 | description: Why should this feature be added? 33 | validations: 34 | required: true 35 | 36 | - type: textarea 37 | attributes: 38 | label: Examples 39 | description: | 40 | Examples of the feature in action and its significance. 41 | 42 | Not required, but will make your request easier to understand. 43 | validations: 44 | required: false 45 | 46 | - type: textarea 47 | attributes: 48 | label: Additional context 49 | description: Anything else to add? 50 | validations: 51 | required: false 52 | 53 | - type: checkboxes 54 | attributes: 55 | label: Acknowledgements 56 | description: Before submitting this issue, please make sure you have completed the following checklist. 57 | options: 58 | - label: I have looked for similar requests before submitting this one. 59 | required: true 60 | - label: I understand that the developers have lives and my issue will be answered when possible. 61 | required: true 62 | - label: I understand the developers of this program are human, and I will make my requests politely. 
63 | required: true 64 | 65 | - type: markdown 66 | attributes: 67 | value: | 68 | ## Thanks! 69 | Well-formatted issues improve YALS and make the development process smoother. 70 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | **Warning** 2 | Make all Pull Requests to the dev branch. Main is reserved for stability and building actions. 3 | 4 | **Is your pull request related to a problem? Please describe.** 5 | A clear and concise description of what the problem is. You can also link to an existing issue. 6 | 7 | **Why should this feature be added?** 8 | An explanation of why the feature should be added. Please be as specific as possible to help us understand the reasoning. 9 | 10 | **Examples** 11 | Examples of the feature in action and its significance compared to not having the feature. 12 | 13 | **Additional context** 14 | Add any other context or screenshots about the pull request here. 15 | -------------------------------------------------------------------------------- /.github/workflows/warmup.yml: -------------------------------------------------------------------------------- 1 | name: Cache Management 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 */2 * *' 6 | workflow_dispatch: 7 | inputs: 8 | force_warmup: 9 | description: 'Force cache warmup regardless of age' 10 | required: false 11 | default: false 12 | type: boolean 13 | 14 | jobs: 15 | cleanup-cache: 16 | runs-on: ubuntu-latest 17 | permissions: 18 | actions: write 19 | steps: 20 | - name: Cleanup old caches 21 | uses: actions/github-script@v7 22 | with: 23 | script: | 24 | const retentionCount = 2; // Keep the 2 most recent caches per prefix 25 | const forceWarmupDays = 5; // Force warmup if newest cache is older than this 26 | const caches = await github.rest.actions.getActionsCacheList({ 27 | owner: context.repo.owner, 28 | repo: context.repo.repo 29 | }); 30 | 31 | // Track if any cache needs warming up 32 | let needsWarmup = false; 33 | 34 | // Group caches by their prefix pattern (e.g., sccache-Windows-cuda-main) 35 | const cacheGroups = {}; 36 | for (const cache of caches.data.actions_caches) { 37 | if (cache.key.startsWith('sccache-')) { 38 | // Extract the prefix pattern (everything before the last hyphen + timestamp) 39 | const prefixPattern = cache.key.replace(/-\d+$/, ''); 40 | 41 | if (!cacheGroups[prefixPattern]) { 42 | cacheGroups[prefixPattern] = []; 43 | } 44 | cacheGroups[prefixPattern].push(cache); 45 | } 46 | } 47 | 48 | const now = new Date(); 49 | // Process each group of caches 50 | for (const prefix in cacheGroups) { 51 | // Sort caches by creation date (newest first) 52 | const sortedCaches = cacheGroups[prefix].sort((a, b) => 53 | new Date(b.created_at) - new Date(a.created_at)); 54 | 55 | // Check if most recent cache is older than forceWarmupDays 56 | if (sortedCaches.length > 0) { 57 | const newestCache = sortedCaches[0]; 58 | const createdAt = new Date(newestCache.created_at); 59 | const ageInDays = (now - createdAt) / (1000 * 60 * 60 * 24); 60 | 61 | if (ageInDays > forceWarmupDays) { 62 | console.log(`Cache ${prefix} is stale (${ageInDays.toFixed(1)} days old). 
Will force warmup.`); 63 | needsWarmup = true; 64 | } 65 | 66 | // Log the kept most recent cache 67 | console.log(`Keeping most recent cache: ${newestCache.key}, created ${ageInDays.toFixed(1)} days ago`); 68 | 69 | // Keep second most recent cache if it exists 70 | if (sortedCaches.length > 1) { 71 | const secondCache = sortedCaches[1]; 72 | const secondCreatedAt = new Date(secondCache.created_at); 73 | const secondAgeInDays = (now - secondCreatedAt) / (1000 * 60 * 60 * 24); 74 | console.log(`Keeping second most recent cache: ${secondCache.key}, created ${secondAgeInDays.toFixed(1)} days ago`); 75 | } 76 | } 77 | 78 | // Delete all caches beyond the retention count 79 | for (let i = retentionCount; i < sortedCaches.length; i++) { 80 | const cache = sortedCaches[i]; 81 | const createdAt = new Date(cache.created_at); 82 | const ageInDays = (now - createdAt) / (1000 * 60 * 60 * 24); 83 | 84 | console.log(`Deleting old cache: ${cache.key}, created ${ageInDays.toFixed(1)} days ago`); 85 | await github.rest.actions.deleteActionsCacheByKey({ 86 | owner: context.repo.owner, 87 | repo: context.repo.repo, 88 | key: cache.key 89 | }); 90 | } 91 | } 92 | 93 | // Set output to control whether to run warmup jobs 94 | core.setOutput('needs_warmup', needsWarmup.toString()); 95 | 96 | warmup-unix: 97 | needs: cleanup-cache 98 | if: ${{ needs.cleanup-cache.outputs.needs_warmup == 'true' || (github.event_name == 'workflow_dispatch' && github.event.inputs.force_warmup == 'true') }} 99 | runs-on: ${{ matrix.os }} 100 | strategy: 101 | matrix: 102 | os: [ubuntu-22.04, macos-15] 103 | device: [cpu, metal, cuda] 104 | exclude: 105 | - os: macos-15 106 | device: cpu 107 | - os: macos-15 108 | device: cuda 109 | - os: ubuntu-22.04 110 | device: metal 111 | 112 | container: ${{ matrix.device == 'cuda' && 'nvidia/cuda:12.8.0-devel-ubuntu22.04' || '' }} 113 | steps: 114 | - uses: actions/checkout@v4 115 | - name: Run sccache-cache 116 | uses: mozilla-actions/sccache-action@v0.0.7 117 | - name: Configure sccache 118 | id: sccache 119 | run: | 120 | mkdir -p "$PWD/bindings/.sccache" 121 | export SCCACHE_DIR="$PWD/bindings/.sccache" 122 | echo "SCCACHE_DIR=$SCCACHE_DIR" >> $GITHUB_ENV 123 | - name: Cache sccache storage 124 | uses: actions/cache@v4 125 | with: 126 | path: ${{ env.SCCACHE_DIR }} 127 | key: sccache-${{ runner.os }}-${{ matrix.device }}-${{ github.ref_name }}-${{ github.run_id }} 128 | restore-keys: | 129 | sccache-${{ runner.os }}-${{ matrix.device }}-${{ github.ref_name }}- 130 | sccache-${{ runner.os }}-${{ matrix.device }}- 131 | 132 | warmup-win: 133 | needs: cleanup-cache 134 | if: ${{ needs.cleanup-cache.outputs.needs_warmup == 'true' || (github.event_name == 'workflow_dispatch' && github.event.inputs.force_warmup == 'true') }} 135 | runs-on: ${{ matrix.os }} 136 | strategy: 137 | matrix: 138 | os: [windows-2022] 139 | device: [cpu, cuda] 140 | 141 | steps: 142 | - uses: actions/checkout@v4 143 | - name: Run sccache-cache 144 | uses: mozilla-actions/sccache-action@v0.0.7 145 | - name: Configure sccache 146 | run: | 147 | New-Item -ItemType Directory -Force -Path "$PWD/bindings/.sccache" 148 | $env:SCCACHE_DIR="$PWD/bindings/.sccache" 149 | echo "SCCACHE_DIR=$env:SCCACHE_DIR" >> $env:GITHUB_ENV 150 | - name: Cache sccache storage 151 | uses: actions/cache@v4 152 | with: 153 | path: ${{ env.SCCACHE_DIR }} 154 | key: sccache-${{ runner.os }}-${{ matrix.device }}-${{ github.ref_name }}-${{ github.run_id }} 155 | restore-keys: | 156 | sccache-${{ runner.os }}-${{ matrix.device }}-${{ github.ref_name }}- 
157 | sccache-${{ runner.os }}-${{ matrix.device }}- 158 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Deno ### 2 | /.idea 3 | */.idea/ 4 | /.vscode/ 5 | 6 | /node_modules 7 | /bindings/home/ 8 | /bindings/gguf 9 | bindings/.sccache 10 | 11 | .env 12 | *.orig 13 | *.pyc 14 | *.swp 15 | *.gguf 16 | 17 | # End of https://www.toptal.com/developers/gitignore/api/deno 18 | 19 | # C++ ignores 20 | cmake-build-debug* 21 | */cmake-build-debug* 22 | cmake-build-release* 23 | */cmake-build-release* 24 | bindings/build 25 | 26 | # Lib 27 | lib/* 28 | !lib/place_libs_here.txt 29 | 30 | # build 31 | build/* 32 | venv/* 33 | .venv/* 34 | 35 | # User configuration 36 | config.yml 37 | api_tokens.yml 38 | 39 | # Models folder 40 | models/* 41 | !models/place_your_models_here.txt 42 | 43 | # Templates folder 44 | templates/* 45 | !templates/place_your_templates_here.txt 46 | !templates/alpaca.jinja 47 | !templates/chatml.jinja 48 | 49 | # Sampler overrides folder 50 | sampler_overrides/* 51 | !sampler_overrides/sample_preset.yml 52 | 53 | # Compiled binaries and embedded assets 54 | gitSha.txt 55 | YALS.exe 56 | YALS 57 | 58 | # Markdown 59 | .obsidian/ 60 | 61 | # macOS 62 | *.DS_Store 63 | 64 | # Exclude all .yml except config sample 65 | *.yml 66 | !config_sample.yml -------------------------------------------------------------------------------- /BUILDING.md: -------------------------------------------------------------------------------- 1 | # Build Instructions 2 | 3 | YALS contains two components: 4 | 1. TypeScript code: Universally buildable on any OS 5 | 2. C++ bindings: Requires an OS-specific C++ compiler and additional setup 6 | 7 | The C++ bindings need to be built to integrate the `llama.cpp` library and provide the necessary "glue" required by YALS. 8 | 9 | ## Prerequisites 10 | 11 | To get started, install the following prerequisites: 12 | - [Deno](https://deno.com) 13 | - A C++ compiler: 14 | - Windows: Visual Studio 2022 build tools 15 | - macOS: Xcode command-line tools (`xcode-select --install`) 16 | - Linux: GCC (`sudo apt install build-essential`) 17 | - CMake: 18 | - Windows: Installed with Visual Studio build tools 19 | - macOS (homebrew): `brew install cmake` 20 | - Linux: `sudo apt install cmake` (For Ubuntu 22.04, follow this [askubuntu](https://askubuntu.com/a/865294) answer to install the latest version) 21 | - Ninja (Makes builds faster) 22 | - Windows: `winget install -e --id Ninja-build.Ninja` 23 | - macOS (homebrew): `brew install ninja` 24 | - Linux: `sudo apt install ninja-build` 25 | - [sccache](https://github.com/mozilla/sccache) (optional, but speeds up subsequent builds) 26 | - [Rust](https://rustup.rs)(Used for improved grammar parsing via LLGuidance) 27 | 28 | ## Building 29 | 30 | Clone the repository and navigate to the project folder: 31 | ```sh 32 | git clone https://github.com/theroyallab/YALS.git 33 | cd YALS 34 | ``` 35 | 36 | All build commands are encapsulated in Deno tasks, similar to npm scripts in NodeJS. 37 | 38 | > [!NOTE] 39 | > Unlike llama.cpp and its derivatives, YALS uses an extremely fast grammar tool called llguidance for JSON schemas, Regex, and lark grammars. 40 | > 41 | > Due to an extra dependency being required for users systems, llguidance is off by default, but it is **highly recommended** to turn it on at build time for improved grammar handling. 
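
For example, on macOS/Linux the bindings build described below could be invoked with llguidance (and, optionally, CUDA) enabled roughly as follows. This is only a sketch: it assumes Rust is installed (plus the CUDA toolkit for the CUDA variant) and that your shell accepts the inline `VAR=value` prefix form.

```sh
# Sketch: build the C++ bindings with llguidance enabled (assumes Rust is installed)
LLGUIDANCE=1 deno task bindings

# Optionally combine it with a GPU backend, e.g. CUDA (assumes the CUDA toolkit is installed)
LLGUIDANCE=1 GGML_CUDA=1 deno task bindings
```

On Windows, set the variable in PowerShell instead (e.g. `$env:LLGUIDANCE = "1"`) before running `deno task bindings-win`.
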
42 | 43 | To enable it, set `LLGUIDANCE=1` in your shell before invoking the deno task. 44 | 45 | To build the C++ bindings: 46 | 47 | - Windows: `deno task bindings-win` 48 | - macOS/Linux: `deno task bindings` 49 | 50 | This will invoke CMake to build the bindings and copy the resulting shared libraries to the `lib` folder. 51 | 52 | Optionally, environment variables can be set for certain architectures when building (ex. CUDA): 53 | - `MAX_JOBS`: Number of parallel jobs (defaults to the number of CPU cores) 54 | - `LLAMACPP_REPO`: Point to a custom repository for llama.cpp (Here be dragons!) 55 | - `LLAMACPP_TAG`: Set a specific tag for llama.cpp (Here be dragons!) 56 | - `GGML_CUDA=1`: Enables CUDA support 57 | - `CMAKE_CUDA_ARCHITECTURES`: Specifies CUDA compute capabilities (defaults to `native` if using CMake > 3.24) 58 | - `GGML_VULKAN=1`: Enables Vulkan support 59 | - `GGML_HIP=1`: Enables HIP ROCm support (requires specifying `AMDGPU_TARGETS`, Linux only) 60 | - `AMDGPU_TARGETS`: Specifies the ROCm target (example: `gfx1030`) 61 | - `LLGUIDANCE=1`: (Recommended) Enables llguidance for grammars. Requires Rust on the system. (default `0`) 62 | 63 | ## Running 64 | 65 | To start the server with necessary permissions: 66 | ```sh 67 | deno task start 68 | ``` 69 | 70 | With full permissions (useful for testing new features): 71 | ```sh 72 | deno run -A main.ts 73 | ``` 74 | 75 | ## Packaging 76 | 77 | > [!NOTE] 78 | > All YALS commits are built via GitHub Actions, so manual packaging is typically unnecessary unless you need to distribute builds with a custom build configuration. 79 | 80 | To create a distributable binary: 81 | 82 | 1. Run `deno task build` to package all TypeScript code into a standalone binary 83 | 2. Zip the following files and directories: 84 | - `YALS(.exe)` 85 | - `lib/` 86 | - `models/` 87 | - `templates/` 88 | - `config_sample.yml` 89 | 3. Distribute the archive, and the recipient can simply extract and run it. 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | # YALS 6 | 7 |

8 | Python 3.10, 3.11, and 3.12 9 | 10 | License: AGPL v3 11 | 12 | 13 | Discord Server 14 | 15 |

16 | 17 |

18 | 19 | Support on Ko-Fi 20 | 21 |

22 | 23 | > [!NOTE] 24 | > 25 | > Need help? Join the [Discord Server](https://discord.gg/sYQxnuD7Fj) and get the `Tabby` role. Please be nice when asking questions. 26 | 27 | Welcome to YALS, also known as **Y**et **A**nother **L**lamacpp **S**erver. 28 | 29 | YALS is a friendly OAI compatible API server built with Deno, Hono, and Zod, designed to facilitate LLM text generation via the [llama.cpp backend](https://github.com/ggml-org/llama.cpp). 30 | 31 | ## Disclaimer 32 | 33 | This project is in an alpha state. There may be bugs, possibly even ones that could cause thermonuclear war. Please note that commits happen frequently, and builds are distributed via CI. 34 | 35 | YALS is a hobby project made for a small number of users. It is not meant to run on production servers. For that, please look at other solutions that support those workloads. 36 | 37 | ## Why? 38 | 39 | The AI space is full of backend projects that wrap llama.cpp, but I felt that something was missing. This led me to create my own backend, one which is extensible, speedy, and as elegant as TabbyAPI, but specifically for llama.cpp and GGUF. 40 | 41 | ## What about TabbyAPI? 42 | 43 | Here are the reasons why I decided to create a separate project instead of integrating llama.cpp support into TabbyAPI: 44 | 45 | 1. **Separation of concerns**: I want TabbyAPI to stay focused on ExLlama, not become a monolithic backend. 46 | 2. **Distribution patterns**: Unlike TabbyAPI, llama.cpp backends are often distributed as binaries. Deno’s compile command is vastly superior to PyInstaller, making binary distribution easier. 47 | 3. **Dependency hell**: Python’s dependency system is a mess. Adding another layer of abstractions would confuse users further. 48 | 4. **New technologies**: Since C++ (via C bindings) is universally compatible via an FFI interface, I wanted to try something new instead of struggling with Python. The main reason for using Deno is that it augments an easy-to-learn language (TypeScript) with inbuilt tooling and a robust FFI system. 49 | ## Getting Started 50 | 51 | To get started, download the latest zip from [releases](https://github.com/theroyallab/YALS/releases/latest) that corresponds to your setup. 52 | 53 | The currently supported builds via CI are: 54 | 55 | - **macOS**: Metal 56 | - **Windows/Linux**: CPU 57 | - **Windows/Linux**: CUDA (built for Pascal and newer consumer architectures) 58 | 59 | > [!NOTE] 60 | > 61 | > If your specific setup is not available via CI, you can build locally via the [building guide](https://github.com/theroyallab/YALS/blob/main/BUILDING.md), or request a certain architecture in issues. 62 | 63 | Then follow these steps: 64 | 65 | 1. Extract the zip file 66 | 2. Copy `config_sample.yml` to a file called `config.yml` 67 | 3. Edit `config.yml` to configure model loading, networking, and other parameters. 68 | 1. All options are commented: **if you're unsure about an option, it's best to leave it unchanged**. 69 | 2. You can also use CLI arguments, similar to TabbyAPI (ex. `--flash-attention true`). 70 | 4. Download a `.gguf` model into the `models` directory (or whatever you set your directory to) 71 | 1. If the model is split into multiple parts (`00001-of-0000x.gguf`), set `model_name` in `config.yml` to the **first part** (ending in `00001`). Other parts will load automatically. 72 | 5. Start YALS: 73 | 1. Windows: Double-click `YALS.exe` or run `.\YALS.exe` from the terminal (recommended) 74 | 2. macOS/Linux: Open a terminal and run `./YALS` 75 | 6. 
Navigate to `http://<address>:<port>/docs` (ex. `http://localhost:5000/docs`) to view the YALS Scalar API documentation. 76 | ## Features 77 | 78 | - OpenAI compatible API 79 | - Loading/unloading models 80 | - Flexible Jinja2 template engine for chat completions that conforms to HuggingFace's chat template format 81 | - Fast JSON schema + Regex + EBNF support via llguidance 82 | - String banning 83 | - Concurrent inference with Hono + async TypeScript 84 | - Robust validation with Zod 85 | - Utilizes modern TS paradigms and the Deno runtime 86 | - Inbuilt proxy to override client request parameters/samplers 87 | - Continuous slot-based batching engine with improved KV cache assignment 88 | 89 | More features will be added as the project matures. If something is missing here, PR it in! 90 | 91 | ## Supported Model Types 92 | 93 | Since YALS uses llama.cpp for inference, the only supported model format is GGUF. 94 | 95 | If you want to use other model formats such as Exl2, try [tabbyAPI](https://github.com/theroyallab/TabbyAPI). 96 | 97 | ## Contributing 98 | 99 | Use the template when creating issues or pull requests; otherwise, the developers may not look at your post. 100 | 101 | If you have issues with the project: 102 | 103 | - Describe the issue in detail 104 | - If you have a feature request, please indicate it as such. 105 | 106 | If you have a Pull Request: 107 | 108 | - Describe the pull request in detail: what you are changing and why 109 | 110 | ## Developers and Permissions 111 | 112 | Creators/Developers: 113 | 114 | - [kingbri](https://github.com/kingbri1) - TypeScript, Deno, and some C++ 115 | - [CoffeeVampire](https://github.com/CoffeeVampir3) - Main C++ developer 116 | 117 | ## Acknowledgements 118 | 119 | YALS would not exist without the work of other contributors and FOSS projects: 120 | 121 | - [llama.cpp](https://github.com/ggml-org/llama.cpp) 122 | - [Deno](https://deno.com) 123 | - [Hono](https://hono.dev) 124 | - [Zod](https://zod.dev) 125 | - [llguidance](https://github.com/guidance-ai/llguidance) 126 | - [KoboldCpp](https://github.com/lostruins/koboldcpp) 127 | - [SillyTavern](https://github.com/SillyTavern/SillyTavern) 128 | -------------------------------------------------------------------------------- /api/OAI/router.ts: -------------------------------------------------------------------------------- 1 | import { Hono } from "hono"; 2 | import { HTTPException } from "hono/http-exception"; 3 | import { streamSSE } from "hono/streaming"; 4 | import { describeRoute } from "hono-openapi"; 5 | import { validator as sValidator } from "hono-openapi"; 6 | import { 7 | ChatCompletionRequest, 8 | ChatCompletionResponse, 9 | } from "@/api/OAI/types/chatCompletions.ts"; 10 | import { 11 | generateChatCompletion, 12 | streamChatCompletion, 13 | } from "@/api/OAI/utils/chatCompletion.ts"; 14 | import { AuthKeyPermission } from "@/common/auth.ts"; 15 | import { jsonContent } from "@/common/networking.ts"; 16 | import { PromptTemplate } from "@/common/templating.ts"; 17 | 18 | import authMiddleware from "../middleware/authMiddleware.ts"; 19 | import checkModelMiddleware from "../middleware/checkModelMiddleware.ts"; 20 | import { CompletionRequest, CompletionResponse } from "./types/completions.ts"; 21 | import { generateCompletion, streamCompletion } from "./utils/completion.ts"; 22 | 23 | const router = new Hono(); 24 | 25 | const completionsRoute = describeRoute({ 26 | responses: { 27 | 200: jsonContent(CompletionResponse, "Response to completions"), 28 | }, 29 | }); 30 | 31 | router.post( 32 |
"/v1/completions", 33 | completionsRoute, 34 | authMiddleware(AuthKeyPermission.API), 35 | checkModelMiddleware, 36 | sValidator("json", CompletionRequest), 37 | async (c) => { 38 | const params = c.req.valid("json"); 39 | 40 | if (params.stream) { 41 | return streamSSE(c, async (stream) => { 42 | await streamCompletion( 43 | c.var.requestId, 44 | stream, 45 | params, 46 | c.var.model, 47 | c.req.raw.signal, 48 | ); 49 | }); 50 | } else { 51 | const completionResult = await generateCompletion( 52 | c.var.requestId, 53 | params, 54 | c.var.model, 55 | c.req.raw.signal, 56 | ); 57 | 58 | return c.json(completionResult); 59 | } 60 | }, 61 | ); 62 | 63 | const chatCompletionsRoute = describeRoute({ 64 | responses: { 65 | 200: jsonContent( 66 | ChatCompletionResponse, 67 | "Response to chat completions", 68 | ), 69 | }, 70 | }); 71 | 72 | router.post( 73 | "/v1/chat/completions", 74 | chatCompletionsRoute, 75 | authMiddleware(AuthKeyPermission.API), 76 | checkModelMiddleware, 77 | sValidator("json", ChatCompletionRequest), 78 | async (c) => { 79 | const params = c.req.valid("json"); 80 | 81 | let promptTemplate: PromptTemplate; 82 | if (c.var.model.promptTemplate) { 83 | promptTemplate = c.var.model.promptTemplate; 84 | } else { 85 | throw new HTTPException(422, { 86 | message: 87 | "Chat completions are disabled because a prompt template isn't set.", 88 | }); 89 | } 90 | 91 | if (params.stream) { 92 | return streamSSE(c, async (stream) => { 93 | await streamChatCompletion( 94 | c.var.requestId, 95 | stream, 96 | params, 97 | c.var.model, 98 | promptTemplate, 99 | c.req.raw.signal, 100 | ); 101 | }); 102 | } else { 103 | const chatCompletionResult = await generateChatCompletion( 104 | c.var.requestId, 105 | params, 106 | c.var.model, 107 | promptTemplate, 108 | c.req.raw.signal, 109 | ); 110 | return c.json(chatCompletionResult); 111 | } 112 | }, 113 | ); 114 | 115 | export default router; 116 | -------------------------------------------------------------------------------- /api/OAI/types/chatCompletions.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { 3 | CommonCompletionRequest, 4 | UsageStats, 5 | } from "@/api/OAI/types/completions.ts"; 6 | import { BaseSamplerRequest } from "@/common/sampling.ts"; 7 | 8 | const ChatCompletionImageUrl = z.object({ 9 | url: z.string(), 10 | }); 11 | 12 | const ChatCompletionMessagePart = z.object({ 13 | type: z.string().nullish().coalesce("text"), 14 | text: z.string().nullish(), 15 | image_url: ChatCompletionImageUrl.nullish(), 16 | }); 17 | 18 | export type ChatCompletionMessagePart = z.infer< 19 | typeof ChatCompletionMessagePart 20 | >; 21 | 22 | export const ChatCompletionMessage = z.object({ 23 | role: z.string().default("user"), 24 | content: z.union([z.string(), z.array(ChatCompletionMessagePart)]), 25 | }); 26 | 27 | export type ChatCompletionMessage = z.infer; 28 | 29 | const ChatCompletionStreamOptions = z.object({ 30 | include_usage: z.boolean().nullish().coalesce(false), 31 | }); 32 | 33 | export const ChatCompletionRequest = z.aliasedObject( 34 | z.object({ 35 | messages: z.array(ChatCompletionMessage).nullish().coalesce([]), 36 | stream_options: ChatCompletionStreamOptions.nullish(), 37 | add_generation_prompt: z.boolean().nullish().coalesce(true), 38 | prompt_template: z.string().nullish(), 39 | template_vars: z.record(z.string(), z.unknown()).nullish().coalesce({}), 40 | }), 41 | [ 42 | { field: "template_vars", aliases: ["chat_template_kwargs"] }, 
43 | ], 44 | ) 45 | .and(CommonCompletionRequest) 46 | .and(BaseSamplerRequest) 47 | .transform((obj) => { 48 | // Always unset add_bos_token 49 | obj.add_bos_token = undefined; 50 | return obj; 51 | }); 52 | 53 | export type ChatCompletionRequest = z.infer; 54 | 55 | export const ChatCompletionRespChoice = z.object({ 56 | index: z.number().default(0), 57 | finish_reason: z.string().optional(), 58 | message: ChatCompletionMessage, 59 | }); 60 | 61 | export const ChatCompletionResponse = z.object({ 62 | id: z.string().default( 63 | `chatcmpl-${crypto.randomUUID().replaceAll("-", "")}`, 64 | ), 65 | choices: z.array(ChatCompletionRespChoice), 66 | created: z.number().default(Math.floor(Date.now() / 1000)), 67 | model: z.string(), 68 | object: z.string().default("chat.completion"), 69 | usage: UsageStats.optional(), 70 | }); 71 | 72 | export const ChatCompletionStreamChoice = z.object({ 73 | index: z.number().default(0), 74 | finish_reason: z.string().optional(), 75 | delta: z.union([ChatCompletionMessage, z.record(z.string(), z.unknown())]), 76 | }); 77 | 78 | export const ChatCompletionStreamChunk = z.object({ 79 | id: z.string().default( 80 | `chatcmpl-${crypto.randomUUID().replaceAll("-", "")}`, 81 | ), 82 | choices: z.array(ChatCompletionStreamChoice).default([]), 83 | created: z.number().default(Math.floor(Date.now() / 1000)), 84 | model: z.string(), 85 | object: z.string().default("chat.completion.chunk"), 86 | usage: UsageStats.optional(), 87 | }); 88 | -------------------------------------------------------------------------------- /api/OAI/types/completions.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { BaseSamplerRequest } from "@/common/sampling.ts"; 3 | 4 | export const CompletionResponseFormat = z.object({ 5 | type: z.string().default("text"), 6 | }); 7 | 8 | export const UsageStats = z.object({ 9 | prompt_tokens: z.number(), 10 | completion_tokens: z.number(), 11 | total_tokens: z.number(), 12 | }); 13 | 14 | export type UsageStats = z.infer; 15 | 16 | export const CommonCompletionRequest = z.object({ 17 | model: z.string().nullish(), 18 | stream: z.boolean().nullish().coalesce(false), 19 | logprobs: z.number().gte(0).nullish().coalesce(0), 20 | response_format: CompletionResponseFormat.nullish().coalesce( 21 | CompletionResponseFormat.parse({}), 22 | ), 23 | n: z.number().gte(1).nullish().coalesce(1), 24 | best_of: z.number().nullish(), 25 | echo: z.boolean().nullish().coalesce(false), 26 | suffix: z.string().nullish(), 27 | user: z.string().nullish(), 28 | }) 29 | .and(BaseSamplerRequest) 30 | .transform((obj) => { 31 | if (obj.response_format.type === "json") { 32 | obj.json_schema = { 33 | "type": "object", 34 | }; 35 | } 36 | 37 | return obj; 38 | }); 39 | 40 | export const CompletionRequest = z.object({ 41 | prompt: z.union([ 42 | z.string(), 43 | z.array(z.string()).transform((arr) => arr.join("\n")), 44 | ]), 45 | }) 46 | .and(CommonCompletionRequest) 47 | .describe("Completion Request parameters"); 48 | 49 | export type CompletionRequest = z.infer; 50 | 51 | export const CompletionRespChoice = z.object({ 52 | index: z.number().default(0), 53 | finish_reason: z.string().optional(), 54 | text: z.string(), 55 | }); 56 | 57 | export const CompletionResponse = z.object({ 58 | id: z.string().default(`cmpl-${crypto.randomUUID().replaceAll("-", "")}`), 59 | choices: z.array(CompletionRespChoice), 60 | created: z.number().default(Math.floor(Date.now() / 1000)), 61 | model: z.string(), 
62 | object: z.string().default("text_completion"), 63 | usage: UsageStats.optional(), 64 | }); 65 | -------------------------------------------------------------------------------- /api/OAI/utils/chatCompletion.ts: -------------------------------------------------------------------------------- 1 | import { SSEStreamingApi } from "hono/streaming"; 2 | 3 | import { 4 | convertFinishReason, 5 | createUsageStats, 6 | GenerationType, 7 | staticGenerate, 8 | } from "@/api/OAI/utils/generation.ts"; 9 | import { Model } from "@/bindings/bindings.ts"; 10 | import { FinishChunk, GenerationChunk } from "@/bindings/types.ts"; 11 | import { toGeneratorError } from "@/common/networking.ts"; 12 | import { PromptTemplate } from "@/common/templating.ts"; 13 | 14 | import { 15 | ChatCompletionMessage, 16 | ChatCompletionMessagePart, 17 | ChatCompletionRequest, 18 | ChatCompletionRespChoice, 19 | ChatCompletionResponse, 20 | ChatCompletionStreamChoice, 21 | ChatCompletionStreamChunk, 22 | } from "../types/chatCompletions.ts"; 23 | import { CancellationError } from "@/common/errors.ts"; 24 | import { logger } from "@/common/logging.ts"; 25 | 26 | interface TemplateFormatOptions { 27 | addBosToken?: boolean; 28 | banEosToken?: boolean; 29 | addGenerationPrompt?: boolean; 30 | templateVars?: Record; 31 | } 32 | 33 | function createResponse(chunk: FinishChunk, modelName: string) { 34 | const message = ChatCompletionMessage.parse({ 35 | role: "assistant", 36 | content: chunk.text, 37 | }); 38 | 39 | const choice = ChatCompletionRespChoice.parse({ 40 | message: message, 41 | finish_reason: convertFinishReason(chunk), 42 | }); 43 | 44 | const usage = createUsageStats(chunk); 45 | 46 | const response = ChatCompletionResponse.parse({ 47 | choices: [choice], 48 | model: modelName, 49 | usage, 50 | }); 51 | 52 | return response; 53 | } 54 | 55 | function createStreamChunk( 56 | chunk: GenerationChunk, 57 | modelName: string, 58 | cmplId: string, 59 | ) { 60 | const message = ChatCompletionMessage.parse({ 61 | role: "assistant", 62 | content: chunk.text, 63 | }); 64 | 65 | const choice = ChatCompletionStreamChoice.parse({ 66 | delta: message, 67 | }); 68 | 69 | const response = ChatCompletionStreamChunk.parse({ 70 | id: cmplId, 71 | choices: [choice], 72 | model: modelName, 73 | }); 74 | 75 | return response; 76 | } 77 | 78 | function createUsageChunk( 79 | chunk: FinishChunk, 80 | modelName: string, 81 | cmplId: string, 82 | ) { 83 | const response = ChatCompletionStreamChunk.parse({ 84 | id: cmplId, 85 | model: modelName, 86 | usage: createUsageStats(chunk), 87 | }); 88 | 89 | return response; 90 | } 91 | 92 | export function applyChatTemplate( 93 | model: Model, 94 | promptTemplate: PromptTemplate, 95 | messages: ChatCompletionMessage[], 96 | options: TemplateFormatOptions = {}, 97 | ): string { 98 | const { 99 | addGenerationPrompt = true, 100 | templateVars = {}, 101 | } = options; 102 | 103 | messages.forEach((message) => { 104 | if (Array.isArray(message.content)) { 105 | const messageParts = message.content as ChatCompletionMessagePart[]; 106 | message.content = messageParts.find((part) => 107 | part.type === "text" 108 | )?.text ?? ""; 109 | } 110 | }); 111 | 112 | const bosToken = model.tokenizer.bosToken; 113 | let prompt = promptTemplate.render({ 114 | ...templateVars, 115 | messages: messages, 116 | bos_token: bosToken?.piece ?? "", 117 | eos_token: model.tokenizer.eosToken?.piece ?? 
"", 118 | add_generation_prompt: addGenerationPrompt, 119 | }); 120 | 121 | // Remove extra BOS token at start of prompt if present 122 | // Some model templates don't respect their own add_bos_token setting 123 | // Better to do this since a template can add BOS anywhere 124 | if ( 125 | bosToken && model.tokenizer.addBosToken && 126 | prompt.startsWith(bosToken.piece) 127 | ) { 128 | prompt = prompt.slice(bosToken.piece.length); 129 | } 130 | 131 | return prompt; 132 | } 133 | 134 | function addTemplateMetadata( 135 | promptTemplate: PromptTemplate, 136 | params: ChatCompletionRequest, 137 | ) { 138 | const metadata = promptTemplate.metadata; 139 | 140 | if (metadata.stop_strings) { 141 | params.stop.push(...metadata.stop_strings); 142 | } 143 | } 144 | 145 | // TODO: Possibly rewrite this to unify with completions 146 | export async function streamChatCompletion( 147 | requestId: string, 148 | stream: SSEStreamingApi, 149 | params: ChatCompletionRequest, 150 | model: Model, 151 | promptTemplate: PromptTemplate, 152 | requestSignal: AbortSignal, 153 | ) { 154 | logger.info(`Received streaming chat completion request ${requestId}`); 155 | 156 | const cmplId = `chatcmpl-${crypto.randomUUID().replaceAll("-", "")}`; 157 | const abortController = new AbortController(); 158 | let finished = false; 159 | 160 | // If an abort happens before streaming starts 161 | requestSignal.addEventListener("abort", () => { 162 | if (!finished) { 163 | abortController.abort( 164 | new CancellationError( 165 | `Streaming chat completion ${requestId} cancelled by user.`, 166 | ), 167 | ); 168 | finished = true; 169 | } 170 | }); 171 | 172 | const prompt = applyChatTemplate( 173 | model, 174 | promptTemplate, 175 | params.messages, 176 | { 177 | addGenerationPrompt: params.add_generation_prompt, 178 | templateVars: params.template_vars, 179 | }, 180 | ); 181 | 182 | addTemplateMetadata(promptTemplate, params); 183 | 184 | try { 185 | const generator = model.generateGen( 186 | requestId, 187 | prompt, 188 | params, 189 | abortController.signal, 190 | ); 191 | 192 | for await (const chunk of generator) { 193 | const streamChunk = createStreamChunk( 194 | chunk, 195 | model.path.name, 196 | cmplId, 197 | ); 198 | 199 | await stream.writeSSE({ 200 | data: JSON.stringify(streamChunk), 201 | }); 202 | 203 | // Write usage stats if user requests it 204 | if ( 205 | params.stream_options?.include_usage && chunk.kind === "finish" 206 | ) { 207 | const usageChunk = createUsageChunk( 208 | chunk, 209 | model.path.name, 210 | cmplId, 211 | ); 212 | 213 | await stream.writeSSE({ 214 | data: JSON.stringify(usageChunk), 215 | }); 216 | } 217 | } 218 | 219 | logger.info(`Finished streaming chat completion request ${requestId}`); 220 | } catch (error) { 221 | await stream.writeSSE({ 222 | data: JSON.stringify(toGeneratorError(error)), 223 | }); 224 | } 225 | 226 | finished = true; 227 | } 228 | 229 | export async function generateChatCompletion( 230 | requestId: string, 231 | params: ChatCompletionRequest, 232 | model: Model, 233 | promptTemplate: PromptTemplate, 234 | requestSignal: AbortSignal, 235 | ) { 236 | logger.info(`Received chat completion request ${requestId}`); 237 | 238 | const prompt = applyChatTemplate( 239 | model, 240 | promptTemplate, 241 | params.messages, 242 | { 243 | addGenerationPrompt: params.add_generation_prompt, 244 | templateVars: params.template_vars, 245 | }, 246 | ); 247 | 248 | addTemplateMetadata(promptTemplate, params); 249 | 250 | // Handle generation in the common function 251 | const gen = 
await staticGenerate( 252 | requestId, 253 | GenerationType.ChatCompletion, 254 | prompt, 255 | params, 256 | model, 257 | requestSignal, 258 | ); 259 | const response = createResponse(gen, model.path.name); 260 | 261 | return response; 262 | } 263 | -------------------------------------------------------------------------------- /api/OAI/utils/completion.ts: -------------------------------------------------------------------------------- 1 | import { SSEStreamingApi } from "hono/streaming"; 2 | 3 | import { 4 | convertFinishReason, 5 | createUsageStats, 6 | GenerationType, 7 | staticGenerate, 8 | } from "@/api/OAI/utils/generation.ts"; 9 | import { Model } from "@/bindings/bindings.ts"; 10 | import { GenerationChunk } from "@/bindings/types.ts"; 11 | import { CancellationError } from "@/common/errors.ts"; 12 | import { toGeneratorError } from "@/common/networking.ts"; 13 | import { logger } from "@/common/logging.ts"; 14 | import { 15 | CompletionRequest, 16 | CompletionRespChoice, 17 | CompletionResponse, 18 | } from "../types/completions.ts"; 19 | 20 | function createResponse(chunk: GenerationChunk, modelName: string) { 21 | const finishReason = chunk.kind === "finish" 22 | ? convertFinishReason(chunk) 23 | : undefined; 24 | const choice = CompletionRespChoice.parse({ 25 | text: chunk.text, 26 | finish_reason: finishReason, 27 | }); 28 | 29 | const usage = chunk.kind === "finish" ? createUsageStats(chunk) : undefined; 30 | 31 | const response = CompletionResponse.parse({ 32 | choices: [choice], 33 | model: modelName, 34 | usage, 35 | }); 36 | 37 | return response; 38 | } 39 | 40 | export async function streamCompletion( 41 | requestId: string, 42 | stream: SSEStreamingApi, 43 | params: CompletionRequest, 44 | model: Model, 45 | requestSignal: AbortSignal, 46 | ) { 47 | logger.info(`Received streaming completion request ${requestId}`); 48 | 49 | const abortController = new AbortController(); 50 | let finished = false; 51 | 52 | // If an abort happens before streaming starts 53 | requestSignal.addEventListener("abort", () => { 54 | if (!finished) { 55 | abortController.abort( 56 | new CancellationError( 57 | `Streaming completion ${requestId} cancelled by user.`, 58 | ), 59 | ); 60 | finished = true; 61 | } 62 | }); 63 | 64 | try { 65 | const generator = model.generateGen( 66 | requestId, 67 | params.prompt, 68 | params, 69 | abortController.signal, 70 | ); 71 | 72 | for await (const chunk of generator) { 73 | const streamChunk = createResponse(chunk, model.path.name); 74 | 75 | await stream.writeSSE({ 76 | data: JSON.stringify(streamChunk), 77 | }); 78 | } 79 | 80 | logger.info(`Finished streaming completion request ${requestId}`); 81 | } catch (error) { 82 | await stream.writeSSE({ 83 | data: JSON.stringify(toGeneratorError(error)), 84 | }); 85 | } 86 | 87 | finished = true; 88 | } 89 | 90 | export async function generateCompletion( 91 | requestId: string, 92 | params: CompletionRequest, 93 | model: Model, 94 | requestSignal: AbortSignal, 95 | ) { 96 | logger.info(`Received completion request ${requestId}`); 97 | 98 | // Handle generation in the common function 99 | const gen = await staticGenerate( 100 | requestId, 101 | GenerationType.Completion, 102 | params.prompt, 103 | params, 104 | model, 105 | requestSignal, 106 | ); 107 | 108 | const response = createResponse(gen, model.path.name); 109 | 110 | return response; 111 | } 112 | -------------------------------------------------------------------------------- /api/OAI/utils/generation.ts: 
-------------------------------------------------------------------------------- 1 | import { UsageStats } from "@/api/OAI/types/completions.ts"; 2 | import { Model } from "@/bindings/bindings.ts"; 3 | import { FinishChunk, ReadbackFinishReason } from "@/bindings/types.ts"; 4 | import { logger } from "@/common/logging.ts"; 5 | import { BaseSamplerRequest } from "@/common/sampling.ts"; 6 | import { toHttpException } from "@/common/networking.ts"; 7 | import { CancellationError } from "@/common/errors.ts"; 8 | 9 | export enum GenerationType { 10 | Completion = "Completion", 11 | ChatCompletion = "Chat completion", 12 | } 13 | 14 | export function createUsageStats(chunk: FinishChunk) { 15 | const usage = UsageStats.parse({ 16 | prompt_tokens: chunk.promptTokens, 17 | completion_tokens: chunk.genTokens, 18 | total_tokens: chunk.promptTokens + chunk.genTokens, 19 | }); 20 | 21 | return usage; 22 | } 23 | 24 | export function convertFinishReason(chunk: FinishChunk) { 25 | return chunk.finishReason === ReadbackFinishReason.MaxNewTokens 26 | ? "length" 27 | : "stop"; 28 | } 29 | 30 | export async function staticGenerate( 31 | requestId: string, 32 | genType: GenerationType, 33 | prompt: string, 34 | params: BaseSamplerRequest, 35 | model: Model, 36 | requestSignal: AbortSignal, 37 | ) { 38 | const abortController = new AbortController(); 39 | let finished = false; 40 | 41 | requestSignal.addEventListener("abort", () => { 42 | if (!finished) { 43 | abortController.abort( 44 | new CancellationError( 45 | `${genType} ${requestId} cancelled by user.`, 46 | ), 47 | ); 48 | finished = true; 49 | } 50 | }); 51 | 52 | try { 53 | const result = await model.generate( 54 | requestId, 55 | prompt, 56 | params, 57 | abortController.signal, 58 | ); 59 | 60 | logger.info(`Finished ${genType.toLowerCase()} request ${requestId}`); 61 | 62 | finished = true; 63 | return result; 64 | } catch (error) { 65 | throw toHttpException(error); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /api/core/router.ts: -------------------------------------------------------------------------------- 1 | import { Hono } from "hono"; 2 | import { HTTPException } from "hono/http-exception"; 3 | import { describeRoute } from "hono-openapi"; 4 | import { validator as sValidator } from "hono-openapi"; 5 | import { AuthPermissionResponse } from "@/api/core/types/auth.ts"; 6 | import { HealthSchema } from "@/api/core/types/health.ts"; 7 | import { applyChatTemplate } from "@/api/OAI/utils/chatCompletion.ts"; 8 | import { 9 | ModelCard, 10 | ModelList, 11 | ModelLoadRequest, 12 | } from "@/api/core/types/model.ts"; 13 | import { 14 | TemplateList, 15 | TemplateSwitchRequest, 16 | } from "@/api/core/types/template.ts"; 17 | import { 18 | TokenDecodeRequest, 19 | TokenDecodeResponse, 20 | TokenEncodeRequest, 21 | TokenEncodeResponse, 22 | } from "@/api/core/types/token.ts"; 23 | import { AuthKeyPermission, getAuthPermission } from "@/common/auth.ts"; 24 | import { ModelConfig } from "@/common/configModels.ts"; 25 | import { config } from "@/common/config.ts"; 26 | import { logger } from "@/common/logging.ts"; 27 | import * as modelContainer from "@/common/modelContainer.ts"; 28 | import { jsonContent, toHttpException } from "@/common/networking.ts"; 29 | import { PromptTemplate } from "@/common/templating.ts"; 30 | 31 | import authMiddleware from "../middleware/authMiddleware.ts"; 32 | import checkModelMiddleware from "../middleware/checkModelMiddleware.ts"; 33 | 34 | const router = new 
Hono(); 35 | 36 | const healthRoute = describeRoute({ 37 | responses: { 38 | 200: jsonContent(HealthSchema, "Health status of server"), 39 | }, 40 | }); 41 | 42 | router.get( 43 | "/health", 44 | healthRoute, 45 | checkModelMiddleware, 46 | (c) => { 47 | return c.json(HealthSchema.parse({ health: "ok" })); 48 | }, 49 | ); 50 | 51 | const modelsRoute = describeRoute({ 52 | responses: { 53 | 200: jsonContent(ModelList, "List of models in directory"), 54 | }, 55 | }); 56 | 57 | router.on( 58 | "GET", 59 | ["/v1/models", "/v1/model/list"], 60 | modelsRoute, 61 | authMiddleware(AuthKeyPermission.API), 62 | async (c) => { 63 | const modelCards: ModelCard[] = []; 64 | for await (const file of Deno.readDir(config.model.model_dir)) { 65 | if (!file.name.endsWith(".gguf")) { 66 | continue; 67 | } 68 | 69 | const modelCard = ModelCard.parse({ 70 | id: file.name.replace(".gguf", ""), 71 | }); 72 | 73 | modelCards.push(modelCard); 74 | } 75 | 76 | const modelList = ModelList.parse({ 77 | data: modelCards, 78 | }); 79 | 80 | return c.json(modelList); 81 | }, 82 | ); 83 | 84 | const currentModelRoute = describeRoute({ 85 | responses: { 86 | 200: jsonContent( 87 | ModelCard, 88 | "The currently loaded model (if it exists)", 89 | ), 90 | }, 91 | }); 92 | 93 | router.get( 94 | "/v1/model", 95 | currentModelRoute, 96 | authMiddleware(AuthKeyPermission.API), 97 | checkModelMiddleware, 98 | (c) => { 99 | const modelCard = ModelCard.parse({ 100 | id: c.var.model.path.base, 101 | }); 102 | 103 | return c.json(modelCard); 104 | }, 105 | ); 106 | 107 | const loadModelRoute = describeRoute({ 108 | responses: { 109 | 200: { 110 | description: "Model successfully loaded", 111 | }, 112 | }, 113 | }); 114 | 115 | // TODO: Make this a streaming response if necessary 116 | router.post( 117 | "/v1/model/load", 118 | loadModelRoute, 119 | authMiddleware(AuthKeyPermission.Admin), 120 | sValidator("json", ModelLoadRequest), 121 | async (c) => { 122 | const params = c.req.valid("json"); 123 | const loadParams = ModelConfig.parse({ 124 | ...params, 125 | model_dir: config.model.model_dir, 126 | }); 127 | 128 | // Makes sure the event doesn't fire multiple times 129 | let finished = false; 130 | 131 | // Abort handler 132 | const progressAbort = new AbortController(); 133 | c.req.raw.signal.addEventListener("abort", () => { 134 | if (!finished) { 135 | progressAbort.abort(); 136 | } 137 | }); 138 | 139 | const progressCallback = (_progress: number): boolean => { 140 | if (progressAbort.signal.aborted) { 141 | logger.error("Load request cancelled"); 142 | return false; 143 | } 144 | 145 | return true; 146 | }; 147 | 148 | // Load the model and re-raise errors 149 | try { 150 | await modelContainer.loadModel(loadParams, progressCallback); 151 | } catch (error) { 152 | if (error instanceof Error) { 153 | throw new HTTPException(422, error); 154 | } 155 | } 156 | 157 | finished = true; 158 | 159 | c.status(200); 160 | return c.body(null); 161 | }, 162 | ); 163 | 164 | const unloadRoute = describeRoute({ 165 | responses: { 166 | 200: { 167 | description: "Model successfully unloaded", 168 | }, 169 | }, 170 | }); 171 | 172 | router.post( 173 | "/v1/model/unload", 174 | unloadRoute, 175 | authMiddleware(AuthKeyPermission.Admin), 176 | checkModelMiddleware, 177 | async (c) => { 178 | await modelContainer.unloadModel(true); 179 | 180 | c.status(200); 181 | return c.body(null); 182 | }, 183 | ); 184 | 185 | const templatesRoute = describeRoute({ 186 | responses: { 187 | 200: jsonContent(TemplateList, "List of prompt templates"), 188 | }, 
189 | }); 190 | 191 | router.on( 192 | "GET", 193 | ["/v1/templates", "/v1/template/list"], 194 | templatesRoute, 195 | authMiddleware(AuthKeyPermission.API), 196 | async (c) => { 197 | const templates: string[] = []; 198 | for await (const file of Deno.readDir("templates")) { 199 | if (!file.name.endsWith(".jinja")) { 200 | continue; 201 | } 202 | 203 | templates.push(file.name.replace(".jinja", "")); 204 | } 205 | 206 | const templateList = TemplateList.parse({ 207 | data: templates, 208 | }); 209 | 210 | return c.json(templateList); 211 | }, 212 | ); 213 | 214 | const templateSwitchRoute = describeRoute({ 215 | responses: { 216 | 200: { 217 | description: "Prompt template switched", 218 | }, 219 | }, 220 | }); 221 | 222 | router.post( 223 | "/v1/template/switch", 224 | templateSwitchRoute, 225 | authMiddleware(AuthKeyPermission.API), 226 | checkModelMiddleware, 227 | sValidator("json", TemplateSwitchRequest), 228 | async (c) => { 229 | const params = c.req.valid("json"); 230 | 231 | const templatePath = `templates/${params.prompt_template_name}`; 232 | c.var.model.promptTemplate = await PromptTemplate.fromFile( 233 | templatePath, 234 | ); 235 | }, 236 | ); 237 | 238 | const authPermissionRoute = describeRoute({ 239 | responses: { 240 | 200: jsonContent( 241 | AuthPermissionResponse, 242 | "Returns permissions of a given auth key", 243 | ), 244 | }, 245 | }); 246 | 247 | router.get( 248 | "/v1/auth/permission", 249 | authPermissionRoute, 250 | authMiddleware(AuthKeyPermission.API), 251 | (c) => { 252 | try { 253 | const permission = getAuthPermission(c.req.header()); 254 | const response = AuthPermissionResponse.parse({ 255 | permission, 256 | }); 257 | 258 | return c.json(response); 259 | } catch (error) { 260 | throw toHttpException(error, 400); 261 | } 262 | }, 263 | ); 264 | 265 | const tokenEncodeRoute = describeRoute({ 266 | responses: { 267 | 200: jsonContent(TokenEncodeResponse, "Encode token response"), 268 | }, 269 | }); 270 | 271 | router.post( 272 | "/v1/token/encode", 273 | tokenEncodeRoute, 274 | authMiddleware(AuthKeyPermission.API), 275 | checkModelMiddleware, 276 | sValidator("json", TokenEncodeRequest), 277 | async (c) => { 278 | const params = c.req.valid("json"); 279 | 280 | let text: string; 281 | if (typeof params.text === "string") { 282 | text = params.text; 283 | } else if (Array.isArray(params.text)) { 284 | if (!c.var.model.promptTemplate) { 285 | throw new HTTPException(422, { 286 | message: "Cannot tokenize chat completion " + 287 | "because a prompt template is not set", 288 | }); 289 | } 290 | 291 | text = applyChatTemplate( 292 | c.var.model, 293 | c.var.model.promptTemplate, 294 | params.text, 295 | { 296 | addBosToken: params.add_bos_token, 297 | addGenerationPrompt: false, 298 | }, 299 | ); 300 | } else { 301 | throw new HTTPException(422, { 302 | message: "Unable to tokenize the provided text. 
" + 303 | "Check your formatting?", 304 | }); 305 | } 306 | 307 | const tokens = await c.var.model.tokenizer.tokenize( 308 | text, 309 | params.add_bos_token, 310 | params.encode_special_tokens, 311 | ); 312 | 313 | const resp = TokenEncodeResponse.parse({ 314 | tokens, 315 | length: tokens.length, 316 | }); 317 | 318 | return c.json(resp); 319 | }, 320 | ); 321 | 322 | const tokenDecodeRoute = describeRoute({ 323 | responses: { 324 | 200: jsonContent(TokenDecodeResponse, "Decode token response"), 325 | }, 326 | }); 327 | 328 | router.post( 329 | "/v1/token/decode", 330 | tokenDecodeRoute, 331 | authMiddleware(AuthKeyPermission.API), 332 | checkModelMiddleware, 333 | sValidator("json", TokenDecodeRequest), 334 | async (c) => { 335 | const params = c.req.valid("json"); 336 | 337 | const text = await c.var.model.tokenizer.detokenize( 338 | params.tokens, 339 | undefined, 340 | params.add_bos_token, 341 | params.decode_special_tokens, 342 | ); 343 | 344 | const resp = TokenDecodeResponse.parse({ 345 | text, 346 | }); 347 | 348 | return c.json(resp); 349 | }, 350 | ); 351 | 352 | export default router; 353 | -------------------------------------------------------------------------------- /api/core/types/auth.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | 3 | export const AuthPermissionResponse = z.object({ 4 | permission: z.string(), 5 | }); 6 | -------------------------------------------------------------------------------- /api/core/types/health.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | 3 | export const HealthSchema = z.object({ 4 | health: z.enum(["ok", "unhealthy"]), 5 | }); 6 | -------------------------------------------------------------------------------- /api/core/types/model.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { ModelConfig } from "@/common/configModels.ts"; 3 | import { applyLoadDefaults } from "@/common/modelContainer.ts"; 4 | 5 | export const ModelLoadRequest = z.preprocess( 6 | (data: unknown) => applyLoadDefaults(data), 7 | ModelConfig.extend({ 8 | model_name: z.string(), 9 | }).omit({ 10 | model_dir: true, 11 | use_as_default: true, 12 | }), 13 | ); 14 | 15 | export const ModelCard = z.object({ 16 | id: z.string().default("test"), 17 | object: z.string().default("model"), 18 | created: z.number().default(Date.now()), 19 | owned_by: z.string().default("YALS"), 20 | }); 21 | 22 | export type ModelCard = z.infer; 23 | 24 | export const ModelList = z.object({ 25 | object: z.string().default("list"), 26 | data: z.array(ModelCard).default([]), 27 | }); 28 | 29 | export type ModelList = z.infer; 30 | -------------------------------------------------------------------------------- /api/core/types/template.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | 3 | export const TemplateList = z.object({ 4 | object: z.string().default("list"), 5 | data: z.array(z.string()).default([]), 6 | }); 7 | 8 | export const TemplateSwitchRequest = z.aliasedObject( 9 | z.object({ 10 | prompt_template_name: z.string(), 11 | }), 12 | [{ field: "prompt_template_name", aliases: ["name"] }], 13 | ); 14 | -------------------------------------------------------------------------------- /api/core/types/token.ts: 
-------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { ChatCompletionMessage } from "@/api/OAI/types/chatCompletions.ts"; 3 | 4 | const CommonTokenRequest = z.object({ 5 | add_bos_token: z.boolean().nullish().coalesce(true), 6 | encode_special_tokens: z.boolean().nullish().coalesce(true), 7 | decode_special_tokens: z.boolean().nullish().coalesce(true), 8 | }); 9 | 10 | export const TokenEncodeRequest = z.object({ 11 | text: z.union([z.string(), z.array(ChatCompletionMessage)]), 12 | }) 13 | .merge(CommonTokenRequest); 14 | 15 | export const TokenEncodeResponse = z.object({ 16 | tokens: z.array(z.number()), 17 | length: z.number(), 18 | }); 19 | 20 | export const TokenDecodeRequest = z.object({ 21 | tokens: z.array(z.number()), 22 | }) 23 | .merge(CommonTokenRequest); 24 | 25 | export const TokenDecodeResponse = z.object({ 26 | text: z.string(), 27 | }); 28 | -------------------------------------------------------------------------------- /api/middleware/authMiddleware.ts: -------------------------------------------------------------------------------- 1 | import { HTTPException } from "hono/http-exception"; 2 | import { createMiddleware } from "hono/factory"; 3 | import { AuthKeyPermission, authKeys } from "@/common/auth.ts"; 4 | import { config } from "@/common/config.ts"; 5 | 6 | // Middleware for checking a request's authentication key 7 | // Rejects the request with a 401 if the key does not grant the required permission 8 | const authMiddleware = (permission: AuthKeyPermission) => { 9 | return createMiddleware(async (c, next) => { 10 | if (config.network.disable_auth) { 11 | await next(); 12 | return; 13 | } 14 | 15 | const headers = c.req.header(); 16 | const xHeader = `x-${permission.toLowerCase()}-key`; 17 | 18 | // TODO: Possibly refactor error throws 19 | if (xHeader in headers) { 20 | const valid = authKeys?.verifyKey(headers[xHeader], permission); 21 | if (!valid) { 22 | throw new HTTPException(401, { 23 | message: `Invalid ${permission} key`, 24 | }); 25 | } 26 | } else if ("authorization" in headers) { 27 | const splitKey = headers["authorization"].split(" "); 28 | if (splitKey.length < 2) { 29 | throw new HTTPException(401, { 30 | message: `Invalid ${permission} key`, 31 | }); 32 | } 33 | 34 | const valid = splitKey[0].toLowerCase() === "bearer" && 35 | authKeys?.verifyKey(splitKey[1], permission); 36 | 37 | if (!valid) { 38 | throw new HTTPException(401, { 39 | message: `Invalid ${permission} key`, 40 | }); 41 | } 42 | } else { 43 | throw new HTTPException(401, { message: "Key not provided" }); 44 | } 45 | 46 | await next(); 47 | }); 48 | }; 49 | 50 | export default authMiddleware; 51 | -------------------------------------------------------------------------------- /api/middleware/checkModelMiddleware.ts: -------------------------------------------------------------------------------- 1 | import { createMiddleware } from "hono/factory"; 2 | 3 | import { Model } from "@/bindings/bindings.ts"; 4 | import { ModelNotLoadedError } from "@/common/errors.ts"; 5 | import { model } from "@/common/modelContainer.ts"; 6 | 7 | // Extra vars for context 8 | interface CtxOptions { 9 | Variables: { 10 | model: Model; 11 | }; 12 | } 13 | 14 | // Middleware for checking if the model exists 15 | // Sends a validated version of the model via Hono's ctx 16 | const checkModelMiddleware = createMiddleware<CtxOptions>( 17 | async (c, next) => { 18 | if (!model) { 19 | throw new ModelNotLoadedError(); 20 | } 21 | 22 | // Validated reference 23 | c.set("model",
model); 24 | 25 | await next(); 26 | }, 27 | ); 28 | 29 | export default checkModelMiddleware; 30 | -------------------------------------------------------------------------------- /api/middleware/requestLogMiddleware.ts: -------------------------------------------------------------------------------- 1 | import { createMiddleware } from "hono/factory"; 2 | import { logger } from "../../common/logging.ts"; 3 | 4 | // Middleware for logging parts of a request 5 | const requestLogMiddleware = createMiddleware( 6 | async (c, next) => { 7 | const logMessage = [ 8 | `Information for ${c.req.method} request ${c.var.requestId}`, 9 | ]; 10 | 11 | logMessage.push(`URL: ${c.req.url}`); 12 | 13 | const headers = Object.fromEntries(c.req.raw.headers); 14 | logMessage.push(`Headers: ${JSON.stringify(headers, null, 2)}`); 15 | 16 | if (c.req.method !== "GET") { 17 | const clonedReq = c.req.raw.clone(); 18 | const textBody = await clonedReq.text(); 19 | 20 | if (textBody) { 21 | logMessage.push(`Body: ${textBody}`); 22 | } 23 | } 24 | 25 | logger.info(logMessage.join("\n")); 26 | 27 | await next(); 28 | }, 29 | ); 30 | 31 | export default requestLogMiddleware; 32 | -------------------------------------------------------------------------------- /api/server.ts: -------------------------------------------------------------------------------- 1 | import { Hono } from "hono"; 2 | import { cors } from "hono/cors"; 3 | import { requestId } from "hono/request-id"; 4 | import { logger as loggerMiddleware } from "hono/logger"; 5 | import { ContentfulStatusCode } from "hono/utils/http-status"; 6 | import { openAPISpecs } from "hono-openapi"; 7 | import { apiReference } from "@scalar/hono-api-reference"; 8 | 9 | import { config } from "@/common/config.ts"; 10 | import { logger } from "@/common/logging.ts"; 11 | import core from "./core/router.ts"; 12 | import oai from "./OAI/router.ts"; 13 | import { generateUuidHex } from "@/common/utils.ts"; 14 | import { ModelNotLoadedError } from "@/common/errors.ts"; 15 | import requestLogMiddleware from "./middleware/requestLogMiddleware.ts"; 16 | 17 | export function createApi() { 18 | const app = new Hono(); 19 | 20 | // TODO: Use a custom middleware instead of overriding Hono's logger 21 | const printToLogger = (message: string, ...rest: string[]) => { 22 | logger.info(message, { rest }); 23 | }; 24 | 25 | // Middleware 26 | app.use(loggerMiddleware(printToLogger)); 27 | app.use("*", cors()); 28 | app.use(requestId({ limitLength: 16, generator: generateUuidHex })); 29 | 30 | if (config.logging.log_requests) { 31 | app.use(requestLogMiddleware); 32 | } 33 | 34 | // Add routers 35 | app.route("/", core); 36 | app.route("/", oai); 37 | 38 | // OpenAPI documentation 39 | app.get( 40 | "/openapi.json", 41 | openAPISpecs(app, { 42 | documentation: { 43 | openapi: "3.0.0", 44 | info: { 45 | version: "0.0.1", 46 | title: "YALS", 47 | }, 48 | }, 49 | }), 50 | ); 51 | 52 | app.get( 53 | "/docs", 54 | apiReference({ 55 | spec: { 56 | url: "/openapi.json", 57 | }, 58 | }), 59 | ); 60 | 61 | // Error handling 62 | // Originally from the Stoker package 63 | app.onError((err, c) => { 64 | const currentStatus = "status" in err 65 | ? err.status 66 | : c.newResponse(null).status; 67 | const statusCode = currentStatus != 200 68 | ? 
(currentStatus as ContentfulStatusCode) 69 | : 500; 70 | 71 | const logError = !( 72 | statusCode === 401 73 | ); 74 | 75 | // Only log in console if the error allows it 76 | if (logError) { 77 | const messageOnly = statusCode === 408 || 78 | err instanceof ModelNotLoadedError; 79 | 80 | if (messageOnly) { 81 | logger.error(`Sent to request: ${err.message}`); 82 | } else { 83 | logger.error(`Sent to request: ${err.stack || err.message}`); 84 | } 85 | } 86 | 87 | // Always send error + message to client 88 | return c.json({ 89 | detail: err.message, 90 | }, statusCode); 91 | }); 92 | 93 | app.notFound((c) => { 94 | return c.json({ 95 | message: `Method or path not found - ${c.req.method} ${c.req.path}`, 96 | }, 404); 97 | }); 98 | 99 | // Serve 100 | Deno.serve({ 101 | hostname: config.network.host, 102 | port: config.network.port, 103 | handler: app.fetch, 104 | onListen: ({ hostname, port }) => { 105 | logger.info(`Server running on http://${hostname}:${port}`); 106 | }, 107 | }); 108 | } 109 | -------------------------------------------------------------------------------- /assets/icon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theroyallab/YALS/60286959be95d577e05efdf33ba6733395d60020/assets/icon.ico -------------------------------------------------------------------------------- /assets/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theroyallab/YALS/60286959be95d577e05efdf33ba6733395d60020/assets/icon.png -------------------------------------------------------------------------------- /bindings/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14.0) 2 | project(LlamaMultiUserInference) 3 | set(CMAKE_CXX_STANDARD 17) 4 | 5 | option(LLGUIDANCE "Enable LLGuidance support (requires Rust)" OFF) 6 | 7 | # Set RPath for Apple and Unix systems 8 | if (APPLE) 9 | set(CMAKE_INSTALL_RPATH "@loader_path") 10 | set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) 11 | set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) 12 | elseif (UNIX) 13 | set(CMAKE_INSTALL_RPATH "$ORIGIN") 14 | set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) 15 | set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) 16 | endif() 17 | 18 | if (DEFINED LLAMACPP_REPO OR DEFINED LLAMACPP_COMMIT) 19 | message(STATUS "Using a custom commit or repo for llama.cpp. Build might not work as expected. 
Here be dragons!") 20 | endif() 21 | 22 | # Do not cache these variables with subsequent builds 23 | set(LLAMACPP_REPO "https://github.com/ggerganov/llama.cpp.git") 24 | message(STATUS "Using llama.cpp repo ${LLAMACPP_REPO}") 25 | 26 | # Stable llama.cpp commit for bindings 27 | set(LLAMACPP_COMMIT "7675c555a13c9f473249e59a54db35032ce8e0fc") 28 | message(STATUS "Using llama.cpp commit ${LLAMACPP_COMMIT}") 29 | 30 | # Optional: You can also enable mixed FP16/FP32 computation for faster processing 31 | # set(LLAMA_CUDA_F16 ON CACHE BOOL "llama.cpp: use float16 for GPU operations" FORCE) 32 | # set(GGML_CUDA ON CACHE BOOL "llama.cpp: enable CUDA" FORCE) 33 | 34 | # Disable unused components to speed up build 35 | set(LLAMA_BUILD_EXAMPLES OFF CACHE BOOL "llama.cpp: build examples" FORCE) 36 | set(LLAMA_BUILD_TESTS OFF CACHE BOOL "llama.cpp: build tests" FORCE) 37 | set(LLAMA_BUILD_SERVER OFF CACHE BOOL "llama.cpp: build server" FORCE) 38 | set(LLAMA_CURL OFF CACHE BOOL "llama.cpp: use libcurl" FORCE) 39 | 40 | # Enable common 41 | set(LLAMA_BUILD_COMMON ON CACHE BOOL "llama.cpp: build common utils library" FORCE) 42 | 43 | if(LLGUIDANCE) 44 | find_program(CARGO cargo) 45 | if(CARGO) 46 | message(STATUS "Including LLGuidance in build") 47 | set(LLAMA_LLGUIDANCE ON CACHE BOOL "llama.cpp: enable LLGuidance support" FORCE) 48 | else() 49 | message(FATAL_ERROR "LLGuidance is enabled, but requires Rust for compilation. Get it at https://rustup.rs") 50 | endif() 51 | else() 52 | message(STATUS "LLGuidance support is disabled. Enable with -DLLGUIDANCE=ON for grammar, JSON schema, and regex support.") 53 | set(LLAMA_LLGUIDANCE OFF CACHE BOOL "llama.cpp: disable LLGuidance support" FORCE) 54 | endif() 55 | 56 | # Fetch llama.cpp at the pinned commit 57 | # FIXME: Maybe use a vendored llama.cpp build for stability 58 | include(FetchContent) 59 | FetchContent_Declare( 60 | llama 61 | GIT_REPOSITORY ${LLAMACPP_REPO} 62 | GIT_TAG ${LLAMACPP_COMMIT} 63 | ) 64 | 65 | # Set build type to Release for performance 66 | set(CMAKE_BUILD_TYPE Release) 67 | 68 | # Build all libs to bin 69 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 70 | 71 | # Make llama.cpp available 72 | FetchContent_MakeAvailable(llama) 73 | 74 | message(STATUS "llama source dir: ${llama_SOURCE_DIR}") 75 | 76 | # Apple build changes 77 | # From llama-cpp-python 78 | if (APPLE AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm64") 79 | # Need to disable these llama.cpp flags on Apple x86_64, 80 | # otherwise users may encounter invalid instruction errors 81 | set(GGML_AVX "Off" CACHE BOOL "ggml: enable AVX" FORCE) 82 | set(GGML_AVX2 "Off" CACHE BOOL "ggml: enable AVX2" FORCE) 83 | set(GGML_FMA "Off" CACHE BOOL "ggml: enable FMA" FORCE) 84 | set(GGML_F16C "Off" CACHE BOOL "ggml: enable F16C" FORCE) 85 | endif() 86 | 87 | if (APPLE) 88 | set(GGML_METAL_EMBED_LIBRARY ON CACHE BOOL "llama: embed metal library" FORCE) 89 | endif() 90 | 91 | # Create a library from c_library.cpp 92 | add_library(c_library SHARED 93 | server/c_library.cpp 94 | ) 95 | 96 | # Set include directories for the library 97 | target_include_directories(c_library PUBLIC 98 | ${CMAKE_CURRENT_SOURCE_DIR} 99 | ${CMAKE_CURRENT_SOURCE_DIR}/server 100 | ${llama_SOURCE_DIR}/src 101 | ) 102 | 103 | # Link llama libraries to our c_library 104 | target_link_libraries(c_library PUBLIC llama common) 105 | 106 | # Create our main executable 107 | add_executable(multi_user_inference 108 | server/server_basic_example.cpp 109 | ) 110 | 111 | # set_target_properties(multi_user_inference
PROPERTIES 112 | # INSTALL_RPATH "${CMAKE_BINARY_DIR}/bin" 113 | # ) 114 | 115 | # Include directories for main executable 116 | target_include_directories(multi_user_inference PRIVATE 117 | ${CMAKE_CURRENT_SOURCE_DIR} 118 | ${CMAKE_CURRENT_SOURCE_DIR}/server 119 | ) 120 | 121 | # Link our c_library to the main executable 122 | target_link_libraries(multi_user_inference PRIVATE 123 | c_library 124 | ) 125 | 126 | if(LLGUIDANCE) 127 | target_compile_definitions(c_library PUBLIC LLGUIDANCE_BUILT=1) 128 | endif() 129 | 130 | # Windows options 131 | if(WIN32) 132 | set_target_properties(c_library PROPERTIES 133 | WINDOWS_EXPORT_ALL_SYMBOLS TRUE 134 | ) 135 | endif() -------------------------------------------------------------------------------- /bindings/bindings.ps1: -------------------------------------------------------------------------------- 1 | if (Get-Command cmake -ErrorAction SilentlyContinue) { 2 | Write-Host "Found CMake: $(cmake --version)" 3 | } else { 4 | Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\Common7\Tools\Microsoft.VisualStudio.DevShell.dll' 5 | Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools' -DevCmdArguments '-arch=x64 -host_arch=x64' 6 | } 7 | 8 | $jobs = if ($env:MAX_JOBS) { 9 | $env:MAX_JOBS 10 | } else { 11 | $env:NUMBER_OF_PROCESSORS 12 | } 13 | 14 | $extraCmakeArgs = @() 15 | 16 | # llama.cpp dev options 17 | if ($env:LLAMACPP_REPO) { 18 | $extraCmakeArgs += "-DLLAMACPP_REPO=$env:LLAMACPP_REPO" 19 | Write-Host "Using custom llama.cpp repo: $env:LLAMACPP_REPO" 20 | } 21 | 22 | if ($env:LLAMACPP_COMMIT) { 23 | $extraCmakeArgs += "-DLLAMACPP_COMMIT=$env:LLAMACPP_COMMIT" 24 | Write-Host "Using custom llama.cpp commit: $env:LLAMACPP_COMMIT" 25 | } 26 | 27 | if ($env:LLGUIDANCE -eq 1) { 28 | $env:RUSTC_WRAPPER="sccache" 29 | Write-Host "LLGuidance enabled, including in build" 30 | $extraCmakeArgs += "-DLLGUIDANCE=ON" 31 | } 32 | 33 | if ($env:GGML_CUDA -eq 1) { 34 | Write-Host "CUDA enabled, including in build" 35 | 36 | $extraCmakeArgs += "-DGGML_CUDA=ON" 37 | 38 | if ($env:CMAKE_CUDA_ARCHITECTURES) { 39 | $extraCmakeArgs += @( 40 | "-DCMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES", 41 | "-DGGML_NATIVE=OFF" 42 | ) 43 | } 44 | } 45 | 46 | if ($env:GGML_VULKAN -eq 1) { 47 | Write-Host "Vulkan enabled, including in build" 48 | 49 | $extraCmakeArgs += "-DGGML_VULKAN=ON" 50 | } 51 | 52 | cmake . 
-B build -G "Ninja" -DCMAKE_BUILD_TYPE=Release $extraCmakeArgs 53 | cmake --build build --config Release --target c_library -j $jobs 54 | Copy-Item build/*.dll ../lib 55 | Copy-Item build/bin/*.dll ../lib -------------------------------------------------------------------------------- /bindings/bindings.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=$(uname -s) 4 | 5 | # Set number of jobs for parallel build 6 | if [ -n "$MAX_JOBS" ]; then 7 | JOBS=$MAX_JOBS 8 | elif [ "$OS" = "Darwin" ]; then 9 | JOBS=$(sysctl -n hw.physicalcpu) 10 | else 11 | JOBS=$(nproc --all) 12 | fi 13 | 14 | # Initialize as empty array 15 | EXTRA_CMAKE_ARGS=() 16 | 17 | # llama.cpp dev options 18 | if [ -n "$LLAMACPP_REPO" ]; then 19 | EXTRA_CMAKE_ARGS+=("-DLLAMACPP_REPO=$LLAMACPP_REPO") 20 | echo "Using custom llama.cpp repo: ${LLAMACPP_REPO}" 21 | fi 22 | 23 | if [ -n "$LLAMACPP_COMMIT" ]; then 24 | EXTRA_CMAKE_ARGS+=("-DLLAMACPP_COMMIT=$LLAMACPP_COMMIT") 25 | echo "Using custom llama.cpp commit: ${LLAMACPP_COMMIT}" 26 | fi 27 | 28 | if [ "$LLGUIDANCE" = "1" ]; then 29 | export RUSTC_WRAPPER="sccache" 30 | EXTRA_CMAKE_ARGS+=("-DLLGUIDANCE=ON") 31 | echo "LLGuidance enabled, including in build" 32 | fi 33 | 34 | if [ "$GGML_CUDA" = "1" ]; then 35 | EXTRA_CMAKE_ARGS+=("-DGGML_CUDA=ON") 36 | echo "CUDA enabled, including in build" 37 | 38 | if [ -n "$CMAKE_CUDA_ARCHITECTURES" ]; then 39 | EXTRA_CMAKE_ARGS+=( 40 | "-DGGML_NATIVE=OFF" "-DCMAKE_CUDA_ARCHITECTURES=$CMAKE_CUDA_ARCHITECTURES" 41 | ) 42 | fi 43 | fi 44 | 45 | if [ "$GGML_VULKAN" = "1" ]; then 46 | EXTRA_CMAKE_ARGS+=("-DGGML_VULKAN=ON") 47 | echo "Vulkan enabled, including in build" 48 | fi 49 | 50 | if [ "$GGML_HIP" = "1" ]; then 51 | EXTRA_CMAKE_ARGS+=("-DGGML_HIP=ON") 52 | echo "HIP enabled, including in build" 53 | 54 | if [ -n "$AMDGPU_TARGETS" ]; then 55 | EXTRA_CMAKE_ARGS+=( 56 | "-DAMDGPU_TARGETS=$AMDGPU_TARGETS" 57 | ) 58 | fi 59 | fi 60 | 61 | # Join array elements with spaces 62 | CMAKE_ARGS="${EXTRA_CMAKE_ARGS[*]}" 63 | 64 | cmake . 
-B build -G "Ninja" -DCMAKE_BUILD_TYPE=Release ${CMAKE_ARGS} 65 | cmake --build build --config Release --target c_library -j ${JOBS} 66 | 67 | if [ "$OS" = "Darwin" ]; then 68 | echo "Copying .dylib files" 69 | cp build/bin/*.dylib ../lib 70 | elif [ "$OS" = "Linux" ]; then 71 | echo "Copying .so files" 72 | cp build/bin/*.so ../lib 73 | fi -------------------------------------------------------------------------------- /bindings/generationResources.ts: -------------------------------------------------------------------------------- 1 | import { lib } from "./lib.ts"; 2 | import { ReadbackBuffer } from "./readbackBuffer.ts"; 3 | 4 | export class GenerationResources { 5 | private readbackBufferPtr: Deno.PointerValue; 6 | 7 | rawPtr: Deno.PointerValue; 8 | samplerPtr: Deno.PointerValue; 9 | readbackBuffer: ReadbackBuffer; 10 | 11 | constructor() { 12 | this.rawPtr = lib.symbols.generation_resources_make(); 13 | if (!this.rawPtr) { 14 | throw new Error("Could not allocate shared resource bundle."); 15 | } 16 | 17 | const view = new Deno.UnsafePointerView(this.rawPtr); 18 | this.readbackBufferPtr = Deno.UnsafePointer.create( 19 | view.getBigUint64(0), 20 | ); 21 | this.readbackBuffer = new ReadbackBuffer(this.readbackBufferPtr); 22 | 23 | this.samplerPtr = Deno.UnsafePointer.create(view.getBigUint64(8)); 24 | } 25 | 26 | close() { 27 | lib.symbols.generation_resources_release(this.rawPtr); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /bindings/grammar.ts: -------------------------------------------------------------------------------- 1 | import { hasLlguidance } from "@/bindings/lib.ts"; 2 | import { SamplerBuilder } from "@/bindings/samplers.ts"; 3 | import { logger } from "@/common/logging.ts"; 4 | 5 | export class YALSGrammar { 6 | private sampler: SamplerBuilder; 7 | 8 | constructor(sampler: SamplerBuilder) { 9 | this.sampler = sampler; 10 | } 11 | 12 | BNF(grammar: string) { 13 | if (hasLlguidance) { 14 | this.sampler.llguidance(grammar); 15 | } else { 16 | logger.warn( 17 | "YALS was not built with LLGuidance. Using GBNF.", 18 | ); 19 | 20 | this.sampler.grammar(grammar); 21 | } 22 | } 23 | 24 | jsonSchema(schema: Record) { 25 | if (!hasLlguidance) { 26 | logger.warn( 27 | "YALS was not built with LLGuidance. Skipping JSON schema.", 28 | ); 29 | 30 | return; 31 | } 32 | 33 | const grammarArray = ["start: json_object"]; 34 | const schemaString = JSON.stringify( 35 | schema, 36 | null, 37 | 2, 38 | ); 39 | grammarArray.push(`json_object: %json ${schemaString}`); 40 | 41 | this.sampler.llguidance(grammarArray.join("\n")); 42 | } 43 | 44 | regex(regex: string) { 45 | if (!hasLlguidance) { 46 | logger.warn( 47 | "YALS was not built with LLGuidance. 
Skipping Regex parsing.", 48 | ); 49 | 50 | return; 51 | } 52 | 53 | const grammarArray = ["start: text"]; 54 | grammarArray.push(`text: ${regex}`); 55 | 56 | this.sampler.llguidance(grammarArray.join("\n")); 57 | } 58 | } 59 | -------------------------------------------------------------------------------- /bindings/job.ts: -------------------------------------------------------------------------------- 1 | import { lib } from "@/bindings/lib.ts"; 2 | import { ReadbackBuffer } from "./readbackBuffer.ts"; 3 | import { GenerationChunk } from "./types.ts"; 4 | 5 | export class Job { 6 | // Private references 7 | private readbackBuffer: ReadbackBuffer; 8 | private processor: Deno.PointerValue; 9 | 10 | isComplete = false; 11 | id: number; 12 | 13 | constructor( 14 | id: number, 15 | readbackBuffer: ReadbackBuffer, 16 | processor: Deno.PointerValue, 17 | ) { 18 | this.id = id; 19 | this.readbackBuffer = readbackBuffer; 20 | this.processor = processor; 21 | } 22 | 23 | async *stream(): AsyncGenerator { 24 | for await (const { text, token } of this.readbackBuffer.read()) { 25 | yield { kind: "data", text, token }; 26 | } 27 | 28 | const status = await this.readbackBuffer.readStatus(); 29 | if (status) { 30 | yield status; 31 | } 32 | } 33 | 34 | cancel() { 35 | if (this.isComplete) { 36 | return; 37 | } 38 | 39 | this.isComplete = true; 40 | 41 | lib.symbols.processor_cancel_work( 42 | this.processor, 43 | this.id, 44 | ); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /bindings/lib.ts: -------------------------------------------------------------------------------- 1 | import libraryInterface from "./symbols.ts"; 2 | 3 | export let lib: Deno.DynamicLibrary; 4 | export let hasLlguidance: boolean = false; 5 | 6 | export function loadYalsBindings() { 7 | const libName = "c_library"; 8 | const libDir = `${Deno.cwd()}/lib/`; 9 | let libPath = libDir; 10 | 11 | switch (Deno.build.os) { 12 | case "windows": 13 | Deno.env.set("PATH", `${Deno.env.get("PATH")};${libDir}`); 14 | libPath += `${libName}.dll`; 15 | break; 16 | case "linux": 17 | libPath += `lib${libName}.so`; 18 | break; 19 | case "darwin": 20 | libPath += `lib${libName}.dylib`; 21 | break; 22 | default: 23 | throw new Error(`Unsupported operating system: ${Deno.build.os}`); 24 | } 25 | 26 | try { 27 | lib = Deno.dlopen(libPath, libraryInterface); 28 | hasLlguidance = lib.symbols.has_llguidance(); 29 | } catch (error: unknown) { 30 | console.error( 31 | `Failed to load YALS library: ${ 32 | error instanceof Error ? 
error.message : String(error) 33 | }`, 34 | ); 35 | throw error; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /bindings/minimal_cpp_test.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "common.h" 3 | #include "c_library.h" 4 | #include "llama.h" 5 | 6 | int main() { 7 | const auto idk = new float(0.0); 8 | const auto model = model_load( 9 | "/home/blackroot/Desktop/YALS/YALS/models/PocketDoc_Dans-PersonalityEngine-V1.2.0-24b-Q6_K_L.gguf", 10 | 999, 11 | idk, 12 | nullptr 13 | ); 14 | 15 | const auto ctx = ctx_make(model, 1024, 999, 512, false, -1, false, 0, 0, 0.0f); 16 | if (!model || !ctx) { 17 | std::cerr << "Failed to load model" << std::endl; 18 | return 1; 19 | } 20 | 21 | std::cout << "Model and context loaded successfully" << std::endl; 22 | 23 | auto sampler = sampler_make(); 24 | sampler = sampler_temp(sampler, 2); 25 | sampler = sampler_dist(sampler, 1337); 26 | 27 | const auto processor = processor_make(model, ctx, 4); 28 | 29 | const auto readback_buffer = readback_create_buffer(); 30 | 31 | const auto prompt = R"(<|im_start|>system 32 | Respond with *actions* *words* *thoughts* in a json format, with 33 | { 34 | "action" : ["first, second]", 35 | "mood" : "current mood from 20 mood choices", 36 | "magazine capacity" : "a number" 37 | } 38 | <|im_end|> 39 | <|im_start|>user 40 | Hi how are you? 41 | <|im_end|> 42 | <|im_start|>assistant 43 | )"; 44 | 45 | auto lark_grammar = R"( 46 | // Define the start rule 47 | start: json_string 48 | 49 | // The exact JSON string with fixed format 50 | json_string: "{\n \"action\" : [\"" ACTION_CONTENT "\"],\n \"mood\" : \"" EMOTION "\",\n \"magazine capacity\" : \"" CAPACITY_CONTENT "\"\n}" 51 | 52 | // Content restrictions 53 | ACTION_CONTENT: /[a-zA-Z0-9 ,]{1,15}/ 54 | CAPACITY_CONTENT: /[0-9]+( rounds| bullets| shots)?/ 55 | EMOTION: "happy" | "sad" | "angry" | "excited" | "bored" | "anxious" | "calm" | "confused" 56 | | "curious" | "depressed" | "ecstatic" | "fearful" | "grateful" | "hopeful" 57 | | "irritated" | "jealous" | "peaceful" | "proud" | "surprised" | "tired" 58 | )"; 59 | 60 | const char* seq[] = {"*"}; 61 | 62 | processor_submit_work( 63 | processor, 64 | prompt, 65 | sampler, 66 | readback_buffer, 67 | 100, 68 | 0, 69 | 1337, 70 | nullptr, 71 | 0, 72 | nullptr, 73 | 0, 74 | nullptr, 75 | 0, 76 | lark_grammar); 77 | 78 | std::cout << "Starting model:" << std::endl; 79 | while (!readback_is_buffer_finished(readback_buffer)) { 80 | char* char_out; 81 | llama_token token; 82 | if (readback_read_next(readback_buffer, &char_out, &token)) { 83 | std::cout << char_out; 84 | std::cout.flush(); 85 | } 86 | } 87 | 88 | const char* status = readback_read_status(readback_buffer); 89 | std::cout << status << std::endl; 90 | 91 | return 0; 92 | } 93 | -------------------------------------------------------------------------------- /bindings/readbackBuffer.ts: -------------------------------------------------------------------------------- 1 | import { delay } from "@std/async/delay"; 2 | 3 | import { logger } from "@/common/logging.ts"; 4 | import { lib } from "./lib.ts"; 5 | import { FinishChunk } from "@/bindings/types.ts"; 6 | 7 | /** 8 | * ReadbackBuffer provides an interface to read generated tokens and text 9 | * from the LLM generation process. 
10 | */ 11 | export class ReadbackBuffer { 12 | private rawPtr: Deno.PointerValue; 13 | 14 | constructor(readbackPtr: Deno.PointerValue) { 15 | this.rawPtr = readbackPtr; 16 | } 17 | 18 | async *read() { 19 | while (!lib.symbols.readback_is_buffer_finished(this.rawPtr)) { 20 | const charBuf = new Uint8Array(8); 21 | const tokenBuf = new Int32Array(1); 22 | 23 | if ( 24 | !await lib.symbols.readback_read_next( 25 | this.rawPtr, 26 | Deno.UnsafePointer.of(charBuf), 27 | Deno.UnsafePointer.of(tokenBuf), 28 | ) 29 | ) { 30 | await delay(2); 31 | continue; 32 | } 33 | 34 | const ptrVal = new BigUint64Array(charBuf.buffer)[0]; 35 | if (ptrVal === 0n) continue; 36 | 37 | const charPtr = Deno.UnsafePointer.create(ptrVal); 38 | if (!charPtr) continue; 39 | 40 | yield { 41 | text: new Deno.UnsafePointerView(charPtr).getCString(), 42 | token: tokenBuf[0], 43 | }; 44 | } 45 | } 46 | 47 | /** 48 | * Reads the status information from the buffer 49 | * @returns A ReadbackFinish object or null if status couldn't be read 50 | */ 51 | async readStatus(): Promise { 52 | const statusPtr = await lib.symbols.readback_read_status( 53 | this.rawPtr, 54 | ); 55 | if (!statusPtr) { 56 | return null; 57 | } 58 | 59 | const view = new Deno.UnsafePointerView(statusPtr); 60 | const statusStr = view.getCString(); 61 | 62 | try { 63 | const status = JSON.parse(statusStr); 64 | return { 65 | ...status, 66 | kind: "finish", 67 | text: "", 68 | }; 69 | } catch (e) { 70 | logger.error("Failed to parse status JSON:", e); 71 | return null; 72 | } 73 | } 74 | } 75 | -------------------------------------------------------------------------------- /bindings/samplers.ts: -------------------------------------------------------------------------------- 1 | import { lib } from "./lib.ts"; 2 | import { GenerationResources } from "./generationResources.ts"; 3 | 4 | export interface LogitBias { 5 | token: number; 6 | bias: number; 7 | } 8 | 9 | export class SamplerBuilder { 10 | private sampler: Deno.PointerValue; 11 | private readonly model: Deno.PointerValue; 12 | 13 | constructor( 14 | model: Deno.PointerValue, 15 | resourceBundle: GenerationResources, 16 | ) { 17 | this.sampler = resourceBundle.samplerPtr; 18 | if (!this.sampler) { 19 | throw new Error("Failed to create sampler"); 20 | } 21 | this.model = model; 22 | } 23 | 24 | /** 25 | * Adds distribution sampling with the specified seed 26 | * @param seed Random seed for sampling 27 | * @returns This builder instance for chaining 28 | */ 29 | dist(seed: number): SamplerBuilder { 30 | this.sampler = lib.symbols.sampler_dist(this.sampler, seed); 31 | return this; 32 | } 33 | 34 | /** 35 | * Adds grammar-based sampling constraints 36 | * @param grammar Grammar definition as a string 37 | * @param root Root rule name in the grammar 38 | * @returns This builder instance for chaining 39 | */ 40 | grammar(grammar: string): SamplerBuilder { 41 | const grammarPtr = new TextEncoder().encode(grammar + "\0"); 42 | 43 | this.sampler = lib.symbols.sampler_grammar( 44 | this.sampler, 45 | this.model, 46 | grammarPtr, 47 | ); 48 | 49 | return this; 50 | } 51 | 52 | /** 53 | * Adds llguidance sampler 54 | * @param grammar Grammar definition as a string 55 | */ 56 | llguidance(grammar: string): SamplerBuilder { 57 | const grammarPtr = new TextEncoder().encode(grammar + "\0"); 58 | 59 | this.sampler = lib.symbols.sampler_llguidance( 60 | this.sampler, 61 | this.model, 62 | grammarPtr, 63 | ); 64 | 65 | return this; 66 | } 67 | 68 | /** 69 | * Configures the sampler to always choose the most likely 
token (greedy sampling) 70 | * @returns This builder instance for chaining 71 | */ 72 | greedy(): SamplerBuilder { 73 | this.sampler = lib.symbols.sampler_greedy(this.sampler); 74 | return this; 75 | } 76 | 77 | /** 78 | * Configures the sampler for infill generation 79 | * @returns This builder instance for chaining 80 | */ 81 | infill(): SamplerBuilder { 82 | this.sampler = lib.symbols.sampler_infill(this.sampler, this.model); 83 | return this; 84 | } 85 | 86 | /** 87 | * Applies token biases to influence generation probabilities 88 | * @param logitBias Array of token biases to apply 89 | * @returns This builder instance for chaining 90 | */ 91 | logitBias(logitBias: LogitBias[]): SamplerBuilder { 92 | const nBias = logitBias.length; 93 | 94 | const bufferSize = nBias * 8; // 4 bytes for token (int32) + 4 bytes for bias (float) 95 | const buffer = new ArrayBuffer(bufferSize); 96 | const view = new DataView(buffer); 97 | 98 | logitBias.forEach((bias, index) => { 99 | view.setInt32(index * 8, bias.token, true); 100 | view.setFloat32(index * 8 + 4, bias.bias, true); 101 | }); 102 | 103 | this.sampler = lib.symbols.sampler_logit_bias( 104 | this.sampler, 105 | this.model, 106 | nBias, 107 | Deno.UnsafePointer.of(buffer), 108 | ); 109 | 110 | return this; 111 | } 112 | 113 | /** 114 | * Configures DRY (Don't Repeat Yourself) repetition penalty sampling 115 | * @param multiplier DRY penalty multiplier 116 | * @param base DRY penalty base 117 | * @param allowedLength Maximum allowed length 118 | * @param penaltyLastN Penalty context window size 119 | * @param sequenceBreakers Array of strings that break sequences 120 | * @returns This builder instance for chaining 121 | */ 122 | dry( 123 | multiplier: number, 124 | base: number, 125 | allowedLength: number, 126 | penaltyLastN: number, 127 | sequenceBreakers: string[] = [], 128 | ): SamplerBuilder { 129 | const nullTerminatedBreakers = sequenceBreakers.map((str) => 130 | str + "\0" 131 | ); 132 | 133 | // Encode strings to Uint8Arrays 134 | const encodedBreakers = nullTerminatedBreakers.map((str) => 135 | new TextEncoder().encode(str) 136 | ); 137 | 138 | // Create pointers to encoded strings 139 | const breakerPtrs = encodedBreakers.map((encoded) => 140 | Deno.UnsafePointer.of(encoded) 141 | ); 142 | 143 | // Create an array to hold the pointers 144 | const ptrArrayBuffer = new ArrayBuffer(breakerPtrs.length * 8); 145 | const ptrArray = new BigUint64Array(ptrArrayBuffer); 146 | 147 | // Store the pointer values in the array 148 | breakerPtrs.forEach((ptr, index) => { 149 | ptrArray[index] = BigInt(Deno.UnsafePointer.value(ptr)); 150 | }); 151 | 152 | this.sampler = lib.symbols.sampler_dry( 153 | this.sampler, 154 | this.model, 155 | multiplier, 156 | base, 157 | allowedLength, 158 | penaltyLastN, 159 | Deno.UnsafePointer.of(ptrArrayBuffer), 160 | BigInt(sequenceBreakers.length), 161 | ); 162 | 163 | return this; 164 | } 165 | 166 | /** 167 | * Configures minimum-p sampling 168 | * @param minP Minimum probability threshold 169 | * @param minKeep Minimum number of tokens to keep 170 | * @returns This builder instance for chaining 171 | */ 172 | minP(minP: number, minKeep: bigint): SamplerBuilder { 173 | this.sampler = lib.symbols.sampler_min_p(this.sampler, minP, minKeep); 174 | return this; 175 | } 176 | 177 | /** 178 | * Configures mirostat sampling (adaptive temperature) 179 | * @param seed Random seed 180 | * @param tau Target entropy 181 | * @param eta Learning rate 182 | * @param m Order of the mirostat 183 | * @returns This builder instance for chaining 184 | */
185 | mirostat( 186 | seed: number, 187 | tau: number, 188 | eta: number, 189 | m: number, 190 | ): SamplerBuilder { 191 | this.sampler = lib.symbols.sampler_mirostat( 192 | this.sampler, 193 | this.model, 194 | seed, 195 | tau, 196 | eta, 197 | m, 198 | ); 199 | return this; 200 | } 201 | 202 | /** 203 | * Configures mirostat v2 sampling (simplified adaptive temperature) 204 | * @param seed Random seed 205 | * @param tau Target entropy 206 | * @param eta Learning rate 207 | * @returns This builder instance for chaining 208 | */ 209 | mirostatV2(seed: number, tau: number, eta: number): SamplerBuilder { 210 | this.sampler = lib.symbols.sampler_mirostat_v2( 211 | this.sampler, 212 | seed, 213 | tau, 214 | eta, 215 | ); 216 | return this; 217 | } 218 | 219 | /** 220 | * Configures repetition penalties 221 | * @param penaltyLastN Number of tokens to consider for penalties 222 | * @param penaltyRepeat Repetition penalty 223 | * @param penaltyFreq Frequency penalty 224 | * @param penaltyPresent Presence penalty 225 | * @returns This builder instance for chaining 226 | */ 227 | penalties( 228 | penaltyLastN: number, 229 | penaltyRepeat: number, 230 | penaltyFreq: number, 231 | penaltyPresent: number, 232 | ): SamplerBuilder { 233 | this.sampler = lib.symbols.sampler_penalties( 234 | this.sampler, 235 | penaltyLastN, 236 | penaltyRepeat, 237 | penaltyFreq, 238 | penaltyPresent, 239 | ); 240 | return this; 241 | } 242 | 243 | /** 244 | * Sets the sampling temperature 245 | * @param temp Temperature value (higher = more random) 246 | * @returns This builder instance for chaining 247 | */ 248 | temp(temp: number): SamplerBuilder { 249 | this.sampler = lib.symbols.sampler_temp(this.sampler, temp); 250 | return this; 251 | } 252 | 253 | /** 254 | * Sets extended temperature settings 255 | * @param temp Base temperature 256 | * @param dynatempRange Dynamic temperature range 257 | * @param dynatempExponent Dynamic temperature exponent 258 | * @returns This builder instance for chaining 259 | */ 260 | tempExt( 261 | temp: number, 262 | dynatempRange: number, 263 | dynatempExponent: number, 264 | ): SamplerBuilder { 265 | this.sampler = lib.symbols.sampler_temp_ext( 266 | this.sampler, 267 | temp, 268 | dynatempRange, 269 | dynatempExponent, 270 | ); 271 | return this; 272 | } 273 | 274 | /** 275 | * Configures top-k sampling 276 | * @param k Number of most likely tokens to consider 277 | * @returns This builder instance for chaining 278 | */ 279 | topK(k: number): SamplerBuilder { 280 | this.sampler = lib.symbols.sampler_top_k(this.sampler, k); 281 | return this; 282 | } 283 | 284 | /** 285 | * Configures top-p (nucleus) sampling 286 | * @param p Cumulative probability threshold 287 | * @param minKeep Minimum number of tokens to keep 288 | * @returns This builder instance for chaining 289 | */ 290 | topP(p: number, minKeep: bigint): SamplerBuilder { 291 | this.sampler = lib.symbols.sampler_top_p(this.sampler, p, minKeep); 292 | return this; 293 | } 294 | 295 | /** 296 | * Configures typical sampling 297 | * @param typicalP Typical probability threshold 298 | * @param minKeep Minimum number of tokens to keep 299 | * @returns This builder instance for chaining 300 | */ 301 | typical(typicalP: number, minKeep: bigint): SamplerBuilder { 302 | this.sampler = lib.symbols.sampler_typical( 303 | this.sampler, 304 | typicalP, 305 | minKeep, 306 | ); 307 | return this; 308 | } 309 | 310 | /** 311 | * Configures top-n-sigma sampling 312 | * @param nSigma Number of standard deviations to consider 313 | * @returns 
This builder instance for chaining 314 | */ 315 | topNSigma(nSigma: number): SamplerBuilder { 316 | this.sampler = lib.symbols.sampler_top_n_sigma(this.sampler, nSigma); 317 | return this; 318 | } 319 | 320 | /** 321 | * Configures XTC (exploration time control) sampling 322 | * @param xtcProbability XTC probability 323 | * @param xtcThreshold XTC threshold 324 | * @param minKeep Minimum number of tokens to keep 325 | * @param seed Random seed 326 | * @returns This builder instance for chaining 327 | */ 328 | xtc( 329 | xtcProbability: number, 330 | xtcThreshold: number, 331 | minKeep: bigint, 332 | seed: number, 333 | ): SamplerBuilder { 334 | this.sampler = lib.symbols.sampler_xtc( 335 | this.sampler, 336 | xtcProbability, 337 | xtcThreshold, 338 | minKeep, 339 | seed, 340 | ); 341 | return this; 342 | } 343 | 344 | /** 345 | * Builds and returns the configured sampler 346 | * @returns Pointer to the configured sampler 347 | */ 348 | build(): Deno.PointerValue { 349 | return this.sampler; 350 | } 351 | } 352 | -------------------------------------------------------------------------------- /bindings/server/c_library.cpp: -------------------------------------------------------------------------------- 1 | #include "c_library.h" 2 | 3 | #include 4 | 5 | #include "processor.hpp" 6 | #include 7 | 8 | #include "log.h" 9 | 10 | // Implementation of processor interface functions 11 | int processor_submit_work( 12 | Processor* processor, 13 | const char* prompt, 14 | GenerationResources* gen_resources, 15 | const int max_tokens, 16 | const int min_tokens, 17 | const uint32_t max_slot_n_ctx, 18 | const unsigned seed, 19 | const char** rewind_strings, 20 | const unsigned num_rewind_strings, 21 | const char** stopping_strings, 22 | const unsigned num_stopping_strings, 23 | const int32_t* stopping_tokens, 24 | const unsigned num_stopping_tokens, 25 | const bool add_special) { 26 | 27 | const std::string prompt_as_string(prompt); 28 | const InferenceArgs args( 29 | gen_resources, 30 | max_tokens, 31 | min_tokens, 32 | max_slot_n_ctx, 33 | seed, 34 | rewind_strings, 35 | num_rewind_strings, 36 | stopping_strings, 37 | num_stopping_strings, 38 | stopping_tokens, 39 | num_stopping_tokens, 40 | add_special 41 | ); 42 | 43 | return processor->submit_work( 44 | prompt_as_string, 45 | args); 46 | } 47 | 48 | bool processor_cancel_work(Processor* processor, const int request_id_to_cancel) { 49 | return processor->cancel_work(request_id_to_cancel); 50 | } 51 | 52 | Processor* processor_make(llama_model* model, llama_context* ctx, const int num_processor_slots) { 53 | return new Processor(model, ctx, num_processor_slots); 54 | } 55 | 56 | void processor_free(const Processor* processor) { 57 | delete processor; 58 | } 59 | 60 | // Simplified version from common args.cpp 61 | std::vector tensor_type_split(const std::string& value, std::vector& leaked_strings) { 62 | std::vector tensor_buft_overrides; 63 | 64 | std::map buft_list; 65 | if (buft_list.empty()) { 66 | for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { 67 | auto* dev = ggml_backend_dev_get(i); 68 | if (auto* buft = ggml_backend_dev_buffer_type(dev)) { 69 | buft_list[ggml_backend_buft_name(buft)] = buft; 70 | } 71 | } 72 | } 73 | 74 | for (const auto & override : string_split(value, ',')) { 75 | const std::string::size_type pos = override.find('='); 76 | if (pos == std::string::npos) { 77 | throw std::invalid_argument("invalid value"); 78 | } 79 | std::string tensor_name = override.substr(0, pos); 80 | std::string buffer_type = 
override.substr(pos + 1); 81 | if (buft_list.find(buffer_type) == buft_list.end()) { 82 | printf("Available buffer types:\n"); 83 | for (const auto &[name, type] : buft_list) { 84 | printf(" %s\n", ggml_backend_buft_name(type)); 85 | } 86 | throw std::invalid_argument("Attempted to use an invalid buffer override type. Exiting. "); 87 | } 88 | 89 | leaked_strings.push_back(strdup(tensor_name.c_str())); 90 | tensor_buft_overrides.push_back({leaked_strings.back(), buft_list.at(buffer_type)}); 91 | } 92 | 93 | if (!tensor_buft_overrides.empty()) { 94 | //Yes this is some nightmare garbage where it needs a null terminator don't ask me man, this does need to be here. 95 | tensor_buft_overrides.push_back({nullptr, nullptr}); 96 | } 97 | return tensor_buft_overrides; 98 | } 99 | 100 | llama_model* model_load( 101 | const char* model_path, 102 | const int32_t num_gpu_layers, 103 | const int tensor_split_mode, 104 | const float* tensor_split, 105 | const llama_progress_callback callback, 106 | const char* tensor_type_split_regex, 107 | const bool use_mmap, 108 | const bool realtime_process_priority) 109 | { 110 | llama_model_params model_params = llama_model_default_params(); 111 | model_params.n_gpu_layers = num_gpu_layers; 112 | model_params.progress_callback = callback; 113 | 114 | model_params.split_mode = static_cast(tensor_split_mode); 115 | model_params.tensor_split = tensor_split; 116 | model_params.use_mmap = use_mmap; 117 | 118 | // Requires sudo on unix systems 119 | // Requires admin for realtime on Windows 120 | if (realtime_process_priority) { 121 | set_process_priority(GGML_SCHED_PRIO_REALTIME); 122 | } 123 | 124 | if (tensor_type_split_regex != nullptr) { 125 | std::vector leaked_c_strings; 126 | const auto overrides = tensor_type_split(std::string(tensor_type_split_regex), leaked_c_strings); 127 | 128 | if (!overrides.empty()) { 129 | model_params.tensor_buft_overrides = overrides.data(); 130 | } 131 | llama_model* model = llama_model_load_from_file(model_path, model_params); 132 | for (char* ptr : leaked_c_strings) { 133 | free(ptr); 134 | } 135 | return model; 136 | } 137 | 138 | llama_model* model = llama_model_load_from_file(model_path, model_params); 139 | return model; 140 | } 141 | 142 | float model_get_freq_base(const llama_model* model) { 143 | static auto freqBaseKey = "general.rope_freq_base"; 144 | 145 | const int32_t bufSize = llama_model_meta_val_str(model, freqBaseKey, nullptr, 0) + 1; 146 | if (bufSize <= 1) { 147 | return 10000.0f; 148 | } 149 | 150 | std::vector buffer(bufSize); 151 | const int32_t written = llama_model_meta_val_str(model, freqBaseKey, buffer.data(), bufSize); 152 | if (written <= 0) { 153 | return 10000.0f; 154 | } 155 | 156 | try { 157 | std::stringstream ss(buffer.data()); 158 | ss.imbue(std::locale::classic()); 159 | float value; 160 | ss >> value; 161 | 162 | if (ss.fail()) { 163 | return 10000.0f; 164 | } 165 | 166 | return value; 167 | } catch (...) 
{ 168 | return 10000.0f; 169 | } 170 | } 171 | 172 | void model_free(llama_model* model) 173 | { 174 | llama_model_free(model); 175 | } 176 | 177 | llama_token model_vocab_bos(const llama_model* model) 178 | { 179 | return llama_vocab_bos(&model->vocab); 180 | } 181 | 182 | llama_token model_vocab_eos(const llama_model* model) 183 | { 184 | return llama_vocab_eos(&model->vocab); 185 | } 186 | 187 | llama_token model_vocab_eot(const llama_model* model) 188 | { 189 | return llama_vocab_eot(&model->vocab); 190 | } 191 | 192 | bool model_vocab_add_bos(const llama_model* model) 193 | { 194 | return llama_vocab_get_add_bos(&model->vocab); 195 | } 196 | 197 | const char* model_vocab_token_to_string(const llama_model* model, const llama_token token) { 198 | return llama_vocab_get_text(&model->vocab, token); 199 | } 200 | 201 | llama_context* ctx_make( 202 | llama_model* model, 203 | const unsigned context_length, 204 | const unsigned num_batches, 205 | const int32_t num_gpu_layers, 206 | const int32_t num_threads, 207 | const bool flash_attn, 208 | const float rope_freq_base, 209 | const bool use_yarn, 210 | int k_cache_quant_type, 211 | int v_cache_quant_type, 212 | const float kv_defrag_threshold 213 | ) { 214 | llama_context_params ctx_params = llama_context_default_params(); 215 | ctx_params.n_ctx = context_length; 216 | ctx_params.n_batch = num_batches; 217 | ctx_params.n_ubatch = num_batches; 218 | ctx_params.no_perf = false; 219 | ctx_params.flash_attn = flash_attn; 220 | 221 | ctx_params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; 222 | const float freqBaseTrain = model_get_freq_base(model); 223 | 224 | // Yarn, allegedly ext_factor -1 to default to model cfg, but it looks sussy. 225 | // Only set linear RoPE if freq base is greater than the trained base 226 | if (use_yarn) { 227 | ctx_params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; 228 | ctx_params.yarn_ext_factor = -1; 229 | } else if (rope_freq_base > freqBaseTrain) { 230 | ctx_params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; 231 | ctx_params.rope_freq_base = rope_freq_base; 232 | ctx_params.rope_freq_scale = 0; 233 | } 234 | 235 | // Use only one thread if model is fully offloaded on GPU 236 | if (num_gpu_layers >= llama_model_n_layer(model) || num_gpu_layers == -1) { 237 | ctx_params.n_threads = 1; 238 | ctx_params.n_threads_batch = 1; 239 | } else { 240 | ctx_params.n_threads = num_threads; 241 | ctx_params.n_threads_batch = num_threads; 242 | } 243 | 244 | ctx_params.type_k = static_cast(k_cache_quant_type); 245 | ctx_params.type_v = static_cast(v_cache_quant_type); 246 | ctx_params.defrag_thold = kv_defrag_threshold; 247 | llama_context* ctx = llama_init_from_model(model, ctx_params); 248 | 249 | return ctx; 250 | } 251 | 252 | uint32_t ctx_max_seq_len(const llama_context* ctx) 253 | { 254 | return llama_n_ctx(ctx); 255 | } 256 | 257 | void ctx_free(llama_context* ctx) 258 | { 259 | llama_free(ctx); 260 | } 261 | 262 | void ctx_clear_kv(llama_context* ctx) 263 | { 264 | llama_kv_self_clear(ctx); 265 | } 266 | 267 | int32_t* endpoint_tokenize( 268 | const llama_model* model, 269 | const char* prompt, 270 | const bool add_special, 271 | const bool parse_special) { 272 | 273 | const auto promptLength = static_cast(strlen(prompt)); 274 | const int n_prompt = -llama_tokenize(&model->vocab, prompt, promptLength, 275 | nullptr, 0, add_special, parse_special); 276 | const auto tokenArray = new int32_t[n_prompt + 1]; 277 | tokenArray[0] = n_prompt; 278 | 279 | if (llama_tokenize(&model->vocab, prompt, promptLength, 
280 | tokenArray + 1, n_prompt + 1, 281 | add_special, parse_special) < 0) { 282 | return nullptr; 283 | } 284 | 285 | return tokenArray; 286 | } 287 | 288 | char* model_chat_template(const llama_model* model) { 289 | static auto tokenizerTemplateKey = "tokenizer.chat_template"; 290 | const int32_t bufSize = llama_model_meta_val_str(model, tokenizerTemplateKey, nullptr, 0) + 1; 291 | 292 | // Return null if template doesn't exist 293 | if (bufSize <= 1) { 294 | return nullptr; 295 | } 296 | 297 | const auto buffer = new char[bufSize]; 298 | llama_model_meta_val_str(model, tokenizerTemplateKey, buffer, bufSize); 299 | 300 | // Additional check to see if the buffer has data 301 | if (buffer[0] == '\0') { 302 | delete[] buffer; 303 | return nullptr; 304 | } 305 | 306 | return buffer; 307 | } 308 | 309 | char* endpoint_detokenize( 310 | const llama_model* model, 311 | const int32_t* tokens, 312 | const int32_t num_tokens, 313 | const int32_t max_text_size, 314 | const bool add_special, 315 | const bool parse_special) { 316 | const auto outText = new char[max_text_size]; 317 | llama_detokenize(&model->vocab, tokens, num_tokens, outText, max_text_size, add_special, parse_special); 318 | return outText; 319 | } 320 | 321 | void endpoint_free_string(const char* str) { 322 | delete[] str; 323 | } 324 | 325 | void endpoint_free_tokens(const int32_t* tokens) { 326 | delete[] tokens; 327 | } 328 | 329 | bool has_llguidance() { 330 | #if defined(LLGUIDANCE_BUILT) || LLGUIDANCE_BUILT != 0 331 | return true; 332 | #else 333 | return false; 334 | #endif 335 | } 336 | -------------------------------------------------------------------------------- /bindings/server/c_library.h: -------------------------------------------------------------------------------- 1 | #ifndef PROCESSOR_INTERFACE_H 2 | #define PROCESSOR_INTERFACE_H 3 | 4 | #include "llama.h" 5 | 6 | #ifdef __cplusplus 7 | 8 | extern "C" { 9 | #endif 10 | 11 | typedef struct Processor Processor; 12 | typedef struct ReadbackBuffer ReadbackBuffer; 13 | typedef struct GenerationResources GenerationResources; 14 | 15 | // ~~~ Lcpp Model ~~~ 16 | 17 | // LEAKABLE! Ensure you use model_free to clean up. 18 | llama_model* model_load( 19 | const char* model_path, 20 | int32_t num_gpu_layers, 21 | const int tensor_split_mode, 22 | const float* tensor_split, 23 | llama_progress_callback callback, 24 | const char* tensor_type_split_regex, 25 | const bool use_mmap, 26 | const bool realtime_process_priority); 27 | 28 | float model_get_freq_base( 29 | const llama_model* model); 30 | 31 | void model_free( 32 | llama_model* model); 33 | 34 | // LEAKABLE! Ensure you use endpoint_free_string to clean up. 
35 | char* model_chat_template( 36 | const llama_model* model); 37 | 38 | // ~~~ Processor ~~~ 39 | 40 | int processor_submit_work( 41 | Processor* processor, 42 | const char* prompt, 43 | GenerationResources* gen_resources, 44 | const int max_tokens, 45 | const int min_tokens, 46 | const uint32_t max_slot_n_ctx, 47 | const unsigned seed, 48 | const char** rewind_strings, 49 | const unsigned num_rewind_strings, 50 | const char** stopping_strings, 51 | const unsigned num_stopping_strings, 52 | const int32_t* stopping_tokens, 53 | const unsigned num_stopping_tokens, 54 | const bool add_special); 55 | 56 | bool processor_cancel_work( 57 | Processor* processor, 58 | int request_id_to_cancel); 59 | 60 | Processor* processor_make( 61 | llama_model* model, 62 | llama_context* ctx, 63 | int num_processor_slots); 64 | 65 | void processor_free( 66 | const Processor* processor); 67 | 68 | // ~~~ Lcpp Endpoint ~~~ 69 | 70 | // LEAKABLE! Ensure you use endpoint_free_tokens to clean up. 71 | int32_t* endpoint_tokenize( 72 | const llama_model* model, 73 | const char* prompt, 74 | bool add_special, 75 | bool parse_special); 76 | 77 | // LEAKABLE! Ensure you use endpoint_free_string to clean up. 78 | char* endpoint_detokenize( 79 | const llama_model* model, 80 | const int32_t* tokens, 81 | int32_t num_tokens, 82 | int32_t max_text_size, 83 | bool add_special, 84 | bool parse_special); 85 | 86 | void endpoint_free_string( 87 | const char* str); 88 | 89 | void endpoint_free_tokens( 90 | const int32_t* tokens); 91 | 92 | // ~~~ Lcpp Vocab ~~~ 93 | 94 | llama_token model_vocab_bos( 95 | const llama_model* model); 96 | 97 | llama_token model_vocab_eos( 98 | const llama_model* model); 99 | 100 | llama_token model_vocab_eot( 101 | const llama_model* model); 102 | 103 | bool model_vocab_add_bos( 104 | const llama_model* model); 105 | 106 | // LEAKABLE! Ensure you use endpoint_free_string to clean up. 107 | const char* model_vocab_token_to_string( 108 | const llama_model* model, 109 | llama_token token); 110 | 111 | // ~~~ Lcpp Context ~~~ 112 | 113 | // LEAKABLE! Ensure you use ctx_free to clean up. 114 | llama_context* ctx_make( 115 | llama_model* model, 116 | unsigned context_length, 117 | unsigned num_batches, 118 | int32_t num_gpu_layers, 119 | int32_t num_threads, 120 | bool flash_attn, 121 | float rope_freq_base, 122 | bool use_yarn, 123 | int k_cache_quant_type, 124 | int v_cache_quant_type, 125 | float kv_defrag_threshold 126 | ); 127 | 128 | uint32_t ctx_max_seq_len( 129 | const llama_context* ctx); 130 | 131 | void ctx_free( 132 | llama_context* ctx); 133 | 134 | void ctx_clear_kv( 135 | llama_context* ctx); 136 | 137 | // ~~~ Readback Buffer ~~~ 138 | 139 | bool readback_is_buffer_finished( 140 | ReadbackBuffer* buffer); 141 | 142 | bool readback_read_next( 143 | ReadbackBuffer* buffer, 144 | char** outChar, 145 | llama_token* outToken); 146 | 147 | //TODO::@Z Validate. 148 | // Not leakable, owned by readback buffer ? 
149 | char* readback_read_status( 150 | ReadbackBuffer* buffer); 151 | 152 | void readback_annihilate( 153 | ReadbackBuffer* buffer); 154 | 155 | // ~~~ Samplers ~~~ 156 | 157 | llama_sampler* sampler_dist( 158 | llama_sampler* chain, 159 | uint32_t seed); 160 | 161 | llama_sampler* sampler_greedy( 162 | llama_sampler* chain); 163 | 164 | llama_sampler* sampler_min_p( 165 | llama_sampler* chain, 166 | float min_p, 167 | size_t min_keep); 168 | 169 | llama_sampler* sampler_mirostat_v2( 170 | llama_sampler* chain, 171 | uint32_t seed, 172 | float tau, 173 | float eta); 174 | 175 | llama_sampler* sampler_penalties( 176 | llama_sampler* chain, 177 | int penalty_last_n, 178 | float penalty_repeat, 179 | float penalty_freq, 180 | float penalty_present); 181 | 182 | llama_sampler* sampler_temp( 183 | llama_sampler* chain, 184 | float temp); 185 | 186 | llama_sampler* sampler_temp_ext( 187 | llama_sampler* chain, 188 | float temp, 189 | float dynatemp_range, 190 | float dynatemp_exponent); 191 | 192 | llama_sampler* sampler_top_k( 193 | llama_sampler* chain, 194 | int top_k); 195 | 196 | llama_sampler* sampler_top_p( 197 | llama_sampler* chain, 198 | float top_p, 199 | size_t min_keep); 200 | 201 | llama_sampler* sampler_typical( 202 | llama_sampler* chain, 203 | float typical_p, 204 | size_t min_keep); 205 | 206 | llama_sampler* sampler_top_n_sigma( 207 | llama_sampler* chain, 208 | float n_sigma); 209 | 210 | llama_sampler* sampler_xtc( 211 | llama_sampler* chain, 212 | float xtc_probability, 213 | float xtc_threshold, 214 | size_t min_keep, 215 | uint32_t seed); 216 | 217 | llama_sampler* sampler_grammar( 218 | llama_sampler* chain, 219 | const llama_model* model, 220 | const char* grammar); 221 | 222 | llama_sampler* sampler_dry( 223 | llama_sampler* chain, 224 | const llama_model* model, 225 | float multiplier, 226 | float base, 227 | int32_t allowed_length, 228 | int32_t penalty_last_n, 229 | const char** sequence_breakers, 230 | size_t n_breakers); 231 | 232 | llama_sampler* sampler_infill( 233 | llama_sampler* chain, 234 | const llama_model* model); 235 | 236 | llama_sampler* sampler_logit_bias( 237 | llama_sampler* chain, 238 | const llama_model* model, 239 | int32_t n_bias, 240 | const llama_logit_bias* logit_bias); 241 | 242 | llama_sampler* sampler_mirostat( 243 | llama_sampler* chain, 244 | const llama_model* model, 245 | uint32_t seed, 246 | float tau, 247 | float eta, 248 | int m); 249 | 250 | llama_sampler* sampler_llguidance( 251 | llama_sampler* chain, 252 | const llama_model* model, 253 | const char* grammar_data); 254 | 255 | // ~~~ Generation Resources ~~~ 256 | 257 | //Leakable! Shared PTR behaviour, use release to free. 258 | GenerationResources* generation_resources_make(); 259 | 260 | void generation_resources_release( 261 | GenerationResources* resources); 262 | 263 | // ~~~ Features ~~~ 264 | 265 | bool has_llguidance(); 266 | 267 | #ifdef __cplusplus 268 | } 269 | #endif 270 | 271 | #endif // PROCESSOR_INTERFACE_H -------------------------------------------------------------------------------- /bindings/server/generation_resources.hpp: -------------------------------------------------------------------------------- 1 | #ifndef GENERATION_RESOURCES_HPP 2 | #define GENERATION_RESOURCES_HPP 3 | 4 | #include 5 | #include "readback_buffer.hpp" 6 | #include "samplers.hpp" 7 | 8 | /* 9 | * An atomic reference counted shared resource bundle for cooperative resources. 
10 | */ 11 | 12 | struct GenerationResources { 13 | ReadbackBuffer* readback_buffer{nullptr}; 14 | llama_sampler* sampler{nullptr}; 15 | 16 | std::atomic ref_count{1}; 17 | }; 18 | 19 | // C API 20 | // Free with resource_bundle_release -- this is a shared ptr. 21 | GenerationResources* generation_resources_make() { 22 | const auto bundle = new GenerationResources{}; 23 | bundle->readback_buffer = new ReadbackBuffer{}; 24 | bundle->sampler = sampler_make(); 25 | return bundle; 26 | } 27 | 28 | GenerationResources* generation_resources_ref_acquire(GenerationResources* resources) { 29 | resources->ref_count.fetch_add(1, std::memory_order_relaxed); 30 | return resources; 31 | } 32 | 33 | // C API 34 | void generation_resources_release(GenerationResources* resources) { 35 | if (!resources) { 36 | return; 37 | } 38 | 39 | if ((resources->ref_count.fetch_sub(1, std::memory_order_acq_rel)) == 1) { 40 | delete resources->readback_buffer; 41 | sampler_free(resources->sampler); 42 | delete resources; 43 | } 44 | } 45 | 46 | #endif //GENERATION_RESOURCES_HPP 47 | -------------------------------------------------------------------------------- /bindings/server/inference_args.hpp: -------------------------------------------------------------------------------- 1 | #ifndef INFERENCE_ARGS_HPP 2 | #define INFERENCE_ARGS_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | /* 9 | * A lightweight data struct of inference args. 10 | * 11 | * Provides 12 | * Automatic conversion from c-style null-terminated string arrays with length and token arrays with count to vectors 13 | */ 14 | 15 | struct GenerationResources; 16 | 17 | class InferenceArgs { 18 | public: 19 | GenerationResources* gen_resources; 20 | int max_tokens_to_gen; 21 | int min_tokens_to_gen; 22 | uint32_t max_slot_n_ctx; 23 | unsigned seed; 24 | std::vector rewind_strings; 25 | std::vector stopping_strings; 26 | std::vector stopping_tokens; 27 | bool add_special; 28 | 29 | InferenceArgs(): gen_resources(nullptr), max_tokens_to_gen(0), min_tokens_to_gen(0), 30 | max_slot_n_ctx(std::numeric_limits::max()), seed(0), 31 | add_special(true) { 32 | }; 33 | 34 | explicit InferenceArgs( 35 | GenerationResources* gen_resources, 36 | const int max_tokens = 50, 37 | const int min_tokens = 10, 38 | const uint32_t max_slot_n_ctx = std::numeric_limits::max(), 39 | const unsigned seed = 1337, 40 | const char** rewind_strings = nullptr, 41 | const unsigned num_rewind_strings = 0, 42 | const char** stopping_strings = nullptr, 43 | const unsigned num_stopping_strings = 0, 44 | const int32_t* stopping_tokens = nullptr, 45 | const unsigned num_stopping_tokens = 0, 46 | const bool add_special = true) 47 | 48 | : gen_resources(gen_resources), 49 | max_tokens_to_gen(max_tokens), 50 | min_tokens_to_gen(min_tokens), 51 | seed(seed), 52 | add_special(add_special) 53 | { 54 | if (rewind_strings != nullptr && num_rewind_strings > 0) { 55 | this->rewind_strings.reserve(num_rewind_strings); 56 | for (unsigned i = 0; i < num_rewind_strings; ++i) { 57 | if (rewind_strings[i]) { 58 | this->rewind_strings.emplace_back(rewind_strings[i]); 59 | } 60 | } 61 | } 62 | 63 | if (stopping_strings != nullptr && num_stopping_strings > 0) { 64 | this->stopping_strings.reserve(num_stopping_strings); 65 | for (unsigned i = 0; i < num_stopping_strings; ++i) { 66 | if (stopping_strings[i]) { 67 | this->stopping_strings.emplace_back(stopping_strings[i]); 68 | } 69 | } 70 | } 71 | 72 | if (stopping_tokens != nullptr && num_stopping_tokens > 0) { 73 | 
this->stopping_tokens.assign(stopping_tokens, 74 | stopping_tokens + num_stopping_tokens); 75 | } 76 | 77 | this->max_slot_n_ctx = max_slot_n_ctx == 0 ? std::numeric_limits::max() : max_slot_n_ctx; 78 | } 79 | }; 80 | 81 | #endif //INFERENCE_ARGS_HPP 82 | -------------------------------------------------------------------------------- /bindings/server/json_status.hpp: -------------------------------------------------------------------------------- 1 | #ifndef JSON_STATUS_HPP 2 | #define JSON_STATUS_HPP 3 | 4 | #include 5 | #include 6 | #include "slot.hpp" 7 | 8 | inline std::string escape_string(const std::string& input) { 9 | std::ostringstream ss; 10 | ss << std::hex << std::setfill('0'); 11 | 12 | for (const unsigned char ch : input) { 13 | switch (ch) { 14 | case '"': ss << "\\\""; break; 15 | case '\\': ss << "\\\\"; break; 16 | case '\b': ss << "\\b"; break; 17 | case '\f': ss << "\\f"; break; 18 | case '\n': ss << "\\n"; break; 19 | case '\r': ss << "\\r"; break; 20 | case '\t': ss << "\\t"; break; 21 | default: 22 | if (ch < 0x20) { 23 | ss << "\\u" << std::setw(4) << static_cast(ch); 24 | } else { 25 | ss << ch; 26 | } 27 | } 28 | } 29 | 30 | return ss.str(); 31 | } 32 | 33 | template 34 | void add_json_value(std::ostringstream& ss, const std::string& key, const T& value, bool is_last = false) { 35 | ss << "\"" << key << "\":"; 36 | if constexpr (std::is_same_v) { 37 | ss << "\"" << escape_string(value) << "\""; 38 | } else { 39 | ss << value; 40 | } 41 | 42 | if (!is_last) { 43 | ss << ","; 44 | } 45 | } 46 | 47 | inline std::string make_empty_json_status_string(const std::string &finish_reason, 48 | const std::string &stop_token) { 49 | constexpr double prompt_sec = 0.0; 50 | constexpr double gen_sec = 0.0; 51 | constexpr double total_sec = 0.0; 52 | constexpr double prompt_tokens_per_sec = 0.0; 53 | constexpr double gen_tokens_per_sec = 0.0; 54 | constexpr int prompt_tokens = 0; 55 | constexpr int gen_tokens = 0; 56 | constexpr int slot_id = -1; 57 | constexpr int request_id = -1; 58 | constexpr int job_index = -1; 59 | 60 | std::ostringstream ss; 61 | ss << std::fixed << std::setprecision(6) << "{"; 62 | 63 | add_json_value(ss, "slotId", slot_id); 64 | add_json_value(ss, "requestId", request_id); 65 | add_json_value(ss, "jobIndex", job_index); 66 | 67 | add_json_value(ss, "promptTokens", prompt_tokens); 68 | add_json_value(ss, "genTokens", gen_tokens); 69 | 70 | add_json_value(ss, "promptSec", prompt_sec); 71 | add_json_value(ss, "genSec", gen_sec); 72 | add_json_value(ss, "totalSec", total_sec); 73 | add_json_value(ss, "genTokensPerSec", gen_tokens_per_sec); 74 | add_json_value(ss, "promptTokensPerSec", prompt_tokens_per_sec); 75 | 76 | add_json_value(ss, "finishReason", finish_reason); 77 | add_json_value(ss, "stopToken", stop_token, true); 78 | 79 | ss << "}"; 80 | 81 | return ss.str(); 82 | } 83 | 84 | inline std::string make_json_status_string(const Slot& slot, const std::string &finish_reason, 85 | const std::string &stop_token) { 86 | 87 | const double prompt_sec = (slot.prompt_end_time - slot.slot_start_time) / 1000.0; 88 | const double gen_sec = (slot.generating_end_time - slot.prompt_end_time) / 1000.0; 89 | const double total_sec = (slot.generating_end_time - slot.slot_start_time) / 1000.0; 90 | 91 | 92 | const double prompt_tokens_per_sec = prompt_sec > 0 ? 93 | static_cast(slot.prompt_tokens_processed) / prompt_sec : 0.0; 94 | const double gen_tokens_per_sec = gen_sec > 0 ? 
95 | static_cast(slot.tokens_generated) / gen_sec : 0.0; 96 | 97 | std::ostringstream ss; 98 | ss << std::fixed << std::setprecision(2) << "{"; 99 | 100 | add_json_value(ss, "slotId", slot.slot_id); 101 | add_json_value(ss, "requestId", slot.request_id); 102 | add_json_value(ss, "jobIndex", slot.job_index); 103 | 104 | add_json_value(ss, "promptTokens", slot.prompt_tokens_processed); 105 | add_json_value(ss, "genTokens", slot.tokens_generated); 106 | 107 | add_json_value(ss, "promptSec", prompt_sec); 108 | add_json_value(ss, "genSec", gen_sec); 109 | add_json_value(ss, "totalSec", total_sec); 110 | add_json_value(ss, "genTokensPerSec", gen_tokens_per_sec); 111 | add_json_value(ss, "promptTokensPerSec", prompt_tokens_per_sec); 112 | 113 | add_json_value(ss, "finishReason", finish_reason); 114 | add_json_value(ss, "stopToken", stop_token, true); 115 | 116 | ss << "}"; 117 | 118 | return ss.str(); 119 | } 120 | 121 | #endif // JSON_STATUS_HPP -------------------------------------------------------------------------------- /bindings/server/presampler.hpp: -------------------------------------------------------------------------------- 1 | #ifndef PRESAMPLER_HPP 2 | #define PRESAMPLER_HPP 3 | 4 | #include 5 | #include "samplers.hpp" 6 | 7 | /* 8 | * The presampler is responsible for rewind biasing and stopping biasing. 9 | * 10 | * Provides: 11 | * Minimum token generation. (By banning the stop token) 12 | * Rewind bans: Keeps track of the rewinding ban buffer. 13 | * 14 | * Mechanism: 15 | * This is overall simply an extra sampler that is used first in the sampling chain to pre-filter banned logits. 16 | */ 17 | 18 | inline llama_sampler* build_presampler_chain( 19 | const llama_model* model, 20 | const uint32_t seed, 21 | const int32_t n_bias, 22 | const llama_logit_bias* logit_bias) { 23 | llama_sampler* sampler = sampler_make(); 24 | sampler = sampler_logit_bias(sampler, model, n_bias, logit_bias); 25 | sampler = sampler_dist(sampler, seed); 26 | 27 | return sampler; 28 | } 29 | 30 | struct Presampler { 31 | private: 32 | //Biases imposed by the rewind mechanism. 33 | std::unordered_set rewind_biases; 34 | 35 | //Biases imposed by stopping criterion. 
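    // (Typically the EOG/stop tokens, banned while a minimum number of tokens still has
    //  to be generated and removed again via clear_eos_bans once that minimum is reached.)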
36 | std::unordered_set eos_biases; 37 | 38 | void rebuild_presampler(const llama_model* model) { 39 | std::vector biases; 40 | for (const llama_token token : rewind_biases) { 41 | biases.push_back({token, -50000.0f}); 42 | } 43 | for (const llama_token token : eos_biases) { 44 | biases.push_back({token, -50000.0f}); 45 | } 46 | 47 | should_presample = !biases.empty(); 48 | 49 | llama_sampler_free(sampler); 50 | sampler = build_presampler_chain(model, seed, static_cast(biases.size()), biases.data()); 51 | } 52 | 53 | void add_tokens_to_bias(std::unordered_set& bias_set, 54 | const llama_model* model, 55 | const std::vector& tokens) { 56 | for (auto& token : tokens) { 57 | bias_set.insert(token); 58 | } 59 | rebuild_presampler(model); 60 | } 61 | 62 | public: 63 | llama_sampler* sampler {nullptr}; 64 | uint32_t seed = 1337; 65 | bool should_presample = false; 66 | 67 | void add_rewind_bans(const llama_model* model, const std::vector &tokens) { 68 | add_tokens_to_bias(rewind_biases, model, tokens); 69 | } 70 | 71 | void add_eos_ban(const llama_model* model, const std::vector &tokens) { 72 | add_tokens_to_bias(eos_biases, model, tokens); 73 | } 74 | 75 | void clear_rewind_bans(const llama_model* model) { 76 | if (rewind_biases.empty()) { 77 | return; 78 | } 79 | rewind_biases.clear(); 80 | rebuild_presampler(model); 81 | } 82 | 83 | void clear_eos_bans(const llama_model* model) { 84 | if (eos_biases.empty()) { 85 | return; 86 | } 87 | eos_biases.clear(); 88 | rebuild_presampler(model); 89 | } 90 | }; 91 | 92 | #endif //PRESAMPLER_HPP 93 | -------------------------------------------------------------------------------- /bindings/server/readback_buffer.hpp: -------------------------------------------------------------------------------- 1 | #ifndef READBACK_BUFFER_HPP 2 | #define READBACK_BUFFER_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | /** 10 | * Owned buffer for live token and character streaming. 11 | */ 12 | struct ReadbackBuffer { 13 | unsigned last_readback_index {0}; 14 | bool buffer_finished_write {false}; 15 | char* status_buffer = nullptr; 16 | 17 | // Owner internal char*'s. Must free all of them. 
(strdup) 18 | std::vector* data = new std::vector(); 19 | std::vector* ids = new std::vector(); 20 | 21 | // Two phase destruction 22 | std::mutex readback_mutex; 23 | std::atomic being_destroyed {false}; 24 | }; 25 | 26 | template 27 | bool using_readback_buffer(ReadbackBuffer* buffer, Callback&& callback) { 28 | if (!buffer || buffer->being_destroyed) 29 | return false; 30 | 31 | std::lock_guard lock(buffer->readback_mutex); 32 | if (!buffer || buffer->being_destroyed) 33 | return false; 34 | 35 | callback(); 36 | return true; 37 | } 38 | 39 | // C API 40 | bool readback_is_buffer_finished(ReadbackBuffer* buffer) { 41 | bool is_finished = true; 42 | using_readback_buffer(buffer, [&] { 43 | is_finished = buffer->buffer_finished_write && buffer->last_readback_index >= buffer->ids->size(); 44 | }); 45 | return is_finished; 46 | } 47 | 48 | // C API 49 | ReadbackBuffer* readback_create_buffer() { 50 | return new ReadbackBuffer{}; 51 | } 52 | 53 | // C API 54 | bool readback_read_next(ReadbackBuffer* buffer, char** outChar, llama_token* outToken) { 55 | bool success = false; 56 | using_readback_buffer(buffer, [&] { 57 | if (buffer->last_readback_index < buffer->ids->size() && 58 | buffer->last_readback_index < buffer->data->size()) { 59 | *outChar = buffer->data->at(buffer->last_readback_index); 60 | *outToken = buffer->ids->at(buffer->last_readback_index); 61 | buffer->last_readback_index++; 62 | success = true; 63 | } 64 | }); 65 | 66 | return success; 67 | } 68 | 69 | // C API 70 | char* readback_read_status(ReadbackBuffer* buffer) { 71 | char* status = nullptr; 72 | using_readback_buffer(buffer, [&]() { 73 | status = buffer->status_buffer; 74 | }); 75 | 76 | return status; 77 | } 78 | 79 | // C API 80 | void readback_annihilate(ReadbackBuffer* buffer) { 81 | if (!buffer) 82 | return; 83 | 84 | { 85 | std::lock_guard lock(buffer->readback_mutex); 86 | buffer->being_destroyed = true; 87 | 88 | if (buffer->data) { 89 | for (char* str : *(buffer->data)) { 90 | free(str); 91 | } 92 | delete buffer->data; 93 | } 94 | 95 | if (buffer->ids) { 96 | delete buffer->ids; 97 | } 98 | 99 | if (buffer->status_buffer) { 100 | free(buffer->status_buffer); 101 | } 102 | } 103 | delete buffer; 104 | } 105 | 106 | // Internal -- MALLOC copy -- Free all data buffers via free() 107 | void readback_write_to_buffer(ReadbackBuffer* buffer, const std::string& data, const llama_token token) { 108 | using_readback_buffer(buffer, [&]() { 109 | char* copy = strdup(data.c_str()); 110 | buffer->data->push_back(copy); 111 | buffer->ids->push_back(token); 112 | }); 113 | } 114 | 115 | // Internal -- MALLOC copy -- Free status buffer via free() 116 | void readback_finish(ReadbackBuffer* buffer, const std::string& status) { 117 | using_readback_buffer(buffer, [&]() { 118 | char* copy = strdup(status.c_str()); 119 | if (buffer->status_buffer) { 120 | free(buffer->status_buffer); 121 | } 122 | buffer->buffer_finished_write = true; 123 | buffer->status_buffer = copy; 124 | }); 125 | } 126 | 127 | #endif // READBACK_BUFFER_HPP -------------------------------------------------------------------------------- /bindings/server/request.hpp: -------------------------------------------------------------------------------- 1 | #ifndef REQUEST_HPP 2 | #define REQUEST_HPP 3 | 4 | /* 5 | * A light abstraction over a request to fill a slot. This pends in a queue until we have free slots to take 6 | * the next request. 
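 *
 * The fields mirror what a slot needs in order to start: the tokenized prompt, the
 * caller's InferenceArgs, and the ReadbackBuffer that the caller streams results from.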
7 | */ 8 | 9 | struct Request { 10 | int id; 11 | std::vector prompt_tokens; 12 | InferenceArgs inference_args; 13 | ReadbackBuffer* readback_buffer; 14 | }; 15 | 16 | #endif // REQUEST_HPP -------------------------------------------------------------------------------- /bindings/server/samplers.hpp: -------------------------------------------------------------------------------- 1 | #ifndef SAMPLERS_HPP 2 | #define SAMPLERS_HPP 3 | 4 | #include "llama-model.h" 5 | #include "sampling.h" 6 | #include 7 | 8 | /* 9 | * A very minimal abstraction over lcpp samplers primarily to expose to bindings. 10 | */ 11 | 12 | llama_sampler* sampler_make() { 13 | llama_sampler_chain_params params = llama_sampler_chain_default_params(); 14 | params.no_perf = false; 15 | return llama_sampler_chain_init(params); 16 | } 17 | 18 | template 19 | llama_sampler* add_sampler(llama_sampler* chain, T* sampler) { 20 | llama_sampler_chain_add(chain, sampler); 21 | return chain; 22 | } 23 | 24 | void sampler_free(llama_sampler* sampler) { 25 | llama_sampler_free(sampler); 26 | } 27 | 28 | llama_sampler* sampler_llguidance(llama_sampler* chain, const llama_model* model, const char* grammar_data) { 29 | static constexpr auto grammar_kind = "lark"; 30 | return add_sampler(chain, llama_sampler_init_llg(llama_model_get_vocab(model), grammar_kind, grammar_data)); 31 | } 32 | 33 | llama_sampler* sampler_dist(llama_sampler* chain, const uint32_t seed) { 34 | return add_sampler(chain, llama_sampler_init_dist(seed)); 35 | } 36 | 37 | llama_sampler* sampler_greedy(llama_sampler* chain) { 38 | return add_sampler(chain, llama_sampler_init_greedy()); 39 | } 40 | 41 | llama_sampler* sampler_min_p(llama_sampler* chain, const float min_p, const size_t min_keep) { 42 | return add_sampler(chain, llama_sampler_init_min_p(min_p, min_keep)); 43 | } 44 | 45 | llama_sampler* sampler_mirostat_v2(llama_sampler* chain, const uint32_t seed, const float tau, const float eta) { 46 | return add_sampler(chain, llama_sampler_init_mirostat_v2(seed, tau, eta)); 47 | } 48 | 49 | llama_sampler* sampler_penalties(llama_sampler* chain, const int penalty_last_n, const float penalty_repeat, 50 | const float penalty_freq, const float penalty_present) { 51 | return add_sampler(chain, llama_sampler_init_penalties( 52 | penalty_last_n, penalty_repeat, penalty_freq, penalty_present)); 53 | } 54 | 55 | llama_sampler* sampler_temp(llama_sampler* chain, const float temp) { 56 | return add_sampler(chain, llama_sampler_init_temp(temp)); 57 | } 58 | 59 | llama_sampler* sampler_temp_ext(llama_sampler* chain, const float temp, 60 | const float dynatemp_range, const float dynatemp_exponent) { 61 | return add_sampler(chain, llama_sampler_init_temp_ext(temp, dynatemp_range, dynatemp_exponent)); 62 | } 63 | 64 | llama_sampler* sampler_top_k(llama_sampler* chain, const int top_k) { 65 | return add_sampler(chain, llama_sampler_init_top_k(top_k)); 66 | } 67 | 68 | llama_sampler* sampler_top_p(llama_sampler* chain, const float top_p, const size_t min_keep) { 69 | return add_sampler(chain, llama_sampler_init_top_p(top_p, min_keep)); 70 | } 71 | 72 | llama_sampler* sampler_typical(llama_sampler* chain, const float typical_p, const size_t min_keep) { 73 | return add_sampler(chain, llama_sampler_init_typical(typical_p, min_keep)); 74 | } 75 | 76 | llama_sampler* sampler_top_n_sigma(llama_sampler* chain, const float n_sigma) { 77 | return add_sampler(chain, llama_sampler_init_top_n_sigma(n_sigma)); 78 | } 79 | 80 | llama_sampler* sampler_xtc(llama_sampler* chain, const float 
xtc_probability, const float xtc_threshold, 81 | const size_t min_keep, const uint32_t seed) { 82 | return add_sampler(chain, llama_sampler_init_xtc(xtc_probability, xtc_threshold, min_keep, seed)); 83 | } 84 | 85 | llama_sampler* sampler_grammar(llama_sampler* chain, const llama_model* model, const char* grammar) { 86 | static constexpr auto root = "root"; 87 | return add_sampler(chain, llama_sampler_init_grammar(&model->vocab, grammar, root)); 88 | } 89 | 90 | llama_sampler* sampler_dry(llama_sampler* chain, const llama_model* model, const float multiplier, 91 | const float base, const int32_t allowed_length, const int32_t penalty_last_n, 92 | const char** sequence_breakers, const size_t n_breakers) { 93 | return add_sampler(chain, llama_sampler_init_dry( 94 | &model->vocab, llama_model_n_ctx_train(model), multiplier, base, allowed_length, 95 | penalty_last_n, sequence_breakers, n_breakers)); 96 | } 97 | 98 | llama_sampler* sampler_infill(llama_sampler* chain, const llama_model* model) { 99 | return add_sampler(chain, llama_sampler_init_infill(&model->vocab)); 100 | } 101 | 102 | llama_sampler* sampler_logit_bias(llama_sampler* chain, const llama_model* model, 103 | const int32_t n_bias, const llama_logit_bias* logit_bias) { 104 | return add_sampler(chain, llama_sampler_init_logit_bias( 105 | llama_vocab_n_tokens(&model->vocab), n_bias, logit_bias)); 106 | } 107 | 108 | llama_sampler* sampler_mirostat(llama_sampler* chain, const llama_model* model, const uint32_t seed, 109 | const float tau, const float eta, const int m) { 110 | const int n_vocab = llama_vocab_n_tokens(&model->vocab); 111 | return add_sampler(chain, llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m)); 112 | } 113 | 114 | #endif // SAMPLERS_HPP -------------------------------------------------------------------------------- /bindings/server/sequence_stream.hpp: -------------------------------------------------------------------------------- 1 | #ifndef SEQUENCE_STREAM_HPP 2 | #define SEQUENCE_STREAM_HPP 3 | #include 4 | #include 5 | #include "trie.hpp" 6 | 7 | /* 8 | * The sequence stream is responsible for monitoring sequence events in the inference stream. 9 | * 10 | * Provides: 11 | * A lightweight buffer that indicates the status of the stream and how the processor should proceed. 12 | * 13 | * Mechanism 14 | * A sequence buffer and matching trie that checks for stops or rewinds, and indicates when we should buffer inputs. 15 | */ 16 | 17 | class SequenceStream { 18 | int buffered_seq_size {}; 19 | MatchTrie* match_trie = nullptr; 20 | 21 | public: 22 | std::string sequence_buffer; 23 | 24 | enum SequenceStatus { 25 | ACCEPT = 1, 26 | BUFFER = 2, 27 | STOP = 4, 28 | REWIND = 8 29 | }; 30 | 31 | // Contains the result of what was in the buffer during the status. 
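    // ACCEPT: the buffered text is safe to emit as-is. BUFFER: hold output until more
    // tokens arrive. STOP / REWIND: one of the sequences bound via bind_sequences matched,
    // and unmatched_sequence holds the text that precedes the match.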
32 | struct SequenceContext { 33 | SequenceStatus sequence_status {}; 34 | int current_sequence_size {}; 35 | std::string current_text_piece {}; 36 | std::string current_sequence {}; 37 | std::string unmatched_sequence {}; 38 | }; 39 | 40 | SequenceStream() = default; 41 | 42 | void bind_sequences(const std::vector& stop_seq, const std::vector& rewind_seq) { 43 | // Delete nullptr is safe 44 | delete match_trie; 45 | match_trie = new MatchTrie(); 46 | match_trie->add_matchable_words(stop_seq, MatchType::STOP); 47 | match_trie->add_matchable_words(rewind_seq, MatchType::REWIND); 48 | 49 | this->sequence_buffer.clear(); 50 | } 51 | 52 | SequenceContext append(const std::string_view& next_item) { 53 | sequence_buffer += next_item; 54 | buffered_seq_size++; 55 | 56 | const auto [result, unmatched] = match_trie->check_buffer(sequence_buffer); 57 | auto status = SequenceStatus::BUFFER; 58 | switch (result) { 59 | case MatchResult::NO: 60 | status = SequenceStatus::ACCEPT; 61 | break; 62 | case MatchResult::MAYBE: 63 | status = SequenceStatus::BUFFER; 64 | break; 65 | case MatchResult::MATCHED_REWIND: 66 | status = SequenceStatus::REWIND; 67 | break; 68 | case MatchResult::MATCHED_STOP: 69 | status = SequenceStatus::STOP; 70 | break; 71 | } 72 | 73 | const auto seq_ctx = SequenceContext{ 74 | status, 75 | buffered_seq_size, 76 | std::string(next_item), 77 | std::string(sequence_buffer), 78 | std::string(unmatched)}; 79 | 80 | if (result != MatchResult::MAYBE) { 81 | buffered_seq_size = 0; 82 | sequence_buffer.clear(); 83 | } 84 | 85 | return seq_ctx; 86 | } 87 | }; 88 | 89 | #endif // SEQUENCE_STREAM_HPP -------------------------------------------------------------------------------- /bindings/server/server_basic_example.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "c_library.h" 3 | #include "json.hpp" 4 | #include "llama.h" 5 | #include "generation_resources.hpp" 6 | 7 | int main() { 8 | const auto idk = new float(0.0); 9 | const auto model = model_load( 10 | "/home/blackroot/Desktop/YALS/YALS/models/PocketDoc_Dans-PersonalityEngine-V1.2.0-24b-Q6_K_L.gguf", 11 | 999, 12 | idk, 13 | nullptr 14 | ); 15 | 16 | const auto ctx = ctx_make(model, 1024, 999, 512, false, -1, false, 0, 0, 0.0f); 17 | if (!model || !ctx) { 18 | std::cerr << "Failed to load model" << std::endl; 19 | return 1; 20 | } 21 | 22 | std::cout << "Model and context loaded successfully" << std::endl; 23 | 24 | GenerationResources* gen_resources = generation_resources_make(); 25 | auto readback_buffer = gen_resources->readback_buffer; 26 | 27 | auto sampler = gen_resources->sampler; 28 | sampler_temp(sampler, .5); 29 | sampler_dist(sampler, 1337); 30 | 31 | std::cout << "Porc s" << std::endl; 32 | 33 | Processor *processor = processor_make(model, ctx, 1); 34 | 35 | std::cout << "Porc up" << std::endl; 36 | 37 | const auto prompt = R"(<|im_start|>system 38 | Respond with *actions* *words* *thoughts* in a json format, with 39 | { 40 | "action" : ["first, second]", 41 | "mood" : "current mood from 20 mood choices", 42 | "magazine capacity" : "a number" 43 | } 44 | <|im_end|> 45 | <|im_start|>user 46 | Hi how are you? 
47 | <|im_end|> 48 | <|im_start|>assistant 49 | )"; 50 | 51 | std::cout << "Inference" << std::endl; 52 | processor_submit_work( 53 | processor, 54 | prompt, 55 | gen_resources, 56 | 100, 57 | 0, 58 | 1024, 59 | 1337, 60 | nullptr, 61 | 0, 62 | nullptr, 63 | 0, 64 | nullptr, 65 | 0); 66 | 67 | std::cout << "Starting model:" << std::endl; 68 | while (!readback_is_buffer_finished(readback_buffer)) { 69 | char* char_out; 70 | llama_token token; 71 | if (readback_read_next(readback_buffer, &char_out, &token)) { 72 | std::cout << char_out; 73 | std::cout.flush(); 74 | } 75 | } 76 | 77 | const char* status = readback_read_status(readback_buffer); 78 | std::cout << status << std::endl; 79 | 80 | generation_resources_release(gen_resources); 81 | 82 | return 0; 83 | } 84 | -------------------------------------------------------------------------------- /bindings/server/slot.hpp: -------------------------------------------------------------------------------- 1 | #ifndef SLOT_HPP 2 | #define SLOT_HPP 3 | 4 | #include 5 | #include 6 | #include "llama.h" 7 | #include "tokenization.hpp" 8 | #include "sequence_stream.hpp" 9 | #include "generation_resources.hpp" 10 | #include "presampler.hpp" 11 | 12 | /* 13 | * Slots are essentially just a data container holding the current inference state for a single complete inference. 14 | * 15 | * Provides 16 | * A centralized data container for the processor to manage the inference state. 17 | */ 18 | 19 | struct Slot { 20 | enum class State { 21 | IDLE, 22 | PROMPT, 23 | GENERATING, 24 | SUSPENDED, 25 | }; 26 | 27 | struct SlotSnapshot { 28 | size_t prompt_tokens_processed{}; 29 | int tokens_generated{}; 30 | int n_past{}; 31 | int i_batch{}; 32 | llama_token last_token{}; 33 | std::string previous_seq_stream_buffer; 34 | int32_t previous_kv_pos{}; 35 | 36 | static SlotSnapshot snapshot_slot(const Slot& slot, llama_context* ctx, const bool during_prompt) { 37 | SlotSnapshot snapshot; 38 | snapshot.prompt_tokens_processed = slot.prompt_tokens_processed; 39 | snapshot.tokens_generated = slot.tokens_generated; 40 | snapshot.n_past = slot.n_past; 41 | snapshot.i_batch = slot.i_batch; 42 | snapshot.last_token = slot.last_token; 43 | snapshot.previous_seq_stream_buffer = slot.sequence_stream->sequence_buffer; 44 | 45 | // During the prompt because we do not call decode, we need a special case to update the kv pos for prompt 46 | snapshot.previous_kv_pos = during_prompt ? 
slot.n_past : llama_kv_self_seq_pos_max(ctx, slot.slot_id); 47 | return snapshot; 48 | } 49 | 50 | int32_t rewind_slot(Slot& slot) const { 51 | slot.prompt_tokens_processed = prompt_tokens_processed; 52 | slot.tokens_generated = tokens_generated; 53 | slot.n_past = n_past; 54 | slot.i_batch = i_batch; 55 | slot.last_token = last_token; 56 | slot.sequence_stream->sequence_buffer = previous_seq_stream_buffer; 57 | return previous_kv_pos; 58 | } 59 | }; 60 | 61 | int job_index{-1}; 62 | int request_id{-1}; 63 | int slot_id{0}; 64 | uint32_t n_ctx_max{0}; 65 | State state = State::IDLE; 66 | 67 | std::vector prompt_tokens; 68 | size_t prompt_tokens_processed{0}; 69 | int tokens_generated{0}; 70 | 71 | int n_past{0}; 72 | int i_batch{-1}; 73 | 74 | double slot_start_time{0.0}; 75 | double prompt_end_time{0.0}; 76 | double generating_end_time{0.0}; 77 | 78 | llama_token last_token{0}; 79 | std::string generated_text; 80 | 81 | TokenStreamDetokenizer* detokenizer; 82 | SequenceStream* sequence_stream; 83 | SlotSnapshot rewind_snapshot; 84 | 85 | llama_sampler* rule_chain{nullptr}; 86 | Presampler presampler; 87 | llama_sampler* sampler{nullptr}; 88 | 89 | GenerationResources* gen_resources{nullptr}; 90 | class RuleStream* rule_stream{nullptr}; 91 | 92 | explicit Slot(const llama_model* model, llama_context* ctx): presampler() { 93 | detokenizer = new TokenStreamDetokenizer(ctx); 94 | sequence_stream = new SequenceStream(); 95 | } 96 | 97 | ~Slot() { 98 | delete detokenizer; 99 | delete sequence_stream; 100 | generation_resources_release(gen_resources); 101 | } 102 | 103 | [[nodiscard]] bool is_processing() const { return state == State::PROMPT || state == State::GENERATING; } 104 | [[nodiscard]] bool is_processing_prompt() const { return state == State::PROMPT; } 105 | [[nodiscard]] bool is_generating() const { return state == State::GENERATING; } 106 | 107 | void clear() { 108 | request_id = -1; 109 | state = State::IDLE; 110 | prompt_tokens_processed = 0; 111 | tokens_generated = 0; 112 | n_past = 0; 113 | i_batch = -1; 114 | last_token = 0; 115 | slot_start_time = 0; 116 | prompt_end_time = 0.0; 117 | generating_end_time = 0.0; 118 | generated_text.clear(); 119 | detokenizer->reset(); 120 | } 121 | 122 | State previous_state{State::IDLE}; 123 | void suspend() { 124 | if (state == State::SUSPENDED) return; 125 | previous_state = state; 126 | state = State::SUSPENDED; 127 | } 128 | 129 | void resume() { 130 | if (state != State::SUSPENDED) return; 131 | state = previous_state; 132 | } 133 | 134 | void end(const int new_id, llama_context* ctx) { 135 | clear(); 136 | job_index = new_id; 137 | } 138 | }; 139 | 140 | #endif // SLOT_HPP -------------------------------------------------------------------------------- /bindings/server/tokenization.hpp: -------------------------------------------------------------------------------- 1 | #ifndef TOKENIZATION_HPP 2 | #define TOKENIZATION_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | #include "llama.h" 8 | #include "common.h" 9 | 10 | // From Llama cpp server example 11 | static size_t validate_utf8(const std::string& text) { 12 | const size_t len = text.size(); 13 | if (len == 0) return 0; 14 | 15 | for (size_t i = 1; i <= 4 && i <= len; ++i) { 16 | const unsigned char c = text[len - i]; 17 | if ((c & 0xE0) == 0xC0) { 18 | // 110xxxxx 19 | if (i < 2) return len - i; 20 | } else if ((c & 0xF0) == 0xE0) { 21 | // 1110xxxx 22 | if (i < 3) return len - i; 23 | } else if ((c & 0xF8) == 0xF0) { 24 | // 11110xxx 25 | if (i < 4) return len - i; 26 | } 27 | } 
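    // No incomplete multi-byte sequence was found at the tail, so every byte is safe to emit.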
28 | 29 | return len; 30 | } 31 | 32 | class TokenStreamDetokenizer { 33 | std::string utf_buffer; 34 | llama_context* ctx; 35 | 36 | public: 37 | explicit TokenStreamDetokenizer(llama_context* ctx) 38 | : ctx(ctx) { 39 | } 40 | 41 | std::string process_token(const llama_token token, const bool parse_special) { 42 | const std::string piece = common_token_to_piece(ctx, token, parse_special); 43 | utf_buffer += piece; 44 | 45 | const size_t valid_bytes = validate_utf8(utf_buffer); 46 | 47 | if (valid_bytes == 0) { 48 | return std::string{}; 49 | } 50 | 51 | if (valid_bytes == utf_buffer.size()) { 52 | std::string result = std::move(utf_buffer); 53 | utf_buffer.clear(); 54 | return result; 55 | } 56 | 57 | std::string result = utf_buffer.substr(0, valid_bytes); 58 | utf_buffer = utf_buffer.substr(valid_bytes); 59 | return result; 60 | } 61 | 62 | std::string flush() { 63 | std::string result = std::move(utf_buffer); 64 | utf_buffer.clear(); 65 | return result; 66 | } 67 | 68 | [[nodiscard]] bool has_incomplete() const { 69 | return !utf_buffer.empty(); 70 | } 71 | 72 | void reset() { 73 | utf_buffer.clear(); 74 | } 75 | }; 76 | 77 | class Tokenizer { 78 | llama_context* ctx; 79 | const llama_vocab* vocab; 80 | 81 | public: 82 | Tokenizer(const llama_model* model, llama_context* ctx) 83 | : ctx(ctx), vocab(llama_model_get_vocab(model)) { 84 | } 85 | 86 | [[nodiscard]] bool is_end_of_generation_token(const llama_token token) const { 87 | return llama_vocab_is_eog(vocab, token); 88 | } 89 | 90 | [[nodiscard]]std::vector tokenize(const std::string_view& text, const bool add_special = true, const bool parse_special = true) const { 91 | return common_tokenize(vocab, std::string(text), add_special, parse_special); 92 | } 93 | }; 94 | 95 | #endif // TOKENIZATION_HPP -------------------------------------------------------------------------------- /bindings/server/trie.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MATCH_TRIE_HPP 2 | #define MATCH_TRIE_HPP 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | enum class MatchType { 10 | REWIND, 11 | STOP 12 | }; 13 | 14 | enum class MatchResult { 15 | NO, 16 | MAYBE, 17 | MATCHED_REWIND, 18 | MATCHED_STOP 19 | }; 20 | 21 | class TrieNode { 22 | public: 23 | std::unordered_map> children; 24 | bool is_end_of_word; 25 | MatchType match_type; 26 | 27 | TrieNode() : is_end_of_word(false), match_type() { 28 | } 29 | }; 30 | 31 | class MatchTrie { 32 | std::unique_ptr root; 33 | 34 | static char to_lower(const char c) { 35 | return static_cast(std::tolower(static_cast(c))); 36 | } 37 | 38 | public: 39 | MatchTrie() : root(std::make_unique()) {} 40 | 41 | void add_matchable_words(const std::vector& words, const MatchType type) const { 42 | for (const auto& word : words) { 43 | TrieNode* current = root.get(); 44 | 45 | for (const char c : word) { 46 | char lower_char = to_lower(c); 47 | if (current->children.find(lower_char) == current->children.end()) { 48 | current->children[lower_char] = std::make_unique(); 49 | } 50 | current = current->children[lower_char].get(); 51 | } 52 | current->is_end_of_word = true; 53 | current->match_type = type; 54 | } 55 | } 56 | 57 | struct BufferCheckResult { 58 | MatchResult result; 59 | std::string_view unmatched; 60 | }; 61 | 62 | // Does substring matches to check submatches in the buffer, which is actually needed. 
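    // Returns MATCHED_STOP / MATCHED_REWIND for the earliest complete match (with the text
    // before it returned as `unmatched`), MAYBE when a partial match runs into the end of
    // the buffer, and NO otherwise.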
63 | [[nodiscard]] BufferCheckResult check_buffer(const std::string_view& buffer) const { 64 | if (root->children.empty()) 65 | return {MatchResult::NO, buffer}; 66 | 67 | auto best_result = MatchResult::NO; 68 | std::string_view best_unmatched = buffer; 69 | size_t best_match_pos = buffer.length(); 70 | 71 | for (size_t start = 0; start < buffer.length(); ++start) { 72 | TrieNode* node = root.get(); 73 | size_t i = start; 74 | 75 | for (; i < buffer.length(); ++i) { 76 | char lower_char = to_lower(buffer[i]); 77 | 78 | auto it = node->children.find(lower_char); 79 | if (it == node->children.end()) { 80 | break; 81 | } 82 | 83 | node = it->second.get(); 84 | 85 | if (node->is_end_of_word) { 86 | if (start < best_match_pos) { 87 | best_match_pos = start; 88 | best_unmatched = buffer.substr(0, start); 89 | best_result = (node->match_type == MatchType::REWIND) ? 90 | MatchResult::MATCHED_REWIND : 91 | MatchResult::MATCHED_STOP; 92 | } 93 | } 94 | } 95 | 96 | if (i == buffer.length() && !node->children.empty() && best_result == MatchResult::NO) { 97 | best_result = MatchResult::MAYBE; 98 | } 99 | } 100 | 101 | return {best_result, best_unmatched}; 102 | } 103 | }; 104 | 105 | #endif // MATCH_TRIE_HPP -------------------------------------------------------------------------------- /bindings/types.ts: -------------------------------------------------------------------------------- 1 | // Subset for caching 2 | export enum GGMLType { 3 | f32 = 0, 4 | f16 = 1, 5 | q4_0 = 2, 6 | q4_1 = 3, 7 | // 4 and 5 were removed (Q4_2 and Q4_3) 8 | q5_0 = 6, 9 | q5_1 = 7, 10 | q8_0 = 8, 11 | } 12 | 13 | export enum GGMLTensorSplitMode { 14 | none = 0, 15 | layer = 1, 16 | row = 2, 17 | } 18 | 19 | export type GenerationChunk = StreamChunk | FinishChunk; 20 | 21 | export interface StreamChunk { 22 | kind: "data"; 23 | text: string; 24 | token: number; 25 | } 26 | 27 | export enum ReadbackFinishReason { 28 | CtxExceeded = "CtxExceeded", 29 | BatchDecode = "BatchDecode", 30 | StopToken = "StopToken", 31 | MaxNewTokens = "MaxNewTokens", 32 | StopString = "StopString", 33 | TokenEncode = "TokenEncode", 34 | Aborted = "Aborted", 35 | } 36 | 37 | export interface FinishChunk { 38 | kind: "finish"; 39 | text: string; 40 | slotId: number; 41 | requestId: number; 42 | jobIndex: number; 43 | 44 | promptTokens: number; 45 | genTokens: number; 46 | 47 | promptSec: number; 48 | genSec: number; 49 | totalSec: number; 50 | promptTokensPerSec: number; 51 | genTokensPerSec: number; 52 | 53 | finishReason: ReadbackFinishReason; 54 | stopToken: string; 55 | } 56 | -------------------------------------------------------------------------------- /bindings/utils.ts: -------------------------------------------------------------------------------- 1 | import { logger } from "@/common/logging.ts"; 2 | 3 | export function pointerArrayFromStrings(strings: string[]): { 4 | inner: BigUint64Array; 5 | // Return the buffer so it stays alive 6 | buffer: Uint8Array; 7 | } { 8 | const encoder = new TextEncoder(); 9 | 10 | // Calculate total buffer size needed including null terminators 11 | const encodedStrings = strings.map((str) => encoder.encode(str + "\0")); 12 | const totalSize = encodedStrings.reduce( 13 | (sum, encoded) => sum + encoded.length, 14 | 0, 15 | ); 16 | 17 | // Allocate single buffer for all strings 18 | const buffer = new Uint8Array(totalSize); 19 | const ptrArray = new BigUint64Array(strings.length); 20 | 21 | let offset = 0; 22 | strings.forEach((str, index) => { 23 | // Encode string with null terminator 24 | const encoded 
= encoder.encode(str + "\0"); 25 | buffer.set(encoded, offset); 26 | 27 | // Store pointer to current string 28 | ptrArray[index] = BigInt(Deno.UnsafePointer.value( 29 | Deno.UnsafePointer.of(buffer.subarray(offset)), 30 | )); 31 | 32 | offset += encoded.length; 33 | }); 34 | 35 | // Return both the pointer array and the buffer 36 | return { inner: ptrArray, buffer }; 37 | } 38 | 39 | export function adjustCacheSize(cacheSize: number, maxSeqLen: number) { 40 | if (cacheSize < maxSeqLen) { 41 | logger.warn( 42 | `The given cache_size (${cacheSize}) is smaller than the ` + 43 | "desired context length.\n" + 44 | "Overriding cache_size to max_seq_len. ", 45 | ); 46 | 47 | cacheSize = maxSeqLen; 48 | } 49 | 50 | const cacheRemainder = cacheSize % 256; 51 | if (cacheRemainder != 0) { 52 | const roundedCacheSize = 256 * 53 | Math.floor((cacheSize - cacheRemainder) / 256 + 1); 54 | logger.info( 55 | `Rounding cache size from ${cacheSize} to ${roundedCacheSize} ` + 56 | `tokens (multiple of 256)`, 57 | ); 58 | cacheSize = roundedCacheSize; 59 | } 60 | 61 | return cacheSize; 62 | } 63 | -------------------------------------------------------------------------------- /common/args.ts: -------------------------------------------------------------------------------- 1 | // @ts-types="npm:@types/command-line-args" 2 | import commandLineArgs from "command-line-args"; 3 | 4 | // @ts-types="npm:@types/command-line-usage"; 5 | import commandLineUsage from "command-line-usage"; 6 | import * as z from "@/common/myZod.ts"; 7 | import { ConfigSchema } from "@/common/configModels.ts"; 8 | 9 | // Replicates Python's strtobool for handling boolean values 10 | function strToBool(value: string): boolean { 11 | return z.stringbool().parse(value); 12 | } 13 | 14 | // Converts the ConfigSchema to CLI arguments 15 | function configToArgs() { 16 | const configGroups: commandLineUsage.OptionList[] = []; 17 | 18 | // Iterate and create groups from top-level arguments 19 | for (const [groupName, params] of Object.entries(ConfigSchema.shape)) { 20 | const groupOptions = createGroupOptions(groupName, params.shape); 21 | configGroups.push({ header: groupName, optionList: groupOptions }); 22 | } 23 | 24 | return configGroups; 25 | } 26 | 27 | // Creates inner arg options for argument groups 28 | function createGroupOptions(groupName: string, shape: z.ZodRawShape) { 29 | return Object.entries(shape).map(([key, value]) => { 30 | const option: commandLineUsage.OptionDefinition = { 31 | name: key.replaceAll("_", "-"), 32 | group: groupName, 33 | }; 34 | 35 | setArgType(option, value); 36 | return option; 37 | }); 38 | } 39 | 40 | // Converts a Zod schema type to a command-line-args type 41 | // Drills down recursively until primitive types are found 42 | function setArgType( 43 | option: commandLineUsage.OptionDefinition, 44 | zodType: z.core.$ZodType, 45 | ) { 46 | // Use _zod for fetching the underlying type 47 | const typeName = zodType._zod.def.type; 48 | 49 | switch (typeName) { 50 | case "string": 51 | option["type"] = String; 52 | break; 53 | case "number": 54 | option["type"] = Number; 55 | break; 56 | case "boolean": 57 | option["type"] = strToBool; 58 | break; 59 | case "optional": 60 | setArgType( 61 | option, 62 | (zodType as z.ZodOptional).unwrap(), 63 | ); 64 | break; 65 | case "nullable": 66 | setArgType( 67 | option, 68 | (zodType as z.ZodNullable).unwrap(), 69 | ); 70 | break; 71 | case "union": 72 | setArgType( 73 | option, 74 | (zodType as z.ZodUnion<[z.ZodType, ...z.ZodType[]]>).def 75 | .options[0], 76 | ); 77 
| break; 78 | case "pipe": 79 | setArgType( 80 | option, 81 | (zodType as z.ZodPipe).def.in, 82 | ); 83 | break; 84 | case "array": 85 | option["multiple"] = true; 86 | setArgType(option, (zodType as z.ZodArray).element); 87 | break; 88 | } 89 | } 90 | 91 | // Parses global arguments from Deno.args 92 | export function parseArgs() { 93 | // Define option groups 94 | const helpGroup: commandLineUsage.Section = { 95 | header: "Support", 96 | optionList: [{ 97 | name: "help", 98 | type: Boolean, 99 | description: "Prints this menu", 100 | group: "support", 101 | }], 102 | }; 103 | 104 | const epilog: commandLineUsage.Section = { 105 | header: "Epilog", 106 | content: "- strtobool flags require an explicit value. " + 107 | "Example: --flash-attention true", 108 | raw: true, 109 | }; 110 | 111 | const configGroups = configToArgs(); 112 | const optionGroups = [...configGroups, helpGroup]; 113 | const usage = commandLineUsage([...optionGroups, epilog]); 114 | const cliOptions: commandLineUsage.OptionDefinition[] = optionGroups 115 | .flatMap((option) => option.optionList ?? []); 116 | 117 | // Parse the options 118 | const args = commandLineArgs(cliOptions, { argv: Deno.args }); 119 | 120 | // Replace keys with underscores for config parsing 121 | for (const groupName of Object.keys(args)) { 122 | const groupArgs = args[groupName]; 123 | 124 | if (groupArgs && typeof groupArgs === "object") { 125 | args[groupName] = Object.fromEntries( 126 | Object.entries(groupArgs).map(( 127 | [k, v], 128 | ) => [k.replaceAll("-", "_"), v]), 129 | ); 130 | } 131 | } 132 | 133 | return { args, usage }; 134 | } 135 | -------------------------------------------------------------------------------- /common/auth.ts: -------------------------------------------------------------------------------- 1 | import * as YAML from "@std/yaml"; 2 | import * as z from "@/common/myZod.ts"; 3 | import { config } from "@/common/config.ts"; 4 | import { logger } from "@/common/logging.ts"; 5 | import { generateUuidHex } from "@/common/utils.ts"; 6 | 7 | const AuthFileSchema = z.object({ 8 | api_key: z.string(), 9 | admin_key: z.string(), 10 | }); 11 | 12 | type AuthFile = z.infer; 13 | 14 | export enum AuthKeyPermission { 15 | API = "API", 16 | Admin = "Admin", 17 | } 18 | 19 | export class AuthKeys { 20 | public apiKey: string; 21 | public adminKey: string; 22 | 23 | public constructor( 24 | apiKey: string, 25 | adminKey: string, 26 | ) { 27 | this.apiKey = apiKey; 28 | this.adminKey = adminKey; 29 | } 30 | 31 | public verifyKey(testKey: string, permission: AuthKeyPermission): boolean { 32 | switch (permission) { 33 | case AuthKeyPermission.Admin: 34 | return testKey === this.adminKey; 35 | case AuthKeyPermission.API: 36 | return testKey === this.apiKey || testKey === this.adminKey; 37 | default: 38 | return false; 39 | } 40 | } 41 | } 42 | 43 | export let authKeys: AuthKeys | undefined = undefined; 44 | 45 | export async function loadAuthKeys() { 46 | const authFilePath = "api_tokens.yml"; 47 | 48 | if (config.network.disable_auth) { 49 | logger.warn( 50 | "Disabling authentication makes your instance vulnerable. 
\n" + 51 | "Set the `disable_auth` flag to false in config.yml " + 52 | "to share this instance with others.", 53 | ); 54 | } 55 | 56 | const fileInfo = await Deno.stat(authFilePath).catch(() => null); 57 | if (fileInfo?.isFile) { 58 | const rawKeys = await Deno.readTextFile(authFilePath); 59 | const parsedKeys = AuthFileSchema.parse(YAML.parse(rawKeys)); 60 | authKeys = new AuthKeys( 61 | parsedKeys.api_key, 62 | parsedKeys.admin_key, 63 | ); 64 | } else { 65 | const newAuthFile = AuthFileSchema.parse({ 66 | api_key: generateUuidHex(), 67 | admin_key: generateUuidHex(), 68 | }); 69 | 70 | authKeys = new AuthKeys( 71 | newAuthFile.api_key, 72 | newAuthFile.admin_key, 73 | ); 74 | 75 | await Deno.writeFile( 76 | authFilePath, 77 | new TextEncoder().encode(YAML.stringify(newAuthFile)), 78 | ); 79 | } 80 | 81 | logger.info( 82 | "\n" + 83 | `Your API key is: ${authKeys.apiKey}\n` + 84 | `Your Admin key is: ${authKeys.adminKey}\n\n` + 85 | "If these keys get compromised, make sure to delete api_tokens.yml " + 86 | "and restart the server. Have fun!", 87 | ); 88 | } 89 | 90 | export function getAuthPermission(headers: Record) { 91 | if (config.network.disable_auth) { 92 | return AuthKeyPermission.Admin; 93 | } 94 | 95 | let testKey = headers["x-admin-key"] ?? headers["x-api-key"] ?? 96 | headers["authorization"]; 97 | 98 | if (!testKey) { 99 | throw new Error("The provided authentication key is missing."); 100 | } 101 | 102 | if (testKey.toLowerCase().startsWith("bearer")) { 103 | testKey = testKey.split(" ")[1]; 104 | } 105 | 106 | if (authKeys?.verifyKey(testKey, AuthKeyPermission.Admin)) { 107 | return AuthKeyPermission.Admin.toLowerCase(); 108 | } else if (authKeys?.verifyKey(testKey, AuthKeyPermission.API)) { 109 | return AuthKeyPermission.API.toLowerCase(); 110 | } else { 111 | throw new Error("The provided authentication key is invalid."); 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /common/config.ts: -------------------------------------------------------------------------------- 1 | import * as YAML from "@std/yaml"; 2 | 3 | import { 4 | ConfigSchema, 5 | DeveloperConfig, 6 | LoggingConfig, 7 | ModelConfig, 8 | NetworkConfig, 9 | SamplingConfig, 10 | } from "./configModels.ts"; 11 | import { logger } from "./logging.ts"; 12 | import { applyLoadDefaults } from "./modelContainer.ts"; 13 | 14 | // Initialize with an empty config 15 | export let config: ConfigSchema = ConfigSchema.parse({ 16 | network: NetworkConfig.parse({}), 17 | logging: LoggingConfig.parse({}), 18 | model: ModelConfig.parse({}), 19 | sampling: SamplingConfig.parse({}), 20 | developer: DeveloperConfig.parse({}), 21 | }); 22 | 23 | export async function loadConfig(args: Record) { 24 | const configPath = "config.yml"; 25 | let parsedConfig: Record = {}; 26 | 27 | // Warn if the file doesn't exist 28 | const fileInfo = await Deno.stat(configPath).catch(() => null); 29 | if (fileInfo?.isFile) { 30 | const rawConfig = await Deno.readTextFile(configPath); 31 | parsedConfig = YAML.parse(rawConfig) as Record; 32 | } else { 33 | logger.warn("Could not find a config file. 
Starting anyway."); 34 | } 35 | 36 | const mergedConfig: Record = {}; 37 | 38 | // Single loop to merge default config, file config, and args 39 | for (const key of Object.keys(config) as Array) { 40 | mergedConfig[key] = { 41 | ...(parsedConfig[key] as Record || {}), 42 | ...(args[key] as Record || {}), 43 | }; 44 | } 45 | 46 | if (mergedConfig["model"]) { 47 | mergedConfig["model"] = await applyLoadDefaults(mergedConfig["model"]); 48 | } 49 | 50 | config = ConfigSchema.parse(mergedConfig); 51 | } 52 | -------------------------------------------------------------------------------- /common/configModels.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { GGMLType } from "@/bindings/types.ts"; 3 | 4 | export const NetworkConfig = z.object({ 5 | host: z.string().nullish().coalesce("127.0.0.1"), 6 | port: z.number().nullish().coalesce(5000), 7 | disable_auth: z.boolean().nullish().coalesce(false), 8 | }); 9 | 10 | export type NetworkConfig = z.infer; 11 | 12 | export const LoggingConfig = z.object({ 13 | log_prompt: z.boolean().nullish().coalesce(false), 14 | log_generation_params: z.boolean().nullish().coalesce(false), 15 | log_requests: z.boolean().nullish().coalesce(false), 16 | }); 17 | 18 | export type LoggingConfig = z.infer; 19 | 20 | export const ModelConfig = z.object({ 21 | model_dir: z.string().nullish().coalesce("models"), 22 | model_name: z.string().nullish(), 23 | use_as_default: z.array(z.string()).nullish().coalesce([]), 24 | max_seq_len: z.number().nullish() 25 | .coalesce(4096).transform((v) => v === -1 ? 0 : v), 26 | num_slots: z.number().nullish().coalesce(1), 27 | cache_size: z.number().nullish(), 28 | chunk_size: z.number().nullish().coalesce(512), 29 | num_gpu_layers: z.number().nullish().coalesce(0), 30 | gpu_split: z.array(z.number()).nullish().coalesce([]), 31 | tensor_parallel: z.boolean().nullish().coalesce(false), 32 | num_threads: z.number().nullish().coalesce(-1), 33 | prompt_template: z.string().nullish(), 34 | flash_attention: z.boolean().nullish().coalesce(true), 35 | rope_freq_base: z.number().nullish().coalesce(0), 36 | enable_yarn: z.boolean().nullish().coalesce(false), 37 | cache_mode_k: z.union([ 38 | z.string().transform((str) => 39 | GGMLType[str.toLowerCase() as keyof typeof GGMLType] 40 | ), 41 | z.number(), 42 | ]) 43 | .nullish() 44 | .coalesce(GGMLType.f16), 45 | cache_mode_v: z.union([ 46 | z.string().transform((str) => 47 | GGMLType[str.toLowerCase() as keyof typeof GGMLType] 48 | ), 49 | z.number(), 50 | ]) 51 | .nullish() 52 | .coalesce(GGMLType.f16), 53 | override_tensor: z.string().nullish(), 54 | mmap: z.boolean().nullish().coalesce(true), 55 | }); 56 | 57 | export type ModelConfig = z.infer; 58 | 59 | export const SamplingConfig = z.object({ 60 | override_preset: z.string().nullish(), 61 | }); 62 | 63 | export type SamplingConfig = z.infer; 64 | 65 | export const DeveloperConfig = z.object({ 66 | realtime_process_priority: z.boolean().nullish().coalesce(true), 67 | }); 68 | 69 | export const ConfigSchema = z.object({ 70 | network: NetworkConfig, 71 | logging: LoggingConfig, 72 | model: ModelConfig, 73 | sampling: SamplingConfig, 74 | developer: DeveloperConfig, 75 | }); 76 | 77 | export type ConfigSchema = z.infer; 78 | 79 | // Config shim for inline overrides 80 | export const InlineConfigSchema = z.object({ 81 | model: z.record(z.string(), z.unknown()), 82 | }); 83 | 84 | export type InlineConfigSchema = z.infer; 85 | 
-------------------------------------------------------------------------------- /common/errors.ts: -------------------------------------------------------------------------------- 1 | import { HTTPException } from "hono/http-exception"; 2 | 3 | export class CancellationError extends Error { 4 | constructor(message: string = "Operation cancelled") { 5 | super(message); 6 | this.name = "CancellationError"; 7 | } 8 | } 9 | 10 | export class ModelNotLoadedError extends HTTPException { 11 | constructor(message: string = "A model is not loaded.") { 12 | super(503, { message: message }); 13 | this.name = "ModelNotLoadedError"; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /common/logging.ts: -------------------------------------------------------------------------------- 1 | import winston from "winston"; 2 | import colors from "yoctocolors"; 3 | import { config } from "@/common/config.ts"; 4 | import { BaseSamplerRequest } from "@/common/sampling.ts"; 5 | 6 | const customFormat = winston.format.printf(({ timestamp, level, message }) => { 7 | const coloredTimestamp = colors.dim(timestamp as string); 8 | const upperLevel = level.toUpperCase(); 9 | 10 | // Set colored log level 11 | let coloredLevel = upperLevel; 12 | switch (level) { 13 | case "error": 14 | coloredLevel = colors.red(upperLevel); 15 | break; 16 | case "warn": 17 | coloredLevel = colors.yellow(upperLevel); 18 | break; 19 | case "info": 20 | coloredLevel = colors.green(upperLevel); 21 | break; 22 | case "debug": 23 | coloredLevel = colors.cyan(upperLevel); 24 | break; 25 | default: 26 | coloredLevel = colors.dim(upperLevel); 27 | } 28 | coloredLevel = colors.bold(coloredLevel); 29 | 30 | const coloredPrefix = colors.dim("YALS"); 31 | 32 | return `${coloredTimestamp} ${coloredLevel} ${coloredPrefix}: ${message}`; 33 | }); 34 | 35 | export const logger = winston.createLogger({ 36 | level: "debug", 37 | format: winston.format.combine( 38 | winston.format.splat(), 39 | winston.format.timestamp({ format: "YYYY-MM-DD HH:mm:ss.SSS" }), 40 | customFormat, 41 | ), 42 | transports: [new winston.transports.Console({ level: "info" })], 43 | }); 44 | 45 | export function logPrompt(prompt: string) { 46 | // Log prompt to console 47 | // Prompts can be very large, so make the newline log a console.log instead 48 | if (config.logging.log_prompt) { 49 | logger.info(`Prompt:`); 50 | console.log(prompt); 51 | } 52 | } 53 | 54 | export function logGenParams(requestId: string, params: BaseSamplerRequest) { 55 | if (config.logging.log_generation_params) { 56 | const samplerParams = BaseSamplerRequest.parse(params); 57 | const formattedParams = Deno.inspect(samplerParams, { 58 | depth: 2, 59 | compact: true, 60 | breakLength: Infinity, 61 | }); 62 | 63 | logger.info( 64 | `Generation Parameters (ID: ${requestId}): ${ 65 | colors.green(formattedParams) 66 | }`, 67 | ); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /common/modelContainer.ts: -------------------------------------------------------------------------------- 1 | import { Mutex } from "@core/asyncutil"; 2 | import * as Path from "@std/path"; 3 | import * as YAML from "@std/yaml"; 4 | 5 | import * as z from "./myZod.ts"; 6 | import { Model } from "@/bindings/bindings.ts"; 7 | import { config } from "./config.ts"; 8 | import { InlineConfigSchema, ModelConfig } from "./configModels.ts"; 9 | import { logger } from "./logging.ts"; 10 | 11 | export let model: Model | undefined = undefined; 12 | 
const loadLock = new Mutex(); 13 | 14 | export async function loadModel( 15 | params: ModelConfig, 16 | progressCallback?: (progress: number) => boolean, 17 | ) { 18 | if (loadLock.locked) { 19 | throw new Error( 20 | "Another model load operation is in progress. Please wait.", 21 | ); 22 | } 23 | 24 | using _lock = await loadLock.acquire(); 25 | 26 | if (model) { 27 | if (model?.path.name === params.model_name?.replace(".gguf", "")) { 28 | throw new Error( 29 | `Model ${params.model_name} is already loaded! Aborting.`, 30 | ); 31 | } 32 | 33 | logger.info("Unloading existing model."); 34 | await unloadModel(); 35 | } 36 | 37 | if (!params.model_name?.endsWith(".gguf")) { 38 | params.model_name = `${params.model_name}.gguf`; 39 | } 40 | 41 | model = await Model.init(params, progressCallback); 42 | } 43 | 44 | export async function unloadModel(skipQueue: boolean = false) { 45 | await model?.unload(skipQueue); 46 | model = undefined; 47 | } 48 | 49 | // Applies model load overrides. Sources are inline and model config 50 | // Agnostic due to passing of ModelLoadRequest and ModelConfig 51 | export async function applyLoadDefaults(item: unknown) { 52 | const obj = z.record(z.string(), z.unknown()).safeParse(item); 53 | 54 | // Silently return since further validation will fail 55 | if (!obj.success) { 56 | return item; 57 | } 58 | 59 | const data = { ...obj.data }; 60 | const modelOverrides: Record = {}; 61 | 62 | if (typeof data["model_name"] === "string") { 63 | const modelName = data["model_name"] as string; 64 | const modelDir = data["model_dir"] as string ?? config.model.model_dir; 65 | const inlineConfigPath = Path.join( 66 | modelDir, 67 | `${modelName.replace(".gguf", "")}.yml`, 68 | ); 69 | 70 | const fileInfo = await Deno.stat(inlineConfigPath).catch(() => null); 71 | if (fileInfo?.isFile) { 72 | const rawInlineConfig = await Deno.readTextFile(inlineConfigPath); 73 | const inlineYaml = YAML.parse(rawInlineConfig) as Record< 74 | string, 75 | unknown 76 | >; 77 | 78 | const inlineResult = InlineConfigSchema.safeParse(inlineYaml); 79 | if (inlineResult.success) { 80 | Object.assign(modelOverrides, inlineResult.data.model); 81 | } else { 82 | logger.warn( 83 | `Invalid inline config for ${modelName}: ` + 84 | inlineResult.error.message, 85 | ); 86 | } 87 | } 88 | } 89 | 90 | // Iterate through defaults 91 | for (const key of config.model.use_as_default) { 92 | if (key in config.model) { 93 | modelOverrides[key] = 94 | config.model[key as keyof typeof config.model]; 95 | } 96 | } 97 | 98 | // Apply modelOverrides first then overlay data 99 | return { ...modelOverrides, ...data }; 100 | } 101 | -------------------------------------------------------------------------------- /common/myZod.ts: -------------------------------------------------------------------------------- 1 | import * as z from "zod/v4"; 2 | 3 | // Coalesce function 4 | 5 | function coalesce>>( 6 | this: T, 7 | defaultValue: D, 8 | ) { 9 | // Needs a second assertion to satisfy arrays 10 | return this.transform((val) => 11 | val ?? 
(defaultValue as NonNullable>) 12 | ); 13 | } 14 | 15 | z.ZodType.prototype.coalesce = coalesce; 16 | 17 | // Sampler overrides 18 | 19 | // Store the sampler override default function to prevent circular import 20 | let samplerOverrideResolver = (_key: string): unknown | null | undefined => 21 | undefined; 22 | 23 | export function registerSamplerOverrideResolver( 24 | resolver: (key: string) => unknown | null | undefined, 25 | ) { 26 | samplerOverrideResolver = resolver; 27 | } 28 | 29 | const samplerOverride = function (this: T, key: string) { 30 | return this.transform((value, ctx) => { 31 | if (value !== undefined && value !== null) { 32 | return value; 33 | } 34 | 35 | const defaultValue = samplerOverrideResolver(key); 36 | const result = this.safeParse(defaultValue); 37 | if (result.success) { 38 | return defaultValue as z.output; 39 | } else { 40 | let expectedType = ""; 41 | 42 | const issues = result.error.issues; 43 | if (issues.length > 0 && issues[0].code === "invalid_type") { 44 | const issue = issues[0] as z.core.$ZodIssueInvalidType; 45 | expectedType = issue.expected; 46 | } 47 | 48 | ctx.addIssue({ 49 | code: "custom", 50 | message: `Sampler override for ${key} must match ` + 51 | `the input type ${expectedType}`, 52 | input: defaultValue, 53 | path: ["samplerOverride"], 54 | }); 55 | 56 | return z.NEVER; 57 | } 58 | }); 59 | }; 60 | 61 | z.ZodType.prototype.samplerOverride = samplerOverride; 62 | 63 | // Alias support 64 | interface AliasChoice { 65 | field: string; 66 | aliases: string[]; 67 | } 68 | 69 | export function aliasedObject< 70 | O extends z.ZodTypeAny, 71 | >( 72 | schema: O, 73 | aliasChoices: AliasChoice[], 74 | ) { 75 | return z.preprocess((item: unknown) => { 76 | const obj = z.record(z.string(), z.unknown()).safeParse(item); 77 | if (obj.success) { 78 | for (const choice of aliasChoices) { 79 | // If the field contains a value, skip 80 | if (obj.data[choice.field]) { 81 | continue; 82 | } 83 | 84 | // Replace with the first found alias value 85 | const foundAlias = choice.aliases.find((alias) => 86 | alias in obj.data 87 | ); 88 | if (foundAlias) { 89 | obj.data[choice.field] = obj.data[foundAlias]; 90 | } 91 | } 92 | 93 | // Reassign the object 94 | item = obj.data; 95 | } 96 | 97 | return obj.data; 98 | }, schema); 99 | } 100 | 101 | // Export all types 102 | export * from "zod/v4"; 103 | 104 | declare module "zod/v4" { 105 | interface ZodType { 106 | coalesce>>( 107 | defaultValue: D, 108 | ): ReturnType>; 109 | 110 | samplerOverride( 111 | key: string, 112 | ): ReturnType>; 113 | } 114 | } 115 | -------------------------------------------------------------------------------- /common/networking.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { HTTPException } from "hono/http-exception"; 3 | import { ContentfulStatusCode } from "hono/utils/http-status"; 4 | import { resolver } from "hono-openapi"; 5 | 6 | import { CancellationError } from "@/common/errors.ts"; 7 | import { logger } from "@/common/logging.ts"; 8 | 9 | // Originally from Stoker adopted for hono-openapi 10 | export const jsonContent = ( 11 | schema: T, 12 | description: string, 13 | ) => { 14 | return { 15 | content: { 16 | "application/json": { 17 | schema: resolver(schema), 18 | }, 19 | }, 20 | description, 21 | }; 22 | }; 23 | 24 | // Return an HTTP exception for static request errors 25 | export function toHttpException(error: unknown, status = 422) { 26 | let message = "An unexpected error 
occurred"; 27 | 28 | if (error instanceof CancellationError) { 29 | status = 408; 30 | message = error.message; 31 | } else if (error instanceof Error) { 32 | message = error.message; 33 | } 34 | 35 | const statusCode = status as ContentfulStatusCode; 36 | throw new HTTPException(statusCode, { message }); 37 | } 38 | 39 | // Return an error payload for stream generators 40 | export function toGeneratorError(error: unknown) { 41 | let message = "An unexpected error occurred"; 42 | 43 | if (error instanceof CancellationError) { 44 | logger.error(error.message); 45 | message = error.message; 46 | } else if (error instanceof Error) { 47 | logger.error(error.stack || error.message); 48 | message = error.message; 49 | } 50 | 51 | return { 52 | error: { 53 | message, 54 | }, 55 | }; 56 | } 57 | -------------------------------------------------------------------------------- /common/samplerOverrides.ts: -------------------------------------------------------------------------------- 1 | import * as YAML from "@std/yaml"; 2 | import * as z from "@/common/myZod.ts"; 3 | import { logger } from "@/common/logging.ts"; 4 | import { BaseSamplerRequest } from "@/common/sampling.ts"; 5 | 6 | export const SamplerOverride = z.object({ 7 | override: z.unknown().refine((val) => val !== undefined && val !== null, { 8 | error: "Override value cannot be undefined or null", 9 | }), 10 | force: z.boolean().optional().default(false), 11 | additive: z.boolean().optional().default(false), 12 | }); 13 | 14 | // Sampler overrides 15 | export type SamplerOverride = z.infer; 16 | 17 | class SamplerOverridesContainer { 18 | selectedPreset?: string; 19 | overrides: Record = {}; 20 | forcedOverrides: Record = {}; 21 | } 22 | 23 | // No need to export this, the functions work properly 24 | const overridesContainer = new SamplerOverridesContainer(); 25 | 26 | export function overridesFromDict(newOverrides: Record) { 27 | const parsedOverrides: Record = {}; 28 | 29 | // Forced also includes additive 30 | const forcedOverrides: Record = {}; 31 | 32 | // Validate each entry as a SamplerOverride type 33 | for (const [key, value] of Object.entries(newOverrides)) { 34 | try { 35 | const parsedOverride = SamplerOverride.parse(value); 36 | parsedOverrides[key] = parsedOverride; 37 | 38 | // Add to forced object for faster lookup 39 | if (parsedOverride.force || parsedOverride.additive) { 40 | forcedOverrides[key] = parsedOverride; 41 | } 42 | } catch (error) { 43 | if (error instanceof Error) { 44 | logger.error(error.stack); 45 | logger.warn( 46 | `Skipped assignment of override with key "${key}" ` + 47 | "due to the above error.", 48 | ); 49 | } 50 | } 51 | } 52 | 53 | overridesContainer.overrides = parsedOverrides; 54 | overridesContainer.forcedOverrides = forcedOverrides; 55 | } 56 | 57 | export async function overridesFromFile(presetName: string) { 58 | const presetPath = `sampler_overrides/${presetName}.yml`; 59 | overridesContainer.selectedPreset = presetName; 60 | 61 | // Read from override preset file 62 | const fileInfo = await Deno.stat(presetPath).catch(() => null); 63 | if (fileInfo?.isFile) { 64 | const rawPreset = await Deno.readTextFile(presetPath); 65 | const presetsYaml = YAML.parse(rawPreset) as Record; 66 | 67 | overridesFromDict(presetsYaml); 68 | 69 | logger.info(`Applied sampler overrides from preset ${presetName}`); 70 | } else { 71 | throw new Error( 72 | `Sampler override file named ${presetName} was not found. 
` + 73 | "Make sure it's located in the sampler_overrides folder.", 74 | ); 75 | } 76 | } 77 | 78 | export function forcedSamplerOverrides(params: BaseSamplerRequest) { 79 | const forcedKeys: string[] = []; 80 | const castParams = params as Record; 81 | 82 | for ( 83 | const [key, value] of Object.entries(overridesContainer.forcedOverrides) 84 | ) { 85 | if (value.force) { 86 | castParams[key] = value.override; 87 | forcedKeys.push(key); 88 | } else if ( 89 | value.additive && Array.isArray(value.override) && 90 | Array.isArray(castParams[key]) 91 | ) { 92 | castParams[key] = Array.from( 93 | new Set([...castParams[key], ...value.override]), 94 | ); 95 | forcedKeys.push(key); 96 | } 97 | } 98 | 99 | return { params, forcedKeys }; 100 | } 101 | 102 | export function getSamplerDefault(key: string) { 103 | const defaultValue = overridesContainer.overrides[key]?.override; 104 | 105 | if (defaultValue === undefined || defaultValue === null) { 106 | return defaultValue; 107 | } 108 | 109 | return defaultValue; 110 | } 111 | 112 | // Link resolver to Zod 113 | z.registerSamplerOverrideResolver(getSamplerDefault); 114 | -------------------------------------------------------------------------------- /common/sampling.ts: -------------------------------------------------------------------------------- 1 | import * as z from "@/common/myZod.ts"; 2 | import { forcedSamplerOverrides } from "@/common/samplerOverrides.ts"; 3 | 4 | // Sampling schemas 5 | const GenerationOptionsSchema = z.aliasedObject( 6 | z.object({ 7 | max_tokens: z.number().gte(0).nullish() 8 | .samplerOverride("max_tokens") 9 | .coalesce(0) 10 | .describe("Aliases: max_length"), 11 | min_tokens: z.number().gte(0).nullish() 12 | .samplerOverride("min_tokens") 13 | .coalesce(0) 14 | .describe("Aliases: min_length"), 15 | stop: z.union([ 16 | z.string().transform((str) => [str]), 17 | z.array(z.union([z.string(), z.number()])), 18 | ]) 19 | .nullish() 20 | .samplerOverride("stop") 21 | .coalesce([]) 22 | .describe("Aliases: stop_sequence"), 23 | add_bos_token: z.boolean().nullish() 24 | .samplerOverride("add_bos_token"), 25 | ban_eos_token: z.boolean().nullish() 26 | .samplerOverride("ban_eos_token") 27 | .coalesce(false) 28 | .describe("Aliases: ignore_eos"), 29 | seed: z.number().nullish() 30 | .samplerOverride("seed"), 31 | logit_bias: z.record(z.string(), z.number()).nullish() 32 | .samplerOverride("logit_bias") 33 | .coalesce({}), 34 | json_schema: z.record(z.string(), z.unknown()).nullish() 35 | .samplerOverride("json_schema"), 36 | regex_pattern: z.string().nullish() 37 | .samplerOverride("regex_pattern"), 38 | grammar_string: z.string().nullish() 39 | .samplerOverride("grammar_string"), 40 | banned_tokens: z.union([ 41 | z.array(z.number()), 42 | z.string() 43 | .transform((str) => 44 | str.replaceAll(" ", "") 45 | .split(",") 46 | .filter((x) => /^\d+$/.test(x)) 47 | .map((x) => parseInt(x)) 48 | ), 49 | ]) 50 | .nullish() 51 | .samplerOverride("banned_tokens") 52 | .coalesce([]) 53 | .describe("Aliases: custom_token_bans"), 54 | banned_strings: z.union([ 55 | z.string().transform((str) => [str]), 56 | z.array(z.string()), 57 | ]) 58 | .nullish() 59 | .samplerOverride("banned_strings") 60 | .coalesce([]), 61 | }), 62 | [ 63 | { field: "max_tokens", aliases: ["max_length"] }, 64 | { field: "min_tokens", aliases: ["min_length"] }, 65 | { field: "ban_eos_token", aliases: ["ignore_eos"] }, 66 | { field: "stop", aliases: ["stop_sequence"] }, 67 | { field: "banned_tokens", aliases: ["custom_token_bans"] }, 68 | ], 69 | ) 70 | 
.describe("Generation options"); 71 | 72 | const TemperatureSamplerSchema = z.object({ 73 | temperature: z.number().gte(0).nullish() 74 | .samplerOverride("temperature") 75 | .coalesce(1), 76 | temperature_last: z.boolean().nullish() 77 | .samplerOverride("temperature_last") 78 | .coalesce(false), 79 | }) 80 | .describe("Temperature options"); 81 | 82 | const AlphabetSamplerSchema = z.aliasedObject( 83 | z.object({ 84 | top_k: z.number().gte(-1) 85 | .transform((top_k) => top_k == -1 ? 0 : top_k) 86 | .nullish() 87 | .samplerOverride("top_k") 88 | .coalesce(0), 89 | top_p: z.number().gte(0).lte(1).nullish() 90 | .samplerOverride("top_p") 91 | .coalesce(1), 92 | min_p: z.number().gte(0).lte(1).nullish() 93 | .samplerOverride("min_p") 94 | .coalesce(0), 95 | typical: z.number().gt(0).lte(1).nullish() 96 | .samplerOverride("typical") 97 | .coalesce(1), 98 | nsigma: z.number().gte(0).nullish() 99 | .samplerOverride("nsigma") 100 | .coalesce(0), 101 | }), 102 | [{ field: "typical", aliases: ["typical_p"] }], 103 | ) 104 | .describe("Alphabet samplers"); 105 | 106 | const PenaltySamplerSchema = z.aliasedObject( 107 | z.object({ 108 | frequency_penalty: z.number().gte(0).nullish() 109 | .samplerOverride("frequency_penalty") 110 | .coalesce(0), 111 | presence_penalty: z.number().gte(0).nullish() 112 | .samplerOverride("presence_penalty") 113 | .coalesce(0), 114 | repetition_penalty: z.number().gt(0).nullish() 115 | .samplerOverride("repetition_penalty") 116 | .coalesce(1) 117 | .describe("Aliases: rep_pen"), 118 | penalty_range: z.number().nullish() 119 | .samplerOverride("penalty_range") 120 | .coalesce(-1) 121 | .describe( 122 | "Aliases: repetition_range, repetition_penalty_range, rep_pen_range", 123 | ), 124 | }), 125 | [ 126 | { field: "repetition_penalty", aliases: ["rep_pen"] }, 127 | { 128 | field: "penalty_range", 129 | aliases: [ 130 | "repetition_range", 131 | "repetition_penalty_range", 132 | "rep_pen_range", 133 | ], 134 | }, 135 | ], 136 | ) 137 | .describe("Penalty samplers"); 138 | 139 | const DrySchema = z.aliasedObject( 140 | z.object({ 141 | dry_multiplier: z.number().nullish() 142 | .samplerOverride("dry_multiplier") 143 | .coalesce(0), 144 | dry_base: z.number().nullish() 145 | .samplerOverride("dry_base") 146 | .coalesce(0), 147 | dry_allowed_length: z.number().nullish() 148 | .samplerOverride("dry_allowed_length") 149 | .coalesce(0), 150 | dry_sequence_breakers: z.union([ 151 | z.string() 152 | .transform((str) => { 153 | if (!str.startsWith("[")) { 154 | str = `[${str}]`; 155 | } 156 | 157 | // Parse can fail, so return a default value if it does 158 | try { 159 | return JSON.parse(str); 160 | } catch { 161 | return []; 162 | } 163 | }), 164 | z.array(z.string()), 165 | ]) 166 | .nullish() 167 | .samplerOverride("dry_sequence_breakers") 168 | .coalesce([]), 169 | dry_range: z.number().nullish() 170 | .samplerOverride("dry_range") 171 | .coalesce(0) 172 | .describe("Aliases: dry_penalty_last_n"), 173 | }), 174 | [{ field: "dry_range", aliases: ["dry_penalty_last_n"] }], 175 | ) 176 | .describe("DRY options"); 177 | 178 | const XtcSchema = z.object({ 179 | xtc_probability: z.number().nullish() 180 | .samplerOverride("xtc_probability") 181 | .coalesce(0), 182 | xtc_threshold: z.number().nullish() 183 | .samplerOverride("xtc_threshold") 184 | .coalesce(0.1), 185 | }) 186 | .describe("XTC options"); 187 | 188 | const DynatempSchema = z.aliasedObject( 189 | z.object({ 190 | max_temp: z.number().gte(0).nullish() 191 | .samplerOverride("max_temp") 192 | .coalesce(1) 193 | 
.describe("Aliases: dynatemp_high"), 194 | min_temp: z.number().gte(0).nullish() 195 | .samplerOverride("min_temp") 196 | .coalesce(1) 197 | .describe("Aliases: dynatemp_low"), 198 | temp_exponent: z.number().gte(0).nullish() 199 | .samplerOverride("temp_exponent") 200 | .coalesce(1) 201 | .describe("Aliases: dynatemp_exponent"), 202 | }), 203 | [ 204 | { field: "max_temp", aliases: ["dynatemp_high"] }, 205 | { field: "min_temp", aliases: ["dynatemp_low"] }, 206 | { field: "temp_exponent", aliases: ["dynatemp_exponent"] }, 207 | ], 208 | ) 209 | .describe("DynaTemp options"); 210 | 211 | const MirostatSchema = z.object({ 212 | mirostat_mode: z.number().nullish() 213 | .samplerOverride("mirostat_mode") 214 | .coalesce(0), 215 | mirostat_tau: z.number().nullish() 216 | .samplerOverride("mirostat_tau") 217 | .coalesce(1), 218 | mirostat_eta: z.number().nullish() 219 | .samplerOverride("mirostat_eta") 220 | .coalesce(0), 221 | }) 222 | .describe("Mirostat options"); 223 | 224 | // Define the schema 225 | const BaseSamplerRequestSchema = GenerationOptionsSchema 226 | .and(TemperatureSamplerSchema) 227 | .and(AlphabetSamplerSchema) 228 | .and(PenaltySamplerSchema) 229 | .and(DrySchema) 230 | .and(XtcSchema) 231 | .and(DynatempSchema) 232 | .and(MirostatSchema); 233 | 234 | // Define the type from the schema 235 | export type BaseSamplerRequest = z.infer; 236 | 237 | // Apply transforms and expose the type 238 | export const BaseSamplerRequest = BaseSamplerRequestSchema 239 | .transform((obj, ctx) => { 240 | const { params, forcedKeys } = forcedSamplerOverrides(obj); 241 | 242 | if (forcedKeys.length > 0) { 243 | const result = BaseSamplerRequestSchema.safeParse(params); 244 | if (!result.success) { 245 | ctx.addIssue({ 246 | code: "custom", 247 | message: 248 | `Forced sampler overrides must match the input type`, 249 | path: ["forcedSamplerOverrides"], 250 | params: { 251 | details: result.error.issues, 252 | }, 253 | }); 254 | } 255 | } 256 | 257 | return params; 258 | }); 259 | -------------------------------------------------------------------------------- /common/templating.ts: -------------------------------------------------------------------------------- 1 | // @ts-types="@/types/jinja.d.ts" 2 | import { 3 | ArrayLiteral, 4 | Environment, 5 | Identifier, 6 | Interpreter, 7 | Literal, 8 | SetStatement, 9 | Template, 10 | } from "@huggingface/jinja"; 11 | import * as z from "@/common/myZod.ts"; 12 | import * as Path from "@std/path"; 13 | 14 | // From @huggingface/jinja 15 | export function range(start: number, stop?: number, step = 1): number[] { 16 | if (stop === undefined) { 17 | stop = start; 18 | start = 0; 19 | } 20 | 21 | const result: number[] = []; 22 | for (let i = start; i < stop; i += step) { 23 | result.push(i); 24 | } 25 | return result; 26 | } 27 | 28 | const TemplateMetadataSchema = z.object({ 29 | stop_strings: z.array(z.string()).default([]), 30 | tool_start: z.string().optional(), 31 | tool_start_token: z.number().optional(), 32 | }); 33 | 34 | type TemplateMetadata = z.infer; 35 | 36 | export class PromptTemplate { 37 | name: string; 38 | rawTemplate: string; 39 | template: Template; 40 | metadata: TemplateMetadata; 41 | 42 | public constructor( 43 | name: string, 44 | rawTemplate: string, 45 | ) { 46 | this.name = name; 47 | this.rawTemplate = rawTemplate; 48 | this.template = new Template(rawTemplate); 49 | this.metadata = this.extractMetadata(this.template); 50 | } 51 | 52 | // Overrides the template's render function to expose the env 53 | public render(context: 
Record<string, unknown> = {}): string { 54 | const env = new Environment(); 55 | 56 | // Environment vars 57 | env.set("false", false); 58 | env.set("true", true); 59 | 60 | // Function vars 61 | env.set("raise_exception", (args: string) => { 62 | throw new Error(args); 63 | }); 64 | env.set("range", range); 65 | 66 | // Add custom template vars 67 | for (const [key, value] of Object.entries(context)) { 68 | env.set(key, value); 69 | } 70 | 71 | // Run the template 72 | const interpreter = new Interpreter(env); 73 | const response = interpreter.run(this.template.parsed); 74 | 75 | // Value is always a string here 76 | return response.value as string; 77 | } 78 | 79 | private assignMetadataValue<K extends keyof TemplateMetadata>( 80 | metadata: TemplateMetadata, 81 | key: K, 82 | value: unknown, 83 | ) { 84 | metadata[key] = value as TemplateMetadata[K]; 85 | } 86 | 87 | private extractMetadata(template: Template) { 88 | const metadata: TemplateMetadata = TemplateMetadataSchema.parse({}); 89 | 90 | template.parsed.body.forEach((statement) => { 91 | if (statement.type === "Set") { 92 | const setStatement = statement as SetStatement; 93 | 94 | const assignee = setStatement.assignee as Identifier; 95 | const foundMetaKey = Object.keys(TemplateMetadataSchema.shape) 96 | .find( 97 | (key) => key === assignee.value, 98 | ) as keyof TemplateMetadata; 99 | 100 | if (foundMetaKey) { 101 | const fieldSchema = 102 | TemplateMetadataSchema.shape[foundMetaKey]; 103 | 104 | let result: unknown; 105 | if (setStatement.value.type === "ArrayLiteral") { 106 | const arrayValue = setStatement.value as ArrayLiteral; 107 | result = arrayValue.value.map((e) => { 108 | const literalValue = e as Literal<unknown>; 109 | return literalValue.value; 110 | }); 111 | } else if (setStatement.value.type.endsWith("Literal")) { 112 | const literalValue = setStatement.value as Literal< 113 | unknown 114 | >; 115 | result = literalValue.value; 116 | } 117 | 118 | const parsedValue = fieldSchema.safeParse(result); 119 | if (parsedValue.success) { 120 | this.assignMetadataValue( 121 | metadata, 122 | foundMetaKey, 123 | parsedValue.data, 124 | ); 125 | } 126 | } 127 | } 128 | }); 129 | 130 | return metadata; 131 | } 132 | 133 | static async fromFile(templatePath: string) { 134 | const parsedPath = Path.parse(templatePath); 135 | parsedPath.ext = ".jinja"; 136 | const formattedPath = Path.format({ ...parsedPath, base: undefined }); 137 | const rawTemplate = await Deno.readTextFile(formattedPath); 138 | return new PromptTemplate(parsedPath.name, rawTemplate); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /common/utils.ts: -------------------------------------------------------------------------------- 1 | import os from "node:os"; 2 | import { logger } from "./logging.ts"; 3 | 4 | export function defer(callback: () => void): Disposable { 5 | return { 6 | [Symbol.dispose]: () => callback(), 7 | }; 8 | } 9 | 10 | export function asyncDefer(callback: () => Promise<void>): AsyncDisposable { 11 | return { 12 | [Symbol.asyncDispose]: async () => await callback(), 13 | }; 14 | } 15 | 16 | export async function getCommitSha() { 17 | const cmd = new Deno.Command("git", { 18 | args: ["rev-parse", "--short", "HEAD"], 19 | }); 20 | try { 21 | const { stdout } = await cmd.output(); 22 | const sha = new TextDecoder().decode(stdout).trim(); 23 | 24 | return sha; 25 | } catch (error) { 26 | console.error(`Failed to get commit SHA: ${error}`); 27 | return undefined; 28 | } 29 | } 30 | 31 | export async function getYalsVersion(root?: string) { 32 | const 
shaPath = root ? `${root}/gitSha.txt` : "gitSha.txt"; 33 | 34 | try { 35 | const cachedSha = await Deno.readTextFile(shaPath); 36 | return cachedSha.trim(); 37 | } catch { 38 | return await getCommitSha(); 39 | } 40 | } 41 | 42 | export function generateUuidHex() { 43 | const buffer = new Uint8Array(16); 44 | crypto.getRandomValues(buffer); 45 | 46 | // To hex string 47 | const token = Array.from(buffer) 48 | .map((b) => b.toString(16).padStart(2, "0")) 49 | .join(""); 50 | 51 | return token; 52 | } 53 | 54 | // Sets the process priority to realtime 55 | export function elevateProcessPriority() { 56 | try { 57 | os.setPriority(os.constants.priority.PRIORITY_HIGHEST); 58 | logger.warn("EXPERIMENTAL: Process priority set to Realtime."); 59 | 60 | if (Deno.build.os === "windows") { 61 | logger.warn( 62 | "If you're not running YALS as administrator," + 63 | "the priority is set to high.", 64 | ); 65 | } 66 | } catch { 67 | logger.warn( 68 | "Cannot set the process priority to realtime. " + 69 | "Restart the program with sudo permissions.", 70 | ); 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /config_sample.yml: -------------------------------------------------------------------------------- 1 | # Options for networking 2 | network: 3 | # The IP to host on (default: 127.0.0.1). 4 | # Use 0.0.0.0 to expose on all network adapters. 5 | host: 127.0.0.1 6 | 7 | # The port to host on (default: 5000). 8 | # Note: Recommended to use 5001 on MacOS because AirServer runs on port 5000 9 | port: 5000 10 | 11 | # Disable HTTP token authentication with requests. 12 | # WARNING: This will make your instance vulnerable! 13 | # Turn on this option if you are ONLY connecting from localhost. 14 | disable_auth: false 15 | 16 | # Options for logging 17 | logging: 18 | # Enable prompt logging (default: False) 19 | log_prompt: false 20 | 21 | # Enable generation parameter logging (default: False) 22 | log_generation_params: false 23 | 24 | # Enable request logging (default: False). 25 | # NOTE: Only use this for debugging! 26 | log_requests: false 27 | 28 | # Options for model overrides and loading 29 | model: 30 | # Directory to look for models (default: models). 31 | # Windows users, do NOT put this path in quotes! 32 | model_dir: models 33 | 34 | # An initial model to load. 35 | # Make sure the model is located in the model directory! 36 | # REQUIRED: This must be filled out to load a model on startup. 37 | model_name: 38 | 39 | # Names of args to use as a fallback for API load requests (default: []). 40 | # For example, if you always want cache_mode to be Q4 instead of on the inital model load, add "cache_mode" to this array. 41 | # Example: ['max_seq_len', 'num_gpu_layers']. 42 | use_as_default: [] 43 | 44 | # Max sequence length (default: Empty). 45 | # Fetched from the model's base sequence length by default. 46 | max_seq_len: 47 | 48 | # Number of slots for continuous batching (default: 1) 49 | num_slots: 1 50 | 51 | # Size (in tokens) of the KV cache (default: max_seq_len). 52 | # At maximum, should be the max_seq_len * num_slots. 53 | cache_size: 54 | 55 | # Chunk size for prompt ingestion (default: 512). 56 | # A lower value reduces VRAM usage but decreases ingestion speed. 57 | # NOTE: Effects vary depending on the model. 58 | # An ideal value is between 512 and 4096. 
59 | chunk_size: 512 60 | 61 | # Number of model layers to offload on the GPU (default: 0) 62 | # Set this to 999 to offload all layers to the GPU 63 | num_gpu_layers: 0 64 | 65 | # An integer array of GBs of VRAM to split between GPUs (default: []). 66 | # Going over the max amount of GPUs will crash when loading the model 67 | gpu_split: [] 68 | 69 | # Enables row tensor split mode (default: false) 70 | # This is referenced as "tensor parallelism" in lcpp, so mark the arg as such 71 | # for clarity 72 | tensor_parallel: false 73 | 74 | # Number of CPU threads to use during processing/generation (default: -1) 75 | # NOTE: Does not apply if model is fully offloaded to GPU 76 | num_threads: -1 77 | 78 | # Prompt template to use for chat completions (default: None) 79 | prompt_template: 80 | 81 | # Enable flash attention (default: true) 82 | # Disable if problems arise with the model's architecture 83 | flash_attention: true 84 | 85 | # Rope freq base. 0 = model default (default: 0) 86 | # Adjust this value for NTK scaling 87 | rope_freq_base: 0 88 | 89 | # Enable YaRN scaling. All other parameters inherited from the model (default: 0) 90 | # Turning this on disables linear/NTK RoPE scaling 91 | enable_yarn: false 92 | 93 | # K cache quantization type (default: F16) 94 | # Possible values - f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0 95 | cache_mode_k: f16 96 | 97 | # V cache quantization type (default: F16) 98 | # Possible values - f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0 99 | cache_mode_v: f16 100 | 101 | # Override tensors to different devices (default: None) 102 | # Takes in a regex string. Recommended to set num_threads 103 | override_tensor: 104 | 105 | # Lazily load the model into virtual memory. This is fast and efficient (default: true) 106 | # Turning mmap off will take longer to load, but will reduce the risk of pageouts 107 | # WARNING: Do not adjust this parameter unless you know what you're doing! 108 | mmap: true 109 | 110 | # Options for Sampling 111 | sampling: 112 | # Select a sampler override preset (default: None). 113 | # Find this in the sampler_overrides folder. 114 | # This overrides default fallbacks for sampler values that are passed to the API. 115 | override_preset: 116 | 117 | developer: 118 | # Set process to use a higher priority. 119 | # For realtime process priority, run as administrator or sudo. 120 | # Otherwise, the priority will be set to high. 
121 | realtime_process_priority: false 122 | -------------------------------------------------------------------------------- /deno.json: -------------------------------------------------------------------------------- 1 | { 2 | "tasks": { 3 | "dev": "deno run -A --watch main.ts", 4 | "start": "deno run --allow-read --allow-write=api_tokens.yml --allow-env --allow-sys --allow-net --allow-ffi --allow-run main.ts", 5 | "bindings": "cd bindings && ./bindings.sh", 6 | "generate-sha": "deno run --allow-run --allow-write=gitSha.txt --allow-env generateGitSha.ts", 7 | "compile": "deno compile --allow-read --allow-write=api_tokens.yml --allow-env --allow-sys --allow-net --allow-ffi --allow-run --include gitSha.txt main.ts", 8 | "build": "deno task generate-sha && deno task compile", 9 | "bindings-win": "cd bindings && powershell -ExecutionPolicy Bypass -File bindings.ps1", 10 | "compile-win": "deno compile --allow-read --allow-write=api_tokens.yml --allow-env --allow-sys --allow-net --allow-ffi --allow-run --include gitSha.txt --icon assets/icon.ico main.ts", 11 | "build-win": "deno task generate-sha && deno task compile-win" 12 | }, 13 | "imports": { 14 | "@/": "./", 15 | "@core/asyncutil": "jsr:@core/asyncutil@^1.2.0", 16 | "@hono/standard-validator": "npm:@hono/standard-validator@^0.1.2", 17 | "@kingbri1/standard-json": "npm:@kingbri1/standard-json@^0.2.1-pre8", 18 | "hono-openapi": "npm:@kingbri1/hono-openapi@^0.5.0-pre6", 19 | "@kingbri1/standard-openapi": "npm:@kingbri1/standard-openapi@^0.1.3-pre1", 20 | "command-line-args": "npm:command-line-args@^6.0.1", 21 | "command-line-usage": "npm:command-line-usage@^7.0.3", 22 | "hono": "npm:hono@^4.7.9", 23 | "@huggingface/jinja": "npm:@huggingface/jinja@^0.5.0", 24 | "@scalar/hono-api-reference": "npm:@scalar/hono-api-reference@0.5.172", 25 | "@std/async": "jsr:@std/async@^1.0.13", 26 | "@std/path": "jsr:@std/path@^1.0.9", 27 | "@std/yaml": "jsr:@std/yaml@^1.0.6", 28 | "json-schema-walker": "npm:json-schema-walker@^3.0.0", 29 | "winston": "npm:winston@^3.17.0", 30 | "yoctocolors": "npm:yoctocolors@^2.1.1", 31 | "zod": "npm:zod@^3.25.41", 32 | "@types/command-line-args": "npm:@types/command-line-args@^5.2.3", 33 | "@types/command-line-usage": "npm:@types/command-line-usage@^5.0.4" 34 | }, 35 | "fmt": { 36 | "indentWidth": 4, 37 | "semiColons": true, 38 | "include": ["**/*.ts"], 39 | "lineWidth": 80 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /generateGitSha.ts: -------------------------------------------------------------------------------- 1 | import { getCommitSha } from "@/common/utils.ts"; 2 | 3 | if (import.meta.main) { 4 | const sha = await getCommitSha(); 5 | 6 | if (sha) { 7 | await Deno.writeTextFile("gitSha.txt", sha); 8 | console.log(`Successfully wrote Git SHA (${sha}) to gitSha.txt.`); 9 | } else { 10 | console.log("Failed to write Git SHA due to the errors above."); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /lib/place_libs_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theroyallab/YALS/60286959be95d577e05efdf33ba6733395d60020/lib/place_libs_here.txt -------------------------------------------------------------------------------- /main.ts: -------------------------------------------------------------------------------- 1 | import { createApi } from "@/api/server.ts"; 2 | import { loadAuthKeys } from "@/common/auth.ts"; 3 | import { parseArgs } 
from "@/common/args.ts"; 4 | import { config, loadConfig } from "@/common/config.ts"; 5 | import { logger } from "@/common/logging.ts"; 6 | import { loadModel } from "@/common/modelContainer.ts"; 7 | import { elevateProcessPriority, getYalsVersion } from "@/common/utils.ts"; 8 | import { overridesFromFile } from "@/common/samplerOverrides.ts"; 9 | import { loadYalsBindings } from "@/bindings/lib.ts"; 10 | 11 | if (import.meta.main) { 12 | // Use Promise resolution to avoid nested try/catch 13 | const version = await getYalsVersion(import.meta.dirname); 14 | 15 | if (version) { 16 | logger.info(`Using YALS commit ${version}`); 17 | } else { 18 | logger.info("Could not find YALS commit version. Launching anyway."); 19 | } 20 | 21 | // Load bindings 22 | loadYalsBindings(); 23 | 24 | //Parse CLI args 25 | const { args, usage } = parseArgs(); 26 | 27 | // Display help message if needed 28 | if (args.support.help) { 29 | console.log(usage); 30 | Deno.exit(); 31 | } 32 | 33 | await loadConfig(args); 34 | 35 | // console.log(config.model) 36 | // Load model if present 37 | if (config.model.model_name) { 38 | // Load model in bindings 39 | await loadModel(config.model); 40 | } 41 | 42 | // Attempt to set RT process priority 43 | if (config.developer.realtime_process_priority) { 44 | elevateProcessPriority(); 45 | } 46 | 47 | // Set sampler overrides 48 | if (config.sampling.override_preset) { 49 | await overridesFromFile(config.sampling.override_preset); 50 | } 51 | 52 | await loadAuthKeys(); 53 | createApi(); 54 | } 55 | -------------------------------------------------------------------------------- /minimal_test_setup.ts: -------------------------------------------------------------------------------- 1 | import { loadModel, model } from "./common/modelContainer.ts"; 2 | import { ModelConfig } from "./common/configModels.ts"; 3 | import { BaseSamplerRequest } from "./common/sampling.ts"; 4 | 5 | // Create the model configuration matching the Model.init expectations 6 | const modelConfig = ModelConfig.parse({ 7 | model_dir: "/home/blackroot/Desktop/YALS/bindings/gguf/", 8 | model_name: "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf", 9 | num_gpu_layers: 999, 10 | max_seq_len: undefined, 11 | }); 12 | 13 | // Load the model with the new configuration 14 | await loadModel(modelConfig); 15 | 16 | const samplerRequest = BaseSamplerRequest.parse({ 17 | temperature: 0, 18 | max_tokens: 200, 19 | }); 20 | 21 | let abort = new AbortController(); 22 | const encoder = new TextEncoder(); 23 | let buffer = ""; 24 | 25 | const prompt = 26 | '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a robot<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nRespond litrally with "Hi nice to meet you"<|eot_id|><|start_header_id|>assistant<|end_header_id|>'; 27 | for (let i = 0; i < 4; i++) { 28 | console.log(); 29 | console.log("NEXT"); 30 | console.log(); 31 | 32 | for await ( 33 | const chunk of model!.generateGen(prompt, samplerRequest, abort.signal) 34 | ) { 35 | if (chunk.kind === "data") { 36 | await Deno.stdout.write(encoder.encode(chunk.text)); 37 | await Deno.stdout.write(encoder.encode(chunk.token)); 38 | buffer += chunk.text; 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /models/place_your_models_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theroyallab/YALS/60286959be95d577e05efdf33ba6733395d60020/models/place_your_models_here.txt 
-------------------------------------------------------------------------------- /sampler_overrides/sample_preset.yml: -------------------------------------------------------------------------------- 1 | # Sample YAML file for override presets. 2 | # Each block corresponds to a sampler fallback override. Remove ones that you don't need. 3 | # "force" always overrides the sampler to the specified value. 4 | # For example, a top-p override of 1.5 with force = true will make every API request have a top_p value of 1.5 5 | 6 | # You can use https://www.yamllint.com/ if you want to check your YAML formatting. 7 | 8 | # TODO: Improve documentation for each field 9 | 10 | # MARK: Misc generation parameters 11 | max_tokens: 12 | override: 150 13 | force: false 14 | stop: 15 | override: [] 16 | force: false 17 | additive: false 18 | seed: 19 | override: 20 | force: false 21 | banned_strings: 22 | override: [] 23 | force: false 24 | additive: false 25 | 26 | # MARK: Temperature 27 | temperature: 28 | override: 1.0 29 | force: false 30 | temperature_last: 31 | override: false 32 | force: false 33 | min_temp: 34 | override: 1.0 35 | force: false 36 | max_temp: 37 | override: 1.0 38 | force: false 39 | temp_exponent: 40 | override: 1.0 41 | force: false 42 | 43 | # MARK: Alphabet soup 44 | top_k: 45 | override: 0 46 | force: false 47 | top_p: 48 | override: 1.0 49 | force: false 50 | top_a: 51 | override: 0.0 52 | force: false 53 | min_p: 54 | override: 0.0 55 | force: false 56 | typical: 57 | override: 1.0 58 | force: false 59 | nsigma: 60 | override: 0 61 | force: false 62 | xtc_probability: 63 | override: 0.0 64 | force: false 65 | xtc_threshold: 66 | override: 0.1 67 | force: false 68 | 69 | # MARK: Penalty settings 70 | frequency_penalty: 71 | override: 0.0 72 | force: false 73 | presence_penalty: 74 | override: 0.0 75 | force: false 76 | repetition_penalty: 77 | override: 1.0 78 | force: false 79 | penalty_range: 80 | override: -1 81 | force: false 82 | 83 | # MARK: DRY 84 | dry_multiplier: 85 | override: 0.0 86 | force: false 87 | dry_base: 88 | override: 0.0 89 | force: false 90 | dry_allowed_length: 91 | override: 0 92 | force: false 93 | dry_range: 94 | override: 0 95 | force: false 96 | dry_sequence_breakers: 97 | override: [] 98 | force: false 99 | additive: false 100 | 101 | # MARK: Token options 102 | add_bos_token: 103 | override: 104 | force: false 105 | ban_eos_token: 106 | override: false 107 | force: false 108 | logit_bias: 109 | override: 110 | force: false 111 | additive: false 112 | banned_tokens: 113 | override: [] 114 | force: false 115 | additive: false 116 | -------------------------------------------------------------------------------- /templates/alpaca.jinja: -------------------------------------------------------------------------------- 1 | {# Metadata #} 2 | {%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%} 3 | {# Template #} 4 | {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} 5 | 6 | {% for message in messages %} 7 | {% if message['role'] == 'user' %} 8 | ### Instruction: 9 | {{ message['content']|trim -}} 10 | {% if not loop.last %} 11 | 12 | 13 | {% endif %} 14 | {% elif message['role'] == 'assistant' %} 15 | ### Response: 16 | {{ message['content']|trim -}} 17 | {% if not loop.last %} 18 | 19 | 20 | {% endif %} 21 | {% elif message['role'] == 'user_context' %} 22 | ### Input: 23 | {{ message['content']|trim -}} 24 | {% if not loop.last %} 25 
| 26 | 27 | {% endif %} 28 | {% endif %} 29 | {% endfor %} 30 | {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} 31 | ### Response: 32 | {% endif %} 33 | -------------------------------------------------------------------------------- /templates/chatml.jinja: -------------------------------------------------------------------------------- 1 | {# Metadata #} 2 | {%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} 3 | 4 | {# Template #} 5 | {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} 6 | {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %} 7 | -------------------------------------------------------------------------------- /templates/place_your_templates_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theroyallab/YALS/60286959be95d577e05efdf33ba6733395d60020/templates/place_your_templates_here.txt -------------------------------------------------------------------------------- /types/jinja.d.ts: -------------------------------------------------------------------------------- 1 | export * from "@huggingface/jinja"; 2 | 3 | declare module "@huggingface/jinja" { 4 | export class Statement { 5 | type: string; 6 | } 7 | 8 | export class Expression extends Statement { 9 | type: string; 10 | } 11 | 12 | abstract class Literal extends Expression { 13 | value: T; 14 | type: string; 15 | constructor(value: T); 16 | } 17 | 18 | export class Identifier extends Expression { 19 | value: string; 20 | type: string; 21 | 22 | /** 23 | * @param {string} value The name of the identifier 24 | */ 25 | constructor(value: string); 26 | } 27 | 28 | export class NumericLiteral extends Literal { 29 | type: string; 30 | } 31 | 32 | export class StringLiteral extends Literal { 33 | type: string; 34 | } 35 | 36 | export class BooleanLiteral extends Literal { 37 | type: string; 38 | } 39 | 40 | export class NullLiteral extends Literal { 41 | type: string; 42 | } 43 | 44 | export class ArrayLiteral extends Literal { 45 | type: string; 46 | } 47 | 48 | export class TupleLiteral extends Literal { 49 | type: string; 50 | } 51 | 52 | export class ObjectLiteral extends Literal> { 53 | type: string; 54 | } 55 | 56 | export class SetStatement extends Statement { 57 | assignee: Expression; 58 | value: Expression; 59 | type: string; 60 | constructor(assignee: Expression, value: Expression); 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /types/utils.ts: -------------------------------------------------------------------------------- 1 | export type MaybePromise = (() => T) | (() => Promise); 2 | --------------------------------------------------------------------------------
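Appendix: a minimal usage sketch for the custom Zod helpers in common/myZod.ts. DemoSchema below is hypothetical (not part of the repo) and only illustrates how aliasedObject rewrites alias keys and how coalesce supplies fallbacks for missing or null values, mirroring the way common/sampling.ts builds its sampler schemas.

// Hypothetical demo schema: shows the myZod helpers in isolation.
import * as z from "@/common/myZod.ts";

const DemoSchema = z.aliasedObject(
    z.object({
        // nullish() lets the key be omitted; coalesce() fills in the fallback
        max_tokens: z.number().gte(0).nullish().coalesce(150),
    }),
    // "max_length" is accepted as an alias and copied into "max_tokens"
    [{ field: "max_tokens", aliases: ["max_length"] }],
);

console.log(DemoSchema.parse({ max_length: 100 })); // { max_tokens: 100 }
console.log(DemoSchema.parse({}));                  // { max_tokens: 150 }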