├── .coderabbit.yaml ├── .github ├── ISSUE_TEMPLATE │ └── bug-report.md └── workflows │ ├── closeinactive.yml │ ├── containers.yml │ ├── go-ci-windows.yml │ ├── go-ci.yml │ └── release.yml ├── .gitignore ├── .goreleaser.yaml ├── LICENSE.md ├── Makefile ├── README.md ├── config.example.yaml ├── docker ├── build-container.sh ├── config.example.yaml └── llama-swap.Containerfile ├── examples ├── README.md ├── aider-qwq-coder │ ├── README.md │ ├── aider.model.settings.dualgpu.yml │ ├── aider.model.settings.yml │ └── llama-swap.yaml ├── benchmark-snakegame │ ├── README.md │ └── run-benchmark.sh ├── restart-on-config-change │ └── README.md └── speculative-decoding │ └── README.md ├── go.mod ├── go.sum ├── header.jpeg ├── header2.png ├── llama-swap.go ├── misc ├── assets │ └── favicon-raw.png ├── simple-responder │ └── simple-responder.go └── test-rerank │ ├── README.md │ └── reranker-test.json ├── models ├── .gitignore └── README.md ├── proxy ├── config.go ├── config_posix_test.go ├── config_test.go ├── config_windows_test.go ├── helpers_test.go ├── html │ ├── favicon.ico │ ├── index.html │ └── logs.html ├── html_files.go ├── logMonitor.go ├── logMonitor_test.go ├── process.go ├── process_test.go ├── processgroup.go ├── processgroup_test.go ├── proxymanager.go ├── proxymanager_loghandlers.go ├── proxymanager_test.go ├── sanitize_cors.go └── sanitize_cors_test.go └── scripts ├── install.sh └── uninstall.sh /.coderabbit.yaml: -------------------------------------------------------------------------------- 1 | # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json 2 | language: "en-US" 3 | early_access: false 4 | reviews: 5 | profile: "chill" 6 | request_changes_workflow: false 7 | high_level_summary: true 8 | poem: false 9 | review_status: true 10 | collapse_walkthrough: false 11 | auto_review: 12 | enabled: true 13 | drafts: false 14 | chat: 15 | auto_reply: true 16 | 
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Something is not working as expected... 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Expected behaviour** 14 | A clear and concise description of what you expected to happen. 15 | 16 | **Operating system and version** 17 | 18 | - OS: (linux, osx, windows, freebsd, etc) 19 | - GPUs: (list architecture) 20 | 21 | **My Configuration** 22 | 23 | ```yaml 24 | # copy / paste your configuration here 25 | ``` 26 | 27 | **Proxy Logs** 28 | 29 | ``` 30 | # copy / paste from /logs 31 | ``` 32 | 33 | **Upstream Logs** 34 | 35 | ``` 36 | # copy/paste from /logs 37 | ``` 38 | -------------------------------------------------------------------------------- /.github/workflows/closeinactive.yml: -------------------------------------------------------------------------------- 1 | # https://docs.github.com/en/actions/use-cases-and-examples/project-management/closing-inactive-issues 2 | name: Close inactive issues 3 | on: 4 | schedule: 5 | - cron: "32 1 * * *" 6 | 7 | jobs: 8 | close-issues: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | issues: write 12 | pull-requests: write 13 | steps: 14 | - uses: actions/stale@v9 15 | with: 16 | days-before-issue-stale: 14 17 | days-before-issue-close: 14 18 | stale-issue-label: "stale" 19 | stale-issue-message: "This issue is stale because it has been open for 2 weeks with no activity." 20 | close-issue-message: "This issue was closed because it has been inactive for 2 weeks since being marked as stale." 
21 | days-before-pr-stale: -1 22 | days-before-pr-close: -1 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | -------------------------------------------------------------------------------- /.github/workflows/containers.yml: -------------------------------------------------------------------------------- 1 | name: Build Containers 2 | 3 | on: 4 | # time has no specific meaning, trying to time it after 5 | # the llama.cpp daily packages are published 6 | # https://github.com/ggml-org/llama.cpp/blob/master/.github/workflows/docker.yml 7 | schedule: 8 | - cron: "37 5 * * *" 9 | 10 | # Allows manual triggering of the workflow 11 | workflow_dispatch: 12 | 13 | jobs: 14 | build-and-push: 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | #platform: [intel, cuda, vulkan, cpu, musa] 19 | platform: [cuda, vulkan, cpu, musa] 20 | fail-fast: false 21 | steps: 22 | - name: Checkout code 23 | uses: actions/checkout@v4 24 | 25 | - name: Log in to GitHub Container Registry 26 | uses: docker/login-action@v2 27 | with: 28 | registry: ghcr.io 29 | username: ${{ github.actor }} 30 | password: ${{ secrets.GITHUB_TOKEN }} 31 | 32 | - name: Run build-container 33 | env: 34 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 35 | run: ./docker/build-container.sh ${{ matrix.platform }} true 36 | 37 | # note make sure mostlygeek/llama-swap has admin rights to the llama-swap package 38 | # see: https://github.com/actions/delete-package-versions/issues/74 39 | delete-untagged-containers: 40 | needs: build-and-push 41 | runs-on: ubuntu-latest 42 | steps: 43 | - uses: actions/delete-package-versions@v5 44 | with: 45 | package-name: 'llama-swap' 46 | package-type: 'container' 47 | delete-only-untagged-versions: 'true' 48 | -------------------------------------------------------------------------------- /.github/workflows/go-ci-windows.yml: -------------------------------------------------------------------------------- 1 | name: Windows CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | 7 
| pull_request: 8 | branches: [ "main" ] 9 | 10 | # Allows manual triggering of the workflow 11 | workflow_dispatch: 12 | 13 | jobs: 14 | 15 | run-tests: 16 | runs-on: windows-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Set up Go 21 | uses: actions/setup-go@v4 22 | with: 23 | go-version: '1.23' 24 | 25 | # cache simple-responder to save the build time 26 | - name: Restore Simple Responder 27 | id: restore-simple-responder 28 | uses: actions/cache/restore@v4 29 | with: 30 | path: ./build 31 | key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }} 32 | 33 | # necessary for testing proxy/Process swapping 34 | - name: Create simple-responder 35 | if: steps.restore-simple-responder.outputs.cache-hit != 'true' 36 | shell: bash 37 | run: make simple-responder-windows 38 | 39 | - name: Save Simple Responder 40 | # nothing new to save ... skip this step 41 | if: steps.restore-simple-responder.outputs.cache-hit != 'true' 42 | id: save-simple-responder 43 | uses: actions/cache/save@v4 44 | with: 45 | path: ./build 46 | key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }} 47 | 48 | - name: Test all 49 | shell: bash 50 | run: make test-all -------------------------------------------------------------------------------- /.github/workflows/go-ci.yml: -------------------------------------------------------------------------------- 1 | name: Linux CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | 7 | pull_request: 8 | branches: [ "main" ] 9 | 10 | # Allows manual triggering of the workflow 11 | workflow_dispatch: 12 | 13 | jobs: 14 | 15 | run-tests: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: actions/checkout@v4 19 | 20 | - name: Set up Go 21 | uses: actions/setup-go@v4 22 | with: 23 | go-version: '1.23' 24 | 25 | # cache simple-responder to save the build time 26 | - name: Restore Simple Responder 27 | id: restore-simple-responder 28 | uses: 
actions/cache/restore@v4 29 | with: 30 | path: ./build 31 | key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }} 32 | 33 | # necessary for testing proxy/Process swapping 34 | - name: Create simple-responder 35 | run: make simple-responder 36 | 37 | - name: Save Simple Responder 38 | # nothing new to save ... skip this step 39 | if: steps.restore-simple-responder.outputs.cache-hit != 'true' 40 | id: save-simple-responder 41 | uses: actions/cache/save@v4 42 | with: 43 | path: ./build 44 | key: ${{ runner.os }}-simple-responder-${{ hashFiles('misc/simple-responder/simple-responder.go') }} 45 | 46 | - name: Test all 47 | run: make test-all -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: goreleaser 2 | 3 | on: 4 | push: 5 | tags: 6 | - '*' 7 | 8 | # Allows manual triggering of the workflow 9 | workflow_dispatch: 10 | 11 | permissions: 12 | contents: write 13 | 14 | jobs: 15 | goreleaser: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - 19 | name: Checkout 20 | uses: actions/checkout@v4 21 | with: 22 | fetch-depth: 0 23 | - 24 | name: Set up Go 25 | uses: actions/setup-go@v5 26 | - 27 | name: Run GoReleaser 28 | uses: goreleaser/goreleaser-action@v6 29 | with: 30 | # either 'goreleaser' (default) or 'goreleaser-pro' 31 | distribution: goreleaser 32 | # 'latest', 'nightly', or a semver 33 | version: '~> v2' 34 | args: release --clean 35 | env: 36 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .aider* 2 | .env 3 | build/ 4 | dist/ 5 | .vscode 6 | .DS_Store 7 | -------------------------------------------------------------------------------- /.goreleaser.yaml: 
-------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | builds: 4 | - env: 5 | - CGO_ENABLED=0 6 | goos: 7 | - linux 8 | - darwin 9 | - freebsd 10 | - windows 11 | goarch: 12 | - amd64 13 | - arm64 14 | ignore: 15 | - goos: freebsd 16 | goarch: arm64 17 | - goos: windows 18 | goarch: arm64 19 | 20 | # use zip format for windows 21 | archives: 22 | - id: default 23 | format: tar.gz 24 | name_template: "{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}" 25 | builds_info: 26 | group: root 27 | owner: root 28 | format_overrides: 29 | - goos: windows 30 | format: zip -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Benson Wong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Define variables for the application 2 | APP_NAME = llama-swap 3 | BUILD_DIR = build 4 | 5 | # Get the current Git hash 6 | GIT_HASH := $(shell git rev-parse --short HEAD) 7 | ifneq ($(shell git status --porcelain),) 8 | # There are untracked changes 9 | GIT_HASH := $(GIT_HASH)+ 10 | endif 11 | 12 | # Capture the current build date in RFC3339 format 13 | BUILD_DATE := $(shell date -u +"%Y-%m-%dT%H:%M:%SZ") 14 | 15 | # Default target: Builds binaries for both OSX and Linux 16 | all: mac linux simple-responder 17 | 18 | # Clean build directory 19 | clean: 20 | rm -rf $(BUILD_DIR) 21 | 22 | test: 23 | go test -short -v -count=1 ./proxy 24 | 25 | test-all: 26 | go test -v -count=1 ./proxy 27 | 28 | # Build OSX binary 29 | mac: 30 | @echo "Building Mac binary..." 31 | GOOS=darwin GOARCH=arm64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-darwin-arm64 32 | 33 | # Build Linux binary 34 | linux: 35 | @echo "Building Linux binary..." 36 | GOOS=linux GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-linux-amd64 37 | 38 | # Build Windows binary 39 | windows: 40 | @echo "Building Windows binary..." 
41 | GOOS=windows GOARCH=amd64 go build -ldflags="-X main.commit=${GIT_HASH} -X main.version=local_${GIT_HASH} -X main.date=${BUILD_DATE}" -o $(BUILD_DIR)/$(APP_NAME)-windows-amd64.exe 42 | 43 | # for testing proxy.Process 44 | simple-responder: 45 | @echo "Building simple responder" 46 | GOOS=darwin GOARCH=arm64 go build -o $(BUILD_DIR)/simple-responder_darwin_arm64 misc/simple-responder/simple-responder.go 47 | GOOS=linux GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder_linux_amd64 misc/simple-responder/simple-responder.go 48 | 49 | simple-responder-windows: 50 | @echo "Building simple responder for windows" 51 | GOOS=windows GOARCH=amd64 go build -o $(BUILD_DIR)/simple-responder.exe misc/simple-responder/simple-responder.go 52 | 53 | # Ensure build directory exists 54 | $(BUILD_DIR): 55 | mkdir -p $(BUILD_DIR) 56 | 57 | # Create a new release tag 58 | release: 59 | @echo "Checking for unstaged changes..." 60 | @if [ -n "$(shell git status --porcelain)" ]; then \ 61 | echo "Error: There are unstaged changes. Please commit or stash your changes before creating a release tag." 
>&2; \ 62 | exit 1; \ 63 | fi 64 | 65 | # Get the highest tag in v{number} format, increment it, and create a new tag 66 | @highest_tag=$$(git tag --sort=-v:refname | grep -E '^v[0-9]+$$' | head -n 1 || echo "v0"); \ 67 | new_tag="v$$(( $${highest_tag#v} + 1 ))"; \ 68 | echo "tagging new version: $$new_tag"; \ 69 | git tag "$$new_tag"; 70 | 71 | # Phony targets 72 | .PHONY: all clean mac linux windows simple-responder 73 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![llama-swap header image](header2.png) 2 | ![GitHub Downloads (all assets, all releases)](https://img.shields.io/github/downloads/mostlygeek/llama-swap/total) 3 | ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/mostlygeek/llama-swap/go-ci.yml) 4 | ![GitHub Repo stars](https://img.shields.io/github/stars/mostlygeek/llama-swap) 5 | 6 | # llama-swap 7 | 8 | llama-swap is a light weight, transparent proxy server that provides automatic model swapping to llama.cpp's server. 9 | 10 | Written in golang, it is very easy to install (single binary with no dependencies) and configure (single yaml file). To get started, download a pre-built binary or use the provided docker images. 
11 | 12 | ## Features: 13 | 14 | - ✅ Easy to deploy: single binary with no dependencies 15 | - ✅ Easy to config: single yaml file 16 | - ✅ On-demand model switching 17 | - ✅ OpenAI API supported endpoints: 18 | - `v1/completions` 19 | - `v1/chat/completions` 20 | - `v1/embeddings` 21 | - `v1/rerank` 22 | - `v1/audio/speech` ([#36](https://github.com/mostlygeek/llama-swap/issues/36)) 23 | - `v1/audio/transcriptions` ([docs](https://github.com/mostlygeek/llama-swap/issues/41#issuecomment-2722637867)) 24 | - ✅ llama-swap custom API endpoints 25 | - `/log` - remote log monitoring 26 | - `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) 27 | - `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58)) 28 | - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) 29 | - ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) 30 | - ✅ Automatic unloading of models after timeout by setting a `ttl` 31 | - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) 32 | - ✅ Docker and Podman support 33 | - ✅ Full control over server settings per model 34 | 35 | ## How does llama-swap work? 36 | 37 | When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. 38 | 39 | In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. 
40 | 41 | ## config.yaml 42 | 43 | llama-swap's configuration is purposefully simple: 44 | 45 | ```yaml 46 | models: 47 | "qwen2.5": 48 | cmd: | 49 | /app/llama-server 50 | -hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M 51 | --port ${PORT} 52 | 53 | "smollm2": 54 | cmd: | 55 | /app/llama-server 56 | -hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M 57 | --port ${PORT} 58 | ``` 59 | 60 | .. but also supports many advanced features: 61 | 62 | - `groups` to run multiple models at once 63 | - `macros` for reusable snippets 64 | - `ttl` to automatically unload models 65 | - `aliases` to use familiar model names (e.g., "gpt-4o-mini") 66 | - `env` variables to pass custom environment to inference servers 67 | - `useModelName` to override model names sent to upstream servers 68 | - `healthCheckTimeout` to control model startup wait times 69 | - `${PORT}` automatic port variables for dynamic port assignment 70 | - `cmdStop` for to gracefully stop Docker/Podman containers 71 | 72 | Check the [configuration documentation](https://github.com/mostlygeek/llama-swap/wiki/Configuration) in the wiki for all options. 
73 | 74 | ## Docker Install ([download images](https://github.com/mostlygeek/llama-swap/pkgs/container/llama-swap)) 75 | 76 | Docker is the quickest way to try out llama-swap: 77 | 78 | ```shell 79 | # use CPU inference comes with the example config above 80 | $ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu 81 | 82 | # qwen2.5 0.5B 83 | $ curl -s http://localhost:9292/v1/chat/completions \ 84 | -H "Content-Type: application/json" \ 85 | -H "Authorization: Bearer no-key" \ 86 | -d '{"model":"qwen2.5","messages": [{"role": "user","content": "tell me a joke"}]}' | \ 87 | jq -r '.choices[0].message.content' 88 | 89 | # SmolLM2 135M 90 | $ curl -s http://localhost:9292/v1/chat/completions \ 91 | -H "Content-Type: application/json" \ 92 | -H "Authorization: Bearer no-key" \ 93 | -d '{"model":"smollm2","messages": [{"role": "user","content": "tell me a joke"}]}' | \ 94 | jq -r '.choices[0].message.content' 95 | ``` 96 | 97 |
98 | Docker images are built nightly for cuda, intel, vulcan, etc ... 99 | 100 | They include: 101 | 102 | - `ghcr.io/mostlygeek/llama-swap:cpu` 103 | - `ghcr.io/mostlygeek/llama-swap:cuda` 104 | - `ghcr.io/mostlygeek/llama-swap:intel` 105 | - `ghcr.io/mostlygeek/llama-swap:vulkan` 106 | - ROCm disabled until fixed in llama.cpp container 107 | 108 | Specific versions are also available and are tagged with the llama-swap, architecture and llama.cpp versions. For example: `ghcr.io/mostlygeek/llama-swap:v89-cuda-b4716` 109 | 110 | Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration. 111 | 112 | ```shell 113 | $ docker run -it --rm --runtime nvidia -p 9292:8080 \ 114 | -v /path/to/models:/models \ 115 | -v /path/to/custom/config.yaml:/app/config.yaml \ 116 | ghcr.io/mostlygeek/llama-swap:cuda 117 | ``` 118 | 119 |
120 | 121 | ## Bare metal Install ([download](https://github.com/mostlygeek/llama-swap/releases)) 122 | 123 | Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are automatically published and are likely a few hours ahead of the docker releases. The baremetal install works with any OpenAI compatible server, not just llama-server. 124 | 125 | 1. Create a configuration file, see [config.example.yaml](config.example.yaml) 126 | 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. 127 | 1. Run the binary with `llama-swap --config path/to/config.yaml`. 128 | Available flags: 129 | - `--config`: Path to the configuration file (default: `config.yaml`). 130 | - `--listen`: Address and port to listen on (default: `:8080`). 131 | - `--version`: Show version information and exit. 132 | - `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`). 133 | 134 | ### Building from source 135 | 136 | 1. Install golang for your system 137 | 1. `git clone git@github.com:mostlygeek/llama-swap.git` 138 | 1. `make clean all` 139 | 1. Binaries will be in `build/` subdirectory 140 | 141 | ## Monitoring Logs 142 | 143 | Open the `http:///logs` with your browser to get a web interface with streaming logs. 
144 | 145 | Of course, CLI access is also supported: 146 | 147 | ```shell 148 | # sends up to the last 10KB of logs 149 | curl http://host/logs' 150 | 151 | # streams combined logs 152 | curl -Ns 'http://host/logs/stream' 153 | 154 | # just llama-swap's logs 155 | curl -Ns 'http://host/logs/stream/proxy' 156 | 157 | # just upstream's logs 158 | curl -Ns 'http://host/logs/stream/upstream' 159 | 160 | # stream and filter logs with linux pipes 161 | curl -Ns http://host/logs/stream | grep 'eval time' 162 | 163 | # skips history and just streams new log entries 164 | curl -Ns 'http://host/logs/stream?no-history' 165 | ``` 166 | 167 | ## Do I need to use llama.cpp's server (llama-server)? 168 | 169 | Any OpenAI compatible server would work. llama-swap was originally designed for llama-server and it is the best supported. 170 | 171 | For Python based inference servers like vllm or tabbyAPI it is recommended to run them via podman or docker. This provides clean environment isolation as well as responding correctly to `SIGTERM` signals to shutdown. 
172 | 173 | ## Star History 174 | 175 | [![Star History Chart](https://api.star-history.com/svg?repos=mostlygeek/llama-swap&type=Date)](https://www.star-history.com/#mostlygeek/llama-swap&Date) 176 | -------------------------------------------------------------------------------- /config.example.yaml: -------------------------------------------------------------------------------- 1 | # Seconds to wait for llama.cpp to be available to serve requests 2 | # Default (and minimum): 15 seconds 3 | healthCheckTimeout: 90 4 | 5 | # valid log levels: debug, info (default), warn, error 6 | logLevel: debug 7 | 8 | # creating a coding profile with models for code generation and general questions 9 | groups: 10 | coding: 11 | swap: false 12 | members: 13 | - "qwen" 14 | - "llama" 15 | 16 | models: 17 | "llama": 18 | cmd: | 19 | models/llama-server-osx 20 | --port ${PORT} 21 | -m models/Llama-3.2-1B-Instruct-Q4_0.gguf 22 | 23 | # list of model name aliases this llama.cpp instance can serve 24 | aliases: 25 | - gpt-4o-mini 26 | 27 | # check this path for a HTTP 200 response for the server to be ready 28 | checkEndpoint: /health 29 | 30 | # unload model after 5 seconds 31 | ttl: 5 32 | 33 | "qwen": 34 | cmd: models/llama-server-osx --port ${PORT} -m models/qwen2.5-0.5b-instruct-q8_0.gguf 35 | aliases: 36 | - gpt-3.5-turbo 37 | 38 | # Embedding example with Nomic 39 | # https://huggingface.co/nomic-ai/nomic-embed-text-v1.5-GGUF 40 | "nomic": 41 | cmd: | 42 | models/llama-server-osx --port ${PORT} 43 | -m models/nomic-embed-text-v1.5.Q8_0.gguf 44 | --ctx-size 8192 45 | --batch-size 8192 46 | --rope-scaling yarn 47 | --rope-freq-scale 0.75 48 | -ngl 99 49 | --embeddings 50 | 51 | # Reranking example with bge-reranker 52 | # https://huggingface.co/gpustack/bge-reranker-v2-m3-GGUF 53 | "bge-reranker": 54 | cmd: | 55 | models/llama-server-osx --port ${PORT} 56 | -m models/bge-reranker-v2-m3-Q4_K_M.gguf 57 | --ctx-size 8192 58 | --reranking 59 | 60 | # Docker Support (v26.1.4+ 
required!) 61 | "dockertest": 62 | cmd: | 63 | docker run --name dockertest 64 | --init --rm -p ${PORT}:8080 -v /mnt/nvme/models:/models 65 | ghcr.io/ggerganov/llama.cpp:server 66 | --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' 67 | 68 | "simple": 69 | # example of setting environment variables 70 | env: 71 | - CUDA_VISIBLE_DEVICES=0,1 72 | - env1=hello 73 | cmd: build/simple-responder --port ${PORT} 74 | unlisted: true 75 | 76 | # use "none" to skip check. Caution this may cause some requests to fail 77 | # until the upstream server is ready for traffic 78 | checkEndpoint: none 79 | 80 | # don't use these, just for testing if things are broken 81 | "broken": 82 | cmd: models/llama-server-osx --port 8999 -m models/doesnotexist.gguf 83 | proxy: http://127.0.0.1:8999 84 | unlisted: true 85 | "broken_timeout": 86 | cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf 87 | proxy: http://127.0.0.1:9000 88 | unlisted: true -------------------------------------------------------------------------------- /docker/build-container.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd $(dirname "$0") 4 | 5 | ARCH=$1 6 | PUSH_IMAGES=${2:-false} 7 | 8 | # List of allowed architectures 9 | ALLOWED_ARCHS=("intel" "vulkan" "musa" "cuda" "cpu") 10 | 11 | # Check if ARCH is in the allowed list 12 | if [[ ! " ${ALLOWED_ARCHS[@]} " =~ " ${ARCH} " ]]; then 13 | echo "Error: ARCH must be one of the following: ${ALLOWED_ARCHS[@]}" 14 | exit 1 15 | fi 16 | 17 | # Check if GITHUB_TOKEN is set and not empty 18 | if [[ -z "$GITHUB_TOKEN" ]]; then 19 | echo "Error: GITHUB_TOKEN is not set or is empty." 
20 | exit 1 21 | fi 22 | 23 | # the most recent llama-swap tag 24 | # have to strip out the 'v' due to .tar.gz file naming 25 | LS_VER=$(curl -s https://api.github.com/repos/mostlygeek/llama-swap/releases/latest | jq -r .tag_name | sed 's/v//') 26 | 27 | if [ "$ARCH" == "cpu" ]; then 28 | # cpu only containers just use the latest available 29 | CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:cpu" 30 | echo "Building ${CONTAINER_LATEST} $LS_VER" 31 | docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server --build-arg LS_VER=${LS_VER} -t ${CONTAINER_LATEST} . 32 | if [ "$PUSH_IMAGES" == "true" ]; then 33 | docker push ${CONTAINER_LATEST} 34 | fi 35 | else 36 | LCPP_TAG=$(curl -s -H "Authorization: Bearer $GITHUB_TOKEN" \ 37 | "https://api.github.com/users/ggml-org/packages/container/llama.cpp/versions" \ 38 | | jq -r --arg arch "$ARCH" '.[] | select(.metadata.container.tags[] | startswith("server-\($arch)")) | .metadata.container.tags[]' \ 39 | | sort -r | head -n1 | awk -F '-' '{print $3}') 40 | 41 | # Abort if LCPP_TAG is empty. 42 | if [[ -z "$LCPP_TAG" ]]; then 43 | echo "Abort: Could not find llama-server container for arch: $ARCH" 44 | exit 1 45 | fi 46 | 47 | CONTAINER_TAG="ghcr.io/mostlygeek/llama-swap:v${LS_VER}-${ARCH}-${LCPP_TAG}" 48 | CONTAINER_LATEST="ghcr.io/mostlygeek/llama-swap:${ARCH}" 49 | echo "Building ${CONTAINER_TAG} $LS_VER" 50 | docker build -f llama-swap.Containerfile --build-arg BASE_TAG=server-${ARCH}-${LCPP_TAG} --build-arg LS_VER=${LS_VER} -t ${CONTAINER_TAG} -t ${CONTAINER_LATEST} . 
51 | if [ "$PUSH_IMAGES" == "true" ]; then 52 | docker push ${CONTAINER_TAG} 53 | docker push ${CONTAINER_LATEST} 54 | fi 55 | fi -------------------------------------------------------------------------------- /docker/config.example.yaml: -------------------------------------------------------------------------------- 1 | healthCheckTimeout: 300 2 | logRequests: true 3 | 4 | models: 5 | "qwen2.5": 6 | proxy: "http://127.0.0.1:9999" 7 | cmd: > 8 | /app/llama-server 9 | -hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M 10 | --port 9999 11 | 12 | "smollm2": 13 | proxy: "http://127.0.0.1:9999" 14 | cmd: > 15 | /app/llama-server 16 | -hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M 17 | --port 9999 -------------------------------------------------------------------------------- /docker/llama-swap.Containerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_TAG=server-cuda 2 | FROM ghcr.io/ggml-org/llama.cpp:${BASE_TAG} 3 | 4 | # has to be after the FROM 5 | ARG LS_VER=89 6 | 7 | WORKDIR /app 8 | RUN \ 9 | curl -LO https://github.com/mostlygeek/llama-swap/releases/download/v"${LS_VER}"/llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \ 10 | tar -zxf llama-swap_"${LS_VER}"_linux_amd64.tar.gz && \ 11 | rm llama-swap_"${LS_VER}"_linux_amd64.tar.gz 12 | 13 | COPY config.example.yaml /app/config.yaml 14 | 15 | HEALTHCHECK CMD curl -f http://localhost:8080/ || exit 1 16 | ENTRYPOINT [ "/app/llama-swap", "-config", "/app/config.yaml" ] -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Example Configs and Use Cases 2 | 3 | A collections of usecases and examples for getting the most out of llama-swap. 4 | 5 | * [Speculative Decoding](speculative-decoding/README.md) - using a small draft model can increase inference speeds from 20% to 40%. 
This example includes a configurations Qwen2.5-Coder-32B (2.5x increase) and Llama-3.1-70B (1.4x increase) in the best cases. 6 | * [Optimizing Code Generation](benchmark-snakegame/README.md) - find the optimal settings for your machine. This example demonstrates defining multiple configurations and testing which one is fastest. -------------------------------------------------------------------------------- /examples/aider-qwq-coder/README.md: -------------------------------------------------------------------------------- 1 | # aider, QwQ, Qwen-Coder 2.5 and llama-swap 2 | 3 | This guide show how to use aider and llama-swap to get a 100% local coding co-pilot setup. The focus is on the trickest part which is configuring aider, llama-swap and llama-server to work together. 4 | 5 | ## Here's what you you need: 6 | 7 | - aider - [installation docs](https://aider.chat/docs/install.html) 8 | - llama-server - [download latest release](https://github.com/ggml-org/llama.cpp/releases) 9 | - llama-swap - [download latest release](https://github.com/mostlygeek/llama-swap/releases) 10 | - [QwQ 32B](https://huggingface.co/bartowski/Qwen_QwQ-32B-GGUF) and [Qwen Coder 2.5 32B](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) models 11 | - 24GB VRAM video card 12 | 13 | ## Running aider 14 | 15 | The goal is getting this command line to work: 16 | 17 | ```sh 18 | aider --architect \ 19 | --no-show-model-warnings \ 20 | --model openai/QwQ \ 21 | --editor-model openai/qwen-coder-32B \ 22 | --model-settings-file aider.model.settings.yml \ 23 | --openai-api-key "sk-na" \ 24 | --openai-api-base "http://10.0.1.24:8080/v1" \ 25 | ``` 26 | 27 | Set `--openai-api-base` to the IP and port where your llama-swap is running. 28 | 29 | ## Create an aider model settings file 30 | 31 | ```yaml 32 | # aider.model.settings.yml 33 | 34 | # 35 | # !!! important: model names must match llama-swap configuration names !!! 
36 | # 37 | 38 | - name: "openai/QwQ" 39 | edit_format: diff 40 | extra_params: 41 | max_tokens: 16384 42 | top_p: 0.95 43 | top_k: 40 44 | presence_penalty: 0.1 45 | repetition_penalty: 1 46 | num_ctx: 16384 47 | use_temperature: 0.6 48 | reasoning_tag: think 49 | weak_model_name: "openai/qwen-coder-32B" 50 | editor_model_name: "openai/qwen-coder-32B" 51 | 52 | - name: "openai/qwen-coder-32B" 53 | edit_format: diff 54 | extra_params: 55 | max_tokens: 16384 56 | top_p: 0.8 57 | top_k: 20 58 | repetition_penalty: 1.05 59 | use_temperature: 0.6 60 | reasoning_tag: think 61 | editor_edit_format: editor-diff 62 | editor_model_name: "openai/qwen-coder-32B" 63 | ``` 64 | 65 | ## llama-swap configuration 66 | 67 | ```yaml 68 | # config.yaml 69 | 70 | # The parameters are tweaked to fit model+context into 24GB VRAM GPUs 71 | models: 72 | "qwen-coder-32B": 73 | proxy: "http://127.0.0.1:8999" 74 | cmd: > 75 | /path/to/llama-server 76 | --host 127.0.0.1 --port 8999 --flash-attn --slots 77 | --ctx-size 16000 78 | --cache-type-k q8_0 --cache-type-v q8_0 79 | -ngl 99 80 | --model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf 81 | 82 | "QwQ": 83 | proxy: "http://127.0.0.1:9503" 84 | cmd: > 85 | /path/to/llama-server 86 | --host 127.0.0.1 --port 9503 --flash-attn --metrics--slots 87 | --cache-type-k q8_0 --cache-type-v q8_0 88 | --ctx-size 32000 89 | --samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc" 90 | --temp 0.6 --repeat-penalty 1.1 --dry-multiplier 0.5 91 | --min-p 0.01 --top-k 40 --top-p 0.95 92 | -ngl 99 93 | --model /mnt/nvme/models/bartowski/Qwen_QwQ-32B-Q4_K_M.gguf 94 | ``` 95 | 96 | ## Advanced, Dual GPU Configuration 97 | 98 | If you have _dual 24GB GPUs_ you can use llama-swap profiles to avoid swapping between QwQ and Qwen Coder. 99 | 100 | In llama-swap's configuration file: 101 | 102 | 1. add a `profiles` section with `aider` as the profile name 103 | 2. 
using the `env` field to specify the GPU IDs for each model 104 | 105 | ```yaml 106 | # config.yaml 107 | 108 | # Add a profile for aider 109 | profiles: 110 | aider: 111 | - qwen-coder-32B 112 | - QwQ 113 | 114 | models: 115 | "qwen-coder-32B": 116 | # manually set the GPU to run on 117 | env: 118 | - "CUDA_VISIBLE_DEVICES=0" 119 | proxy: "http://127.0.0.1:8999" 120 | cmd: /path/to/llama-server ... 121 | 122 | "QwQ": 123 | # manually set the GPU to run on 124 | env: 125 | - "CUDA_VISIBLE_DEVICES=1" 126 | proxy: "http://127.0.0.1:9503" 127 | cmd: /path/to/llama-server ... 128 | ``` 129 | 130 | Append the profile tag, `aider:`, to the model names in the model settings file 131 | 132 | ```yaml 133 | # aider.model.settings.yml 134 | - name: "openai/aider:QwQ" 135 | weak_model_name: "openai/aider:qwen-coder-32B-aider" 136 | editor_model_name: "openai/aider:qwen-coder-32B-aider" 137 | 138 | - name: "openai/aider:qwen-coder-32B" 139 | editor_model_name: "openai/aider:qwen-coder-32B-aider" 140 | ``` 141 | 142 | Run aider with: 143 | 144 | ```sh 145 | $ aider --architect \ 146 | --no-show-model-warnings \ 147 | --model openai/aider:QwQ \ 148 | --editor-model openai/aider:qwen-coder-32B \ 149 | --config aider.conf.yml \ 150 | --model-settings-file aider.model.settings.yml 151 | --openai-api-key "sk-na" \ 152 | --openai-api-base "http://10.0.1.24:8080/v1" 153 | ``` 154 | -------------------------------------------------------------------------------- /examples/aider-qwq-coder/aider.model.settings.dualgpu.yml: -------------------------------------------------------------------------------- 1 | # this makes use of llama-swap's profile feature to 2 | # keep the architect and editor models in VRAM on different GPUs 3 | 4 | - name: "openai/aider:QwQ" 5 | edit_format: diff 6 | extra_params: 7 | max_tokens: 16384 8 | top_p: 0.95 9 | top_k: 40 10 | presence_penalty: 0.1 11 | repetition_penalty: 1 12 | num_ctx: 16384 13 | use_temperature: 0.6 14 | reasoning_tag: think 15 | 
weak_model_name: "openai/aider:qwen-coder-32B" 16 | editor_model_name: "openai/aider:qwen-coder-32B" 17 | 18 | - name: "openai/aider:qwen-coder-32B" 19 | edit_format: diff 20 | extra_params: 21 | max_tokens: 16384 22 | top_p: 0.8 23 | top_k: 20 24 | repetition_penalty: 1.05 25 | use_temperature: 0.6 26 | reasoning_tag: think 27 | editor_edit_format: editor-diff 28 | editor_model_name: "openai/aider:qwen-coder-32B" -------------------------------------------------------------------------------- /examples/aider-qwq-coder/aider.model.settings.yml: -------------------------------------------------------------------------------- 1 | - name: "openai/QwQ" 2 | edit_format: diff 3 | extra_params: 4 | max_tokens: 16384 5 | top_p: 0.95 6 | top_k: 40 7 | presence_penalty: 0.1 8 | repetition_penalty: 1 9 | num_ctx: 16384 10 | use_temperature: 0.6 11 | reasoning_tag: think 12 | weak_model_name: "openai/qwen-coder-32B" 13 | editor_model_name: "openai/qwen-coder-32B" 14 | 15 | - name: "openai/qwen-coder-32B" 16 | edit_format: diff 17 | extra_params: 18 | max_tokens: 16384 19 | top_p: 0.8 20 | top_k: 20 21 | repetition_penalty: 1.05 22 | use_temperature: 0.6 23 | reasoning_tag: think 24 | editor_edit_format: editor-diff 25 | editor_model_name: "openai/qwen-coder-32B" 26 | 27 | -------------------------------------------------------------------------------- /examples/aider-qwq-coder/llama-swap.yaml: -------------------------------------------------------------------------------- 1 | healthCheckTimeout: 300 2 | logLevel: debug 3 | 4 | profiles: 5 | aider: 6 | - qwen-coder-32B 7 | - QwQ 8 | 9 | models: 10 | "qwen-coder-32B": 11 | env: 12 | - "CUDA_VISIBLE_DEVICES=0" 13 | aliases: 14 | - coder 15 | proxy: "http://127.0.0.1:8999" 16 | 17 | # set appropriate paths for your environment 18 | cmd: > 19 | /path/to/llama-server 20 | --host 127.0.0.1 --port 8999 --flash-attn --slots 21 | --ctx-size 16000 22 | --ctx-size-draft 16000 23 | --model /path/to/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf 
24 | --model-draft /path/to/Qwen2.5-Coder-1.5B-Instruct-Q8_0.gguf 25 | -ngl 99 -ngld 99 26 | --draft-max 16 --draft-min 4 --draft-p-min 0.4 27 | --cache-type-k q8_0 --cache-type-v q8_0 28 | "QwQ": 29 | env: 30 | - "CUDA_VISIBLE_DEVICES=1" 31 | proxy: "http://127.0.0.1:9503" 32 | 33 | # set appropriate paths for your environment 34 | cmd: > 35 | /path/to/llama-server 36 | --host 127.0.0.1 --port 9503 37 | --flash-attn --metrics 38 | --slots 39 | --model /path/to/Qwen_QwQ-32B-Q4_K_M.gguf 40 | --cache-type-k q8_0 --cache-type-v q8_0 41 | --ctx-size 32000 42 | --samplers "top_k;top_p;min_p;temperature;dry;typ_p;xtc" 43 | --temp 0.6 44 | --repeat-penalty 1.1 45 | --dry-multiplier 0.5 46 | --min-p 0.01 47 | --top-k 40 48 | --top-p 0.95 49 | -ngl 99 -ngld 99 -------------------------------------------------------------------------------- /examples/benchmark-snakegame/README.md: -------------------------------------------------------------------------------- 1 | # Optimizing Code Generation with llama-swap 2 | 3 | Finding the best mix of settings for your hardware can be time consuming. This example demonstrates using a custom configuration file to automate testing different scenarios to find an optimal configuration. 4 | 5 | The benchmark writes a snake game in Python, TypeScript, and Swift using the Qwen 2.5 Coder models. The experiments were done using a 3090 and a P40. 
6 | 7 | **Benchmark Scenarios** 8 | 9 | Three scenarios are tested: 10 | 11 | - 3090-only: Just the main model on the 3090 12 | - 3090-with-draft: the main and draft models on the 3090 13 | - 3090-P40-draft: the main model on the 3090 with the draft model offloaded to the P40 14 | 15 | **Available Devices** 16 | 17 | Use the following command to list available devices IDs for the configuration: 18 | 19 | ``` 20 | $ /mnt/nvme/llama-server/llama-server-f3252055 --list-devices 21 | ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no 22 | ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no 23 | ggml_cuda_init: found 4 CUDA devices: 24 | Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes 25 | Device 1: Tesla P40, compute capability 6.1, VMM: yes 26 | Device 2: Tesla P40, compute capability 6.1, VMM: yes 27 | Device 3: Tesla P40, compute capability 6.1, VMM: yes 28 | Available devices: 29 | CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 406 MiB free) 30 | CUDA1: Tesla P40 (24438 MiB, 22942 MiB free) 31 | CUDA2: Tesla P40 (24438 MiB, 24144 MiB free) 32 | CUDA3: Tesla P40 (24438 MiB, 24144 MiB free) 33 | ``` 34 | 35 | **Configuration** 36 | 37 | The configuration file, `benchmark-config.yaml`, defines the three scenarios: 38 | 39 | ```yaml 40 | models: 41 | "3090-only": 42 | proxy: "http://127.0.0.1:9503" 43 | cmd: > 44 | /mnt/nvme/llama-server/llama-server-f3252055 45 | --host 127.0.0.1 --port 9503 46 | --flash-attn 47 | --slots 48 | 49 | --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf 50 | -ngl 99 51 | --device CUDA0 52 | 53 | --ctx-size 32768 54 | --cache-type-k q8_0 --cache-type-v q8_0 55 | 56 | "3090-with-draft": 57 | proxy: "http://127.0.0.1:9503" 58 | # --ctx-size 28500 max that can fit on 3090 after draft model 59 | cmd: > 60 | /mnt/nvme/llama-server/llama-server-f3252055 61 | --host 127.0.0.1 --port 9503 62 | --flash-attn 63 | --slots 64 | 65 | --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf 66 | -ngl 99 67 | --device CUDA0 68 | 69 | 
--model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf 70 | -ngld 99 71 | --draft-max 16 72 | --draft-min 4 73 | --draft-p-min 0.4 74 | --device-draft CUDA0 75 | 76 | --ctx-size 28500 77 | --cache-type-k q8_0 --cache-type-v q8_0 78 | 79 | "3090-P40-draft": 80 | proxy: "http://127.0.0.1:9503" 81 | cmd: > 82 | /mnt/nvme/llama-server/llama-server-f3252055 83 | --host 127.0.0.1 --port 9503 84 | --flash-attn --metrics 85 | --slots 86 | --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf 87 | -ngl 99 88 | --device CUDA0 89 | 90 | --model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf 91 | -ngld 99 92 | --draft-max 16 93 | --draft-min 4 94 | --draft-p-min 0.4 95 | --device-draft CUDA1 96 | 97 | --ctx-size 32768 98 | --cache-type-k q8_0 --cache-type-v q8_0 99 | ``` 100 | 101 | > Note in the `3090-with-draft` scenario the `--ctx-size` had to be reduced from 32768 to 28500 to accommodate the draft model. 102 | 103 | 104 | **Running the Benchmark** 105 | 106 | To run the benchmark, execute the following commands: 107 | 108 | 1. `llama-swap -config benchmark-config.yaml` 109 | 1. `./run-benchmark.sh http://localhost:8080 "3090-only" "3090-with-draft" "3090-P40-draft"` 110 | 111 | The [benchmark script](run-benchmark.sh) generates a CSV output of the results, which can be converted to a Markdown table for readability. 112 | 113 | **Results (tokens/second)** 114 | 115 | | model | python | typescript | swift | 116 | |-----------------|--------|------------|-------| 117 | | 3090-only | 34.03 | 34.01 | 34.01 | 118 | | 3090-with-draft | 106.65 | 70.48 | 57.89 | 119 | | 3090-P40-draft | 81.54 | 60.35 | 46.50 | 120 | 121 | Many different factors, like the programming language, can have big impacts on the performance gains. However, with a custom configuration file for benchmarking it is easy to test the different variations to discover what's best for your hardware. 122 | 123 | Happy coding! 
-------------------------------------------------------------------------------- /examples/benchmark-snakegame/run-benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script generates a CSV file showing the token/second for generating a Snake Game in python, typescript and swift 4 | # It was created to test the effects of speculative decoding and the various draft settings on performance. 5 | # 6 | # Writing code with a low temperature seems to provide fairly consistent logic. 7 | # 8 | # Usage: ./benchmark.sh <url> <model1> [model2 ...] 9 | # Example: ./benchmark.sh http://localhost:8080 model1 model2 10 | 11 | if [ "$#" -lt 2 ]; then 12 | echo "Usage: $0 <url> <model1> [model2 ...]" 13 | exit 1 14 | fi 15 | 16 | url=$1; shift 17 | 18 | echo "model,python,typescript,swift" 19 | 20 | for model in "$@"; do 21 | 22 | echo -n "$model," 23 | 24 | for lang in "python" "typescript" "swift"; do 25 | # expects a llama.cpp after PR https://github.com/ggerganov/llama.cpp/pull/10548 26 | # (Dec 3rd/2024) 27 | time=$(curl -s --url "$url/v1/chat/completions" -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"top_k\": 1, \"timings_per_token\":true, \"model\":\"$model\"}" | jq -r .timings.predicted_per_second) 28 | 29 | if [ $? -ne 0 ]; then 30 | time="error" 31 | exit 1 32 | fi 33 | 34 | if [ "$lang" != "swift" ]; then 35 | printf "%0.2f tps," $time 36 | else 37 | printf "%0.2f tps\n" $time 38 | fi 39 | done 40 | done -------------------------------------------------------------------------------- /examples/restart-on-config-change/README.md: -------------------------------------------------------------------------------- 1 | # Restart llama-swap on config change 2 | 3 | Sometimes editing the configuration file can take a bit of trial and error to get a model configuration tuned just right. 
The `watch-and-restart.sh` script can be used to watch `config.yaml` for changes and restart `llama-swap` when it detects a change. 4 | 5 | ```bash 6 | #!/bin/bash 7 | # 8 | # A simple watch and restart llama-swap when its configuration 9 | # file changes. Useful for trying out configuration changes 10 | # without manually restarting the server each time. 11 | if [ -z "$1" ]; then 12 | echo "Usage: $0 " 13 | exit 1 14 | fi 15 | 16 | while true; do 17 | # Start the process again 18 | ./llama-swap-linux-amd64 -config $1 -listen :1867 & 19 | PID=$! 20 | echo "Started llama-swap with PID $PID" 21 | 22 | # Wait for modifications in the specified directory or file 23 | inotifywait -e modify "$1" 24 | 25 | # Check if process exists before sending signal 26 | if kill -0 $PID 2>/dev/null; then 27 | echo "Sending SIGTERM to $PID" 28 | kill -SIGTERM $PID 29 | wait $PID 30 | else 31 | echo "Process $PID no longer exists" 32 | fi 33 | sleep 1 34 | done 35 | ``` 36 | 37 | ## Usage and output example 38 | 39 | ```bash 40 | $ ./watch-and-restart.sh config.yaml 41 | Started llama-swap with PID 495455 42 | Setting up watches. 43 | Watches established. 44 | llama-swap listening on :1867 45 | Sending SIGTERM to 495455 46 | Shutting down llama-swap 47 | Started llama-swap with PID 495486 48 | Setting up watches. 49 | Watches established. 50 | llama-swap listening on :1867 51 | ``` 52 | -------------------------------------------------------------------------------- /examples/speculative-decoding/README.md: -------------------------------------------------------------------------------- 1 | # Speculative Decoding 2 | 3 | Speculative decoding can significantly improve the tokens per second. However, this comes at the cost of increased VRAM usage for the draft model. The examples provided are based on a server with three P40s and one 3090. 4 | 5 | ## Coding Use Case 6 | 7 | This example uses Qwen2.5 Coder 32B with the 0.5B model as a draft. 
A quantization of Q8_0 was chosen for the draft model, as quantization has a greater impact on smaller models. 8 | 9 | The models used are: 10 | 11 | * [Bartowski Qwen2.5-Coder-32B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-32B-Instruct-GGUF) 12 | * [Bartowski Qwen2.5-Coder-0.5B-Instruct](https://huggingface.co/bartowski/Qwen2.5-Coder-0.5B-Instruct-GGUF) 13 | 14 | The llama-swap configuration is as follows: 15 | 16 | ```yaml 17 | models: 18 | "qwen-coder-32b-q4": 19 | # main model on 3090, draft on P40 #1 20 | cmd: > 21 | /mnt/nvme/llama-server/llama-server-be0e35 22 | --host 127.0.0.1 --port 9503 23 | --flash-attn --metrics 24 | --slots 25 | --model /mnt/nvme/models/Qwen2.5-Coder-32B-Instruct-Q4_K_M.gguf 26 | -ngl 99 27 | --ctx-size 19000 28 | --model-draft /mnt/nvme/models/Qwen2.5-Coder-0.5B-Instruct-Q8_0.gguf 29 | -ngld 99 30 | --draft-max 16 31 | --draft-min 4 32 | --draft-p-min 0.4 33 | --device CUDA0 34 | --device-draft CUDA1 35 | proxy: "http://127.0.0.1:9503" 36 | ``` 37 | 38 | In this configuration, two GPUs are used: a 3090 (CUDA0) for the main model and a P40 (CUDA1) for the draft model. Although both models can fit on the 3090, relocating the draft model to the P40 freed up space for a larger context size. Despite the P40 being about 1/3rd the speed of the 3090, the small model still improved tokens per second. 39 | 40 | Multiple tests were run with various parameters, and the fastest result was chosen for the configuration. In all tests, the 0.5B model produced the largest improvements to tokens per second. 41 | 42 | Baseline: 33.92 tokens/second on 3090 without a draft model. 
43 | 44 | | draft-max | draft-min | draft-p-min | python | TS | swift | 45 | |-----------|-----------|-------------|--------|----|-------| 46 | | 16 | 1 | 0.9 | 71.64 | 55.55 | 48.06 | 47 | | 16 | 1 | 0.4 | 83.21 | 58.55 | 45.50 | 48 | | 16 | 1 | 0.1 | 79.72 | 55.66 | 43.94 | 49 | | 16 | 2 | 0.9 | 68.47 | 55.13 | 43.12 | 50 | | 16 | 2 | 0.4 | 82.82 | 57.42 | 48.83 | 51 | | 16 | 2 | 0.1 | 81.68 | 51.37 | 45.72 | 52 | | 16 | 4 | 0.9 | 66.44 | 48.49 | 42.40 | 53 | | 16 | 4 | 0.4 | _83.62_ (fastest)| _58.29_ | _50.17_ | 54 | | 16 | 4 | 0.1 | 82.46 | 51.45 | 40.71 | 55 | | 8 | 1 | 0.4 | 67.07 | 55.17 | 48.46 | 56 | | 4 | 1 | 0.4 | 50.13 | 44.96 | 40.79 | 57 | 58 | The test script can be found in this [gist](https://gist.github.com/mostlygeek/da429769796ac8a111142e75660820f1). It is a simple curl script that prompts generating a snake game in Python, TypeScript, or Swift. Evaluation metrics were pulled from llama.cpp's logs. 59 | 60 | ```bash 61 | for lang in "python" "typescript" "swift"; do 62 | echo "Generating Snake Game in $lang using $model" 63 | curl -s --url http://localhost:8080/v1/chat/completions -d "{\"messages\": [{\"role\": \"system\", \"content\": \"you only write code.\"}, {\"role\": \"user\", \"content\": \"write snake game in $lang\"}], \"temperature\": 0.1, \"model\":\"$model\"}" > /dev/null 64 | done 65 | ``` 66 | 67 | Python consistently outperformed Swift in all tests, likely due to the 0.5B draft model being more proficient in generating Python code accepted by the larger 32B model. 68 | 69 | ## Chat 70 | 71 | This configuration is for a regular chat use case. It produces approximately 13 tokens/second in typical use, up from ~9 tokens/second with only the 3xP40s. This is great news for P40 owners. 
72 | 73 | The models used are: 74 | 75 | * [Bartowski Meta-Llama-3.1-70B-Instruct-GGUF](https://huggingface.co/bartowski/Meta-Llama-3.1-70B-Instruct-GGUF) 76 | * [Bartowski Llama-3.2-3B-Instruct-GGUF](https://huggingface.co/bartowski/Llama-3.2-3B-Instruct-GGUF) 77 | 78 | ```yaml 79 | models: 80 | "llama-70B": 81 | cmd: > 82 | /mnt/nvme/llama-server/llama-server-be0e35 83 | --host 127.0.0.1 --port 9602 84 | --flash-attn --metrics 85 | --split-mode row 86 | --ctx-size 80000 87 | --model /mnt/nvme/models/Meta-Llama-3.1-70B-Instruct-Q4_K_L.gguf 88 | -ngl 99 89 | --model-draft /mnt/nvme/models/Llama-3.2-3B-Instruct-Q4_K_M.gguf 90 | -ngld 99 91 | --draft-max 16 92 | --draft-min 1 93 | --draft-p-min 0.4 94 | --device-draft CUDA0 95 | --tensor-split 0,1,1,1 96 | ``` 97 | 98 | In this configuration, Llama-3.1-70B is split across three P40s, and Llama-3.2-3B is on the 3090. 99 | 100 | Some flags deserve further explanation: 101 | 102 | * `--split-mode row` - increases inference speeds using multiple P40s by about 30%. This is a P40-specific feature. 103 | * `--tensor-split 0,1,1,1` - controls how the main model is split across the GPUs. This means 0% on the 3090 and an even split across the P40s. A value of `--tensor-split 0,5,4,1` would mean 0% on the 3090, 50%, 40%, and 10% respectively across the other P40s. However, this would exceed the available VRAM. 104 | * `--ctx-size 80000` - the maximum context size that can fit in the remaining VRAM. 105 | 106 | ## What is CUDA0, CUDA1, CUDA2, CUDA3? 107 | 108 | These devices are the IDs used by llama.cpp. 
109 | 110 | ```bash 111 | $ ./llama-server --list-devices 112 | ggml_cuda_init: GGML_CUDA_FORCE_MMQ: no 113 | ggml_cuda_init: GGML_CUDA_FORCE_CUBLAS: no 114 | ggml_cuda_init: found 4 CUDA devices: 115 | Device 0: NVIDIA GeForce RTX 3090, compute capability 8.6, VMM: yes 116 | Device 1: Tesla P40, compute capability 6.1, VMM: yes 117 | Device 2: Tesla P40, compute capability 6.1, VMM: yes 118 | Device 3: Tesla P40, compute capability 6.1, VMM: yes 119 | Available devices: 120 | CUDA0: NVIDIA GeForce RTX 3090 (24154 MiB, 23892 MiB free) 121 | CUDA1: Tesla P40 (24438 MiB, 24290 MiB free) 122 | CUDA2: Tesla P40 (24438 MiB, 24290 MiB free) 123 | CUDA3: Tesla P40 (24438 MiB, 24290 MiB free) 124 | ``` -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/mostlygeek/llama-swap 2 | 3 | go 1.23.0 4 | 5 | require ( 6 | github.com/fsnotify/fsnotify v1.9.0 7 | github.com/gin-gonic/gin v1.10.0 8 | github.com/stretchr/testify v1.9.0 9 | github.com/tidwall/gjson v1.18.0 10 | github.com/tidwall/sjson v1.2.5 11 | gopkg.in/yaml.v3 v3.0.1 12 | ) 13 | 14 | require ( 15 | github.com/billziss-gh/golib v0.2.0 // indirect 16 | github.com/bytedance/sonic v1.11.6 // indirect 17 | github.com/bytedance/sonic/loader v0.1.1 // indirect 18 | github.com/cloudwego/base64x v0.1.4 // indirect 19 | github.com/cloudwego/iasm v0.2.0 // indirect 20 | github.com/davecgh/go-spew v1.1.1 // indirect 21 | github.com/gabriel-vasile/mimetype v1.4.3 // indirect 22 | github.com/gin-contrib/sse v0.1.0 // indirect 23 | github.com/go-playground/locales v0.14.1 // indirect 24 | github.com/go-playground/universal-translator v0.18.1 // indirect 25 | github.com/go-playground/validator/v10 v10.20.0 // indirect 26 | github.com/goccy/go-json v0.10.2 // indirect 27 | github.com/json-iterator/go v1.1.12 // indirect 28 | github.com/klauspost/cpuid/v2 v2.2.7 // indirect 29 | 
github.com/leodido/go-urn v1.4.0 // indirect 30 | github.com/mattn/go-isatty v0.0.20 // indirect 31 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 32 | github.com/modern-go/reflect2 v1.0.2 // indirect 33 | github.com/pelletier/go-toml/v2 v2.2.2 // indirect 34 | github.com/pmezard/go-difflib v1.0.0 // indirect 35 | github.com/tidwall/match v1.1.1 // indirect 36 | github.com/tidwall/pretty v1.2.1 // indirect 37 | github.com/twitchyliquid64/golang-asm v0.15.1 // indirect 38 | github.com/ugorji/go/codec v1.2.12 // indirect 39 | golang.org/x/arch v0.8.0 // indirect 40 | golang.org/x/crypto v0.36.0 // indirect 41 | golang.org/x/net v0.38.0 // indirect 42 | golang.org/x/sys v0.31.0 // indirect 43 | golang.org/x/text v0.23.0 // indirect 44 | google.golang.org/protobuf v1.34.1 // indirect 45 | ) 46 | -------------------------------------------------------------------------------- /go.sum: -------------------------------------------------------------------------------- 1 | github.com/billziss-gh/golib v0.2.0 h1:NyvcAQdfvM8xokKkKotiligKjKXzuQD4PPykg1nKc/8= 2 | github.com/billziss-gh/golib v0.2.0/go.mod h1:mZpUYANXZkDKSnyYbX9gfnyxwe0ddRhUtfXcsD5r8dw= 3 | github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= 4 | github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= 5 | github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= 6 | github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= 7 | github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= 8 | github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= 9 | github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= 10 | github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= 11 | github.com/davecgh/go-spew v1.1.0/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 12 | github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= 13 | github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= 14 | github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= 15 | github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= 16 | github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= 17 | github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= 18 | github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= 19 | github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= 20 | github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= 21 | github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= 22 | github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= 23 | github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= 24 | github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= 25 | github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= 26 | github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= 27 | github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= 28 | github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= 29 | github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= 30 | github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= 31 | github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= 32 | 
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= 33 | github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 34 | github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= 35 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= 36 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= 37 | github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= 38 | github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= 39 | github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= 40 | github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= 41 | github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= 42 | github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= 43 | github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= 44 | github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= 45 | github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= 46 | github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= 47 | github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 48 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= 49 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= 50 | github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= 51 | github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= 52 | 
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= 53 | github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= 54 | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= 55 | github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 56 | github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 57 | github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 58 | github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 59 | github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= 60 | github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 61 | github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 62 | github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 63 | github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 64 | github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 65 | github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 66 | github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= 67 | github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 68 | github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 69 | github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= 70 | github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= 71 | github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= 72 | github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= 73 | github.com/tidwall/pretty v1.2.0/go.mod 
h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 74 | github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= 75 | github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= 76 | github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= 77 | github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= 78 | github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= 79 | github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= 80 | github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= 81 | github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= 82 | golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= 83 | golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= 84 | golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= 85 | golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= 86 | golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= 87 | golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= 88 | golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= 89 | golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 90 | golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 91 | golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= 92 | golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 93 | golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= 94 | golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= 95 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 
h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= 96 | golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 97 | google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= 98 | google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= 99 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= 100 | gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 101 | gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 102 | gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= 103 | gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= 104 | nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= 105 | rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= 106 | -------------------------------------------------------------------------------- /header.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostlygeek/llama-swap/a84098d3b43b5453f0600d942da7cf9bf762612a/header.jpeg -------------------------------------------------------------------------------- /header2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostlygeek/llama-swap/a84098d3b43b5453f0600d942da7cf9bf762612a/header2.png -------------------------------------------------------------------------------- /llama-swap.go: -------------------------------------------------------------------------------- 1 | package main 2 | 3 | import ( 4 | "context" 5 | "flag" 6 | "fmt" 7 | "log" 8 | "net/http" 9 | "os" 10 | "os/signal" 11 | "path/filepath" 12 | "syscall" 13 | "time" 14 | 15 | "github.com/fsnotify/fsnotify" 16 | 
"github.com/gin-gonic/gin" 17 | "github.com/mostlygeek/llama-swap/proxy" 18 | ) 19 | 20 | var ( 21 | version string = "0" 22 | commit string = "abcd1234" 23 | date string = "unknown" 24 | ) 25 | 26 | func main() { 27 | // Define a command-line flag for the port 28 | configPath := flag.String("config", "config.yaml", "config file name") 29 | listenStr := flag.String("listen", ":8080", "listen ip/port") 30 | showVersion := flag.Bool("version", false, "show version of build") 31 | watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change") 32 | 33 | flag.Parse() // Parse the command-line flags 34 | 35 | if *showVersion { 36 | fmt.Printf("version: %s (%s), built at %s\n", version, commit, date) 37 | os.Exit(0) 38 | } 39 | 40 | config, err := proxy.LoadConfig(*configPath) 41 | if err != nil { 42 | fmt.Printf("Error loading config: %v\n", err) 43 | os.Exit(1) 44 | } 45 | 46 | if len(config.Profiles) > 0 { 47 | fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. 
See the README for more information.") 48 | } 49 | 50 | if mode := os.Getenv("GIN_MODE"); mode != "" { 51 | gin.SetMode(mode) 52 | } else { 53 | gin.SetMode(gin.ReleaseMode) 54 | } 55 | 56 | proxyManager := proxy.New(config) 57 | 58 | // Setup channels for server management 59 | reloadChan := make(chan *proxy.ProxyManager) 60 | exitChan := make(chan struct{}) 61 | sigChan := make(chan os.Signal, 1) 62 | signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) 63 | 64 | // Create server with initial handler 65 | srv := &http.Server{ 66 | Addr: *listenStr, 67 | Handler: proxyManager, 68 | } 69 | 70 | // Start server 71 | fmt.Printf("llama-swap listening on %s\n", *listenStr) 72 | go func() { 73 | if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { 74 | fmt.Printf("Fatal server error: %v\n", err) 75 | close(exitChan) 76 | } 77 | }() 78 | 79 | // Handle config reloads and signals 80 | go func() { 81 | currentManager := proxyManager 82 | for { 83 | select { 84 | case newManager := <-reloadChan: 85 | log.Println("Config change detected, waiting for in-flight requests to complete...") 86 | // Stop old manager processes gracefully (this waits for in-flight requests) 87 | currentManager.StopProcesses(proxy.StopWaitForInflightRequest) 88 | // Now do a full shutdown to clear the process map 89 | currentManager.Shutdown() 90 | currentManager = newManager 91 | srv.Handler = newManager 92 | log.Println("Server handler updated with new config") 93 | case sig := <-sigChan: 94 | fmt.Printf("Received signal %v, shutting down...\n", sig) 95 | ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 96 | defer cancel() 97 | currentManager.Shutdown() 98 | if err := srv.Shutdown(ctx); err != nil { 99 | fmt.Printf("Server shutdown error: %v\n", err) 100 | } 101 | close(exitChan) 102 | return 103 | } 104 | } 105 | }() 106 | 107 | // Start file watcher if requested 108 | if *watchConfig { 109 | absConfigPath, err := filepath.Abs(*configPath) 110 | if 
err != nil { 111 | log.Printf("Error getting absolute path for config: %v. File watching disabled.", err) 112 | } else { 113 | go watchConfigFileWithReload(absConfigPath, reloadChan) 114 | } 115 | } 116 | 117 | // Wait for exit signal 118 | <-exitChan 119 | } 120 | 121 | // watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan. 122 | func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) { 123 | watcher, err := fsnotify.NewWatcher() 124 | if err != nil { 125 | log.Printf("Error creating file watcher: %v. File watching disabled.", err) 126 | return 127 | } 128 | defer watcher.Close() 129 | 130 | err = watcher.Add(configPath) 131 | if err != nil { 132 | log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err) 133 | return 134 | } 135 | 136 | log.Printf("Watching config file for changes: %s", configPath) 137 | 138 | var debounceTimer *time.Timer 139 | debounceDuration := 2 * time.Second 140 | 141 | for { 142 | select { 143 | case event, ok := <-watcher.Events: 144 | if !ok { 145 | return 146 | } 147 | // We only care about writes to the specific config file 148 | if event.Name == configPath && event.Has(fsnotify.Write) { 149 | // Reset or start the debounce timer 150 | if debounceTimer != nil { 151 | debounceTimer.Stop() 152 | } 153 | debounceTimer = time.AfterFunc(debounceDuration, func() { 154 | log.Printf("Config file modified: %s, reloading...", event.Name) 155 | 156 | // Try up to 3 times with exponential backoff 157 | var newConfig proxy.Config 158 | var err error 159 | for retries := 0; retries < 3; retries++ { 160 | // Load new configuration 161 | newConfig, err = proxy.LoadConfig(configPath) 162 | if err == nil { 163 | break 164 | } 165 | log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err) 166 | if retries < 2 { 167 | time.Sleep(time.Duration(1< 1 { 224 | break runloop 225 | } else { 226 | 
log.Println("Received SIGINT, send another SIGINT to shutdown") 227 | } 228 | case syscall.SIGTERM: 229 | if *ignoreSigTerm { 230 | log.Println("Ignoring SIGTERM") 231 | } else { 232 | log.Println("Received SIGTERM, shutting down") 233 | break runloop 234 | } 235 | default: 236 | break runloop 237 | } 238 | } 239 | 240 | log.Println("simple-responder shutting down") 241 | } 242 | -------------------------------------------------------------------------------- /misc/test-rerank/README.md: -------------------------------------------------------------------------------- 1 | The rerank-test.json data is from https://github.com/ggerganov/llama.cpp/pull/9510 2 | 3 | To run it: 4 | > curl http://127.0.0.1:8080/v1/rerank -H "Content-Type: application/json" -d @reranker-test.json -v | jq . -------------------------------------------------------------------------------- /misc/test-rerank/reranker-test.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "bge-reranker", 3 | "query": "Organic skincare products for sensitive skin", 4 | "top_n": 3, 5 | "documents": [ 6 | "Organic skincare for sensitive skin with aloe vera and chamomile: Imagine the soothing embrace of nature with our organic skincare range, crafted specifically for sensitive skin. Infused with the calming properties of aloe vera and chamomile, each product provides gentle nourishment and protection. Say goodbye to irritation and hello to a glowing, healthy complexion.", 7 | "New makeup trends focus on bold colors and innovative techniques: Step into the world of cutting-edge beauty with this seasons makeup trends. Bold, vibrant colors and groundbreaking techniques are redefining the art of makeup. 
From neon eyeliners to holographic highlighters, unleash your creativity and make a statement with every look.", 8 | "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille: Erleben Sie die wohltuende Wirkung unserer Bio-Hautpflege, speziell für empfindliche Haut entwickelt. Mit den beruhigenden Eigenschaften von Aloe Vera und Kamille pflegen und schützen unsere Produkte Ihre Haut auf natürliche Weise. Verabschieden Sie sich von Hautirritationen und genießen Sie einen strahlenden Teint.", 9 | "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken: Tauchen Sie ein in die Welt der modernen Schönheit mit den neuesten Make-up-Trends. Kräftige, lebendige Farben und innovative Techniken setzen neue Maßstäbe. Von auffälligen Eyelinern bis hin zu holografischen Highlightern – lassen Sie Ihrer Kreativität freien Lauf und setzen Sie jedes Mal ein Statement.", 10 | "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla: Descubre el poder de la naturaleza con nuestra línea de cuidado de la piel orgánico, diseñada especialmente para pieles sensibles. Enriquecidos con aloe vera y manzanilla, estos productos ofrecen una hidratación y protección suave. Despídete de las irritaciones y saluda a una piel radiante y saludable.", 11 | "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras: Entra en el fascinante mundo del maquillaje con las tendencias más actuales. Colores vivos y técnicas innovadoras están revolucionando el arte del maquillaje. 
Desde delineadores neón hasta iluminadores holográficos, desata tu creatividad y destaca en cada look.", 12 | "针对敏感肌专门设计的天然有机护肤产品:体验由芦荟和洋甘菊提取物带来的自然呵护。我们的护肤产品特别为敏感肌设计,温和滋润,保护您的肌肤不受刺激。让您的肌肤告别不适,迎来健康光彩。", 13 | "新的化妆趋势注重鲜艳的颜色和创新的技巧:进入化妆艺术的新纪元,本季的化妆趋势以大胆的颜色和创新的技巧为主。无论是霓虹眼线还是全息高光,每一款妆容都能让您脱颖而出,展现独特魅力。", 14 | "敏感肌のために特別に設計された天然有機スキンケア製品: アロエベラとカモミールのやさしい力で、自然の抱擁を感じてください。敏感肌用に特別に設計された私たちのスキンケア製品は、肌に優しく栄養を与え、保護します。肌トラブルにさようなら、輝く健康な肌にこんにちは。", 15 | "新しいメイクのトレンドは鮮やかな色と革新的な技術に焦点を当てています: 今シーズンのメイクアップトレンドは、大胆な色彩と革新的な技術に注目しています。ネオンアイライナーからホログラフィックハイライターまで、クリエイティビティを解き放ち、毎回ユニークなルックを演出しましょう。" 16 | ] 17 | } 18 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !README.md -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | TODO improve these docs 2 | 3 | 1. Download a llama-server suitable for your architecture 4 | 1. Fetch some small models for testing / swapping between 5 | - `huggingface-cli download bartowski/Qwen2.5-1.5B-Instruct-GGUF --include "Qwen2.5-1.5B-Instruct-Q4_K_M.gguf" --local-dir ./` 6 | - `huggingface-cli download bartowski/Llama-3.2-1B-Instruct-GGUF --include "Llama-3.2-1B-Instruct-Q4_K_M.gguf" --local-dir ./` 7 | 1. 
Create a new config.yaml file (see `config.example.yaml`) pointing to the models -------------------------------------------------------------------------------- /proxy/config.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "io" 6 | "os" 7 | "regexp" 8 | "runtime" 9 | "sort" 10 | "strconv" 11 | "strings" 12 | 13 | "github.com/billziss-gh/golib/shlex" 14 | "gopkg.in/yaml.v3" 15 | ) 16 | 17 | const DEFAULT_GROUP_ID = "(default)" 18 | 19 | type ModelConfig struct { 20 | Cmd string `yaml:"cmd"` 21 | CmdStop string `yaml:"cmdStop"` 22 | Proxy string `yaml:"proxy"` 23 | Aliases []string `yaml:"aliases"` 24 | Env []string `yaml:"env"` 25 | CheckEndpoint string `yaml:"checkEndpoint"` 26 | UnloadAfter int `yaml:"ttl"` 27 | Unlisted bool `yaml:"unlisted"` 28 | UseModelName string `yaml:"useModelName"` 29 | 30 | // Limit concurrency of HTTP requests to process 31 | ConcurrencyLimit int `yaml:"concurrencyLimit"` 32 | } 33 | 34 | func (m *ModelConfig) SanitizedCommand() ([]string, error) { 35 | return SanitizeCommand(m.Cmd) 36 | } 37 | 38 | type GroupConfig struct { 39 | Swap bool `yaml:"swap"` 40 | Exclusive bool `yaml:"exclusive"` 41 | Persistent bool `yaml:"persistent"` 42 | Members []string `yaml:"members"` 43 | } 44 | 45 | // set default values for GroupConfig 46 | func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { 47 | type rawGroupConfig GroupConfig 48 | defaults := rawGroupConfig{ 49 | Swap: true, 50 | Exclusive: true, 51 | Persistent: false, 52 | Members: []string{}, 53 | } 54 | 55 | if err := unmarshal(&defaults); err != nil { 56 | return err 57 | } 58 | 59 | *c = GroupConfig(defaults) 60 | return nil 61 | } 62 | 63 | type Config struct { 64 | HealthCheckTimeout int `yaml:"healthCheckTimeout"` 65 | LogRequests bool `yaml:"logRequests"` 66 | LogLevel string `yaml:"logLevel"` 67 | Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ 68 | 
Profiles map[string][]string `yaml:"profiles"` 69 | Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ 70 | 71 | // for key/value replacements in model's cmd, cmdStop, proxy, checkEndPoint 72 | Macros map[string]string `yaml:"macros"` 73 | 74 | // map aliases to actual model IDs 75 | aliases map[string]string 76 | 77 | // automatic port assignments 78 | StartPort int `yaml:"startPort"` 79 | } 80 | 81 | func (c *Config) RealModelName(search string) (string, bool) { 82 | if _, found := c.Models[search]; found { 83 | return search, true 84 | } else if name, found := c.aliases[search]; found { 85 | return name, found 86 | } else { 87 | return "", false 88 | } 89 | } 90 | 91 | func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) { 92 | if realName, found := c.RealModelName(modelName); !found { 93 | return ModelConfig{}, "", false 94 | } else { 95 | return c.Models[realName], realName, true 96 | } 97 | } 98 | 99 | func LoadConfig(path string) (Config, error) { 100 | file, err := os.Open(path) 101 | if err != nil { 102 | return Config{}, err 103 | } 104 | defer file.Close() 105 | return LoadConfigFromReader(file) 106 | } 107 | 108 | func LoadConfigFromReader(r io.Reader) (Config, error) { 109 | data, err := io.ReadAll(r) 110 | if err != nil { 111 | return Config{}, err 112 | } 113 | 114 | var config Config 115 | err = yaml.Unmarshal(data, &config) 116 | if err != nil { 117 | return Config{}, err 118 | } 119 | 120 | if config.HealthCheckTimeout == 0 { 121 | // this high default timeout helps avoid failing health checks 122 | // for configurations that wait for docker or have slower startup 123 | config.HealthCheckTimeout = 120 124 | } else if config.HealthCheckTimeout < 15 { 125 | // set a minimum of 15 seconds 126 | config.HealthCheckTimeout = 15 127 | } 128 | 129 | // set default port ranges 130 | if config.StartPort == 0 { 131 | // default to 5800 132 | config.StartPort = 5800 133 | } else if config.StartPort < 1 { 134 | return 
Config{}, fmt.Errorf("startPort must be greater than 1") 135 | } 136 | 137 | // Populate the aliases map 138 | config.aliases = make(map[string]string) 139 | for modelName, modelConfig := range config.Models { 140 | for _, alias := range modelConfig.Aliases { 141 | if _, found := config.aliases[alias]; found { 142 | return Config{}, fmt.Errorf("duplicate alias %s found in model: %s", alias, modelName) 143 | } 144 | config.aliases[alias] = modelName 145 | } 146 | } 147 | 148 | /* check macro constraint rules: 149 | 150 | - name must fit the regex ^[a-zA-Z0-9_-]+$ 151 | - names must be less than 64 characters (no reason, just cause) 152 | - name can not be any reserved macros: PORT 153 | - macro values must be less than 1024 characters 154 | */ 155 | macroNameRegex := regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) 156 | for macroName, macroValue := range config.Macros { 157 | if len(macroName) >= 64 { 158 | return Config{}, fmt.Errorf("macro name '%s' exceeds maximum length of 63 characters", macroName) 159 | } 160 | if !macroNameRegex.MatchString(macroName) { 161 | return Config{}, fmt.Errorf("macro name '%s' contains invalid characters, must match pattern ^[a-zA-Z0-9_-]+$", macroName) 162 | } 163 | if len(macroValue) >= 1024 { 164 | return Config{}, fmt.Errorf("macro value for '%s' exceeds maximum length of 1024 characters", macroName) 165 | } 166 | switch macroName { 167 | case "PORT": 168 | return Config{}, fmt.Errorf("macro name '%s' is reserved and cannot be used", macroName) 169 | } 170 | } 171 | 172 | // Get and sort all model IDs first, makes testing more consistent 173 | modelIds := make([]string, 0, len(config.Models)) 174 | for modelId := range config.Models { 175 | modelIds = append(modelIds, modelId) 176 | } 177 | sort.Strings(modelIds) // This guarantees stable iteration order 178 | 179 | nextPort := config.StartPort 180 | for _, modelId := range modelIds { 181 | modelConfig := config.Models[modelId] 182 | 183 | // go through model config fields: cmd, cmdStop, 
proxy, checkEndPoint and replace macros with macro values 184 | for macroName, macroValue := range config.Macros { 185 | macroSlug := fmt.Sprintf("${%s}", macroName) 186 | modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, macroSlug, macroValue) 187 | modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, macroSlug, macroValue) 188 | modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, macroSlug, macroValue) 189 | modelConfig.CheckEndpoint = strings.ReplaceAll(modelConfig.CheckEndpoint, macroSlug, macroValue) 190 | } 191 | 192 | // only iterate over models that use ${PORT} to keep port numbers from increasing unnecessarily 193 | if strings.Contains(modelConfig.Cmd, "${PORT}") || strings.Contains(modelConfig.Proxy, "${PORT}") || strings.Contains(modelConfig.CmdStop, "${PORT}") { 194 | if modelConfig.Proxy == "" { 195 | modelConfig.Proxy = "http://localhost:${PORT}" 196 | } 197 | 198 | nextPortStr := strconv.Itoa(nextPort) 199 | modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", nextPortStr) 200 | modelConfig.CmdStop = strings.ReplaceAll(modelConfig.CmdStop, "${PORT}", nextPortStr) 201 | modelConfig.Proxy = strings.ReplaceAll(modelConfig.Proxy, "${PORT}", nextPortStr) 202 | nextPort++ 203 | } else if modelConfig.Proxy == "" { 204 | return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId) 205 | } 206 | 207 | // make sure there are no unknown macros that have not been replaced 208 | macroPattern := regexp.MustCompile(`\$\{([a-zA-Z0-9_-]+)\}`) 209 | fieldMap := map[string]string{ 210 | "cmd": modelConfig.Cmd, 211 | "cmdStop": modelConfig.CmdStop, 212 | "proxy": modelConfig.Proxy, 213 | "checkEndpoint": modelConfig.CheckEndpoint, 214 | } 215 | 216 | for fieldName, fieldValue := range fieldMap { 217 | matches := macroPattern.FindAllStringSubmatch(fieldValue, -1) 218 | for _, match := range matches { 219 | macroName := match[1] 220 | if _, exists := config.Macros[macroName]; !exists { 221 | return 
Config{}, fmt.Errorf("unknown macro '${%s}' found in %s.%s", macroName, modelId, fieldName) 222 | } 223 | } 224 | } 225 | 226 | config.Models[modelId] = modelConfig 227 | } 228 | 229 | config = AddDefaultGroupToConfig(config) 230 | // check that members are all unique in the groups 231 | memberUsage := make(map[string]string) // maps member to group it appears in 232 | for groupID, groupConfig := range config.Groups { 233 | prevSet := make(map[string]bool) 234 | for _, member := range groupConfig.Members { 235 | // Check for duplicates within this group 236 | if _, found := prevSet[member]; found { 237 | return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID) 238 | } 239 | prevSet[member] = true 240 | 241 | // Check if member is used in another group 242 | if existingGroup, exists := memberUsage[member]; exists { 243 | return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID) 244 | } 245 | memberUsage[member] = groupID 246 | } 247 | } 248 | 249 | return config, nil 250 | } 251 | 252 | // rewrites the yaml to include a default group with any orphaned models 253 | func AddDefaultGroupToConfig(config Config) Config { 254 | 255 | if config.Groups == nil { 256 | config.Groups = make(map[string]GroupConfig) 257 | } 258 | 259 | defaultGroup := GroupConfig{ 260 | Swap: true, 261 | Exclusive: true, 262 | Members: []string{}, 263 | } 264 | // if groups is empty, create a default group and put 265 | // all models into it 266 | if len(config.Groups) == 0 { 267 | for modelName := range config.Models { 268 | defaultGroup.Members = append(defaultGroup.Members, modelName) 269 | } 270 | } else { 271 | // iterate over existing group members and add non-grouped models into the default group 272 | for modelName, _ := range config.Models { 273 | foundModel := false 274 | found: 275 | // search for the model in existing groups 276 | for _, groupConfig := range config.Groups { 277 | for _, 
member := range groupConfig.Members { 278 | if member == modelName { 279 | foundModel = true 280 | break found 281 | } 282 | } 283 | } 284 | 285 | if !foundModel { 286 | defaultGroup.Members = append(defaultGroup.Members, modelName) 287 | } 288 | } 289 | } 290 | 291 | sort.Strings(defaultGroup.Members) // make consistent ordering for testing 292 | config.Groups[DEFAULT_GROUP_ID] = defaultGroup 293 | 294 | return config 295 | } 296 | 297 | func SanitizeCommand(cmdStr string) ([]string, error) { 298 | var cleanedLines []string 299 | for _, line := range strings.Split(cmdStr, "\n") { 300 | trimmed := strings.TrimSpace(line) 301 | // Skip comment lines 302 | if strings.HasPrefix(trimmed, "#") { 303 | continue 304 | } 305 | // Handle trailing backslashes by replacing with space 306 | if strings.HasSuffix(trimmed, "\\") { 307 | cleanedLines = append(cleanedLines, strings.TrimSuffix(trimmed, "\\")+" ") 308 | } else { 309 | cleanedLines = append(cleanedLines, line) 310 | } 311 | } 312 | 313 | // put it back together 314 | cmdStr = strings.Join(cleanedLines, "\n") 315 | 316 | // Split the command into arguments 317 | var args []string 318 | if runtime.GOOS == "windows" { 319 | args = shlex.Windows.Split(cmdStr) 320 | } else { 321 | args = shlex.Posix.Split(cmdStr) 322 | } 323 | 324 | // Ensure the command is not empty 325 | if len(args) == 0 { 326 | return nil, fmt.Errorf("empty command") 327 | } 328 | 329 | return args, nil 330 | } 331 | -------------------------------------------------------------------------------- /proxy/config_posix_test.go: -------------------------------------------------------------------------------- 1 | //go:build !windows 2 | 3 | package proxy 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestConfig_SanitizeCommand(t *testing.T) { 12 | // Test a command with spaces and newlines 13 | args, err := SanitizeCommand(`python model1.py \ 14 | -a "double quotes" \ 15 | --arg2 'single quotes' 16 | -s 17 | 
# comment 1 18 | --arg3 123 \ 19 | 20 | # comment 2 21 | --arg4 '"string in string"' 22 | 23 | 24 | # this will get stripped out as well as the white space above 25 | -c "'single quoted'" 26 | `) 27 | assert.NoError(t, err) 28 | assert.Equal(t, []string{ 29 | "python", "model1.py", 30 | "-a", "double quotes", 31 | "--arg2", "single quotes", 32 | "-s", 33 | "--arg3", "123", 34 | "--arg4", `"string in string"`, 35 | "-c", `'single quoted'`, 36 | }, args) 37 | 38 | // Test an empty command 39 | args, err = SanitizeCommand("") 40 | assert.Error(t, err) 41 | assert.Nil(t, args) 42 | } 43 | -------------------------------------------------------------------------------- /proxy/config_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "os" 5 | "path/filepath" 6 | "strings" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | func TestConfig_Load(t *testing.T) { 13 | // Create a temporary YAML file for testing 14 | tempDir, err := os.MkdirTemp("", "test-config") 15 | if err != nil { 16 | t.Fatalf("Failed to create temporary directory: %v", err) 17 | } 18 | defer os.RemoveAll(tempDir) 19 | 20 | tempFile := filepath.Join(tempDir, "config.yaml") 21 | content := ` 22 | macros: 23 | svr-path: "path/to/server" 24 | models: 25 | model1: 26 | cmd: path/to/cmd --arg1 one 27 | proxy: "http://localhost:8080" 28 | aliases: 29 | - "m1" 30 | - "model-one" 31 | env: 32 | - "VAR1=value1" 33 | - "VAR2=value2" 34 | checkEndpoint: "/health" 35 | model2: 36 | cmd: ${svr-path} --arg1 one 37 | proxy: "http://localhost:8081" 38 | aliases: 39 | - "m2" 40 | checkEndpoint: "/" 41 | model3: 42 | cmd: path/to/cmd --arg1 one 43 | proxy: "http://localhost:8081" 44 | aliases: 45 | - "mthree" 46 | checkEndpoint: "/" 47 | model4: 48 | cmd: path/to/cmd --arg1 one 49 | proxy: "http://localhost:8082" 50 | checkEndpoint: "/" 51 | 52 | healthCheckTimeout: 15 53 | profiles: 54 | test: 55 | - model1 56 | - model2 
57 | groups: 58 | group1: 59 | swap: true 60 | exclusive: false 61 | members: ["model2"] 62 | forever: 63 | exclusive: false 64 | persistent: true 65 | members: 66 | - "model4" 67 | ` 68 | 69 | if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { 70 | t.Fatalf("Failed to write temporary file: %v", err) 71 | } 72 | 73 | // Load the config and verify 74 | config, err := LoadConfig(tempFile) 75 | if err != nil { 76 | t.Fatalf("Failed to load config: %v", err) 77 | } 78 | 79 | expected := Config{ 80 | StartPort: 5800, 81 | Macros: map[string]string{ 82 | "svr-path": "path/to/server", 83 | }, 84 | Models: map[string]ModelConfig{ 85 | "model1": { 86 | Cmd: "path/to/cmd --arg1 one", 87 | Proxy: "http://localhost:8080", 88 | Aliases: []string{"m1", "model-one"}, 89 | Env: []string{"VAR1=value1", "VAR2=value2"}, 90 | CheckEndpoint: "/health", 91 | }, 92 | "model2": { 93 | Cmd: "path/to/server --arg1 one", 94 | Proxy: "http://localhost:8081", 95 | Aliases: []string{"m2"}, 96 | Env: nil, 97 | CheckEndpoint: "/", 98 | }, 99 | "model3": { 100 | Cmd: "path/to/cmd --arg1 one", 101 | Proxy: "http://localhost:8081", 102 | Aliases: []string{"mthree"}, 103 | Env: nil, 104 | CheckEndpoint: "/", 105 | }, 106 | "model4": { 107 | Cmd: "path/to/cmd --arg1 one", 108 | Proxy: "http://localhost:8082", 109 | CheckEndpoint: "/", 110 | }, 111 | }, 112 | HealthCheckTimeout: 15, 113 | Profiles: map[string][]string{ 114 | "test": {"model1", "model2"}, 115 | }, 116 | aliases: map[string]string{ 117 | "m1": "model1", 118 | "model-one": "model1", 119 | "m2": "model2", 120 | "mthree": "model3", 121 | }, 122 | Groups: map[string]GroupConfig{ 123 | DEFAULT_GROUP_ID: { 124 | Swap: true, 125 | Exclusive: true, 126 | Members: []string{"model1", "model3"}, 127 | }, 128 | "group1": { 129 | Swap: true, 130 | Exclusive: false, 131 | Members: []string{"model2"}, 132 | }, 133 | "forever": { 134 | Swap: true, 135 | Exclusive: false, 136 | Persistent: true, 137 | Members: []string{"model4"}, 138 | 
}, 139 | }, 140 | } 141 | 142 | assert.Equal(t, expected, config) 143 | 144 | realname, found := config.RealModelName("m1") 145 | assert.True(t, found) 146 | assert.Equal(t, "model1", realname) 147 | } 148 | 149 | func TestConfig_GroupMemberIsUnique(t *testing.T) { 150 | content := ` 151 | models: 152 | model1: 153 | cmd: path/to/cmd --arg1 one 154 | proxy: "http://localhost:8080" 155 | model2: 156 | cmd: path/to/cmd --arg1 one 157 | proxy: "http://localhost:8081" 158 | checkEndpoint: "/" 159 | model3: 160 | cmd: path/to/cmd --arg1 one 161 | proxy: "http://localhost:8081" 162 | checkEndpoint: "/" 163 | 164 | healthCheckTimeout: 15 165 | groups: 166 | group1: 167 | swap: true 168 | exclusive: false 169 | members: ["model2"] 170 | group2: 171 | swap: true 172 | exclusive: false 173 | members: ["model2"] 174 | ` 175 | // Load the config and verify 176 | _, err := LoadConfigFromReader(strings.NewReader(content)) 177 | 178 | // a Contains as order of the map is not guaranteed 179 | assert.Contains(t, err.Error(), "model member model2 is used in multiple groups:") 180 | } 181 | 182 | func TestConfig_ModelAliasesAreUnique(t *testing.T) { 183 | content := ` 184 | models: 185 | model1: 186 | cmd: path/to/cmd --arg1 one 187 | proxy: "http://localhost:8080" 188 | aliases: 189 | - m1 190 | model2: 191 | cmd: path/to/cmd --arg1 one 192 | proxy: "http://localhost:8081" 193 | checkEndpoint: "/" 194 | aliases: 195 | - m1 196 | - m2 197 | ` 198 | // Load the config and verify 199 | _, err := LoadConfigFromReader(strings.NewReader(content)) 200 | 201 | // this is a contains because it could be `model1` or `model2` depending on the order 202 | // go decided on the order of the map 203 | assert.Contains(t, err.Error(), "duplicate alias m1 found in model: model") 204 | } 205 | 206 | func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { 207 | config := &ModelConfig{ 208 | Cmd: `python model1.py \ 209 | --arg1 value1 \ 210 | --arg2 value2`, 211 | } 212 | 213 | args, err := 
config.SanitizedCommand() 214 | assert.NoError(t, err) 215 | assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args) 216 | } 217 | 218 | func TestConfig_FindConfig(t *testing.T) { 219 | 220 | // TODO? 221 | // make make this shared between the different tests 222 | config := &Config{ 223 | Models: map[string]ModelConfig{ 224 | "model1": { 225 | Cmd: "python model1.py", 226 | Proxy: "http://localhost:8080", 227 | Aliases: []string{"m1", "model-one"}, 228 | Env: []string{"VAR1=value1", "VAR2=value2"}, 229 | CheckEndpoint: "/health", 230 | }, 231 | "model2": { 232 | Cmd: "python model2.py", 233 | Proxy: "http://localhost:8081", 234 | Aliases: []string{"m2", "model-two"}, 235 | Env: []string{"VAR3=value3", "VAR4=value4"}, 236 | CheckEndpoint: "/status", 237 | }, 238 | }, 239 | HealthCheckTimeout: 10, 240 | aliases: map[string]string{ 241 | "m1": "model1", 242 | "model-one": "model1", 243 | "m2": "model2", 244 | }, 245 | } 246 | 247 | // Test finding a model by its name 248 | modelConfig, modelId, found := config.FindConfig("model1") 249 | assert.True(t, found) 250 | assert.Equal(t, "model1", modelId) 251 | assert.Equal(t, config.Models["model1"], modelConfig) 252 | 253 | // Test finding a model by its alias 254 | modelConfig, modelId, found = config.FindConfig("m1") 255 | assert.True(t, found) 256 | assert.Equal(t, "model1", modelId) 257 | assert.Equal(t, config.Models["model1"], modelConfig) 258 | 259 | // Test finding a model that does not exist 260 | modelConfig, modelId, found = config.FindConfig("model3") 261 | assert.False(t, found) 262 | assert.Equal(t, "", modelId) 263 | assert.Equal(t, ModelConfig{}, modelConfig) 264 | } 265 | 266 | func TestConfig_AutomaticPortAssignments(t *testing.T) { 267 | 268 | t.Run("Default Port Ranges", func(t *testing.T) { 269 | content := `` 270 | config, err := LoadConfigFromReader(strings.NewReader(content)) 271 | if !assert.NoError(t, err) { 272 | t.Fatalf("Failed to load config: %v", err) 
273 | } 274 | 275 | assert.Equal(t, 5800, config.StartPort) 276 | }) 277 | t.Run("User specific port ranges", func(t *testing.T) { 278 | content := `startPort: 1000` 279 | config, err := LoadConfigFromReader(strings.NewReader(content)) 280 | if !assert.NoError(t, err) { 281 | t.Fatalf("Failed to load config: %v", err) 282 | } 283 | 284 | assert.Equal(t, 1000, config.StartPort) 285 | }) 286 | 287 | t.Run("Invalid start port", func(t *testing.T) { 288 | content := `startPort: abcd` 289 | _, err := LoadConfigFromReader(strings.NewReader(content)) 290 | assert.NotNil(t, err) 291 | }) 292 | 293 | t.Run("start port must be greater than 1", func(t *testing.T) { 294 | content := `startPort: -99` 295 | _, err := LoadConfigFromReader(strings.NewReader(content)) 296 | assert.NotNil(t, err) 297 | }) 298 | 299 | t.Run("Automatic port assignments", func(t *testing.T) { 300 | content := ` 301 | startPort: 5800 302 | models: 303 | model1: 304 | cmd: svr --port ${PORT} 305 | model2: 306 | cmd: svr --port ${PORT} 307 | proxy: "http://172.11.22.33:${PORT}" 308 | model3: 309 | cmd: svr --port 1999 310 | proxy: "http://1.2.3.4:1999" 311 | ` 312 | config, err := LoadConfigFromReader(strings.NewReader(content)) 313 | if !assert.NoError(t, err) { 314 | t.Fatalf("Failed to load config: %v", err) 315 | } 316 | 317 | assert.Equal(t, 5800, config.StartPort) 318 | assert.Equal(t, "svr --port 5800", config.Models["model1"].Cmd) 319 | assert.Equal(t, "http://localhost:5800", config.Models["model1"].Proxy) 320 | 321 | assert.Equal(t, "svr --port 5801", config.Models["model2"].Cmd) 322 | assert.Equal(t, "http://172.11.22.33:5801", config.Models["model2"].Proxy) 323 | 324 | assert.Equal(t, "svr --port 1999", config.Models["model3"].Cmd) 325 | assert.Equal(t, "http://1.2.3.4:1999", config.Models["model3"].Proxy) 326 | 327 | }) 328 | 329 | t.Run("Proxy value required if no ${PORT} in cmd", func(t *testing.T) { 330 | content := ` 331 | models: 332 | model1: 333 | cmd: svr --port 111 334 | ` 335 | _, 
err := LoadConfigFromReader(strings.NewReader(content)) 336 | assert.Equal(t, "model model1 requires a proxy value when not using automatic ${PORT}", err.Error()) 337 | }) 338 | } 339 | 340 | func TestConfig_MacroReplacement(t *testing.T) { 341 | content := ` 342 | startPort: 9990 343 | macros: 344 | svr-path: "path/to/server" 345 | argOne: "--arg1" 346 | argTwo: "--arg2" 347 | autoPort: "--port ${PORT}" 348 | 349 | models: 350 | model1: 351 | cmd: | 352 | ${svr-path} ${argTwo} 353 | # the automatic ${PORT} is replaced 354 | ${autoPort} 355 | ${argOne} 356 | --arg3 three 357 | cmdStop: | 358 | /path/to/stop.sh --port ${PORT} ${argTwo} 359 | ` 360 | 361 | config, err := LoadConfigFromReader(strings.NewReader(content)) 362 | assert.NoError(t, err) 363 | sanitizedCmd, err := SanitizeCommand(config.Models["model1"].Cmd) 364 | assert.NoError(t, err) 365 | assert.Equal(t, "path/to/server --arg2 --port 9990 --arg1 --arg3 three", strings.Join(sanitizedCmd, " ")) 366 | 367 | sanitizedCmdStop, err := SanitizeCommand(config.Models["model1"].CmdStop) 368 | assert.NoError(t, err) 369 | assert.Equal(t, "/path/to/stop.sh --port 9990 --arg2", strings.Join(sanitizedCmdStop, " ")) 370 | } 371 | 372 | func TestConfig_MacroErrorOnUnknownMacros(t *testing.T) { 373 | tests := []struct { 374 | name string 375 | field string 376 | content string 377 | }{ 378 | { 379 | name: "unknown macro in cmd", 380 | field: "cmd", 381 | content: ` 382 | startPort: 9990 383 | macros: 384 | svr-path: "path/to/server" 385 | models: 386 | model1: 387 | cmd: | 388 | ${svr-path} --port ${PORT} 389 | ${unknownMacro} 390 | `, 391 | }, 392 | { 393 | name: "unknown macro in cmdStop", 394 | field: "cmdStop", 395 | content: ` 396 | startPort: 9990 397 | macros: 398 | svr-path: "path/to/server" 399 | models: 400 | model1: 401 | cmd: "${svr-path} --port ${PORT}" 402 | cmdStop: "kill ${unknownMacro}" 403 | `, 404 | }, 405 | { 406 | name: "unknown macro in proxy", 407 | field: "proxy", 408 | content: ` 409 | 
startPort: 9990 410 | macros: 411 | svr-path: "path/to/server" 412 | models: 413 | model1: 414 | cmd: "${svr-path} --port ${PORT}" 415 | proxy: "http://localhost:${unknownMacro}" 416 | `, 417 | }, 418 | { 419 | name: "unknown macro in checkEndpoint", 420 | field: "checkEndpoint", 421 | content: ` 422 | startPort: 9990 423 | macros: 424 | svr-path: "path/to/server" 425 | models: 426 | model1: 427 | cmd: "${svr-path} --port ${PORT}" 428 | checkEndpoint: "http://localhost:${unknownMacro}/health" 429 | `, 430 | }, 431 | } 432 | 433 | for _, tt := range tests { 434 | t.Run(tt.name, func(t *testing.T) { 435 | _, err := LoadConfigFromReader(strings.NewReader(tt.content)) 436 | assert.Error(t, err) 437 | assert.Contains(t, err.Error(), "unknown macro '${unknownMacro}' found in model1."+tt.field) 438 | //t.Log(err) 439 | }) 440 | } 441 | } 442 | -------------------------------------------------------------------------------- /proxy/config_windows_test.go: -------------------------------------------------------------------------------- 1 | //go:build windows 2 | 3 | package proxy 4 | 5 | import ( 6 | "testing" 7 | 8 | "github.com/stretchr/testify/assert" 9 | ) 10 | 11 | func TestConfig_SanitizeCommand(t *testing.T) { 12 | // does not support single quoted strings like in config_posix_test.go 13 | args, err := SanitizeCommand(`python model1.py \ 14 | 15 | -a "double quotes" \ 16 | -s 17 | --arg3 123 \ 18 | 19 | # comment 2 20 | --arg4 '"string in string"' 21 | 22 | 23 | 24 | # this will get stripped out as well as the white space above 25 | -c "'single quoted'" 26 | `) 27 | assert.NoError(t, err) 28 | assert.Equal(t, []string{ 29 | "python", "model1.py", 30 | "-a", "double quotes", 31 | "-s", 32 | "--arg3", "123", 33 | "--arg4", "'string in string'", // this is a little weird but the lexer says so...? 
34 | "-c", `'single quoted'`, 35 | }, args) 36 | 37 | // Test an empty command 38 | args, err = SanitizeCommand("") 39 | assert.Error(t, err) 40 | assert.Nil(t, args) 41 | } 42 | -------------------------------------------------------------------------------- /proxy/helpers_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | "path/filepath" 7 | "runtime" 8 | "sync" 9 | "testing" 10 | 11 | "github.com/gin-gonic/gin" 12 | ) 13 | 14 | var ( 15 | nextTestPort int = 12000 16 | portMutex sync.Mutex 17 | testLogger = NewLogMonitorWriter(os.Stdout) 18 | ) 19 | 20 | // Check if the binary exists 21 | func TestMain(m *testing.M) { 22 | binaryPath := getSimpleResponderPath() 23 | if _, err := os.Stat(binaryPath); os.IsNotExist(err) { 24 | fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath) 25 | os.Exit(1) 26 | } 27 | 28 | gin.SetMode(gin.TestMode) 29 | 30 | switch os.Getenv("LOG_LEVEL") { 31 | case "debug": 32 | testLogger.SetLogLevel(LevelDebug) 33 | case "warn": 34 | testLogger.SetLogLevel(LevelWarn) 35 | case "info": 36 | testLogger.SetLogLevel(LevelInfo) 37 | default: 38 | testLogger.SetLogLevel(LevelWarn) 39 | } 40 | 41 | m.Run() 42 | } 43 | 44 | // Helper function to get the binary path 45 | func getSimpleResponderPath() string { 46 | goos := runtime.GOOS 47 | goarch := runtime.GOARCH 48 | 49 | if goos == "windows" { 50 | return filepath.Join("..", "build", "simple-responder.exe") 51 | } else { 52 | return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) 53 | } 54 | } 55 | 56 | func getTestPort() int { 57 | portMutex.Lock() 58 | defer portMutex.Unlock() 59 | 60 | port := nextTestPort 61 | nextTestPort++ 62 | 63 | return port 64 | } 65 | 66 | func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { 67 | return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) 68 | } 69 | 
70 | func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { 71 | binaryPath := getSimpleResponderPath() 72 | 73 | // Create a process configuration 74 | return ModelConfig{ 75 | Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage), 76 | Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), 77 | CheckEndpoint: "/health", 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /proxy/html/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mostlygeek/llama-swap/a84098d3b43b5453f0600d942da7cf9bf762612a/proxy/html/favicon.ico -------------------------------------------------------------------------------- /proxy/html/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | llama-swap 7 | 8 | 9 |

llama-swap

10 |

11 | view logs | configured models | github 12 |

13 | 14 | 15 | -------------------------------------------------------------------------------- /proxy/html/logs.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Logs 7 | 117 | 118 | 119 |
120 |
121 |

Proxy Logs

122 |
123 | 124 | 125 |
126 |
Waiting for proxy logs...
127 |
128 |
129 |

Upstream Logs

130 |
131 | 132 | 133 |
134 |
Waiting for upstream logs...
135 |
136 |
137 | 258 | 259 | -------------------------------------------------------------------------------- /proxy/html_files.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import "embed" 4 | 5 | //go:embed html 6 | var htmlFiles embed.FS 7 | 8 | func getHTMLFile(path string) ([]byte, error) { 9 | return htmlFiles.ReadFile("html/" + path) 10 | } 11 | -------------------------------------------------------------------------------- /proxy/logMonitor.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "container/ring" 5 | "fmt" 6 | "io" 7 | "os" 8 | "sync" 9 | ) 10 | 11 | type LogLevel int 12 | 13 | const ( 14 | LevelDebug LogLevel = iota 15 | LevelInfo 16 | LevelWarn 17 | LevelError 18 | ) 19 | 20 | type LogMonitor struct { 21 | clients map[chan []byte]bool 22 | mu sync.RWMutex 23 | buffer *ring.Ring 24 | bufferMu sync.RWMutex 25 | 26 | // typically this can be os.Stdout 27 | stdout io.Writer 28 | 29 | // logging levels 30 | level LogLevel 31 | prefix string 32 | } 33 | 34 | func NewLogMonitor() *LogMonitor { 35 | return NewLogMonitorWriter(os.Stdout) 36 | } 37 | 38 | func NewLogMonitorWriter(stdout io.Writer) *LogMonitor { 39 | return &LogMonitor{ 40 | clients: make(map[chan []byte]bool), 41 | buffer: ring.New(10 * 1024), // keep 10KB of buffered logs 42 | stdout: stdout, 43 | level: LevelInfo, 44 | prefix: "", 45 | } 46 | } 47 | 48 | func (w *LogMonitor) Write(p []byte) (n int, err error) { 49 | if len(p) == 0 { 50 | return 0, nil 51 | } 52 | 53 | n, err = w.stdout.Write(p) 54 | if err != nil { 55 | return n, err 56 | } 57 | 58 | w.bufferMu.Lock() 59 | bufferCopy := make([]byte, len(p)) 60 | copy(bufferCopy, p) 61 | w.buffer.Value = bufferCopy 62 | w.buffer = w.buffer.Next() 63 | w.bufferMu.Unlock() 64 | 65 | w.broadcast(bufferCopy) 66 | return n, nil 67 | } 68 | 69 | func (w *LogMonitor) GetHistory() []byte { 70 | w.bufferMu.RLock() 71 | 
defer w.bufferMu.RUnlock() 72 | 73 | var history []byte 74 | w.buffer.Do(func(p any) { 75 | if p != nil { 76 | if content, ok := p.([]byte); ok { 77 | history = append(history, content...) 78 | } 79 | } 80 | }) 81 | return history 82 | } 83 | 84 | func (w *LogMonitor) Subscribe() chan []byte { 85 | w.mu.Lock() 86 | defer w.mu.Unlock() 87 | 88 | ch := make(chan []byte, 100) 89 | w.clients[ch] = true 90 | return ch 91 | } 92 | 93 | func (w *LogMonitor) Unsubscribe(ch chan []byte) { 94 | w.mu.Lock() 95 | defer w.mu.Unlock() 96 | 97 | delete(w.clients, ch) 98 | close(ch) 99 | } 100 | 101 | func (w *LogMonitor) broadcast(msg []byte) { 102 | w.mu.RLock() 103 | defer w.mu.RUnlock() 104 | 105 | for client := range w.clients { 106 | select { 107 | case client <- msg: 108 | default: 109 | // If client buffer is full, skip 110 | } 111 | } 112 | } 113 | 114 | func (w *LogMonitor) SetPrefix(prefix string) { 115 | w.mu.Lock() 116 | defer w.mu.Unlock() 117 | w.prefix = prefix 118 | } 119 | 120 | func (w *LogMonitor) SetLogLevel(level LogLevel) { 121 | w.mu.Lock() 122 | defer w.mu.Unlock() 123 | w.level = level 124 | } 125 | 126 | func (w *LogMonitor) formatMessage(level string, msg string) []byte { 127 | prefix := "" 128 | if w.prefix != "" { 129 | prefix = fmt.Sprintf("[%s] ", w.prefix) 130 | } 131 | return []byte(fmt.Sprintf("%s[%s] %s\n", prefix, level, msg)) 132 | } 133 | 134 | func (w *LogMonitor) log(level LogLevel, msg string) { 135 | if level < w.level { 136 | return 137 | } 138 | w.Write(w.formatMessage(level.String(), msg)) 139 | } 140 | 141 | func (w *LogMonitor) Debug(msg string) { 142 | w.log(LevelDebug, msg) 143 | } 144 | 145 | func (w *LogMonitor) Info(msg string) { 146 | w.log(LevelInfo, msg) 147 | } 148 | 149 | func (w *LogMonitor) Warn(msg string) { 150 | w.log(LevelWarn, msg) 151 | } 152 | 153 | func (w *LogMonitor) Error(msg string) { 154 | w.log(LevelError, msg) 155 | } 156 | 157 | func (w *LogMonitor) Debugf(format string, args ...interface{}) { 158 | 
w.log(LevelDebug, fmt.Sprintf(format, args...)) 159 | } 160 | 161 | func (w *LogMonitor) Infof(format string, args ...interface{}) { 162 | w.log(LevelInfo, fmt.Sprintf(format, args...)) 163 | } 164 | 165 | func (w *LogMonitor) Warnf(format string, args ...interface{}) { 166 | w.log(LevelWarn, fmt.Sprintf(format, args...)) 167 | } 168 | 169 | func (w *LogMonitor) Errorf(format string, args ...interface{}) { 170 | w.log(LevelError, fmt.Sprintf(format, args...)) 171 | } 172 | 173 | func (l LogLevel) String() string { 174 | switch l { 175 | case LevelDebug: 176 | return "DEBUG" 177 | case LevelInfo: 178 | return "INFO" 179 | case LevelWarn: 180 | return "WARN" 181 | case LevelError: 182 | return "ERROR" 183 | default: 184 | return "UNKNOWN" 185 | } 186 | } 187 | -------------------------------------------------------------------------------- /proxy/logMonitor_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "io" 6 | "sync" 7 | "testing" 8 | ) 9 | 10 | func TestLogMonitor(t *testing.T) { 11 | logMonitor := NewLogMonitorWriter(io.Discard) 12 | 13 | // Test subscription 14 | client1 := logMonitor.Subscribe() 15 | client2 := logMonitor.Subscribe() 16 | 17 | defer logMonitor.Unsubscribe(client1) 18 | defer logMonitor.Unsubscribe(client2) 19 | 20 | client1Messages := make([]byte, 0) 21 | client2Messages := make([]byte, 0) 22 | 23 | var wg sync.WaitGroup 24 | wg.Add(1) 25 | 26 | go func() { 27 | defer wg.Done() 28 | for { 29 | select { 30 | case data := <-client1: 31 | client1Messages = append(client1Messages, data...) 32 | case data := <-client2: 33 | client2Messages = append(client2Messages, data...) 
34 | default: 35 | return 36 | } 37 | } 38 | }() 39 | 40 | logMonitor.Write([]byte("1")) 41 | logMonitor.Write([]byte("2")) 42 | logMonitor.Write([]byte("3")) 43 | 44 | // Wait for the goroutine to finish 45 | wg.Wait() 46 | 47 | // Check the buffer 48 | expectedHistory := "123" 49 | history := string(logMonitor.GetHistory()) 50 | 51 | if history != expectedHistory { 52 | t.Errorf("Expected history: %s, got: %s", expectedHistory, history) 53 | } 54 | 55 | c1Data := string(client1Messages) 56 | if c1Data != expectedHistory { 57 | t.Errorf("Client1 expected %s, got: %s", expectedHistory, c1Data) 58 | } 59 | 60 | c2Data := string(client2Messages) 61 | if c2Data != expectedHistory { 62 | t.Errorf("Client2 expected %s, got: %s", expectedHistory, c2Data) 63 | } 64 | } 65 | 66 | func TestWrite_ImmutableBuffer(t *testing.T) { 67 | // Create a new LogMonitor instance 68 | lm := NewLogMonitorWriter(io.Discard) 69 | 70 | // Prepare a message to write 71 | msg := []byte("Hello, World!") 72 | lenmsg := len(msg) 73 | 74 | // Write the message to the LogMonitor 75 | n, err := lm.Write(msg) 76 | if err != nil { 77 | t.Fatalf("Write failed: %v", err) 78 | } 79 | 80 | if n != lenmsg { 81 | t.Errorf("Expected %d bytes written but got %d", lenmsg, n) 82 | } 83 | 84 | // Change the original message 85 | msg[0] = 'B' // This should not affect the buffer 86 | 87 | // Get the history from the LogMonitor 88 | history := lm.GetHistory() 89 | 90 | // Check that the history contains the original message, not the modified one 91 | expected := []byte("Hello, World!") 92 | if !bytes.Equal(history, expected) { 93 | t.Errorf("Expected history to be %q, got %q", expected, history) 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /proxy/process.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "context" 5 | "errors" 6 | "fmt" 7 | "io" 8 | "net/http" 9 | "net/url" 10 | 
"os/exec" 11 | "runtime" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "syscall" 16 | "time" 17 | ) 18 | 19 | type ProcessState string 20 | 21 | const ( 22 | StateStopped ProcessState = ProcessState("stopped") 23 | StateStarting ProcessState = ProcessState("starting") 24 | StateReady ProcessState = ProcessState("ready") 25 | StateStopping ProcessState = ProcessState("stopping") 26 | 27 | // failed a health check on start and will not be recovered 28 | StateFailed ProcessState = ProcessState("failed") 29 | 30 | // process is shutdown and will not be restarted 31 | StateShutdown ProcessState = ProcessState("shutdown") 32 | ) 33 | 34 | type StopStrategy int 35 | 36 | const ( 37 | StopImmediately StopStrategy = iota 38 | StopWaitForInflightRequest 39 | ) 40 | 41 | type Process struct { 42 | ID string 43 | config ModelConfig 44 | cmd *exec.Cmd 45 | 46 | // for p.cmd.Wait() select { ... } 47 | cmdWaitChan chan error 48 | 49 | processLogger *LogMonitor 50 | proxyLogger *LogMonitor 51 | 52 | healthCheckTimeout int 53 | healthCheckLoopInterval time.Duration 54 | 55 | lastRequestHandled time.Time 56 | 57 | stateMutex sync.RWMutex 58 | state ProcessState 59 | 60 | inFlightRequests sync.WaitGroup 61 | 62 | // used to block on multiple start() calls 63 | waitStarting sync.WaitGroup 64 | 65 | // for managing shutdown state 66 | shutdownCtx context.Context 67 | shutdownCancel context.CancelFunc 68 | 69 | // for managing concurrency limits 70 | concurrencyLimitSemaphore chan struct{} 71 | 72 | // stop timeout waiting for graceful shutdown 73 | gracefulStopTimeout time.Duration 74 | 75 | // track that this happened 76 | upstreamWasStoppedWithKill bool 77 | } 78 | 79 | func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { 80 | ctx, cancel := context.WithCancel(context.Background()) 81 | concurrentLimit := 10 82 | if config.ConcurrencyLimit > 0 { 83 | concurrentLimit = config.ConcurrencyLimit 84 | } 85 | 
86 | return &Process{ 87 | ID: ID, 88 | config: config, 89 | cmd: nil, 90 | cmdWaitChan: make(chan error, 1), 91 | processLogger: processLogger, 92 | proxyLogger: proxyLogger, 93 | healthCheckTimeout: healthCheckTimeout, 94 | healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ 95 | state: StateStopped, 96 | shutdownCtx: ctx, 97 | shutdownCancel: cancel, 98 | 99 | // concurrency limit 100 | concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit), 101 | 102 | // stop timeout 103 | gracefulStopTimeout: 10 * time.Second, 104 | upstreamWasStoppedWithKill: false, 105 | } 106 | } 107 | 108 | // LogMonitor returns the log monitor associated with the process. 109 | func (p *Process) LogMonitor() *LogMonitor { 110 | return p.processLogger 111 | } 112 | 113 | // custom error types for swapping state 114 | var ( 115 | ErrExpectedStateMismatch = errors.New("expected state mismatch") 116 | ErrInvalidStateTransition = errors.New("invalid state transition") 117 | ) 118 | 119 | // swapState performs a compare and swap of the state atomically. It returns the current state 120 | // and an error if the swap failed. 
121 | func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) { 122 | p.stateMutex.Lock() 123 | defer p.stateMutex.Unlock() 124 | 125 | if p.state != expectedState { 126 | p.proxyLogger.Warnf("<%s> swapState() Unexpected current state %s, expected %s", p.ID, p.state, expectedState) 127 | return p.state, ErrExpectedStateMismatch 128 | } 129 | 130 | if !isValidTransition(p.state, newState) { 131 | p.proxyLogger.Warnf("<%s> swapState() Invalid state transition from %s to %s", p.ID, p.state, newState) 132 | return p.state, ErrInvalidStateTransition 133 | } 134 | 135 | p.state = newState 136 | p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState) 137 | return p.state, nil 138 | } 139 | 140 | // Helper function to encapsulate transition rules 141 | func isValidTransition(from, to ProcessState) bool { 142 | switch from { 143 | case StateStopped: 144 | return to == StateStarting 145 | case StateStarting: 146 | return to == StateReady || to == StateFailed || to == StateStopping 147 | case StateReady: 148 | return to == StateStopping 149 | case StateStopping: 150 | return to == StateStopped || to == StateShutdown 151 | case StateFailed: 152 | return to == StateStopping 153 | case StateShutdown: 154 | return false // No transitions allowed from these states 155 | } 156 | return false 157 | } 158 | 159 | func (p *Process) CurrentState() ProcessState { 160 | p.stateMutex.RLock() 161 | defer p.stateMutex.RUnlock() 162 | return p.state 163 | } 164 | 165 | // start starts the upstream command, checks the health endpoint, and sets the state to Ready 166 | // it is a private method because starting is automatic but stopping can be called 167 | // at any time. 
168 | func (p *Process) start() error { 169 | 170 | if p.config.Proxy == "" { 171 | return fmt.Errorf("can not start(), upstream proxy missing") 172 | } 173 | 174 | args, err := p.config.SanitizedCommand() 175 | if err != nil { 176 | return fmt.Errorf("unable to get sanitized command: %v", err) 177 | } 178 | 179 | if curState, err := p.swapState(StateStopped, StateStarting); err != nil { 180 | if err == ErrExpectedStateMismatch { 181 | // already starting, just wait for it to complete and expect 182 | // it to be be in the Ready start after. If not, return an error 183 | if curState == StateStarting { 184 | p.waitStarting.Wait() 185 | if state := p.CurrentState(); state == StateReady { 186 | return nil 187 | } else { 188 | return fmt.Errorf("process was already starting but wound up in state %v", state) 189 | } 190 | } else { 191 | return fmt.Errorf("processes was in state %v when start() was called", curState) 192 | } 193 | } else { 194 | return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err) 195 | } 196 | } 197 | 198 | p.waitStarting.Add(1) 199 | defer p.waitStarting.Done() 200 | 201 | p.cmd = exec.Command(args[0], args[1:]...) 202 | p.cmd.Stdout = p.processLogger 203 | p.cmd.Stderr = p.processLogger 204 | p.cmd.Env = p.config.Env 205 | 206 | err = p.cmd.Start() 207 | 208 | // Set process state to failed 209 | if err != nil { 210 | if curState, swapErr := p.swapState(StateStarting, StateFailed); swapErr != nil { 211 | return fmt.Errorf( 212 | "failed to start command and state swap failed. 
command error: %v, current state: %v, state swap error: %v", 213 | err, curState, swapErr, 214 | ) 215 | } 216 | return fmt.Errorf("start() failed: %v", err) 217 | } 218 | 219 | // Capture the exit error for later signalling 220 | go func() { 221 | exitErr := p.cmd.Wait() 222 | p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr) 223 | 224 | // there is a race condition when SIGKILL is used, p.cmd.Wait() returns, and then 225 | // the code below fires, putting an error into cmdWaitChan. This code is to prevent this 226 | if p.upstreamWasStoppedWithKill { 227 | p.proxyLogger.Debugf("<%s> process was killed, NOT sending exitErr: %v", p.ID, exitErr) 228 | p.upstreamWasStoppedWithKill = false 229 | return 230 | } 231 | 232 | p.cmdWaitChan <- exitErr 233 | }() 234 | 235 | // One of three things can happen at this stage: 236 | // 1. The command exits unexpectedly 237 | // 2. The health check fails 238 | // 3. The health check passes 239 | // 240 | // only in the third case will the process be considered Ready to accept 241 | <-time.After(250 * time.Millisecond) // give process a bit of time to start 242 | 243 | checkStartTime := time.Now() 244 | maxDuration := time.Second * time.Duration(p.healthCheckTimeout) 245 | checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint) 246 | 247 | // a "none" means don't check for health ... 
I could have picked a better word :facepalm: 248 | if checkEndpoint != "none" { 249 | // keep default behaviour 250 | if checkEndpoint == "" { 251 | checkEndpoint = "/health" 252 | } 253 | 254 | proxyTo := p.config.Proxy 255 | healthURL, err := url.JoinPath(proxyTo, checkEndpoint) 256 | if err != nil { 257 | return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint) 258 | } 259 | 260 | checkDeadline, cancelHealthCheck := context.WithDeadline( 261 | context.Background(), 262 | checkStartTime.Add(maxDuration), 263 | ) 264 | defer cancelHealthCheck() 265 | 266 | loop: 267 | // Ready Check loop 268 | for { 269 | select { 270 | case <-checkDeadline.Done(): 271 | if curState, err := p.swapState(StateStarting, StateFailed); err != nil { 272 | return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState) 273 | } else { 274 | return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds()) 275 | } 276 | case <-p.shutdownCtx.Done(): 277 | return errors.New("health check interrupted due to shutdown") 278 | case exitErr := <-p.cmdWaitChan: 279 | if exitErr != nil { 280 | p.proxyLogger.Warnf("<%s> upstream command exited prematurely with error: %v", p.ID, exitErr) 281 | if curState, err := p.swapState(StateStarting, StateFailed); err != nil { 282 | return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState) 283 | } else { 284 | return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error()) 285 | } 286 | } else { 287 | p.proxyLogger.Warnf("<%s> upstream command exited prematurely but successfully", p.ID) 288 | if curState, err := p.swapState(StateStarting, StateFailed); err != nil { 289 | return fmt.Errorf("upstream command exited prematurely but successfully AND state swap failed: %v, current state: %v", err, curState) 290 | } else { 291 | return 
fmt.Errorf("upstream command exited prematurely but successfully") 292 | } 293 | } 294 | default: 295 | if err := p.checkHealthEndpoint(healthURL); err == nil { 296 | p.proxyLogger.Infof("<%s> Health check passed on %s", p.ID, healthURL) 297 | cancelHealthCheck() 298 | break loop 299 | } else { 300 | if strings.Contains(err.Error(), "connection refused") { 301 | endTime, _ := checkDeadline.Deadline() 302 | ttl := time.Until(endTime) 303 | p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds()) 304 | } else { 305 | p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err) 306 | } 307 | } 308 | } 309 | 310 | <-time.After(p.healthCheckLoopInterval) 311 | } 312 | } 313 | 314 | if p.config.UnloadAfter > 0 { 315 | // start a goroutine to check every second if 316 | // the process should be stopped 317 | go func() { 318 | maxDuration := time.Duration(p.config.UnloadAfter) * time.Second 319 | 320 | for range time.Tick(time.Second) { 321 | if p.CurrentState() != StateReady { 322 | return 323 | } 324 | 325 | // wait for all inflight requests to complete and ticker 326 | p.inFlightRequests.Wait() 327 | 328 | if time.Since(p.lastRequestHandled) > maxDuration { 329 | p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) 330 | p.Stop() 331 | return 332 | } 333 | } 334 | }() 335 | } 336 | 337 | if curState, err := p.swapState(StateStarting, StateReady); err != nil { 338 | return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err) 339 | } else { 340 | return nil 341 | } 342 | } 343 | 344 | // Stop will wait for inflight requests to complete before stopping the process. 
345 | func (p *Process) Stop() { 346 | if !isValidTransition(p.CurrentState(), StateStopping) { 347 | return 348 | } 349 | 350 | // wait for any inflight requests before proceeding 351 | p.proxyLogger.Debugf("<%s> Stop(): Waiting for inflight requests to complete", p.ID) 352 | p.inFlightRequests.Wait() 353 | p.StopImmediately() 354 | } 355 | 356 | // StopImmediately will transition the process to the stopping state and stop the process with a SIGTERM. 357 | // If the process does not stop within the specified timeout, it will be forcefully stopped with a SIGKILL. 358 | func (p *Process) StopImmediately() { 359 | if !isValidTransition(p.CurrentState(), StateStopping) { 360 | return 361 | } 362 | 363 | p.proxyLogger.Debugf("<%s> Stopping process, current state: %s", p.ID, p.CurrentState()) 364 | currentState := p.CurrentState() 365 | 366 | if currentState == StateFailed { 367 | if curState, err := p.swapState(StateFailed, StateStopping); err != nil { 368 | p.proxyLogger.Infof("<%s> Stop() Failed -> StateStopping err: %v, current state: %v", p.ID, err, curState) 369 | return 370 | } 371 | } else { 372 | if curState, err := p.swapState(StateReady, StateStopping); err != nil { 373 | p.proxyLogger.Infof("<%s> Stop() Ready -> StateStopping err: %v, current state: %v", p.ID, err, curState) 374 | return 375 | } 376 | } 377 | 378 | // stop the process with a graceful exit timeout 379 | p.stopCommand(p.gracefulStopTimeout) 380 | 381 | if curState, err := p.swapState(StateStopping, StateStopped); err != nil { 382 | p.proxyLogger.Infof("<%s> Stop() StateStopping -> StateStopped err: %v, current state: %v", p.ID, err, curState) 383 | } 384 | } 385 | 386 | // Shutdown is called when llama-swap is shutting down. It will give a little bit 387 | // of time for any inflight requests to complete before shutting down. If the Process 388 | // is in the state of starting, it will cancel it and shut it down. 
Once a process is in 389 | // the StateShutdown state, it can not be started again. 390 | func (p *Process) Shutdown() { 391 | if !isValidTransition(p.CurrentState(), StateStopping) { 392 | return 393 | } 394 | 395 | p.shutdownCancel() 396 | p.stopCommand(p.gracefulStopTimeout) 397 | 398 | // just force it to this state since there is no recovery from shutdown 399 | p.state = StateShutdown 400 | } 401 | 402 | // stopCommand will send a SIGTERM to the process and wait for it to exit. 403 | // If it does not exit within 5 seconds, it will send a SIGKILL. 404 | func (p *Process) stopCommand(sigtermTTL time.Duration) { 405 | stopStartTime := time.Now() 406 | defer func() { 407 | p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime)) 408 | }() 409 | 410 | sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) 411 | defer cancelTimeout() 412 | 413 | if p.cmd == nil || p.cmd.Process == nil { 414 | p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID) 415 | return 416 | } 417 | 418 | // if err := p.terminateProcess(); err != nil { 419 | // p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err) 420 | // } 421 | // the default cmdStop to taskkill /f /t /pid ${PID} 422 | if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" { 423 | p.config.CmdStop = "taskkill /f /t /pid ${PID}" 424 | } 425 | 426 | if p.config.CmdStop != "" { 427 | // replace ${PID} with the pid of the process 428 | stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid))) 429 | if err != nil { 430 | p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err) 431 | return 432 | } 433 | 434 | p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " ")) 435 | 436 | stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...) 
437 | stopCmd.Stdout = p.processLogger 438 | stopCmd.Stderr = p.processLogger 439 | stopCmd.Env = p.config.Env 440 | 441 | if err := stopCmd.Run(); err != nil { 442 | p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err) 443 | return 444 | } 445 | } else { 446 | if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil { 447 | p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err) 448 | return 449 | } 450 | } 451 | 452 | select { 453 | case <-sigtermTimeout.Done(): 454 | p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID) 455 | p.upstreamWasStoppedWithKill = true 456 | if err := p.cmd.Process.Kill(); err != nil { 457 | p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err) 458 | } 459 | case err := <-p.cmdWaitChan: 460 | // Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK 461 | // because if we make it here then the cmd has been successfully running and made it 462 | // through the health check. There is a possibility that the cmd crashed after the health check 463 | // succeeded but that's not a case llama-swap is handling for now. 
464 | if err != nil { 465 | if errno, ok := err.(syscall.Errno); ok { 466 | p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) 467 | } else if exitError, ok := err.(*exec.ExitError); ok { 468 | if strings.Contains(exitError.String(), "signal: terminated") { 469 | p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID) 470 | } else if strings.Contains(exitError.String(), "signal: interrupt") { 471 | p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID) 472 | } else { 473 | p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) 474 | } 475 | } else { 476 | p.proxyLogger.Errorf("<%s> Process exited >> %v", p.ID, err) 477 | } 478 | } 479 | } 480 | } 481 | 482 | func (p *Process) checkHealthEndpoint(healthURL string) error { 483 | client := &http.Client{ 484 | Timeout: 500 * time.Millisecond, 485 | } 486 | 487 | req, err := http.NewRequest("GET", healthURL, nil) 488 | if err != nil { 489 | return err 490 | } 491 | 492 | resp, err := client.Do(req) 493 | if err != nil { 494 | return err 495 | } 496 | defer resp.Body.Close() 497 | 498 | // got a response but it was not an OK 499 | if resp.StatusCode != http.StatusOK { 500 | return fmt.Errorf("status code: %d", resp.StatusCode) 501 | } 502 | 503 | return nil 504 | } 505 | 506 | func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { 507 | requestBeginTime := time.Now() 508 | var startDuration time.Duration 509 | 510 | // prevent new requests from being made while stopping or irrecoverable 511 | currentState := p.CurrentState() 512 | if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping { 513 | http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable) 514 | return 515 | } 516 | 517 | select { 518 | case p.concurrencyLimitSemaphore <- struct{}{}: 519 | defer func() { <-p.concurrencyLimitSemaphore }() 520 | default: 521 | http.Error(w, "Too many requests", 
http.StatusTooManyRequests) 522 | return 523 | } 524 | 525 | p.inFlightRequests.Add(1) 526 | defer func() { 527 | p.lastRequestHandled = time.Now() 528 | p.inFlightRequests.Done() 529 | }() 530 | 531 | // start the process on demand 532 | if p.CurrentState() != StateReady { 533 | beginStartTime := time.Now() 534 | if err := p.start(); err != nil { 535 | errstr := fmt.Sprintf("unable to start process: %s", err) 536 | http.Error(w, errstr, http.StatusBadGateway) 537 | return 538 | } 539 | startDuration = time.Since(beginStartTime) 540 | } 541 | 542 | proxyTo := p.config.Proxy 543 | client := &http.Client{} 544 | req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body) 545 | if err != nil { 546 | http.Error(w, err.Error(), http.StatusInternalServerError) 547 | return 548 | } 549 | req.Header = r.Header.Clone() 550 | 551 | contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64) 552 | if err == nil { 553 | req.ContentLength = contentLength 554 | } 555 | 556 | resp, err := client.Do(req) 557 | if err != nil { 558 | http.Error(w, err.Error(), http.StatusBadGateway) 559 | return 560 | } 561 | defer resp.Body.Close() 562 | for k, vv := range resp.Header { 563 | for _, v := range vv { 564 | w.Header().Add(k, v) 565 | } 566 | } 567 | w.WriteHeader(resp.StatusCode) 568 | 569 | // faster than io.Copy when streaming 570 | buf := make([]byte, 32*1024) 571 | for { 572 | n, err := resp.Body.Read(buf) 573 | if n > 0 { 574 | if _, writeErr := w.Write(buf[:n]); writeErr != nil { 575 | return 576 | } 577 | if flusher, ok := w.(http.Flusher); ok { 578 | flusher.Flush() 579 | } 580 | } 581 | if err == io.EOF { 582 | break 583 | } 584 | if err != nil { 585 | http.Error(w, err.Error(), http.StatusBadGateway) 586 | return 587 | } 588 | } 589 | 590 | totalTime := time.Since(requestBeginTime) 591 | p.proxyLogger.Debugf("<%s> request %s - start: %v, total: %v", 592 | p.ID, r.RequestURI, startDuration, totalTime) 593 | } 594 | 
-------------------------------------------------------------------------------- /proxy/process_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "net/http/httptest" 7 | "os" 8 | "runtime" 9 | "sync" 10 | "testing" 11 | "time" 12 | 13 | "github.com/stretchr/testify/assert" 14 | ) 15 | 16 | var ( 17 | debugLogger = NewLogMonitorWriter(os.Stdout) 18 | ) 19 | 20 | func init() { 21 | // flip to help with debugging tests 22 | if false { 23 | debugLogger.SetLogLevel(LevelDebug) 24 | } else { 25 | debugLogger.SetLogLevel(LevelError) 26 | } 27 | } 28 | 29 | func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { 30 | 31 | expectedMessage := "testing91931" 32 | config := getTestSimpleResponderConfig(expectedMessage) 33 | 34 | // Create a process 35 | process := NewProcess("test-process", 5, config, debugLogger, debugLogger) 36 | defer process.Stop() 37 | 38 | req := httptest.NewRequest("GET", "/test", nil) 39 | w := httptest.NewRecorder() 40 | 41 | // process is automatically started 42 | assert.Equal(t, StateStopped, process.CurrentState()) 43 | process.ProxyRequest(w, req) 44 | assert.Equal(t, StateReady, process.CurrentState()) 45 | 46 | assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) 47 | assert.Contains(t, w.Body.String(), expectedMessage) 48 | 49 | // Stop the process 50 | process.Stop() 51 | 52 | req = httptest.NewRequest("GET", "/", nil) 53 | w = httptest.NewRecorder() 54 | 55 | // Proxy the request 56 | process.ProxyRequest(w, req) 57 | 58 | // should have automatically started the process again 59 | if w.Code != http.StatusOK { 60 | t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) 61 | } 62 | } 63 | 64 | // TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests 65 | // are all handled successfully, even though they all may ask for the process to .start() 66 | func 
TestProcess_WaitOnMultipleStarts(t *testing.T) { 67 | 68 | expectedMessage := "testing91931" 69 | config := getTestSimpleResponderConfig(expectedMessage) 70 | 71 | process := NewProcess("test-process", 5, config, debugLogger, debugLogger) 72 | defer process.Stop() 73 | 74 | var wg sync.WaitGroup 75 | for i := 0; i < 5; i++ { 76 | wg.Add(1) 77 | go func(reqID int) { 78 | defer wg.Done() 79 | req := httptest.NewRequest("GET", "/test", nil) 80 | w := httptest.NewRecorder() 81 | process.ProxyRequest(w, req) 82 | assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID) 83 | assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID) 84 | }(i) 85 | } 86 | wg.Wait() 87 | assert.Equal(t, StateReady, process.CurrentState()) 88 | } 89 | 90 | // test that the automatic start returns the expected error type 91 | func TestProcess_BrokenModelConfig(t *testing.T) { 92 | // Create a process configuration 93 | config := ModelConfig{ 94 | Cmd: "nonexistent-command", 95 | Proxy: "http://127.0.0.1:9913", 96 | CheckEndpoint: "/health", 97 | } 98 | 99 | process := NewProcess("broken", 1, config, debugLogger, debugLogger) 100 | 101 | req := httptest.NewRequest("GET", "/", nil) 102 | w := httptest.NewRecorder() 103 | process.ProxyRequest(w, req) 104 | assert.Equal(t, http.StatusBadGateway, w.Code) 105 | assert.Contains(t, w.Body.String(), "unable to start process") 106 | 107 | w = httptest.NewRecorder() 108 | process.ProxyRequest(w, req) 109 | assert.Equal(t, http.StatusServiceUnavailable, w.Code) 110 | assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed") 111 | } 112 | 113 | func TestProcess_UnloadAfterTTL(t *testing.T) { 114 | if testing.Short() { 115 | t.Skip("skipping long auto unload TTL test") 116 | } 117 | 118 | expectedMessage := "I_sense_imminent_danger" 119 | config := getTestSimpleResponderConfig(expectedMessage) 120 | assert.Equal(t, 0, config.UnloadAfter) 121 | config.UnloadAfter = 3 // 
seconds 122 | assert.Equal(t, 3, config.UnloadAfter) 123 | 124 | process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) 125 | defer process.Stop() 126 | 127 | // this should take 4 seconds 128 | req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil) 129 | req2 := httptest.NewRequest("GET", "/test", nil) 130 | 131 | w := httptest.NewRecorder() 132 | 133 | // Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter 134 | process.ProxyRequest(w, req1) 135 | 136 | t.Log("sending slow first request (4 seconds)") 137 | assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) 138 | assert.Contains(t, w.Body.String(), "1234") 139 | assert.Equal(t, StateReady, process.CurrentState()) 140 | 141 | // ensure the TTL timeout does not race slow requests (see issue #25) 142 | t.Log("sending second request (1 second)") 143 | time.Sleep(time.Second) 144 | w = httptest.NewRecorder() 145 | process.ProxyRequest(w, req2) 146 | assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) 147 | assert.Contains(t, w.Body.String(), expectedMessage) 148 | assert.Equal(t, StateReady, process.CurrentState()) 149 | 150 | // wait 5 seconds 151 | t.Log("sleep 5 seconds and check if unloaded") 152 | time.Sleep(5 * time.Second) 153 | assert.Equal(t, StateStopped, process.CurrentState()) 154 | } 155 | 156 | func TestProcess_LowTTLValue(t *testing.T) { 157 | if true { // change this code to run this ... 
158 | t.Skip("skipping test, edit process_test.go to run it ") 159 | } 160 | 161 | config := getTestSimpleResponderConfig("fast_ttl") 162 | assert.Equal(t, 0, config.UnloadAfter) 163 | config.UnloadAfter = 1 // second 164 | assert.Equal(t, 1, config.UnloadAfter) 165 | 166 | process := NewProcess("ttl", 2, config, debugLogger, debugLogger) 167 | defer process.Stop() 168 | 169 | for i := 0; i < 100; i++ { 170 | t.Logf("Waiting before sending request %d", i) 171 | time.Sleep(1500 * time.Millisecond) 172 | 173 | expected := fmt.Sprintf("echo=test_%d", i) 174 | req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil) 175 | w := httptest.NewRecorder() 176 | process.ProxyRequest(w, req) 177 | assert.Equal(t, http.StatusOK, w.Code) 178 | assert.Contains(t, w.Body.String(), expected) 179 | } 180 | 181 | } 182 | 183 | // issue #19 184 | // This test makes sure using Process.Stop() does not affect pending HTTP 185 | // requests. All HTTP requests in this test should complete successfully. 186 | func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { 187 | if testing.Short() { 188 | t.Skip("skipping slow test") 189 | } 190 | 191 | expectedMessage := "12345" 192 | config := getTestSimpleResponderConfig(expectedMessage) 193 | process := NewProcess("t", 10, config, debugLogger, debugLogger) 194 | defer process.Stop() 195 | 196 | results := map[string]string{ 197 | "12345": "", 198 | "abcde": "", 199 | "fghij": "", 200 | } 201 | 202 | var wg sync.WaitGroup 203 | var mu sync.Mutex 204 | 205 | for key := range results { 206 | wg.Add(1) 207 | go func(key string) { 208 | defer wg.Done() 209 | // send a request where simple-responder is will wait 300ms before responding 210 | // this will simulate an in-progress request. 
211 | req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=300ms", key), nil) 212 | w := httptest.NewRecorder() 213 | 214 | process.ProxyRequest(w, req) 215 | 216 | if w.Code != http.StatusOK { 217 | t.Errorf("Expected status OK, got %d for key %s", w.Code, key) 218 | } 219 | 220 | mu.Lock() 221 | results[key] = w.Body.String() 222 | mu.Unlock() 223 | 224 | }(key) 225 | } 226 | 227 | // Stop the process while requests are still being processed 228 | go func() { 229 | <-time.After(150 * time.Millisecond) 230 | process.Stop() 231 | }() 232 | 233 | wg.Wait() 234 | 235 | for key, result := range results { 236 | assert.Equal(t, key, result) 237 | } 238 | } 239 | 240 | func TestProcess_SwapState(t *testing.T) { 241 | tests := []struct { 242 | name string 243 | currentState ProcessState 244 | expectedState ProcessState 245 | newState ProcessState 246 | expectedError error 247 | expectedResult ProcessState 248 | }{ 249 | {"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting}, 250 | {"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady}, 251 | {"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed}, 252 | {"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping}, 253 | {"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping}, 254 | {"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped}, 255 | {"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown}, 256 | {"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped}, 257 | {"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting}, 258 | {"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady}, 259 | {"Ready to Failed", StateReady, StateReady, StateFailed, 
ErrInvalidStateTransition, StateReady}, 260 | {"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping}, 261 | {"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed}, 262 | {"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed}, 263 | {"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown}, 264 | {"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown}, 265 | {"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped}, 266 | } 267 | 268 | for _, test := range tests { 269 | t.Run(test.name, func(t *testing.T) { 270 | p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) 271 | p.state = test.currentState 272 | 273 | resultState, err := p.swapState(test.expectedState, test.newState) 274 | if err != nil && test.expectedError == nil { 275 | t.Errorf("Unexpected error: %v", err) 276 | } else if err == nil && test.expectedError != nil { 277 | t.Errorf("Expected error: %v, but got none", test.expectedError) 278 | } else if err != nil && test.expectedError != nil { 279 | if err.Error() != test.expectedError.Error() { 280 | t.Errorf("Expected error: %v, got: %v", test.expectedError, err) 281 | } 282 | } 283 | 284 | if resultState != test.expectedResult { 285 | t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState) 286 | } 287 | }) 288 | } 289 | } 290 | 291 | func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { 292 | if testing.Short() { 293 | t.Skip("skipping long shutdown test") 294 | } 295 | 296 | expectedMessage := "testing91931" 297 | 298 | // make a config where the healthcheck will always fail because port is wrong 299 | config := getTestSimpleResponderConfigPort(expectedMessage, 9999) 300 | config.Proxy 
= "http://localhost:9998/test" 301 | 302 | healthCheckTTLSeconds := 30 303 | process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) 304 | 305 | // make it a lot faster 306 | process.healthCheckLoopInterval = time.Second 307 | 308 | // start a goroutine to simulate a shutdown 309 | var wg sync.WaitGroup 310 | wg.Add(1) // Add before launching the goroutine (matches the other tests in this file) so Done can never precede Add 311 | go func() { 312 | defer wg.Done() 313 | <-time.After(time.Millisecond * 500) 314 | process.Shutdown() 315 | }() 316 | 317 | // start the process, this is a blocking call 318 | err := process.start() 319 | 320 | wg.Wait() 321 | assert.ErrorContains(t, err, "health check interrupted due to shutdown") 322 | assert.Equal(t, StateShutdown, process.CurrentState()) 323 | } 324 | 325 | func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { 326 | if testing.Short() { 327 | t.Skip("skipping Exit Interrupts Health Check test") 328 | } 329 | 330 | // should run and exit but interrupt the long checkHealthTimeout 331 | checkHealthTimeout := 5 332 | config := ModelConfig{ 333 | Cmd: "sleep 1", 334 | Proxy: "http://127.0.0.1:9913", 335 | CheckEndpoint: "/health", 336 | } 337 | 338 | process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger) 339 | process.healthCheckLoopInterval = time.Second // make it faster 340 | err := process.start() 341 | assert.Equal(t, "upstream command exited prematurely but successfully", err.Error()) 342 | assert.Equal(t, process.CurrentState(), StateFailed) 343 | } 344 | 345 | func TestProcess_ConcurrencyLimit(t *testing.T) { 346 | if testing.Short() { 347 | t.Skip("skipping long concurrency limit test") 348 | } 349 | 350 | expectedMessage := "concurrency_limit_test" 351 | config := getTestSimpleResponderConfig(expectedMessage) 352 | 353 | // only allow 1 concurrent request at a time 354 | config.ConcurrencyLimit = 1 355 | 356 | process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) 357 | assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
358 | defer process.Stop() 359 | 360 | // launch a goroutine first to take up the semaphore 361 | done := make(chan struct{}) 362 | go func() { 363 | defer close(done) 364 | req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil) 365 | w := httptest.NewRecorder() 366 | process.ProxyRequest(w, req1) 367 | assert.Equal(t, http.StatusOK, w.Code) 368 | }() 369 | 370 | // let the goroutine start 371 | <-time.After(time.Millisecond * 25) 372 | 373 | denied := httptest.NewRequest("GET", "/test", nil) 374 | w := httptest.NewRecorder() 375 | process.ProxyRequest(w, denied) 376 | assert.Equal(t, http.StatusTooManyRequests, w.Code) 377 | 378 | // wait for the first request to complete so its assertions cannot fire after the test returns 379 | <-done 380 | } 381 | 382 | func TestProcess_StopImmediately(t *testing.T) { 383 | expectedMessage := "test_stop_immediate" 384 | config := getTestSimpleResponderConfig(expectedMessage) 385 | 386 | process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) 387 | defer process.Stop() 388 | 389 | err := process.start() 390 | assert.Nil(t, err) 391 | assert.Equal(t, process.CurrentState(), StateReady) 392 | go func() { 393 | // slow, but will get killed by StopImmediate 394 | req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=1s", nil) 395 | w := httptest.NewRecorder() 396 | process.ProxyRequest(w, req) 397 | }() 398 | <-time.After(time.Millisecond) 399 | process.StopImmediately() 400 | assert.Equal(t, process.CurrentState(), StateStopped) 401 | } 402 | 403 | // Test that SIGKILL is sent when gracefulStopTimeout is reached and properly terminates 404 | // the upstream command 405 | func TestProcess_ForceStopWithKill(t *testing.T) { 406 | 407 | expectedMessage := "test_sigkill" 408 | binaryPath := getSimpleResponderPath() 409 | port := getTestPort() 410 | 411 | config := ModelConfig{ 412 | // note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent 413 | // to force the process to exit 414 | Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage), 415 | Proxy:
fmt.Sprintf("http://127.0.0.1:%d", port), 411 | CheckEndpoint: "/health", 412 | } 413 | 414 | process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) 415 | defer process.Stop() 416 | 417 | // reduce to make testing go faster 418 | process.gracefulStopTimeout = time.Second 419 | 420 | err := process.start() 421 | assert.Nil(t, err) 422 | assert.Equal(t, process.CurrentState(), StateReady) 423 | 424 | waitChan := make(chan struct{}) 425 | go func() { 426 | // slow, but will get killed by StopImmediate 427 | req := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=2s", nil) 428 | w := httptest.NewRecorder() 429 | process.ProxyRequest(w, req) 430 | 431 | // StatusOK because that was already sent before the kill 432 | assert.Equal(t, http.StatusOK, w.Code) 433 | 434 | // unexpected EOF because the kill happened, the "1" is sent before the kill 435 | // then the unexpected EOF is sent after the kill 436 | if runtime.GOOS == "windows" { 437 | assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host") 438 | } else { 439 | assert.Contains(t, w.Body.String(), "unexpected EOF") 440 | } 441 | 442 | close(waitChan) 443 | }() 444 | 445 | <-time.After(time.Millisecond) 446 | process.StopImmediately() 447 | assert.Equal(t, process.CurrentState(), StateStopped) 448 | 449 | // the request should have been interrupted by SIGKILL 450 | <-waitChan 451 | } 452 | 453 | func TestProcess_StopCmd(t *testing.T) { 454 | config := getTestSimpleResponderConfig("test_stop_cmd") 455 | 456 | if runtime.GOOS == "windows" { 457 | config.CmdStop = "taskkill /f /t /pid ${PID}" 458 | } else { 459 | config.CmdStop = "kill -TERM ${PID}" 460 | } 461 | 462 | process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger) 463 | defer process.Stop() 464 | 465 | err := process.start() 466 | assert.Nil(t, err) 467 | assert.Equal(t, process.CurrentState(), StateReady) 468 | process.StopImmediately() 469 | assert.Equal(t, 
process.CurrentState(), StateStopped) 470 | } 471 | -------------------------------------------------------------------------------- /proxy/processgroup.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "slices" 7 | "sync" 8 | ) 9 | 10 | type ProcessGroup struct { 11 | sync.Mutex 12 | 13 | config Config 14 | id string 15 | swap bool 16 | exclusive bool 17 | persistent bool 18 | 19 | proxyLogger *LogMonitor 20 | upstreamLogger *LogMonitor 21 | 22 | // map of current processes 23 | processes map[string]*Process 24 | lastUsedProcess string 25 | } 26 | 27 | func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { 28 | groupConfig, ok := config.Groups[id] 29 | if !ok { 30 | panic("Unable to find configuration for group id: " + id) 31 | } 32 | 33 | pg := &ProcessGroup{ 34 | id: id, 35 | config: config, 36 | swap: groupConfig.Swap, 37 | exclusive: groupConfig.Exclusive, 38 | persistent: groupConfig.Persistent, 39 | proxyLogger: proxyLogger, 40 | upstreamLogger: upstreamLogger, 41 | processes: make(map[string]*Process), 42 | } 43 | 44 | // Create a Process for each member in the group 45 | for _, modelID := range groupConfig.Members { 46 | modelConfig, modelID, _ := pg.config.FindConfig(modelID) 47 | process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger) 48 | pg.processes[modelID] = process 49 | } 50 | 51 | return pg 52 | } 53 | 54 | // ProxyRequest proxies a request to the specified model 55 | func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error { 56 | if !pg.HasMember(modelID) { 57 | return fmt.Errorf("model %s not part of group %s", modelID, pg.id) 58 | } 59 | 60 | if pg.swap { 61 | pg.Lock() 62 | if pg.lastUsedProcess != modelID { 63 | if pg.lastUsedProcess != "" { 64 | 
pg.processes[pg.lastUsedProcess].Stop() 65 | } 66 | pg.lastUsedProcess = modelID 67 | } 68 | pg.Unlock() 69 | } 70 | 71 | pg.processes[modelID].ProxyRequest(writer, request) 72 | return nil 73 | } 74 | 75 | func (pg *ProcessGroup) HasMember(modelName string) bool { 76 | return slices.Contains(pg.config.Groups[pg.id].Members, modelName) 77 | } 78 | 79 | func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) { 80 | pg.Lock() 81 | defer pg.Unlock() 82 | 83 | if len(pg.processes) == 0 { 84 | return 85 | } 86 | 87 | // stop Processes in parallel 88 | var wg sync.WaitGroup 89 | for _, process := range pg.processes { 90 | wg.Add(1) 91 | go func(process *Process) { 92 | defer wg.Done() 93 | switch strategy { 94 | case StopImmediately: 95 | process.StopImmediately() 96 | default: 97 | process.Stop() 98 | } 99 | }(process) 100 | } 101 | wg.Wait() 102 | } 103 | 104 | func (pg *ProcessGroup) Shutdown() { 105 | var wg sync.WaitGroup 106 | for _, process := range pg.processes { 107 | wg.Add(1) 108 | go func(process *Process) { 109 | defer wg.Done() 110 | process.Shutdown() 111 | }(process) 112 | } 113 | wg.Wait() 114 | } 115 | -------------------------------------------------------------------------------- /proxy/processgroup_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "net/http" 6 | "net/http/httptest" 7 | "testing" 8 | 9 | "github.com/stretchr/testify/assert" 10 | ) 11 | 12 | var processGroupTestConfig = AddDefaultGroupToConfig(Config{ 13 | HealthCheckTimeout: 15, 14 | Models: map[string]ModelConfig{ 15 | "model1": getTestSimpleResponderConfig("model1"), 16 | "model2": getTestSimpleResponderConfig("model2"), 17 | "model3": getTestSimpleResponderConfig("model3"), 18 | "model4": getTestSimpleResponderConfig("model4"), 19 | "model5": getTestSimpleResponderConfig("model5"), 20 | }, 21 | Groups: map[string]GroupConfig{ 22 | "G1": { 23 | Swap: true, 24 | Exclusive: true, 25 | 
Members: []string{"model1", "model2"}, 26 | }, 27 | "G2": { 28 | Swap: false, 29 | Exclusive: true, 30 | Members: []string{"model3", "model4"}, 31 | }, 32 | }, 33 | }) 34 | 35 | func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { 36 | pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) 37 | assert.True(t, pg.HasMember("model5")) 38 | } 39 | 40 | func TestProcessGroup_HasMember(t *testing.T) { 41 | pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) 42 | assert.True(t, pg.HasMember("model1")) 43 | assert.True(t, pg.HasMember("model2")) 44 | assert.False(t, pg.HasMember("model3")) 45 | } 46 | 47 | func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { 48 | pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) 49 | defer pg.StopProcesses(StopWaitForInflightRequest) 50 | 51 | tests := []string{"model1", "model2"} 52 | 53 | for _, modelName := range tests { 54 | t.Run(modelName, func(t *testing.T) { 55 | reqBody := `{"x", "y"}` 56 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 57 | w := httptest.NewRecorder() 58 | 59 | assert.NoError(t, pg.ProxyRequest(modelName, w, req)) 60 | assert.Equal(t, http.StatusOK, w.Code) 61 | assert.Contains(t, w.Body.String(), modelName) 62 | 63 | // make sure only one process is in the running state 64 | count := 0 65 | for _, process := range pg.processes { 66 | if process.CurrentState() == StateReady { 67 | count++ 68 | } 69 | } 70 | assert.Equal(t, 1, count) 71 | }) 72 | } 73 | } 74 | 75 | func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { 76 | pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) 77 | defer pg.StopProcesses(StopWaitForInflightRequest) 78 | 79 | tests := []string{"model3", "model4"} 80 | 81 | for _, modelName := range tests { 82 | t.Run(modelName, func(t *testing.T) { 83 | reqBody := `{"x", "y"}` 84 | req := httptest.NewRequest("POST", 
"/v1/chat/completions", bytes.NewBufferString(reqBody)) 85 | w := httptest.NewRecorder() 86 | assert.NoError(t, pg.ProxyRequest(modelName, w, req)) 87 | assert.Equal(t, http.StatusOK, w.Code) 88 | assert.Contains(t, w.Body.String(), modelName) 89 | }) 90 | } 91 | 92 | // make sure all the processes are running 93 | for _, process := range pg.processes { 94 | assert.Equal(t, StateReady, process.CurrentState()) 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /proxy/proxymanager.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "io" 8 | "mime/multipart" 9 | "net/http" 10 | "os" 11 | "sort" 12 | "strconv" 13 | "strings" 14 | "sync" 15 | "time" 16 | 17 | "github.com/gin-gonic/gin" 18 | "github.com/tidwall/gjson" 19 | "github.com/tidwall/sjson" 20 | ) 21 | 22 | const ( 23 | PROFILE_SPLIT_CHAR = ":" 24 | ) 25 | 26 | type ProxyManager struct { 27 | sync.Mutex 28 | 29 | config Config 30 | ginEngine *gin.Engine 31 | 32 | // logging 33 | proxyLogger *LogMonitor 34 | upstreamLogger *LogMonitor 35 | muxLogger *LogMonitor 36 | 37 | processGroups map[string]*ProcessGroup 38 | } 39 | 40 | func New(config Config) *ProxyManager { 41 | // set up loggers 42 | stdoutLogger := NewLogMonitorWriter(os.Stdout) 43 | upstreamLogger := NewLogMonitorWriter(stdoutLogger) 44 | proxyLogger := NewLogMonitorWriter(stdoutLogger) 45 | 46 | if config.LogRequests { 47 | proxyLogger.Warn("LogRequests configuration is deprecated. 
Use logLevel instead.") 48 | } 49 | 50 | switch strings.ToLower(strings.TrimSpace(config.LogLevel)) { 51 | case "debug": 52 | proxyLogger.SetLogLevel(LevelDebug) 53 | upstreamLogger.SetLogLevel(LevelDebug) 54 | case "info": 55 | proxyLogger.SetLogLevel(LevelInfo) 56 | upstreamLogger.SetLogLevel(LevelInfo) 57 | case "warn": 58 | proxyLogger.SetLogLevel(LevelWarn) 59 | upstreamLogger.SetLogLevel(LevelWarn) 60 | case "error": 61 | proxyLogger.SetLogLevel(LevelError) 62 | upstreamLogger.SetLogLevel(LevelError) 63 | default: 64 | proxyLogger.SetLogLevel(LevelInfo) 65 | upstreamLogger.SetLogLevel(LevelInfo) 66 | } 67 | 68 | pm := &ProxyManager{ 69 | config: config, 70 | ginEngine: gin.New(), 71 | 72 | proxyLogger: proxyLogger, 73 | muxLogger: stdoutLogger, 74 | upstreamLogger: upstreamLogger, 75 | 76 | processGroups: make(map[string]*ProcessGroup), 77 | } 78 | 79 | // create the process groups 80 | for groupID := range config.Groups { 81 | processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger) 82 | pm.processGroups[groupID] = processGroup 83 | } 84 | 85 | pm.setupGinEngine() 86 | return pm 87 | } 88 | 89 | func (pm *ProxyManager) setupGinEngine() { 90 | pm.ginEngine.Use(func(c *gin.Context) { 91 | // Start timer 92 | start := time.Now() 93 | 94 | // capture these because /upstream/:model rewrites them in c.Next() 95 | clientIP := c.ClientIP() 96 | method := c.Request.Method 97 | path := c.Request.URL.Path 98 | 99 | // Process request 100 | c.Next() 101 | 102 | // Stop timer 103 | duration := time.Since(start) 104 | 105 | statusCode := c.Writer.Status() 106 | bodySize := c.Writer.Size() 107 | 108 | pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v", 109 | clientIP, 110 | method, 111 | path, 112 | c.Request.Proto, 113 | statusCode, 114 | bodySize, 115 | c.Request.UserAgent(), 116 | duration, 117 | ) 118 | }) 119 | 120 | // see: issue: #81, #77 and #42 for CORS issues 121 | // respond with permissive OPTIONS for any endpoint 122 | 
pm.ginEngine.Use(func(c *gin.Context) { 123 | if c.Request.Method == "OPTIONS" { 124 | c.Header("Access-Control-Allow-Origin", "*") 125 | c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") 126 | 127 | // allow whatever the client requested by default 128 | if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { 129 | sanitized := SanitizeAccessControlRequestHeaderValues(headers) 130 | c.Header("Access-Control-Allow-Headers", sanitized) 131 | } else { 132 | c.Header( 133 | "Access-Control-Allow-Headers", 134 | "Content-Type, Authorization, Accept, X-Requested-With", 135 | ) 136 | } 137 | c.Header("Access-Control-Max-Age", "86400") 138 | c.AbortWithStatus(http.StatusNoContent) 139 | return 140 | } 141 | c.Next() 142 | }) 143 | 144 | // Set up routes using the Gin engine 145 | pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler) 146 | // Support legacy /v1/completions api, see issue #12 147 | pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler) 148 | 149 | // Support embeddings 150 | pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler) 151 | pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler) 152 | 153 | // Support audio/speech endpoint 154 | pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) 155 | pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler) 156 | 157 | pm.ginEngine.GET("/v1/models", pm.listModelsHandler) 158 | 159 | // in proxymanager_loghandlers.go 160 | pm.ginEngine.GET("/logs", pm.sendLogsHandlers) 161 | pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) 162 | pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) 163 | pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler) 164 | pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE) 165 | 166 | pm.ginEngine.GET("/upstream", pm.upstreamIndex) 167 | pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) 168 | 169 | 
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) 170 | 171 | pm.ginEngine.GET("/running", pm.listRunningProcessesHandler) 172 | 173 | pm.ginEngine.GET("/", func(c *gin.Context) { 174 | // Set the Content-Type header to text/html 175 | c.Header("Content-Type", "text/html") 176 | 177 | // Write the embedded HTML content to the response 178 | htmlData, err := getHTMLFile("index.html") 179 | if err != nil { 180 | c.String(http.StatusInternalServerError, err.Error()) 181 | return 182 | } 183 | _, err = c.Writer.Write(htmlData) 184 | if err != nil { 185 | c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err)) 186 | return 187 | } 188 | }) 189 | 190 | pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) { 191 | if data, err := getHTMLFile("favicon.ico"); err == nil { 192 | c.Data(http.StatusOK, "image/x-icon", data) 193 | } else { 194 | c.String(http.StatusInternalServerError, err.Error()) 195 | } 196 | }) 197 | 198 | // Disable console color for testing 199 | gin.DisableConsoleColor() 200 | } 201 | 202 | // ServeHTTP implements http.Handler interface 203 | func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { 204 | pm.ginEngine.ServeHTTP(w, r) 205 | } 206 | 207 | // StopProcesses acquires a lock and stops all running upstream processes. 208 | // This is the public method safe for concurrent calls. 209 | // Unlike Shutdown, this method only stops the processes but doesn't perform 210 | // a complete shutdown, allowing for process replacement without full termination. 
211 | func (pm *ProxyManager) StopProcesses(strategy StopStrategy) { 212 | pm.Lock() 213 | defer pm.Unlock() 214 | 215 | // stop Processes in parallel 216 | var wg sync.WaitGroup 217 | for _, processGroup := range pm.processGroups { 218 | wg.Add(1) 219 | go func(processGroup *ProcessGroup) { 220 | defer wg.Done() 221 | processGroup.StopProcesses(strategy) 222 | }(processGroup) 223 | } 224 | 225 | wg.Wait() 226 | } 227 | 228 | // Shutdown stops all processes managed by this ProxyManager 229 | func (pm *ProxyManager) Shutdown() { 230 | pm.Lock() 231 | defer pm.Unlock() 232 | 233 | pm.proxyLogger.Debug("Shutdown() called in proxy manager") 234 | 235 | var wg sync.WaitGroup 236 | // Send shutdown signal to all process in groups 237 | for _, processGroup := range pm.processGroups { 238 | wg.Add(1) 239 | go func(processGroup *ProcessGroup) { 240 | defer wg.Done() 241 | processGroup.Shutdown() 242 | }(processGroup) 243 | } 244 | wg.Wait() 245 | } 246 | 247 | func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { 248 | // de-alias the real model name and get a real one 249 | realModelName, found := pm.config.RealModelName(requestedModel) 250 | if !found { 251 | return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel) 252 | } 253 | 254 | processGroup := pm.findGroupByModelName(realModelName) 255 | if processGroup == nil { 256 | return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel) 257 | } 258 | 259 | if processGroup.exclusive { 260 | pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) 261 | for groupId, otherGroup := range pm.processGroups { 262 | if groupId != processGroup.id && !otherGroup.persistent { 263 | otherGroup.StopProcesses(StopWaitForInflightRequest) 264 | } 265 | } 266 | } 267 | 268 | return processGroup, realModelName, nil 269 | } 270 | 271 | func (pm *ProxyManager) listModelsHandler(c 
*gin.Context) { 272 | data := []interface{}{} 273 | for id, modelConfig := range pm.config.Models { 274 | if modelConfig.Unlisted { 275 | continue 276 | } 277 | 278 | data = append(data, map[string]interface{}{ 279 | "id": id, 280 | "object": "model", 281 | "created": time.Now().Unix(), 282 | "owned_by": "llama-swap", 283 | }) 284 | } 285 | 286 | // Set the Content-Type header to application/json 287 | c.Header("Content-Type", "application/json") 288 | 289 | if origin := c.Request.Header.Get("Origin"); origin != "" { 290 | c.Header("Access-Control-Allow-Origin", origin) 291 | } 292 | 293 | // Encode the data as JSON and write it to the response writer 294 | if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil { 295 | pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error())) 296 | return 297 | } 298 | } 299 | 300 | func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { 301 | requestedModel := c.Param("model_id") 302 | 303 | if requestedModel == "" { 304 | pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") 305 | return 306 | } 307 | 308 | processGroup, _, err := pm.swapProcessGroup(requestedModel) 309 | if err != nil { 310 | pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) 311 | return 312 | } 313 | 314 | // rewrite the path 315 | c.Request.URL.Path = c.Param("upstreamPath") 316 | processGroup.ProxyRequest(requestedModel, c.Writer, c.Request) 317 | } 318 | 319 | func (pm *ProxyManager) upstreamIndex(c *gin.Context) { 320 | var html strings.Builder 321 | 322 | html.WriteString("\n

Available Models

Unload all models
    ") 323 | 324 | // Extract keys and sort them 325 | var modelIDs []string 326 | for modelID, modelConfig := range pm.config.Models { 327 | if modelConfig.Unlisted { 328 | continue 329 | } 330 | 331 | modelIDs = append(modelIDs, modelID) 332 | } 333 | sort.Strings(modelIDs) 334 | 335 | // Iterate over sorted keys 336 | for _, modelID := range modelIDs { 337 | // Get process state 338 | processGroup := pm.findGroupByModelName(modelID) 339 | var state string 340 | if processGroup != nil { 341 | process := processGroup.processes[modelID] 342 | if process != nil { 343 | var stateStr string 344 | switch process.CurrentState() { 345 | case StateReady: 346 | stateStr = "Ready" 347 | case StateStarting: 348 | stateStr = "Starting" 349 | case StateStopping: 350 | stateStr = "Stopping" 351 | case StateFailed: 352 | stateStr = "Failed" 353 | case StateShutdown: 354 | stateStr = "Shutdown" 355 | case StateStopped: 356 | stateStr = "Stopped" 357 | default: 358 | stateStr = "Unknown" 359 | } 360 | state = stateStr 361 | } 362 | } 363 | html.WriteString(fmt.Sprintf("
  • %s - %s
  • ", modelID, modelID, state)) 364 | } 365 | html.WriteString("
") 366 | c.Header("Content-Type", "text/html") 367 | c.String(http.StatusOK, html.String()) 368 | } 369 | 370 | func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { 371 | bodyBytes, err := io.ReadAll(c.Request.Body) 372 | if err != nil { 373 | pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") 374 | return 375 | } 376 | 377 | requestedModel := gjson.GetBytes(bodyBytes, "model").String() 378 | if requestedModel == "" { 379 | pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") 380 | return 381 | } 382 | 383 | processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) 384 | if err != nil { 385 | pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) 386 | return 387 | } 388 | 389 | // issue #69 allow custom model names to be sent to upstream 390 | useModelName := pm.config.Models[realModelName].UseModelName 391 | if useModelName != "" { 392 | bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) 393 | if err != nil { 394 | pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) 395 | return 396 | } 397 | } 398 | 399 | c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) 400 | 401 | // dechunk it as we already have all the body bytes see issue #11 402 | c.Request.Header.Del("transfer-encoding") 403 | c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) 404 | c.Request.ContentLength = int64(len(bodyBytes)) 405 | 406 | if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { 407 | pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) 408 | pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) 409 | return 410 | } 411 | } 412 | 413 | func (pm *ProxyManager) proxyOAIPostFormHandler(c 
*gin.Context) { 414 | // Parse multipart form 415 | if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk 416 | pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) 417 | return 418 | } 419 | 420 | // Get model parameter from the form 421 | requestedModel := c.Request.FormValue("model") 422 | if requestedModel == "" { 423 | pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data") 424 | return 425 | } 426 | 427 | processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) 428 | if err != nil { 429 | pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) 430 | return 431 | } 432 | 433 | // We need to reconstruct the multipart form in any case since the body is consumed 434 | // Create a new buffer for the reconstructed request 435 | var requestBuffer bytes.Buffer 436 | multipartWriter := multipart.NewWriter(&requestBuffer) 437 | 438 | // Copy all form values 439 | for key, values := range c.Request.MultipartForm.Value { 440 | for _, value := range values { 441 | fieldValue := value 442 | // If this is the model field and we have a profile, use just the model name 443 | if key == "model" { 444 | // # issue #69 allow custom model names to be sent to upstream 445 | useModelName := pm.config.Models[realModelName].UseModelName 446 | 447 | if useModelName != "" { 448 | fieldValue = useModelName 449 | } else { 450 | fieldValue = requestedModel 451 | } 452 | } 453 | field, err := multipartWriter.CreateFormField(key) 454 | if err != nil { 455 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field") 456 | return 457 | } 458 | if _, err = field.Write([]byte(fieldValue)); err != nil { 459 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field") 460 | return 461 | } 462 | } 463 | } 464 | 465 | 
// Copy all files from the original request 466 | for key, fileHeaders := range c.Request.MultipartForm.File { 467 | for _, fileHeader := range fileHeaders { 468 | formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) 469 | if err != nil { 470 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file") 471 | return 472 | } 473 | 474 | file, err := fileHeader.Open() 475 | if err != nil { 476 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") 477 | return 478 | } 479 | 480 | if _, err = io.Copy(formFile, file); err != nil { 481 | file.Close() 482 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") 483 | return 484 | } 485 | file.Close() 486 | } 487 | } 488 | 489 | // Close the multipart writer to finalize the form 490 | if err := multipartWriter.Close(); err != nil { 491 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form") 492 | return 493 | } 494 | 495 | // Create a new request with the reconstructed form data 496 | modifiedReq, err := http.NewRequestWithContext( 497 | c.Request.Context(), 498 | c.Request.Method, 499 | c.Request.URL.String(), 500 | &requestBuffer, 501 | ) 502 | if err != nil { 503 | pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request") 504 | return 505 | } 506 | 507 | // Copy the headers from the original request 508 | modifiedReq.Header = c.Request.Header.Clone() 509 | modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) 510 | 511 | // set the content length of the body 512 | modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len())) 513 | modifiedReq.ContentLength = int64(requestBuffer.Len()) 514 | 515 | // Use the modified request for proxying 516 | if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { 517 | pm.sendErrorResponse(c, http.StatusInternalServerError, 
fmt.Sprintf("error proxying request: %s", err.Error())) 518 | pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) 519 | return 520 | } 521 | } 522 | 523 | func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { 524 | acceptHeader := c.GetHeader("Accept") 525 | 526 | if strings.Contains(acceptHeader, "application/json") { 527 | c.JSON(statusCode, gin.H{"error": message}) 528 | } else { 529 | c.String(statusCode, message) 530 | } 531 | } 532 | 533 | func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { 534 | pm.StopProcesses(StopImmediately) 535 | c.String(http.StatusOK, "OK") 536 | } 537 | 538 | func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) { 539 | context.Header("Content-Type", "application/json") 540 | runningProcesses := make([]gin.H, 0) // Default to an empty response. 541 | 542 | for _, processGroup := range pm.processGroups { 543 | for _, process := range processGroup.processes { 544 | if process.CurrentState() == StateReady { 545 | runningProcesses = append(runningProcesses, gin.H{ 546 | "model": process.ID, 547 | "state": process.state, 548 | }) 549 | } 550 | } 551 | } 552 | 553 | // Put the results under the `running` key. 
554 | response := gin.H{ 555 | "running": runningProcesses, 556 | } 557 | 558 | context.JSON(http.StatusOK, response) // Always return 200 OK 559 | } 560 | 561 | func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup { 562 | for _, group := range pm.processGroups { 563 | if group.HasMember(modelName) { 564 | return group 565 | } 566 | } 567 | return nil 568 | } 569 | -------------------------------------------------------------------------------- /proxy/proxymanager_loghandlers.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "fmt" 5 | "net/http" 6 | "strings" 7 | 8 | "github.com/gin-gonic/gin" 9 | ) 10 | 11 | func (pm *ProxyManager) sendLogsHandlers(c *gin.Context) { 12 | accept := c.GetHeader("Accept") 13 | if strings.Contains(accept, "text/html") { 14 | // Set the Content-Type header to text/html 15 | c.Header("Content-Type", "text/html") 16 | 17 | // Write the embedded HTML content to the response 18 | logsHTML, err := getHTMLFile("logs.html") 19 | if err != nil { 20 | c.String(http.StatusInternalServerError, err.Error()) 21 | return 22 | } 23 | _, err = c.Writer.Write(logsHTML) 24 | if err != nil { 25 | c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err)) 26 | return 27 | } 28 | } else { 29 | c.Header("Content-Type", "text/plain") 30 | history := pm.muxLogger.GetHistory() 31 | _, err := c.Writer.Write(history) 32 | if err != nil { 33 | c.AbortWithError(http.StatusInternalServerError, err) 34 | return 35 | } 36 | } 37 | } 38 | 39 | func (pm *ProxyManager) streamLogsHandler(c *gin.Context) { 40 | c.Header("Content-Type", "text/plain") 41 | c.Header("Transfer-Encoding", "chunked") 42 | c.Header("X-Content-Type-Options", "nosniff") 43 | 44 | logMonitorId := c.Param("logMonitorID") 45 | logger, err := pm.getLogger(logMonitorId) 46 | if err != nil { 47 | c.String(http.StatusBadRequest, err.Error()) 48 | return 49 | } 50 | ch := 
logger.Subscribe() 51 | defer logger.Unsubscribe(ch) 52 | 53 | notify := c.Request.Context().Done() 54 | flusher, ok := c.Writer.(http.Flusher) 55 | if !ok { 56 | c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported")) 57 | return 58 | } 59 | 60 | _, skipHistory := c.GetQuery("no-history") 61 | // Send history first if not skipped 62 | 63 | if !skipHistory { 64 | history := logger.GetHistory() 65 | if len(history) != 0 { 66 | c.Writer.Write(history) 67 | flusher.Flush() 68 | } 69 | } 70 | 71 | // Stream new logs 72 | for { 73 | select { 74 | case msg := <-ch: 75 | _, err := c.Writer.Write(msg) 76 | if err != nil { 77 | // just break the loop if we can't write for some reason 78 | return 79 | } 80 | flusher.Flush() 81 | case <-notify: 82 | return 83 | } 84 | } 85 | } 86 | 87 | func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) { 88 | c.Header("Content-Type", "text/event-stream") 89 | c.Header("Cache-Control", "no-cache") 90 | c.Header("Connection", "keep-alive") 91 | c.Header("X-Content-Type-Options", "nosniff") 92 | 93 | logMonitorId := c.Param("logMonitorID") 94 | logger, err := pm.getLogger(logMonitorId) 95 | if err != nil { 96 | c.String(http.StatusBadRequest, err.Error()) 97 | return 98 | } 99 | ch := logger.Subscribe() 100 | defer logger.Unsubscribe(ch) 101 | 102 | notify := c.Request.Context().Done() 103 | 104 | // Send history first if not skipped 105 | _, skipHistory := c.GetQuery("no-history") 106 | if !skipHistory { 107 | history := logger.GetHistory() 108 | if len(history) != 0 { 109 | c.SSEvent("message", string(history)) 110 | c.Writer.Flush() 111 | } 112 | } 113 | 114 | // Stream new logs 115 | for { 116 | select { 117 | case msg := <-ch: 118 | c.SSEvent("message", string(msg)) 119 | c.Writer.Flush() 120 | case <-notify: 121 | return 122 | } 123 | } 124 | } 125 | 126 | // getLogger searches for the appropriate logger based on the logMonitorId 127 | func (pm *ProxyManager) getLogger(logMonitorId string) 
(*LogMonitor, error) { 128 | var logger *LogMonitor 129 | 130 | if logMonitorId == "" { 131 | // maintain the default 132 | logger = pm.muxLogger 133 | } else if logMonitorId == "proxy" { 134 | logger = pm.proxyLogger 135 | } else if logMonitorId == "upstream" { 136 | logger = pm.upstreamLogger 137 | } else { 138 | return nil, fmt.Errorf("invalid logger. Use 'proxy' or 'upstream'") 139 | } 140 | 141 | return logger, nil 142 | } 143 | -------------------------------------------------------------------------------- /proxy/proxymanager_test.go: -------------------------------------------------------------------------------- 1 | package proxy 2 | 3 | import ( 4 | "bytes" 5 | "encoding/json" 6 | "fmt" 7 | "math/rand" 8 | "mime/multipart" 9 | "net/http" 10 | "net/http/httptest" 11 | "strconv" 12 | "sync" 13 | "testing" 14 | "time" 15 | 16 | "github.com/stretchr/testify/assert" 17 | "github.com/tidwall/gjson" 18 | ) 19 | 20 | func TestProxyManager_SwapProcessCorrectly(t *testing.T) { 21 | config := AddDefaultGroupToConfig(Config{ 22 | HealthCheckTimeout: 15, 23 | Models: map[string]ModelConfig{ 24 | "model1": getTestSimpleResponderConfig("model1"), 25 | "model2": getTestSimpleResponderConfig("model2"), 26 | }, 27 | LogLevel: "error", 28 | }) 29 | 30 | proxy := New(config) 31 | defer proxy.StopProcesses(StopWaitForInflightRequest) 32 | 33 | for _, modelName := range []string{"model1", "model2"} { 34 | reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) 35 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 36 | w := httptest.NewRecorder() 37 | 38 | proxy.ServeHTTP(w, req) 39 | assert.Equal(t, http.StatusOK, w.Code) 40 | assert.Contains(t, w.Body.String(), modelName) 41 | } 42 | } 43 | 44 | func TestProxyManager_SwapMultiProcess(t *testing.T) { 45 | config := AddDefaultGroupToConfig(Config{ 46 | HealthCheckTimeout: 15, 47 | Models: map[string]ModelConfig{ 48 | "model1": getTestSimpleResponderConfig("model1"), 49 | "model2": 
getTestSimpleResponderConfig("model2"), 50 | }, 51 | LogLevel: "error", 52 | Groups: map[string]GroupConfig{ 53 | "G1": { 54 | Swap: true, 55 | Exclusive: false, 56 | Members: []string{"model1"}, 57 | }, 58 | "G2": { 59 | Swap: true, 60 | Exclusive: false, 61 | Members: []string{"model2"}, 62 | }, 63 | }, 64 | }) 65 | 66 | proxy := New(config) 67 | defer proxy.StopProcesses(StopWaitForInflightRequest) 68 | 69 | tests := []string{"model1", "model2"} 70 | for _, requestedModel := range tests { 71 | t.Run(requestedModel, func(t *testing.T) { 72 | reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) 73 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 74 | w := httptest.NewRecorder() 75 | 76 | proxy.ServeHTTP(w, req) 77 | assert.Equal(t, http.StatusOK, w.Code) 78 | assert.Contains(t, w.Body.String(), requestedModel) 79 | }) 80 | } 81 | 82 | // make sure there's two loaded models 83 | assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) 84 | assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) 85 | } 86 | 87 | // Test that a persistent group is not affected by the swapping behaviour of 88 | // other groups. 
89 | func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { 90 | config := AddDefaultGroupToConfig(Config{ 91 | HealthCheckTimeout: 15, 92 | Models: map[string]ModelConfig{ 93 | "model1": getTestSimpleResponderConfig("model1"), // goes into the default group 94 | "model2": getTestSimpleResponderConfig("model2"), 95 | }, 96 | LogLevel: "error", 97 | Groups: map[string]GroupConfig{ 98 | // the forever group is persistent and should not be affected by model1 99 | "forever": { 100 | Swap: true, 101 | Exclusive: false, 102 | Persistent: true, 103 | Members: []string{"model2"}, 104 | }, 105 | }, 106 | }) 107 | 108 | proxy := New(config) 109 | defer proxy.StopProcesses(StopWaitForInflightRequest) 110 | 111 | // make requests to load all models, loading model1 should not affect model2 112 | tests := []string{"model2", "model1"} 113 | for _, requestedModel := range tests { 114 | reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) 115 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 116 | w := httptest.NewRecorder() 117 | 118 | proxy.ServeHTTP(w, req) 119 | assert.Equal(t, http.StatusOK, w.Code) 120 | assert.Contains(t, w.Body.String(), requestedModel) 121 | } 122 | 123 | assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) 124 | assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) 125 | } 126 | 127 | // When a request for a different model comes in ProxyManager should wait until 128 | // the first request is complete before swapping. 
Both requests should complete 129 | func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { 130 | if testing.Short() { 131 | t.Skip("skipping slow test") 132 | } 133 | 134 | config := AddDefaultGroupToConfig(Config{ 135 | HealthCheckTimeout: 15, 136 | Models: map[string]ModelConfig{ 137 | "model1": getTestSimpleResponderConfig("model1"), 138 | "model2": getTestSimpleResponderConfig("model2"), 139 | "model3": getTestSimpleResponderConfig("model3"), 140 | }, 141 | LogLevel: "error", 142 | }) 143 | 144 | proxy := New(config) 145 | defer proxy.StopProcesses(StopWaitForInflightRequest) 146 | 147 | results := map[string]string{} 148 | 149 | var wg sync.WaitGroup 150 | var mu sync.Mutex 151 | 152 | for key := range config.Models { 153 | wg.Add(1) 154 | go func(key string) { 155 | defer wg.Done() 156 | 157 | reqBody := fmt.Sprintf(`{"model":"%s"}`, key) 158 | req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) 159 | w := httptest.NewRecorder() 160 | 161 | proxy.ServeHTTP(w, req) 162 | 163 | if w.Code != http.StatusOK { 164 | t.Errorf("Expected status OK, got %d for key %s", w.Code, key) 165 | } 166 | 167 | mu.Lock() 168 | var response map[string]string 169 | assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) 170 | results[key] = response["responseMessage"] 171 | mu.Unlock() 172 | }(key) 173 | 174 | <-time.After(time.Millisecond) 175 | } 176 | 177 | wg.Wait() 178 | assert.Len(t, results, len(config.Models)) 179 | 180 | for key, result := range results { 181 | assert.Equal(t, key, result) 182 | } 183 | } 184 | 185 | func TestProxyManager_ListModelsHandler(t *testing.T) { 186 | config := Config{ 187 | HealthCheckTimeout: 15, 188 | Models: map[string]ModelConfig{ 189 | "model1": getTestSimpleResponderConfig("model1"), 190 | "model2": getTestSimpleResponderConfig("model2"), 191 | "model3": getTestSimpleResponderConfig("model3"), 192 | }, 193 | LogLevel: "error", 194 | } 195 | 196 | proxy := New(config) 
197 | 198 | // Create a test request 199 | req := httptest.NewRequest("GET", "/v1/models", nil) 200 | req.Header.Add("Origin", "i-am-the-origin") 201 | w := httptest.NewRecorder() 202 | 203 | // Call the listModelsHandler 204 | proxy.ServeHTTP(w, req) 205 | 206 | // Check the response status code 207 | assert.Equal(t, http.StatusOK, w.Code) 208 | 209 | // Check for Access-Control-Allow-Origin 210 | assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin")) 211 | 212 | // Parse the JSON response 213 | var response struct { 214 | Data []map[string]interface{} `json:"data"` 215 | } 216 | if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { 217 | t.Fatalf("Failed to parse JSON response: %v", err) 218 | } 219 | 220 | // Check the number of models returned 221 | assert.Len(t, response.Data, 3) 222 | 223 | // Check the details of each model 224 | expectedModels := map[string]struct{}{ 225 | "model1": {}, 226 | "model2": {}, 227 | "model3": {}, 228 | } 229 | 230 | for _, model := range response.Data { 231 | modelID, ok := model["id"].(string) 232 | assert.True(t, ok, "model ID should be a string") 233 | _, exists := expectedModels[modelID] 234 | assert.True(t, exists, "unexpected model ID: %s", modelID) 235 | delete(expectedModels, modelID) 236 | 237 | object, ok := model["object"].(string) 238 | assert.True(t, ok, "object should be a string") 239 | assert.Equal(t, "model", object) 240 | 241 | created, ok := model["created"].(float64) 242 | assert.True(t, ok, "created should be a number") 243 | assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive 244 | 245 | ownedBy, ok := model["owned_by"].(string) 246 | assert.True(t, ok, "owned_by should be a string") 247 | assert.Equal(t, "llama-swap", ownedBy) 248 | } 249 | 250 | // Ensure all expected models were returned 251 | assert.Empty(t, expectedModels, "not all expected models were returned") 252 | } 253 | 254 | func TestProxyManager_Shutdown(t 
*testing.T) { 255 | // make broken model configurations 256 | model1Config := getTestSimpleResponderConfigPort("model1", 9991) 257 | model1Config.Proxy = "http://localhost:10001/" 258 | 259 | model2Config := getTestSimpleResponderConfigPort("model2", 9992) 260 | model2Config.Proxy = "http://localhost:10002/" 261 | 262 | model3Config := getTestSimpleResponderConfigPort("model3", 9993) 263 | model3Config.Proxy = "http://localhost:10003/" 264 | 265 | config := AddDefaultGroupToConfig(Config{ 266 | HealthCheckTimeout: 15, 267 | Models: map[string]ModelConfig{ 268 | "model1": model1Config, 269 | "model2": model2Config, 270 | "model3": model3Config, 271 | }, 272 | LogLevel: "error", 273 | Groups: map[string]GroupConfig{ 274 | "test": { 275 | Swap: false, 276 | Members: []string{"model1", "model2", "model3"}, 277 | }, 278 | }, 279 | }) 280 | 281 | proxy := New(config) 282 | 283 | // Start all the processes 284 | var wg sync.WaitGroup 285 | for _, modelName := range []string{"model1", "model2", "model3"} { 286 | wg.Add(1) 287 | go func(modelName string) { 288 | defer wg.Done() 289 | reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) 290 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 291 | w := httptest.NewRecorder() 292 | 293 | // send a request to trigger the proxy to load ... 
this should hang waiting for start up 294 | proxy.ServeHTTP(w, req) 295 | assert.Equal(t, http.StatusBadGateway, w.Code) 296 | assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") 297 | }(modelName) 298 | } 299 | 300 | go func() { 301 | <-time.After(time.Second) 302 | proxy.Shutdown() 303 | }() 304 | wg.Wait() 305 | } 306 | 307 | func TestProxyManager_Unload(t *testing.T) { 308 | config := AddDefaultGroupToConfig(Config{ 309 | HealthCheckTimeout: 15, 310 | Models: map[string]ModelConfig{ 311 | "model1": getTestSimpleResponderConfig("model1"), 312 | }, 313 | LogLevel: "error", 314 | }) 315 | 316 | proxy := New(config) 317 | reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") 318 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 319 | w := httptest.NewRecorder() 320 | proxy.ServeHTTP(w, req) 321 | 322 | assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) 323 | req = httptest.NewRequest("GET", "/unload", nil) 324 | w = httptest.NewRecorder() 325 | proxy.ServeHTTP(w, req) 326 | assert.Equal(t, http.StatusOK, w.Code) 327 | assert.Equal(t, w.Body.String(), "OK") 328 | 329 | // give it a bit of time to stop 330 | <-time.After(time.Millisecond * 250) 331 | assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) 332 | } 333 | 334 | // Test issue #61 `Listing the current list of models and the loaded model.` 335 | func TestProxyManager_RunningEndpoint(t *testing.T) { 336 | // Shared configuration 337 | config := AddDefaultGroupToConfig(Config{ 338 | HealthCheckTimeout: 15, 339 | Models: map[string]ModelConfig{ 340 | "model1": getTestSimpleResponderConfig("model1"), 341 | "model2": getTestSimpleResponderConfig("model2"), 342 | }, 343 | LogLevel: "warn", 344 | }) 345 | 346 | // Define a helper struct to parse the JSON response. 
347 | type RunningResponse struct { 348 | Running []struct { 349 | Model string `json:"model"` 350 | State string `json:"state"` 351 | } `json:"running"` 352 | } 353 | 354 | // Create proxy once for all tests 355 | proxy := New(config) 356 | defer proxy.StopProcesses(StopWaitForInflightRequest) 357 | 358 | t.Run("no models loaded", func(t *testing.T) { 359 | req := httptest.NewRequest("GET", "/running", nil) 360 | w := httptest.NewRecorder() 361 | proxy.ServeHTTP(w, req) 362 | 363 | assert.Equal(t, http.StatusOK, w.Code) 364 | 365 | var response RunningResponse 366 | 367 | // Check if this is a valid JSON object. 368 | assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) 369 | 370 | // We should have an empty running array here. 371 | assert.Empty(t, response.Running, "expected no running models") 372 | }) 373 | 374 | t.Run("single model loaded", func(t *testing.T) { 375 | // Load just a model. 376 | reqBody := `{"model":"model1"}` 377 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 378 | w := httptest.NewRecorder() 379 | proxy.ServeHTTP(w, req) 380 | assert.Equal(t, http.StatusOK, w.Code) 381 | 382 | // Simulate browser call for the `/running` endpoint. 383 | req = httptest.NewRequest("GET", "/running", nil) 384 | w = httptest.NewRecorder() 385 | proxy.ServeHTTP(w, req) 386 | 387 | var response RunningResponse 388 | assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) 389 | 390 | // Check if we have a single array element. 391 | assert.Len(t, response.Running, 1) 392 | 393 | // Is this the right model? 394 | assert.Equal(t, "model1", response.Running[0].Model) 395 | 396 | // Is the model loaded? 
397 | assert.Equal(t, "ready", response.Running[0].State) 398 | }) 399 | } 400 | 401 | func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { 402 | config := AddDefaultGroupToConfig(Config{ 403 | HealthCheckTimeout: 15, 404 | Models: map[string]ModelConfig{ 405 | "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), 406 | }, 407 | LogLevel: "error", 408 | }) 409 | 410 | proxy := New(config) 411 | defer proxy.StopProcesses(StopWaitForInflightRequest) 412 | 413 | // Create a buffer with multipart form data 414 | var b bytes.Buffer 415 | w := multipart.NewWriter(&b) 416 | 417 | // Add the model field 418 | fw, err := w.CreateFormField("model") 419 | assert.NoError(t, err) 420 | _, err = fw.Write([]byte("TheExpectedModel")) 421 | assert.NoError(t, err) 422 | 423 | // Add a file field 424 | fw, err = w.CreateFormFile("file", "test.mp3") 425 | assert.NoError(t, err) 426 | // Generate random content length between 10 and 20 427 | contentLength := rand.Intn(11) + 10 // 10 to 20 428 | content := make([]byte, contentLength) 429 | _, err = fw.Write(content) 430 | assert.NoError(t, err) 431 | w.Close() 432 | 433 | // Create the request with the multipart form data 434 | req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) 435 | req.Header.Set("Content-Type", w.FormDataContentType()) 436 | rec := httptest.NewRecorder() 437 | proxy.ServeHTTP(rec, req) 438 | 439 | // Verify the response 440 | assert.Equal(t, http.StatusOK, rec.Code) 441 | var response map[string]string 442 | err = json.Unmarshal(rec.Body.Bytes(), &response) 443 | assert.NoError(t, err) 444 | assert.Equal(t, "TheExpectedModel", response["model"]) 445 | assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder 446 | assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"]) 447 | } 448 | 449 | // Test useModelName in configuration sends overrides what is sent to upstream 450 | func 
TestProxyManager_UseModelName(t *testing.T) { 451 | upstreamModelName := "upstreamModel" 452 | modelConfig := getTestSimpleResponderConfig(upstreamModelName) 453 | modelConfig.UseModelName = upstreamModelName 454 | 455 | config := AddDefaultGroupToConfig(Config{ 456 | HealthCheckTimeout: 15, 457 | Models: map[string]ModelConfig{ 458 | "model1": modelConfig, 459 | }, 460 | LogLevel: "error", 461 | }) 462 | 463 | proxy := New(config) 464 | defer proxy.StopProcesses(StopWaitForInflightRequest) 465 | 466 | requestedModel := "model1" 467 | 468 | t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) { 469 | reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) 470 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 471 | w := httptest.NewRecorder() 472 | 473 | proxy.ServeHTTP(w, req) 474 | assert.Equal(t, http.StatusOK, w.Code) 475 | assert.Contains(t, w.Body.String(), upstreamModelName) 476 | 477 | // make sure the content length was set correctly 478 | // simple-responder will return the content length it got in the response 479 | body := w.Body.Bytes() 480 | contentLength := int(gjson.GetBytes(body, "h_content_length").Int()) 481 | assert.Equal(t, len(fmt.Sprintf(`{"model":"%s"}`, upstreamModelName)), contentLength) 482 | }) 483 | 484 | t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) { 485 | // Create a buffer with multipart form data 486 | var b bytes.Buffer 487 | w := multipart.NewWriter(&b) 488 | 489 | // Add the model field 490 | fw, err := w.CreateFormField("model") 491 | assert.NoError(t, err) 492 | _, err = fw.Write([]byte(requestedModel)) 493 | assert.NoError(t, err) 494 | 495 | // Add a file field 496 | fw, err = w.CreateFormFile("file", "test.mp3") 497 | assert.NoError(t, err) 498 | _, err = fw.Write([]byte("test")) 499 | assert.NoError(t, err) 500 | w.Close() 501 | 502 | // Create the request with the multipart form data 503 | req := 
httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) 504 | req.Header.Set("Content-Type", w.FormDataContentType()) 505 | rec := httptest.NewRecorder() 506 | proxy.ServeHTTP(rec, req) 507 | 508 | // Verify the response 509 | assert.Equal(t, http.StatusOK, rec.Code) 510 | var response map[string]string 511 | err = json.Unmarshal(rec.Body.Bytes(), &response) 512 | assert.NoError(t, err) 513 | assert.Equal(t, upstreamModelName, response["model"]) 514 | }) 515 | } 516 | 517 | func TestProxyManager_CORSOptionsHandler(t *testing.T) { 518 | config := AddDefaultGroupToConfig(Config{ 519 | HealthCheckTimeout: 15, 520 | Models: map[string]ModelConfig{ 521 | "model1": getTestSimpleResponderConfig("model1"), 522 | }, 523 | LogLevel: "error", 524 | }) 525 | 526 | tests := []struct { 527 | name string 528 | method string 529 | requestHeaders map[string]string 530 | expectedStatus int 531 | expectedHeaders map[string]string 532 | }{ 533 | { 534 | name: "OPTIONS with no headers", 535 | method: "OPTIONS", 536 | expectedStatus: http.StatusNoContent, 537 | expectedHeaders: map[string]string{ 538 | "Access-Control-Allow-Origin": "*", 539 | "Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", 540 | "Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With", 541 | }, 542 | }, 543 | { 544 | name: "OPTIONS with specific headers", 545 | method: "OPTIONS", 546 | requestHeaders: map[string]string{ 547 | "Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header", 548 | }, 549 | expectedStatus: http.StatusNoContent, 550 | expectedHeaders: map[string]string{ 551 | "Access-Control-Allow-Origin": "*", 552 | "Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", 553 | "Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header", 554 | }, 555 | }, 556 | { 557 | name: "Non-OPTIONS request", 558 | method: "GET", 559 | expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined 560 | }, 
561 | } 562 | 563 | for _, tt := range tests { 564 | t.Run(tt.name, func(t *testing.T) { 565 | proxy := New(config) 566 | defer proxy.StopProcesses(StopWaitForInflightRequest) 567 | 568 | req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) 569 | for k, v := range tt.requestHeaders { 570 | req.Header.Set(k, v) 571 | } 572 | 573 | w := httptest.NewRecorder() 574 | proxy.ServeHTTP(w, req) 575 | 576 | assert.Equal(t, tt.expectedStatus, w.Code) 577 | 578 | for header, expectedValue := range tt.expectedHeaders { 579 | assert.Equal(t, expectedValue, w.Header().Get(header)) 580 | } 581 | }) 582 | } 583 | } 584 | 585 | func TestProxyManager_Upstream(t *testing.T) { 586 | config := AddDefaultGroupToConfig(Config{ 587 | HealthCheckTimeout: 15, 588 | Models: map[string]ModelConfig{ 589 | "model1": getTestSimpleResponderConfig("model1"), 590 | }, 591 | LogLevel: "error", 592 | }) 593 | 594 | proxy := New(config) 595 | defer proxy.StopProcesses(StopWaitForInflightRequest) 596 | req := httptest.NewRequest("GET", "/upstream/model1/test", nil) 597 | rec := httptest.NewRecorder() 598 | proxy.ServeHTTP(rec, req) 599 | assert.Equal(t, http.StatusOK, rec.Code) 600 | assert.Equal(t, "model1", rec.Body.String()) 601 | } 602 | 603 | func TestProxyManager_ChatContentLength(t *testing.T) { 604 | config := AddDefaultGroupToConfig(Config{ 605 | HealthCheckTimeout: 15, 606 | Models: map[string]ModelConfig{ 607 | "model1": getTestSimpleResponderConfig("model1"), 608 | }, 609 | LogLevel: "error", 610 | }) 611 | 612 | proxy := New(config) 613 | defer proxy.StopProcesses(StopWaitForInflightRequest) 614 | 615 | reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") 616 | req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) 617 | w := httptest.NewRecorder() 618 | 619 | proxy.ServeHTTP(w, req) 620 | assert.Equal(t, http.StatusOK, w.Code) 621 | var response map[string]string 622 | 
// isTokenChar reports whether r is a valid HTTP header-name character,
// i.e. an RFC 7230 "tchar": ALPHA / DIGIT / one of !#$%&'*+-.^_`|~.
func isTokenChar(r rune) bool {
	switch {
	case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
		return true
	default:
		return strings.ContainsRune("!#$%&'*+-.^_`|~", r)
	}
}

// SanitizeAccessControlRequestHeaderValues filters a comma-separated
// Access-Control-Request-Headers value down to its well-formed entries.
// Each entry is whitespace-trimmed; empty entries and entries containing
// any non-token character are dropped. The surviving entries are joined
// back with ", ".
func SanitizeAccessControlRequestHeaderValues(headerValues string) string {
	rawParts := strings.Split(headerValues, ",")
	sanitized := make([]string, 0, len(rawParts))

	for _, rawPart := range rawParts {
		token := strings.TrimSpace(rawPart)
		if token == "" {
			continue
		}
		// Reject the whole entry if any rune is outside the token charset.
		if strings.IndexFunc(token, func(r rune) bool { return !isTokenChar(r) }) >= 0 {
			continue
		}
		sanitized = append(sanitized, token)
	}

	return strings.Join(sanitized, ", ")
}
29 | expected: "content-type, authorization, x-requested-with", 30 | }, 31 | { 32 | name: "values with extra spaces", 33 | input: " content-type , authorization ", 34 | expected: "content-type, authorization", 35 | }, 36 | { 37 | name: "values with tabs", 38 | input: "content-type,\tauthorization", 39 | expected: "content-type, authorization", 40 | }, 41 | { 42 | name: "values with invalid characters", 43 | input: "content-type, auth\n, x-requested-with\r", 44 | expected: "content-type, auth, x-requested-with", 45 | }, 46 | { 47 | name: "empty values in list", 48 | input: "content-type,,authorization", 49 | expected: "content-type, authorization", 50 | }, 51 | { 52 | name: "leading and trailing commas", 53 | input: ",content-type,authorization,", 54 | expected: "content-type, authorization", 55 | }, 56 | { 57 | name: "mixed valid and invalid values", 58 | input: "content-type, \x00invalid, x-requested-with", 59 | expected: "content-type, x-requested-with", 60 | }, 61 | { 62 | name: "mixed case values", 63 | input: "Content-Type, my-Valid-Header, Another-hEader", 64 | expected: "Content-Type, my-Valid-Header, Another-hEader", 65 | }, 66 | } 67 | 68 | for _, tt := range tests { 69 | t.Run(tt.name, func(t *testing.T) { 70 | got := SanitizeAccessControlRequestHeaderValues(tt.input) 71 | if got != tt.expected { 72 | t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q", 73 | tt.input, got, tt.expected) 74 | } 75 | }) 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /scripts/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | # This script installs llama-swap on Linux. 3 | # It detects the current operating system architecture and installs the appropriate version of llama-swap. 
4 | 5 | set -eu 6 | 7 | LLAMA_SWAP_DEFAULT_ADDRESS=${LLAMA_SWAP_DEFAULT_ADDRESS:-"127.0.0.1:8080"} 8 | 9 | red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)" 10 | plain="$( (/usr/bin/tput sgr0 || :) 2>&-)" 11 | 12 | status() { echo ">>> $*" >&2; } 13 | error() { echo "${red}ERROR:${plain} $*"; exit 1; } 14 | warning() { echo "${red}WARNING:${plain} $*"; } 15 | 16 | available() { command -v "$1" >/dev/null; } 17 | require() { 18 | _MISSING='' 19 | for TOOL in "$@"; do 20 | if ! available "$TOOL"; then 21 | _MISSING="$_MISSING $TOOL" 22 | fi 23 | done 24 | 25 | echo "$_MISSING" 26 | } 27 | 28 | SUDO= 29 | if [ "$(id -u)" -ne 0 ]; then 30 | if ! available sudo; then 31 | error "This script requires superuser permissions. Please re-run as root." 32 | fi 33 | 34 | SUDO="sudo" 35 | fi 36 | 37 | NEEDS=$(require tee tar python3 mktemp) 38 | if [ -n "$NEEDS" ]; then 39 | status "ERROR: The following tools are required but missing:" 40 | for NEED in $NEEDS; do 41 | echo " - $NEED" 42 | done 43 | exit 1 44 | fi 45 | 46 | [ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.' 47 | 48 | ARCH=$(uname -m) 49 | case "$ARCH" in 50 | x86_64) ARCH="amd64" ;; 51 | aarch64|arm64) ARCH="arm64" ;; 52 | *) error "Unsupported architecture: $ARCH" ;; 53 | esac 54 | 55 | IS_WSL2=false 56 | 57 | KERN=$(uname -r) 58 | case "$KERN" in 59 | *icrosoft*WSL2 | *icrosoft*wsl2) IS_WSL2=true;; 60 | *icrosoft) error "Microsoft WSL1 is not currently supported. Please use WSL2 with 'wsl --set-version 2'" ;; 61 | *) ;; 62 | esac 63 | 64 | download_binary() { 65 | ASSET_NAME="linux_$ARCH" 66 | 67 | TMPDIR=$(mktemp -d) 68 | trap 'rm -rf "${TMPDIR}"' EXIT INT TERM HUP 69 | PYTHON_SCRIPT=$(cat </dev/null 2>&1; then 106 | status "Creating llama-swap user..." 107 | $SUDO useradd -r -s /bin/false -U -m -d /usr/share/llama-swap llama-swap 108 | fi 109 | if getent group render >/dev/null 2>&1; then 110 | status "Adding llama-swap user to render group..." 
111 | $SUDO usermod -a -G render llama-swap 112 | fi 113 | if getent group video >/dev/null 2>&1; then 114 | status "Adding llama-swap user to video group..." 115 | $SUDO usermod -a -G video llama-swap 116 | fi 117 | if getent group docker >/dev/null 2>&1; then 118 | status "Adding llama-swap user to docker group..." 119 | $SUDO usermod -a -G docker llama-swap 120 | fi 121 | 122 | status "Adding current user to llama-swap group..." 123 | $SUDO usermod -a -G llama-swap "$(whoami)" 124 | 125 | if [ ! -f "/usr/share/llama-swap/config.yaml" ]; then 126 | status "Creating default config.yaml..." 127 | cat </dev/null 128 | # default 15s likely to fail for default models due to downloading models 129 | healthCheckTimeout: 60 130 | 131 | models: 132 | "qwen2.5": 133 | cmd: | 134 | docker run 135 | --rm 136 | -p \${PORT}:8080 137 | --name qwen2.5 138 | ghcr.io/ggml-org/llama.cpp:server 139 | -hf bartowski/Qwen2.5-0.5B-Instruct-GGUF:Q4_K_M 140 | cmdStop: docker stop qwen2.5 141 | 142 | "smollm2": 143 | cmd: | 144 | docker run 145 | --rm 146 | -p \${PORT}:8080 147 | --name smollm2 148 | ghcr.io/ggml-org/llama.cpp:server 149 | -hf bartowski/SmolLM2-135M-Instruct-GGUF:Q4_K_M 150 | cmdStop: docker stop smollm2 151 | EOF 152 | fi 153 | 154 | status "Creating llama-swap systemd service..." 155 | cat </dev/null 156 | [Unit] 157 | Description=llama-swap 158 | After=network.target 159 | 160 | [Service] 161 | User=llama-swap 162 | Group=llama-swap 163 | 164 | # set this to match your environment 165 | ExecStart=/usr/local/bin/llama-swap --config /usr/share/llama-swap/config.yaml --watch-config -listen ${LLAMA_SWAP_DEFAULT_ADDRESS} 166 | 167 | Restart=on-failure 168 | RestartSec=3 169 | StartLimitBurst=3 170 | StartLimitInterval=30 171 | 172 | [Install] 173 | WantedBy=multi-user.target 174 | EOF 175 | SYSTEMCTL_RUNNING="$(systemctl is-system-running || true)" 176 | case $SYSTEMCTL_RUNNING in 177 | running|degraded) 178 | status "Enabling and starting llama-swap service..." 
#!/bin/sh
# This script uninstalls llama-swap on Linux.
# It removes the binary, systemd service, config.yaml (optional), and llama-swap user and group.

set -eu

red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
plain="$( (/usr/bin/tput sgr0 || :) 2>&-)"

status() { echo ">>> $*" >&2; }
error() { echo "${red}ERROR:${plain} $*"; exit 1; }
warning() { echo "${red}WARNING:${plain} $*"; }

# Quote "$1" so word splitting cannot break the lookup (matches install.sh's version).
available() { command -v "$1" >/dev/null; }

SUDO=
if [ "$(id -u)" -ne 0 ]; then
    if ! available sudo; then
        error "This script requires superuser permissions. Please re-run as root."
    fi

    SUDO="sudo"
fi

# Stop and disable the systemd unit. Tolerate a missing/never-installed unit:
# under `set -e` a failing systemctl call would otherwise abort the whole
# uninstall before the binary and user are removed.
configure_systemd() {
    status "Stopping llama-swap service..."
    $SUDO systemctl stop llama-swap 2>/dev/null || warning "could not stop llama-swap service (it may not be running)"

    status "Disabling llama-swap service..."
    $SUDO systemctl disable llama-swap 2>/dev/null || warning "could not disable llama-swap service (it may not be installed)"
}
if available systemctl; then
    configure_systemd
fi

if available llama-swap; then
    status "Removing llama-swap binary..."
    # POSIX `command -v` instead of non-standard `which`; quoted in case the
    # resolved path contains whitespace.
    $SUDO rm "$(command -v llama-swap)"
fi

if [ -f "/usr/share/llama-swap/config.yaml" ]; then
    while true; do
        printf "Delete config.yaml (/usr/share/llama-swap/config.yaml)? [y/N] " >&2
        # -r stops `read` from mangling backslashes in the reply.
        read -r answer
        case "$answer" in
            [Yy]* )
                $SUDO rm -r /usr/share/llama-swap
                break
                ;;
            [Nn]* | "" )
                break
                ;;
            * )
                echo "Invalid input. Please enter y or n."
                ;;
        esac
    done
fi

if id llama-swap >/dev/null 2>&1; then
    status "Removing llama-swap user..."
    $SUDO userdel llama-swap
fi

if getent group llama-swap >/dev/null 2>&1; then
    status "Removing llama-swap group..."
    $SUDO groupdel llama-swap
fi