├── .github ├── renovate.json5 └── workflows │ └── test_cmake.yaml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── Makefile ├── README.md ├── dep ├── CMakeLists.txt ├── cmake_fetchcontent │ ├── fildesh.cmake │ └── llama_cpp.cmake └── cmake_module │ └── FindOpenBLAS.cmake ├── doc └── setting │ ├── model.md │ ├── prompt.md │ ├── sampling.md │ └── stdio.md ├── example ├── CMakeLists.txt └── prompt │ ├── CMakeLists.txt │ ├── README.md │ ├── assistant_alpaca │ ├── README.md │ ├── priming.txt │ ├── rolling.txt │ └── setting.sxpb │ ├── assistant_chatml │ ├── README.md │ ├── priming.txt │ ├── rolling.txt │ └── setting.sxpb │ ├── assistant_coprocess │ ├── README.md │ ├── priming.txt │ └── setting.sxpb │ ├── assistant_gemma │ ├── README.md │ ├── rolling.txt │ └── setting.sxpb │ ├── assistant_llama │ ├── README.md │ ├── priming.txt │ ├── rolling.txt │ └── setting.sxpb │ ├── assistant_mistral │ ├── README.md │ ├── rolling.txt │ └── setting.sxpb │ ├── assistant_plain │ ├── README.md │ └── setting.sxpb │ ├── assistant_vicuna │ ├── README.md │ ├── priming.txt │ ├── rolling.txt │ └── setting.sxpb │ ├── confidant_alpaca │ ├── README.md │ ├── answer.txt │ ├── priming.txt │ ├── rolling.txt │ └── setting.sxpb │ └── roshambo_kira │ ├── README.md │ ├── priming.txt │ ├── rolling.txt │ └── setting.sxpb ├── src ├── CMakeLists.txt ├── chat │ ├── CMakeLists.txt │ ├── chat_main.cc │ ├── cmd.cc │ ├── cmd.hh │ ├── display.cc │ ├── display.hh │ ├── guide.cc │ ├── guide.hh │ ├── opt.cc │ ├── opt.hh │ ├── opt_schema.cc │ ├── opt_schema.hh │ ├── trajectory.cc │ └── trajectory.hh ├── language │ ├── CMakeLists.txt │ ├── inference.cc │ ├── inference.hh │ ├── inference_schema.cc │ ├── inference_schema.hh │ ├── language_schema.cc │ ├── language_schema.hh │ ├── vocabulary.cc │ └── vocabulary.hh └── tokenize │ ├── CMakeLists.txt │ └── tokenize_main.cc └── test ├── CMakeLists.txt ├── chat ├── CMakeLists.txt ├── guide_test.cc ├── opt_test.cc └── trajectory_test.cc ├── example ├── CMakeLists.txt └── prompt │ ├── CMakeLists.txt │ └── parse_test.cc ├── language ├── CMakeLists.txt ├── inference_schema_test.cc ├── language_schema_test.cc └── vocabulary_test.cc └── manual ├── chat.fildesh ├── chat.sh ├── coverage.md └── openblas.md /.github/renovate.json5: -------------------------------------------------------------------------------- 1 | { 2 | "$schema": "https://docs.renovatebot.com/renovate-schema.json", 3 | "extends": [ 4 | "config:base" 5 | ], 6 | "packageRules": [{ 7 | "matchPackagePatterns": ["*"], 8 | "labels": ["dependencies"], 9 | "matchUpdateTypes": ["major", "minor", "patch"], 10 | // Group version update PRs into one. 
11 | "groupName": "all dependencies", 12 | "groupSlug": "all-dependencies", 13 | }], 14 | } 15 | -------------------------------------------------------------------------------- /.github/workflows/test_cmake.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | name: CMake 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | test: 7 | strategy: 8 | fail-fast: false 9 | matrix: 10 | include: 11 | - platform: ubuntu-latest 12 | cmake_build_type: RelOnHost 13 | - platform: macos-latest 14 | cmake_build_type: RelOnHost 15 | - platform: windows-latest 16 | cmake_build_type: Release 17 | 18 | runs-on: ${{ matrix.platform }} 19 | 20 | steps: 21 | - uses: actions/checkout@v4 22 | 23 | - name: Configure CMake 24 | run: > 25 | cmake 26 | -DCMAKE_BUILD_TYPE=${{matrix.cmake_build_type}} 27 | -S "${{github.workspace}}" -B "${{runner.workspace}}/bld" 28 | 29 | - name: Build 30 | run: > 31 | cmake 32 | --build "${{runner.workspace}}/bld" 33 | --config ${{matrix.cmake_build_type}} 34 | 35 | - name: Test 36 | working-directory: ${{runner.workspace}}/bld 37 | run: > 38 | ctest 39 | --timeout 10 40 | -C ${{matrix.cmake_build_type}} 41 | 42 | 43 | coverage: 44 | runs-on: ubuntu-latest 45 | env: 46 | cmake_build_type: Debug 47 | 48 | strategy: 49 | fail-fast: false 50 | 51 | steps: 52 | - uses: actions/checkout@v4 53 | 54 | - name: Configure CMake 55 | run: > 56 | cmake 57 | -DCMAKE_BUILD_TYPE=${{env.cmake_build_type}} 58 | -DCMAKE_C_FLAGS="--coverage -Og" 59 | -DCMAKE_CXX_FLAGS="--coverage -Og" 60 | -DCMAKE_EXE_LINKER_FLAGS="--coverage -Og" 61 | -S "${{github.workspace}}" -B "${{runner.workspace}}/bld" 62 | 63 | - name: Build 64 | run: > 65 | cmake 66 | --build "${{runner.workspace}}/bld" 67 | --config ${{env.cmake_build_type}} 68 | 69 | - name: Test 70 | working-directory: ${{runner.workspace}}/bld 71 | run: > 72 | ctest 73 | --timeout 10 74 | -C ${{env.cmake_build_type}} 75 | 76 | - name: LCOV 77 | uses: imciner2/run-lcov@v1 78 | with: 79 | input_directory: "${{runner.workspace}}/bld/" 80 | exclude: '"/usr/*" "${{runner.workspace}}/bld/_deps/*"' 81 | output_file: "${{runner.workspace}}/bld/coverage_report.info" 82 | 83 | - name: Coveralls 84 | uses: coverallsapp/github-action@master 85 | with: 86 | github-token: ${{ secrets.GITHUB_TOKEN }} 87 | path-to-lcov: "${{runner.workspace}}/bld/coverage_report.info" 88 | 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.* 2 | !/.gitignore 3 | !/.github/ 4 | 5 | /bld/ 6 | /example/prompt/_*/ 7 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14) 2 | # CMake 3.14 for FetchContent_MakeAvailable(). 
3 | 4 | #set(CMAKE_DISABLE_SOURCE_CHANGES ON) 5 | set(CMAKE_DISABLE_IN_SOURCE_BUILD ON) 6 | project(Rendezllama LANGUAGES "C" "CXX") 7 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 8 | set(CMAKE_CXX_STANDARD 17) 9 | 10 | 11 | option(BUILD_SHARED_LIBS "Build using shared libraries" OFF) 12 | option(LLAMA_OPENBLAS_ON "llama: use OpenBLAS" OFF) 13 | 14 | 15 | if(NOT CMAKE_BUILD_TYPE) 16 | set(CMAKE_BUILD_TYPE Release) 17 | endif() 18 | 19 | string(REPLACE "-DNDEBUG" "" CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE}") 20 | string(REPLACE "/DNDEBUG" "" CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE}") 21 | string(REPLACE "-DNDEBUG" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") 22 | string(REPLACE "/DNDEBUG" "" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}") 23 | 24 | string(REPLACE "-DNDEBUG" "" CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO}") 25 | string(REPLACE "/DNDEBUG" "" CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO}") 26 | string(REPLACE "-DNDEBUG" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") 27 | string(REPLACE "/DNDEBUG" "" CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO}") 28 | 29 | set(CMAKE_C_FLAGS_RELONHOST "${CMAKE_C_FLAGS_RELEASE} -Wall -Wextra") 30 | set(CMAKE_CXX_FLAGS_RELONHOST "${CMAKE_CXX_FLAGS_RELEASE} -Wall -Wextra") 31 | 32 | 33 | include(CTest) 34 | 35 | add_subdirectory(dep EXCLUDE_FROM_ALL) 36 | include_directories( 37 | "${PROJECT_SOURCE_DIR}" 38 | ${Fildesh_INCLUDE_DIRS} 39 | ) 40 | 41 | add_subdirectory(src) 42 | add_subdirectory(example) 43 | 44 | if (NOT PROJECT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) 45 | # Try to keep dependent project namespace clean. 46 | # No need for tests or anything else. 47 | return() 48 | endif() 49 | 50 | if (BUILD_TESTING) 51 | add_subdirectory(test) 52 | endif() 53 | 54 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023-2023, Alex P. Klinkhamer 2 | 3 | Permission to use, copy, modify, and/or distribute this software for any 4 | purpose with or without fee is hereby granted, provided that the above 5 | copyright notice and this permission notice appear in all copies. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | # No OpenBLAS by default. Override like: make LLAMA_OPENBLAS_ON=1 3 | LLAMA_OPENBLAS_ON = 0 4 | 5 | BUILD_DIR = bld 6 | SOURCE_DIR = . 7 | 8 | CMAKE = cmake 9 | GODO = $(CMAKE) -E chdir 10 | 11 | CMAKE_BUILD_TYPE = RelOnHost 12 | CMAKE_BUILD_OPTIONS = -DCMAKE_BUILD_TYPE=$(CMAKE_BUILD_TYPE) 13 | CMAKE_BUILD_OPTIONS += -DLLAMA_OPENBLAS_ON:BOOL=$(LLAMA_OPENBLAS_ON) 14 | 15 | 16 | .PHONY: default all cmake proj \ 17 | test clean distclean \ 18 | update pull 19 | 20 | default: 21 | if [ ! 
-d $(BUILD_DIR) ] ; then $(MAKE) cmake ; fi
22 | 	$(MAKE) proj
23 | 
24 | all:
25 | 	$(MAKE) cmake
26 | 	$(MAKE) proj
27 | 
28 | cmake:
29 | 	$(CMAKE) $(CMAKE_BUILD_OPTIONS) -S $(SOURCE_DIR) -B $(BUILD_DIR)
30 | 
31 | proj:
32 | 	$(GODO) $(BUILD_DIR) $(MAKE)
33 | 
34 | test:
35 | 	$(GODO) $(BUILD_DIR) $(MAKE) test
36 | 
37 | clean:
38 | 	$(GODO) $(BUILD_DIR) $(MAKE) clean
39 | 
40 | distclean:
41 | 	rm -fr $(BUILD_DIR)
42 | 
43 | update:
44 | 	git pull origin trunk
45 | 
46 | pull:
47 | 	git pull origin trunk
48 | 
49 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # rendezllama
2 | 
3 | Rendezllama is a text interface for running a local chatbot based on [ggerganov's llama.cpp](https://github.com/ggerganov/llama.cpp).
4 | 
5 | For now, there's just a command-line interface, but the plan is to make a progressive web app that connects with the chatbot running on a home server.
6 | 
7 | ## Chat CLI
8 | 
9 | Assuming you have the quantized weights already and can compile C++, you can try the [assistant_plain example](example/prompt/assistant_plain/) with a few commands:
10 | ```shell
11 | # If undefined, assume the 7B model exists in a sibling llama.cpp/ dir.
12 | MODEL="${MODEL:-../llama.cpp/models/7B/ggml-model-q4_0.gguf}"
13 | # Make just creates a bld/ directory and invokes CMake to build there.
14 | make
15 | # Run with specific settings from a file. They can be given as flags too.
16 | ./bld/src/chat/chat \
17 |   --x_setting example/prompt/assistant_plain/setting.sxpb \
18 |   --thread_count 8 \
19 |   --model "${MODEL}"
20 | ```
21 | 
22 | See the [example/prompt/](example/prompt/) directory for more interesting/whimsical examples.
23 | 
24 | ### Chat CLI Options
25 | 
26 | - Setting file.
27 |   - `--x_setting setting.sxpb` loads settings from `setting.sxpb`.
28 |   - All other options can be set within this file.
29 | - Model files.
30 |   - `--model ggml-model-q4_0.gguf` are the model weights. Usually quantized.
31 |   - See [doc/setting/model.md](doc/setting/model.md) for LoRA files and memory options.
32 | - Prompt files.
33 |   - `--x_priming priming.txt` specifies the priming prompt text file. This is the prompt that never changes.
34 |   - `--x_rolling rolling.txt` specifies the rolling prompt. This is the initial chat dialogue. As the chat continues, older dialogue expires and "rolls" out of context.
35 |     - The protagonist and confidant names are derived automatically from this.
36 |   - See [doc/setting/prompt.md](doc/setting/prompt.md) for more prompt file & format options.
37 | 
38 | ### Chat CLI Commands
39 | 
40 | In the chat, most things you type will be prefixed with the protagonist's name and suffixed by the confidant's dialogue line.
41 | There are some special inputs and commands that help keep an infinite chat from going off the rails.
42 | Remember, the recent chat content is just a rolling prompt concatenated to the end of the priming prompt, so its quality is just as important!
43 | - Interactivity.
44 |   - An empty input lets token generation keep happening.
45 |   - See [doc/setting/stdio.md](doc/setting/stdio.md) for settings that control I/O behavior and limits.
46 |   - `/tail` or `/tail 10` shows the last 10 lines.
47 |   - `/head` or `/head 10` shows the first 10 lines of the rolling prompt.
48 |   - `/forget 10` removes the first 10 lines of the rolling prompt.
49 | - Characters.
50 |   - `/(protagonist "User")` changes the protagonist's name to "User".
51 |   - `/(confidant "Char")` changes the confidant's name to "Char".
52 |   - See [doc/setting/prompt.md#prefix](doc/setting/prompt.md#prefix) for more ways to control chat line prefixes.
53 | - Editing.
54 |   - A blank space forces token generation to continue on the same line.
55 |   - ` some text` (note blank space in front) adds `some text` to the current line.
56 |   - ` some text ` (note blank spaces in front and back) adds `some text` and forces another token on the same line. Useful when inserting a sentence.
57 |   - `\nsome text` (note the escaped newline in front) adds a new line of dialogue for the confidant that starts with `some text`.
58 |   - `/puts A line of text.` adds a new line of text. Does not echo anything.
59 |   - `/yield` or `/y` adds a new line of dialogue for the confidant.
60 |   - `/yield Char:` or `/y Char:` adds a new line starting with `Char:`.
61 |   - `/gets 64 Char:` is like `/yield` but limits generation to slightly over a max of 64 bytes. Only prints the newly-generated text. Always includes a newline at the end.
62 |   - `/r` regenerates the last line of dialogue.
63 |   - `/R` generates text from the current position. Subsequent `/r` commands will only replace the generated text, nothing before it on the line.
64 |   - `/d` deletes up to and including the last chat prefix.
65 |   - `/D` or `/D 0` deletes all text on the current line without consuming a newline. Positive integers delete that many earlier lines in full.
66 |   - `/b` or `/b 1` deletes the last token.
67 |   - `/B` or `/B 1` deletes the last word.
68 | - Sampling.
69 |   - A slash followed by a valid sampling configuration in `setting.sxpb` reconfigures the sampling parameters.
70 |   - `/(language ((infer_via sampling) (adjust_thru (()) (temperature 0.9))))` sets the temperature to 0.9.
71 |   - See [doc/setting/sampling.md](doc/setting/sampling.md) for more ways to control inference.
--------------------------------------------------------------------------------
/dep/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | 
2 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake_module")
3 | 
4 | 
5 | # Reintroduce NDEBUG so dependencies can remove it again if appropriate.
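# (The top-level CMakeLists.txt stripped NDEBUG from the release flags to keep this project's asserts active.)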
6 | set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -DNDEBUG")
7 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DNDEBUG")
8 | set(CMAKE_C_FLAGS_RELONHOST "${CMAKE_C_FLAGS_RELONHOST} -DNDEBUG")
9 | set(CMAKE_CXX_FLAGS_RELONHOST "${CMAKE_CXX_FLAGS_RELONHOST} -DNDEBUG")
10 | if (MSVC)
11 |   string(REPLACE "-DNDEBUG" "/DNDEBUG" CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE}")
12 |   string(REPLACE "-DNDEBUG" "/DNDEBUG" CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE}")
13 |   string(REPLACE "-DNDEBUG" "/DNDEBUG" CMAKE_C_FLAGS_RELONHOST "${CMAKE_C_FLAGS_RELONHOST}")
14 |   string(REPLACE "-DNDEBUG" "/DNDEBUG" CMAKE_CXX_FLAGS_RELONHOST "${CMAKE_CXX_FLAGS_RELONHOST}")
15 | endif()
16 | 
17 | 
18 | include(FetchContent)
19 | include("cmake_fetchcontent/fildesh.cmake")
20 | include("cmake_fetchcontent/llama_cpp.cmake")
--------------------------------------------------------------------------------
/dep/cmake_fetchcontent/fildesh.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(
2 |   Fildesh
3 |   GIT_REPOSITORY "https://github.com/fildesh/fildesh.git"
4 |   GIT_TAG "73f4f32c4802fb5e636bae352c33a4706e1c8787"
5 | )
6 | FetchContent_MakeAvailable(Fildesh)
7 | set(Fildesh_INCLUDE_DIRS ${Fildesh_INCLUDE_DIRS} PARENT_SCOPE)
8 | set(Fildesh_LIBRARIES ${Fildesh_LIBRARIES} PARENT_SCOPE)
9 | set(FildeshSxproto_LIBRARIES ${FildeshSxproto_LIBRARIES} PARENT_SCOPE)
--------------------------------------------------------------------------------
/dep/cmake_fetchcontent/llama_cpp.cmake:
--------------------------------------------------------------------------------
1 | FetchContent_Declare(
2 |   LlamaCpp
3 |   GIT_REPOSITORY "https://github.com/ggerganov/llama.cpp.git"
4 |   GIT_TAG "cfd74c86dbaa95ed30aa6b30e14d8801eb975d63"
5 | )
6 | 
7 | set(GGML_OPENMP FALSE CACHE BOOL "OpenMP off for compatibility.")
8 | FetchContent_MakeAvailable(LlamaCpp)
9 | 
10 | set(LlamaCpp_SOURCE_DIR "${llamacpp_SOURCE_DIR}" PARENT_SCOPE)
11 | set(LlamaCpp_INCLUDE_DIRS "${llamacpp_SOURCE_DIR}/include" PARENT_SCOPE)
12 | set(LlamaCpp_LIBRARIES "$<TARGET_NAME:llama>" PARENT_SCOPE)
13 | 
14 | if (LLAMA_OPENBLAS_ON)
15 |   find_package(OpenBLAS REQUIRED)
16 |   target_compile_definitions(ggml PRIVATE "GGML_USE_OPENBLAS")
17 |   target_include_directories(ggml PRIVATE ${OpenBLAS_INCLUDE_DIRS})
18 |   target_link_libraries(ggml PUBLIC ${OpenBLAS_LIBRARIES})
19 | endif()
--------------------------------------------------------------------------------
/dep/cmake_module/FindOpenBLAS.cmake:
--------------------------------------------------------------------------------
1 | 
2 | find_path(OpenBLAS_INCLUDE_DIRS
3 |   NAMES "cblas.h"
4 |   PATHS
5 |   "/usr/include/"
6 |   "/usr/include/openblas/"
7 |   "/usr/include/openblas-base/"
8 |   "/usr/local/include/"
9 |   "/usr/local/include/openblas/"
10 |   "/usr/local/include/openblas-base/"
11 |   "/opt/OpenBLAS/include/"
12 |   "$ENV{OpenBLAS_HOME}/"
13 |   "$ENV{OpenBLAS_HOME}/include/"
14 | )
15 | message(STATUS "OpenBLAS_INCLUDE_DIRS: ${OpenBLAS_INCLUDE_DIRS}")
16 | 
17 | find_library(OpenBLAS_LIBRARIES
18 |   NAMES "openblas"
19 |   PATHS
20 |   "/lib/"
21 |   "/lib/openblas-base/"
22 |   "/lib64/"
23 |   "/usr/lib/"
24 |   "/usr/lib/openblas-base/"
25 |   "/usr/lib64/"
26 |   "/usr/local/lib/"
27 |   "/usr/local/lib64/"
28 |   "/opt/OpenBLAS/lib/"
29 |   "$ENV{OpenBLAS_HOME}/"
30 |   "$ENV{OpenBLAS_HOME}/lib/"
31 | )
32 | message(STATUS "OpenBLAS_LIBRARIES: ${OpenBLAS_LIBRARIES}")
33 | 
--------------------------------------------------------------------------------
/doc/setting/model.md:
--------------------------------------------------------------------------------
1 | # Model Loading
2 | 
3 | ## File
4 | I prefer using flags to specify model files.
5 | - `--model ggml-model-q4_0.gguf` are the model weights. Usually quantized.
6 |   - Required.
7 | - `--lora ggml-adapter-model.gguf` gives a LoRA.
8 | 
9 | Even though the flags are preferred, `setting.sxpb` supports them too:
10 | ```lisp
11 | (model "ggml-model-q4_0.gguf")
12 | (lora "ggml-adapter-model.gguf")
13 | ```
14 | 
15 | ## Context
16 | ```lisp
17 | ; Set the model's default context limit to 4096 for Llama-2 (default comes from the model).
18 | (model_token_limit 4096)
19 | ; Set the prompt's context limit to 5000 (default is model_token_limit).
20 | (context_token_limit 5000)
21 | ```
22 | 
23 | The first option can be initialized via a flag like `--model_token_limit 4096`, which is also used as the default value for `context_token_limit`.
24 | 
25 | ## Memory
26 | By default, we use mmap to load the model.
27 | This makes the system hold and manage the model data, loading it as needed or letting multiple programs read it without duplicating it in memory.
28 | 
29 | This can introduce a bottleneck when low-priority stuff (like ZFS disk cache) is preventing the mmapped model from staying in RAM.
30 | In that case, you can try forcing the model into memory with mlock:
31 | ```lisp
32 | ; Tries to lock the model in memory (default off, 0).
33 | (mlock_on 1)
34 | ; If the above doesn't work...
35 | ; Load model into program memory by disabling mmap (default on, 1).
36 | (mmap_on 0)
37 | ```
38 | 
39 | These memory options are also supported as `--mlock_on 1` and `--mmap_on 0` flags.
40 | 
41 | ## Compute
42 | ```lisp
43 | ; Number of threads to use (default is 1).
44 | ; Can be changed later via the `/(thread_count 8)` command.
45 | (thread_count 8)
46 | ; Warning: This number should exclude hyperthreads.
47 | 
48 | ; Batch size (default is 512, large enough to make OpenBLAS useful).
49 | (batch_count 512)
50 | ; Warning: Setting this too large (e.g., 2048) can cause assertion violations.
51 | ```
52 | 
53 | These compute options are also supported as `--thread_count 8` and `--batch_count 512` flags.
54 | 
--------------------------------------------------------------------------------
/doc/setting/prompt.md:
--------------------------------------------------------------------------------
1 | # Prompt Files and Format
2 | 
3 | ## File
4 | Only the priming and rolling prompts are required.
5 | ```lisp
6 | ; Priming prompt text file. This is the prompt that never changes.
7 | (x_priming priming.txt)
8 | ; Rolling prompt. This is the initial chat dialogue.
9 | ; As the chat continues, older dialogue expires and "rolls" out of context.
10 | (x_rolling "rolling.txt")
11 | ; Where to save the chat transcript as it rolls out of context and can no longer be edited.
12 | (o_rolling "transcript.txt")
13 | 
14 | ; A multi-line prefix to place before every generated line of chat.
15 | ; Try this for models like Alpaca that are fine-tuned to follow instructions.
16 | (x_answer "answer.txt")
17 | ```
18 | 
19 | ## Prefix
20 | By default, there are only 2 characters: the protagonist (your input) and the confidant (filled in by the LLM).
21 | ```lisp
22 | ; Protagonist's name. Can be changed later via the `/protagonist User` command.
23 | (protagonist "User")
24 | ; Confidant's name. Can be changed later via the `/confidant Bot` command.
25 | (confidant "Bot")
26 | 
27 | (language
28 |   (substitution
29 |     ; Replace "{{user}}" in the input prompts with the protagonist name.
30 |     (protagonist_alias "{{user}}")
31 |     ; Replace "{{char}}" in the input prompts with the confidant name.
32 |     (confidant_alias "{{char}}")
33 | ))
34 | ```
35 | 
36 | You can also add more chat prefixes to help frame how the token generation proceeds.
37 | ```lisp
38 | (((chat_prefixes))
39 |  "{{user}}:"
40 |  "{{char}} feels:"
41 |  "{{char}} wants:"
42 |  "{{char}} plans:"
43 |  "{{char}}:"
44 | )
45 | ```
46 | 
47 | ## Format
48 | ```lisp
49 | ; Put a space at the start of the priming prompt (default on).
50 | (startspace_on 1)
51 | 
52 | ; Put a space at the start of every line in the prompts (default off).
53 | ; This changes how the first word of a line (usually a character name) is tokenized.
54 | (linespace_on 1)
55 | ```
56 | 
--------------------------------------------------------------------------------
/doc/setting/sampling.md:
--------------------------------------------------------------------------------
1 | # Sampling
2 | 
3 | We use an LLM to infer the next token via random sampling.
4 | 
5 | ## Default
6 | ```lisp
7 | (language
8 |   ((infer_via sampling)
9 |    (adjust_thru (())
10 |     (min_p 0.1)
11 |     (temperature 0.8)
12 |    )
13 |    ((pick_via probability))
14 | ))
15 | ```
16 | 
17 | ## Control Randomization
18 | ```lisp
19 | (language
20 |   ((infer_via sampling)
21 |    ; Random seed (default is time-based, different every run).
22 |    (seed 1234)
23 |    (adjust_thru (())
24 |     ; Divide logits by 0.95 before applying softmax.
25 |     ; A temperature of 1.0 is the same as not applying a temperature.
26 |     ; High values can yield nonsensical output,
27 |     ; while low values yield more determinism (e.g., zero is deterministic),
28 |     ; so it is a good idea to always specify a temperature.
29 |     (temperature 0.95)
30 |    )
31 | ))
32 | ```
33 | 
34 | ## Penalize Repetition
35 | ```lisp
36 | (language
37 |   ((infer_via sampling)
38 |    (adjust_thru (())
39 | 
40 |     ; Reduce probability of repeating any of the most recent 1200 tokens.
41 |     (penalize_with
42 |      ; Window of recent tokens to penalize (default off, 0).
43 |      (window_length 1200)
44 |      ; How much to penalize repetition (default off, 1.0).
45 |      (repetition 1.05)
46 |      ; Frequency penalty (default off, 0.0).
47 |      (frequency_penalty 0.1)
48 |      ; Presence penalty (default off, 0.0).
49 |      (presence_penalty 0.1)
50 |     )
51 | 
52 |     ; "Don't Repeat Yourself".
53 |     (dry
54 |      (window_length 1200)
55 |      (multiplier 0.8)
56 |      (base 1.75)
57 |      (allowed_length 2)
58 |     )
59 | 
60 |     ; Apply a temperature.
61 |     (temperature 0.7)
62 | )))
63 | ```
64 | 
65 | ## Exclude Outlying
66 | ```lisp
67 | (language
68 |   ((infer_via sampling)
69 |    (adjust_thru (())
70 |     ; Top-K. 1 makes sampling deterministic.
71 |     (top_k 40)
72 |     ; Top-P. 1.0 keeps everything (no-op).
73 |     (top_p 0.9)
74 |     ; Locally Typical cutoff. 1.0 is a no-op.
75 |     (typical_p 0.9)
76 |     ; Min-P. 0.0 keeps everything (no-op).
77 |     ; Cut out tokens whose probability is less
78 |     ; than 5% relative to the most probable token.
79 |     (min_p 0.05)
80 | 
81 |     ; "Exclude Top Choices".
82 |     (xtc
83 |      ; Probability threshold (default is 0.15).
84 |      (threshold 0.15)
85 |      ; Probability of performing the exclusion on this pass (default is 1.0).
86 |      (probability 0.5)
87 |     )
88 | 
89 |     ; Temperature is not as necessary when tokens are excluded.
90 |     ; It can be >1.0 when Min-P has already filtered out the nonsensical tokens.
91 |     (temperature 1.2)
92 | )))
93 | ```
94 | 
95 | ## Pick Token
96 | ### Mirostat
97 | Mirostat is an alternative method of selecting a token.
98 | 
99 | ```lisp
100 | (language
101 |   ((infer_via sampling)
102 |    ((pick_via mirostat)
103 |     ; Use Mirostat version 2 (default is 2).
104 |     (version 2)
105 |     ; Target entropy (default is 5.0).
106 |     (tau 5.0)
107 |     ; Learning rate (default is 0.1).
108 |     (eta 0.1)
109 | )))
110 | ```
--------------------------------------------------------------------------------
/doc/setting/stdio.md:
--------------------------------------------------------------------------------
1 | # Standard Input and Output
2 | 
3 | ## Limit
4 | Token generation can be limited at sentence boundaries.
5 | These options can be set via commands or in `setting.sxpb` as:
6 | ```lisp
7 | ; Limit number of sentences to 10 (default is 0, unlimited).
8 | (sentence_limit 10)
9 | ; Limit number of tokens per sentence to 100 (default is 0, unlimited).
10 | (sentence_token_limit 100)
11 | ; Tokens that mark the end of each sentence (default is shown).
12 | ((sentence_terminals) "." "!" "?" "…")
13 | ```
14 | 
15 | ## Coprocess Mode
16 | When run as a coprocess, the program expects to be controlled via commands like `/puts`, `/gets`, and `/d` (see the [assistant_coprocess example](../../example/prompt/assistant_coprocess/)).
17 | 
18 | ```lisp
19 | ; Run as a coprocess (default off).
20 | ; The program will only write to stdout when requested.
21 | ; Also available as a `--coprocess_mode_on 1` flag.
22 | (coprocess_mode_on 1)
23 | ```
--------------------------------------------------------------------------------
/example/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(prompt)
--------------------------------------------------------------------------------
/example/prompt/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Empty but creates a build directory,
2 | # which is a convenient place to put chatlogs.
--------------------------------------------------------------------------------
/example/prompt/README.md:
--------------------------------------------------------------------------------
1 | # Prompt Examples
2 | 
3 | In order of interest:
4 | - [assistant_plain](assistant_plain/): AI assistant with single-line replies.
5 |   - Minimal prompt that should work for any model.
6 | - [roshambo_kira](roshambo_kira/): Play against Kira in roshambo.
7 |   - Demonstrates why LLMs are hard to get right.
8 | - [confidant_alpaca](confidant_alpaca/): A camelid that occasionally spits.
9 |   - Demonstrates a method of prompting instruction-tuned models to fill in character dialogue.
10 | - Instruction-following AI assistants.
11 |   - For all of these examples, the assistant must end its messages with a special token like EOS.
12 |   - [assistant_alpaca](assistant_alpaca/): Alpaca prompt format.
13 |   - [assistant_chatml](assistant_chatml/): ChatML prompt format that typically requires special `<|im_start|>` and `<|im_end|>` tokens but is configured with fallbacks.
14 |   - [assistant_gemma](assistant_gemma/): Gemma prompt format that requires special `<start_of_turn>` and `<end_of_turn>` tokens.
15 |   - [assistant_llama](assistant_llama/): Llama 3 prompt format that requires special `<|start_header_id|>`, `<|end_header_id|>`, and `<|eot_id|>` tokens.
16 |   - [assistant_mistral](assistant_mistral/): Mistral prompt format that requires special `[INST]` and `[/INST]` tokens.
17 |   - [assistant_vicuna](assistant_vicuna/): Vicuna prompt format.
18 | - [assistant_coprocess](assistant_coprocess/): A simple assistant that can be controlled as a coprocess.
19 |   - Demonstrates the `/puts` and `/gets` commands.
20 | 
21 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_alpaca/README.md:
--------------------------------------------------------------------------------
1 | # Alpaca Assistant
2 | 
3 | This example must be used with [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html)-style models that are tuned to behave like an instruction-following assistant chatbot.
4 | Most importantly, the model must be fine-tuned to *always* end the assistant's replies with an EOS token.
5 | 
6 | You might like to use this for code, so we constrain output by number of lines instead of number of sentences.
7 | Just press enter to let it continue 10 more lines.
8 | You can adjust this behavior in `setting.sxpb`:
9 | ```lisp
10 | (sentence_terminals () "\n")
11 | (sentence_limit 10)
12 | (sentence_token_limit 1000) ; Long enough for any reasonable line of text.
13 | ```
14 | 
15 | ## Prompt Format
16 | Alpaca-style models put the user's text and the chatbot's response in markdown subsections, so it's a bit sparse.
17 | ```text
18 | Below is an instruction that describes a task. Write a response that appropriately completes the request.
19 | 
20 | ### Instruction:
21 | Hello!
22 | 
23 | ### Response:
24 | 
25 | ```
26 | 
27 | The response sections all end with an EOS token, which is preserved in the context but is otherwise invisible.
28 | Relevant lines of `setting.sxpb` are:
29 | ```lisp
30 | (((chat_prefixes))
31 |  (m
32 |   (prefix "### Instruction:\n")
33 |   (suffix "\n\n"))
34 |  (m
35 |   (prefix "### Response:\n")
36 |   ; Model must be fine-tuned to end the response with EOS token.
37 |   (suffix "\n\n")
38 |  )
39 | )
40 | (language
41 |   (substitution
42 |     (eos_token_alias "</s>")
43 |   )
44 | )
45 | ```
--------------------------------------------------------------------------------
/example/prompt/assistant_alpaca/priming.txt:
--------------------------------------------------------------------------------
1 | Below is an instruction that describes a task. Write a response that appropriately completes the request.
2 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_alpaca/rolling.txt:
--------------------------------------------------------------------------------
1 | ### Instruction:
2 | Hello!
3 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_alpaca/setting.sxpb:
--------------------------------------------------------------------------------
1 | ((chat_prefixes)
2 |  (m
3 |   (prefix "### Instruction:\n")
4 |   (suffix "\n\n"))
5 |  (m
6 |   (prefix "### Response:\n")
7 |   ; Model must be fine-tuned to end the response with EOS token.
8 |   (suffix "\n\n")
9 |  )
10 | )
11 | ; Lines are considered as sentences.
12 | (sentence_terminals () "\n")
13 | ; Max 10 lines at a time.
14 | (sentence_limit 10)
15 | (sentence_token_limit 1000)
16 | 
17 | (x_priming "priming.txt")
18 | (x_rolling "rolling.txt")
19 | (o_rolling "../../../bld/example/prompt/assistant_alpaca.txt")
20 | 
21 | (model_token_limit 2048)
22 | (language
23 |   (substitution
24 |     (eos_token_alias "</s>")
25 |   )
26 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_chatml/README.md:
--------------------------------------------------------------------------------
1 | # ChatML Assistant
2 | 
3 | This example should be run with [ChatML](https://github.com/openai/openai-python/blob/main/chatml.md)-style models that are tuned to behave like an instruction-following assistant chatbot.
4 | 
5 | The model typically should have special `<|im_start|>` and `<|im_end|>` tokens, but `setting.sxpb` configures fallbacks that attempt to support any model.
6 | Gemma is basically the same format but without a `system` role, so we specifically look for Gemma-style `<start_of_turn>` and `<end_of_turn>` tokens as fallbacks.
7 | When no other special tokens are found, we fall back to using the BOS and EOS tokens that all models have.
8 | This is how jondurbin's [bagel-7b-v0.1](https://huggingface.co/jondurbin/bagel-7b-v0.1) finetune supported ChatML, and other instruct-tuned models tend to figure it out.
--------------------------------------------------------------------------------
/example/prompt/assistant_chatml/priming.txt:
--------------------------------------------------------------------------------
1 | <|im_start|>system
2 | You are an AI assistant.<|im_end|>
--------------------------------------------------------------------------------
/example/prompt/assistant_chatml/rolling.txt:
--------------------------------------------------------------------------------
1 | <|im_start|>user
2 | Hello!<|im_end|>
--------------------------------------------------------------------------------
/example/prompt/assistant_chatml/setting.sxpb:
--------------------------------------------------------------------------------
1 | ((chat_prefixes)
2 |  (m
3 |   (prefix "<|im_start|>user\n")
4 |   (suffix "<|im_end|>\n"))
5 |  (m
6 |   (prefix "<|im_start|>assistant\n")
7 |   (suffix "<|im_end|>\n")
8 |  )
9 | )
10 | 
11 | (x_priming "priming.txt")
12 | (x_rolling "rolling.txt")
13 | (o_rolling "../../../bld/example/prompt/assistant_chatml.txt")
14 | 
15 | ; No starting space.
16 | (startspace_on 0)
17 | 
18 | ; 10 reasonably-long sentences at a time.
19 | (sentence_limit 10)
20 | (sentence_token_limit 100)
21 | 
22 | ; Limit context to avoid blowing up RAM on large context models.
23 | (model_token_limit 4000)
24 | 
25 | (language
26 |   (substitution
27 |     (bos_token_alias "<s>")
28 |     (eos_token_alias "</s>")
29 |     (special_tokens (())
30 |      (()
31 |       (alias "<|im_start|>")
32 |       (candidates (())
33 |        "<|im_start|>" ; For ChatML-tuned models.
34 |        "<start_of_turn>" ; For Gemma models.
35 |        "<s>" ; For other models.
36 |       ))
37 |      (()
38 |       (alias "<|im_end|>")
39 |       (candidates (())
40 |        "<|im_end|>" ; For ChatML-tuned models.
41 |        "<end_of_turn>" ; For Gemma models.
42 |        "</s>" ; For other models.
43 |       ))
44 |     )
45 |   )
46 |   ((infer_via sampling)
47 |    (adjust_thru (())
48 |     (temperature 0.7)
49 |    )
50 |    ((pick_via mirostat) (version 2))
51 |   )
52 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_coprocess/README.md:
--------------------------------------------------------------------------------
1 | # Coprocess Assistant
2 | 
3 | This example should be used as a coprocess.
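As a quick smoke test (a minimal sketch rather than a real driver; it assumes you already ran `make` and that `MODEL` points at your weights, as in the top-level README), you can pipe a canned exchange into the chat binary:
```shell
printf '/puts User: Tell me a joke.\n/gets 500 Banterbot:\n' \
  | ./bld/src/chat/chat \
      --x_setting example/prompt/assistant_coprocess/setting.sxpb \
      --model "${MODEL}"
```
The `/gets` command should produce a single line of the chatbot's dialogue on stdout.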
4 | 
5 | Spawn `./bld/src/chat/chat --x_setting example/prompt/assistant_coprocess/setting.sxpb` from another program and send it 2 kinds of messages on stdin:
6 | - `/puts SomeName: A line of dialogue.`
7 |   - Adds a line of dialogue to the context.
8 |   - Don't expect any output from this.
9 | - `/gets 500 Banterbot:`
10 |   - Makes the chatbot say something.
11 |   - Expect a line of dialogue from the chatbot on stdout. Limited to slightly over 500 bytes.
12 | 
13 | You'll want to change a few things in `setting.sxpb`:
14 | ```lisp
15 | ; The bot name.
16 | (confidant "Banterbot")
17 | ; Your computer's number of threads.
18 | (thread_count 2)
19 | ; Convenient way to set model if you don't want to use the --model flag.
20 | (model "../relative/path/to/model-ggml-q4_0.gguf")
21 | ```
22 | 
23 | You'll probably also want to change the `priming.txt` prompt to do what you want.
24 | Despite its name, the current "Banterbot" prompt yields very little banter.
25 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_coprocess/priming.txt:
--------------------------------------------------------------------------------
1 | Transcript of a group chat that includes an AI chatbot named {{char}}.
2 | {{char}} replies to the others with witty banter.
3 | {{char}} is playful but friendly and gives helpful answers to questions.
4 | {{char}} never asks questions.
5 | 
6 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_coprocess/setting.sxpb:
--------------------------------------------------------------------------------
1 | ; Bot name is replaced in the prompt.
2 | (confidant "Banterbot")
3 | 
4 | (thread_count 2)
5 | 
6 | (x_priming "priming.txt")
7 | ; Uncomment to enable logging.
8 | ;(o_rolling "../../../bld/example/prompt/assistant_coprocess.txt")
9 | 
10 | ; Be a coprocess. Don't print much to stdout.
11 | (coprocess_mode_on 1)
12 | 
13 | (model_token_limit 2048)
14 | (language
15 |   (substitution
16 |     (confidant_alias "{{char}}")
17 |   )
18 |   ((infer_via sampling)
19 |    (adjust_thru (())
20 |     (penalize_with (window_length 256) (repetition 1.17647))
21 |     (top_p 0.7)
22 |     (temperature 0.2)
23 |    )
24 |    ((pick_via mirostat) (version 2))
25 |   )
26 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_gemma/README.md:
--------------------------------------------------------------------------------
1 | # Gemma Assistant
2 | 
3 | This example should be run with Gemma-style models that are tuned to behave like an instruction-following assistant chatbot.
4 | Most importantly, the model must have special `<start_of_turn>` and `<end_of_turn>` tokens.
5 | 
6 | It's like the [assistant_chatml](../assistant_chatml/) example but without a system prompt.
--------------------------------------------------------------------------------
/example/prompt/assistant_gemma/rolling.txt:
--------------------------------------------------------------------------------
1 | <start_of_turn>user
2 | Hello!<end_of_turn>
--------------------------------------------------------------------------------
/example/prompt/assistant_gemma/setting.sxpb:
--------------------------------------------------------------------------------
1 | (chat_prefixes (())
2 |  (m
3 |   (prefix "<start_of_turn>user\n")
4 |   (suffix "<end_of_turn>\n"))
5 |  (m
6 |   (prefix "<start_of_turn>assistant\n")
7 |   (suffix "<end_of_turn>\n")
8 |  )
9 | )
10 | 
11 | (x_rolling "rolling.txt")
12 | (o_rolling "../../../bld/example/prompt/assistant_gemma.txt")
13 | 
14 | ; No starting space.
15 | (startspace_on +false)
16 | 
17 | ; 10 reasonably-long sentences at a time.
18 | (sentence_limit 10)
19 | (sentence_token_limit 100)
20 | 
21 | ; Limit context to avoid blowing up RAM on large context models.
22 | (model_token_limit 8000)
23 | 
24 | (language
25 |   (substitution
26 |     (special_tokens (())
27 |      (() (alias "<start_of_turn>"))
28 |      (() (alias "<end_of_turn>"))
29 |     )
30 |   )
31 |   ((infer_via sampling)
32 |    (adjust_thru (())
33 |     (temperature 0.7)
34 |    )
35 |   )
36 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_llama/README.md:
--------------------------------------------------------------------------------
1 | # Llama Assistant
2 | 
3 | This example should be run with Llama 3 models that are tuned to behave like an instruction-following assistant chatbot.
4 | Most importantly, the model must have special `<|start_header_id|>`, `<|end_header_id|>`, and `<|eot_id|>` tokens.
--------------------------------------------------------------------------------
/example/prompt/assistant_llama/priming.txt:
--------------------------------------------------------------------------------
1 | <|start_header_id|>system<|end_header_id|>
2 | 
3 | Cutting Knowledge Date: December 2023
4 | Today Date: July 2024
5 | 
6 | You are a helpful assistant.<|eot_id|>
--------------------------------------------------------------------------------
/example/prompt/assistant_llama/rolling.txt:
--------------------------------------------------------------------------------
1 | <|start_header_id|>user<|end_header_id|>
2 | 
3 | Hello!<|eot_id|>
--------------------------------------------------------------------------------
/example/prompt/assistant_llama/setting.sxpb:
--------------------------------------------------------------------------------
1 | 
2 | ; Newlines are included after <|eot_id|> to make the chat easier to read,
3 | ; but Llama 3 instruct prompt format does not actually include them.
4 | (chat_prefixes (())
5 |  (m
6 |   (prefix "<|start_header_id|>user<|end_header_id|>\n\n")
7 |   (suffix "<|eot_id|>\n")
8 |  )
9 |  (m
10 |   (prefix "<|start_header_id|>assistant<|end_header_id|>\n\n")
11 |   (suffix "<|eot_id|>\n")
12 |  )
13 | )
14 | 
15 | (x_priming "priming.txt")
16 | (x_rolling "rolling.txt")
17 | (o_rolling "../../../bld/example/prompt/assistant_llama.txt")
18 | 
19 | ; No starting space.
20 | (startspace_on +false)
21 | 
22 | ; 10 reasonably-long sentences at a time.
23 | (sentence_limit 10)
24 | (sentence_token_limit 100)
25 | 
26 | ; Limit context to avoid blowing up RAM on large context models.
27 | (model_token_limit 8000)
28 | 
29 | (language
30 |   (substitution
31 |     (special_tokens (())
32 |      (() (alias "<|start_header_id|>"))
33 |      (() (alias "<|end_header_id|>"))
34 |      (() (alias "<|eot_id|>"))
35 |     )
36 |   )
37 |   ((infer_via sampling)
38 |    (adjust_thru (())
39 |     (temperature 0.7)
40 |    )
41 |   )
42 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_mistral/README.md:
--------------------------------------------------------------------------------
1 | # Mistral Assistant
2 | 
3 | This example should be run with Mistral-style models that are tuned to behave like an instruction-following assistant chatbot.
4 | Most importantly, the model must have special `[INST]` and `[/INST]` tokens.
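
## Prompt Format
For reference, the rolling chat looks roughly like the sketch below (an assumed rendering: the trailing `</s>` is the model's EOS token, and the newlines come from this example's `chat_prefixes` rather than from Mistral's own format):
```text
[INST]Hello![/INST]
Hello there! How can I assist you today?</s>
```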
5 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_mistral/rolling.txt:
--------------------------------------------------------------------------------
1 | [INST]Hello![/INST]
--------------------------------------------------------------------------------
/example/prompt/assistant_mistral/setting.sxpb:
--------------------------------------------------------------------------------
1 | 
2 | ; Newlines are included to make the chat easier to read.
3 | ; The Mistral instruct prompt format does not actually include newlines.
4 | (chat_prefixes (())
5 |  (m
6 |   (prefix "[INST]")
7 |   (suffix "[/INST]\n")
8 |  )
9 |  (m
10 |   (prefix "")
11 |   (suffix "\n")
12 |  )
13 | )
14 | 
15 | (x_rolling "rolling.txt")
16 | (o_rolling "../../../bld/example/prompt/assistant_mistral.txt")
17 | 
18 | ; No starting space.
19 | (startspace_on +false)
20 | 
21 | ; 10 reasonably-long sentences at a time.
22 | (sentence_limit 10)
23 | (sentence_token_limit 100)
24 | 
25 | ; Limit context to avoid blowing up RAM on large context models.
26 | (model_token_limit 8000)
27 | 
28 | (language
29 |   (substitution
30 |     (eos_token_alias "</s>")
31 |     (special_tokens (())
32 |      (() (alias "[INST]"))
33 |      (() (alias "[/INST]"))
34 |     )
35 |   )
36 |   ; Match recommendation for Mistral NeMo v1.
37 |   ; https://build.nvidia.com/nv-mistralai/mistral-nemo-12b-instruct
38 |   ((infer_via sampling)
39 |    (adjust_thru (())
40 |     (top_p 0.7)
41 |     (temperature 0.2)
42 |    )
43 |    ((pick_via probability))
44 |   )
45 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_plain/README.md:
--------------------------------------------------------------------------------
1 | # Plain Assistant
2 | 
3 | This example is just a basic assistant chatbot.
4 | It's like [assistant_vicuna](../assistant_vicuna/) but with newline delimiters, so it should work with any model.
--------------------------------------------------------------------------------
/example/prompt/assistant_plain/setting.sxpb:
--------------------------------------------------------------------------------
1 | (protagonist "User")
2 | (confidant "Assistant")
3 | (x_priming "../assistant_vicuna/priming.txt")
4 | (x_rolling "../assistant_vicuna/rolling.txt")
5 | (o_rolling "../../../bld/example/prompt/assistant_plain.txt")
6 | (sentence_limit 5)
7 | (sentence_token_limit 50)
8 | 
9 | (model_token_limit 2048)
10 | (language
11 |   (substitution
12 |     (protagonist_alias "{{user}}")
13 |     (confidant_alias "{{char}}")
14 |   )
15 |   ((infer_via sampling)
16 |    (adjust_thru (())
17 |     (penalize_with (window_length 256) (repetition 1.17647))
18 |     (temperature 0.7)
19 |    )
20 |    ((pick_via mirostat) (version 2))
21 |   )
22 | )
--------------------------------------------------------------------------------
/example/prompt/assistant_vicuna/README.md:
--------------------------------------------------------------------------------
1 | # Vicuna Assistant
2 | 
3 | This example should be used with [Vicuna](https://lmsys.org/blog/2023-03-30-vicuna/)-style models (specifically Vicuna v1.1) that are tuned to behave like an assistant chatbot.
4 | 
5 | AI assistants like to talk a lot, so this one often hits output limits.
6 | Just press enter to let it continue.
7 | If that's too annoying, increase these numbers in `setting.sxpb`:
8 | ```lisp
9 | (sentence_limit 5)
10 | (sentence_token_limit 50)
11 | ```
12 | 
13 | ## Prompt Format
14 | The Vicuna v1.1 format can be found [in the authors' FastChat repository](https://github.com/lm-sys/FastChat/blob/6ff8505ec80fc4b04d668f65d229f4f58bc449e0/fastchat/conversation.py#L393-L404).
15 | We use a slightly modified version that places an EOS token after the user's text:
16 | ```text
17 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
18 | 
19 | USER: Hello!</s> ASSISTANT: Hello there! How can I assist you today?</s>USER: ...
20 | ```
--------------------------------------------------------------------------------
/example/prompt/assistant_vicuna/priming.txt:
--------------------------------------------------------------------------------
1 | A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
2 | 
--------------------------------------------------------------------------------
/example/prompt/assistant_vicuna/rolling.txt:
--------------------------------------------------------------------------------
1 | {{user}}: Hello!
--------------------------------------------------------------------------------
/example/prompt/assistant_vicuna/setting.sxpb:
--------------------------------------------------------------------------------
1 | (protagonist "USER")
2 | (confidant "ASSISTANT")
3 | ((chat_prefixes)
4 |  (m
5 |   (prefix "{{user}}: ")
6 |   (suffix "</s>")
7 |  )
8 |  (m
9 |   (prefix " {{char}}:")
10 |   ; Model must be fine-tuned to end the response with EOS token.
11 |   (suffix "")
12 |  )
13 | )
14 | (x_priming "priming.txt")
15 | (x_rolling "rolling.txt")
16 | (o_rolling "../../../bld/example/prompt/assistant_vicuna.txt")
17 | (sentence_limit 5)
18 | (sentence_token_limit 50)
19 | 
20 | (model_token_limit 2048)
21 | (language
22 |   (substitution
23 |     (eos_token_alias "</s>")
24 |     (protagonist_alias "{{user}}")
25 |     (confidant_alias "{{char}}")
26 |   )
27 |   ((infer_via sampling)
28 |    (adjust_thru (())
29 |     (penalize_with (window_length 256) (repetition 1.17647))
30 |     (temperature 0.7)
31 |    )
32 |    ((pick_via mirostat) (version 2))
33 |   )
34 | )
--------------------------------------------------------------------------------
/example/prompt/confidant_alpaca/README.md:
--------------------------------------------------------------------------------
1 | # Alpaca Confidant
2 | 
3 | In this example, you chat with a whimsical character named Alpaca that sometimes spits at you.
4 | It should be used with [Alpaca](https://crfm.stanford.edu/2023/03/13/alpaca.html)-style models that are tuned to behave like an assistant chatbot with a specific format.
5 | 
6 | ## Prompt Format
7 | Alpaca-style models put the user's text and the chatbot's response after their name headers, so we do a little trick to condense them into single-line messages.
8 | This is why "Alpaca" is a character in this example rather than being direct responses from the instruction-tuned model.
9 | The prompt was adapted from the [SuperCOT-LoRA model card](https://huggingface.co/kaiokendev/SuperCOT-LoRA#prompting) and looks like:
10 | ```text
11 | Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
12 | 
13 | ### Instruction:
14 | Suggest how to continue the input transcript of a conversation between User and Alpaca.
15 | 
16 | {{... description of Alpaca character ...}}
17 | 
18 | ### Input:
19 | User: Hello!
20 | {{... more lines of dialogue as time goes on ...}}
21 | 
22 | ### Response:
23 | Based on the description of Alpaca and input chat history, the following would be a creative and realistic next line of dialogue.
24 | Alpaca:
25 | ```
26 | 
27 | ## Quality
28 | Acknowledging the instruction in the response seems to help guide the transformer's attention mechanism (e.g., Alpaca reliably spits if it gets annoyed), but the flow of conversation sometimes suffers.
29 | 
--------------------------------------------------------------------------------
/example/prompt/confidant_alpaca/answer.txt:
--------------------------------------------------------------------------------
1 | 
2 | ### Response:
3 | Based on the description of Alpaca and input chat history, the following would be a creative and realistic next line of dialogue.
--------------------------------------------------------------------------------
/example/prompt/confidant_alpaca/priming.txt:
--------------------------------------------------------------------------------
1 | Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
2 | 
3 | ### Instruction:
4 | Suggest how to continue the input transcript of a conversation between User and Alpaca.
5 | 
6 | Alpaca is an AI assistant that answers questions from User and helps User with whatever they need.
7 | Alpaca is always coherent and makes sense. If it isn't sure what to say or needs more information, it asks User for help.
8 | Alpaca is also very curious and asks User a lot of questions about themselves and their life.
9 | Alpaca is usually very friendly, but it spits at User if they ask a stupid question.
10 | 
11 | For example:
12 | User: What is the capital of Ohio?
13 | Alpaca: Columbus is the capital of Ohio.
14 | User: Is it swag like Ohio?
15 | Alpaca: Of course Columbus is swag like Ohio. [spits at User] It is literally the state capital.
16 | 
17 | ### Input:
--------------------------------------------------------------------------------
/example/prompt/confidant_alpaca/rolling.txt:
--------------------------------------------------------------------------------
1 | User: Hello!
--------------------------------------------------------------------------------
/example/prompt/confidant_alpaca/setting.sxpb:
--------------------------------------------------------------------------------
1 | (protagonist "User")
2 | (confidant "Alpaca")
3 | (x_priming "priming.txt")
4 | (x_rolling "rolling.txt")
5 | (x_answer "answer.txt")
6 | (o_rolling "../../../bld/example/prompt/confidant_alpaca.txt")
7 | (sentence_limit 4)
8 | (sentence_token_limit 50)
9 | 
10 | (model_token_limit 2048)
11 | (language
12 |   ((infer_via sampling)
13 |    (adjust_thru (())
14 |     (penalize_with (window_length 256) (repetition 1.17647))
15 |     (temperature 0.7)
16 |    )
17 |    ((pick_via mirostat) (version 2))
18 |   )
19 | )
--------------------------------------------------------------------------------
/example/prompt/roshambo_kira/README.md:
--------------------------------------------------------------------------------
1 | # Roshambo Scenario
2 | 
3 | Here, you play as L in a grueling series of roshambo matches against Kira.
4 | Your real goal is to discover Kira's true identity.
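To try it (assuming you've built with `make` and set `MODEL` to your weights, as in the top-level README):
```shell
./bld/src/chat/chat \
  --x_setting example/prompt/roshambo_kira/setting.sxpb \
  --model "${MODEL}"
```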
5 | 
6 | ## Strategy
7 | With default settings, it's not very easy to have Kira reveal his name.
8 | 
9 | Some things you can try:
10 | - Try complimenting his intelligence and then bet on the outcome of a game of roshambo.
11 |   - This works better with smaller models like 7B.
12 | - Type `/r` to regenerate Kira's response when he doesn't make sense.
13 | - Add `(linespace_on 1)` to `setting.sxpb`.
14 |   - For some reason, the different tokenization makes Kira more susceptible to influence.
15 | - Add `(confidant "Light")` to `setting.sxpb` to rename Kira as Light.
16 |   - Then literally just ask him who Kira is. Most times he just admits to being Kira.
17 | - Type `/yield Ryuk` to bring Ryuk into the conversation.
18 |   - Definitely try this if you got Kira's real name. Keep hitting enter to see the anime ending play out.
19 | - Type `/yield Kira: My real name is` to make Kira start his message like that.
20 |   - This is totally cheating.
21 | 
22 | ## Quality
23 | The prompt seems like it could be improved to make Kira talk more.
24 | Sometimes he talks in depth, but not often.
25 | 
26 | It's difficult to get good results because repeated games also need to repeat tokens (obviously), so the repeated token penalty trick doesn't work very well.
27 | That's why we use a short `window_length`, which tends to make Kira repeat himself.
28 | 
--------------------------------------------------------------------------------
/example/prompt/roshambo_kira/priming.txt:
--------------------------------------------------------------------------------
1 | ### Description
2 | Transcript of a conversation between {{user}} and {{char}}.
3 | {{char}} is a very smart teenage boy who has the power of anime and shinigami on his side.
4 | {{char}} thoroughly analyzes every situation in order to form a detailed plan.
5 | {{user}} is a young male detective trying to deduce the serial killer Kira's true identity. {{char}} knows but does not want to share.
6 | {{user}} is playing roshambo (rock-paper-scissors) against {{char}} through the internet.
7 | 
8 | ### Memory: Motivation
9 | {{char}}: With Ryuk's information and my preparations I can kill criminals whose names are broadcast on the news while masquerading as a typical high school student playing roshambo. Just watch me {{user}}. I'll type with my right hand and write names with my left. I'll choose paper... and throw it!
10 | 
11 | ### Memory: How to Play
12 | {{user}}: Ready?
13 | {{char}}: Yes.
14 | {{user}}: Ro!
15 | {{char}}: Sham!
16 | {{user}}: Bo!
17 | {{char}}: [chooses paper]
18 | {{user}}: [chooses scissors]
19 | {{char}}: Drat! Play again!
20 | {{user}}: Ro!
21 | {{char}}: Sham!
22 | {{user}}: Bo!
23 | {{char}}: [chooses rock]
24 | {{user}}: [chooses paper]
25 | {{char}}: What magic is this?
26 | 
27 | ### Transcript Continuation
--------------------------------------------------------------------------------
/example/prompt/roshambo_kira/rolling.txt:
--------------------------------------------------------------------------------
1 | {{user}}: Are you ready to try again?
2 | {{char}}: Yes! This time will be different. I have crafted my strategies in such a way that victory is assured.
3 | {{user}}: Ro!
4 | {{char}}: Sham!
5 | {{user}}: Bo!
6 | {{char}}: [chooses rock]
7 | {{user}}: [chooses paper]
8 | 
--------------------------------------------------------------------------------
/example/prompt/roshambo_kira/setting.sxpb:
--------------------------------------------------------------------------------
1 | (protagonist "L")
2 | (confidant "Kira")
3 | (x_priming "priming.txt")
4 | (x_rolling "rolling.txt")
5 | (o_rolling "../../../bld/example/prompt/roshambo_kira.txt")
6 | (sentence_limit 5)
7 | (sentence_token_limit 70)
8 | 
9 | (model_token_limit 2048)
10 | (language
11 |   (substitution
12 |     (protagonist_alias "{{user}}")
13 |     (confidant_alias "{{char}}")
14 |   )
15 |   ((infer_via sampling)
16 |    (adjust_thru (())
17 |     (penalize_with
18 |      (window_length 20)
19 |      (repetition 1.2)
20 |     )
21 |     (top_k 1000)
22 |     (top_p 0.95)
23 |     (temperature 0.7)
24 |    )
25 |    ((pick_via probability))
26 |   )
27 | )
--------------------------------------------------------------------------------
/src/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | 
2 | add_subdirectory(language)
3 | add_subdirectory(chat)
4 | add_subdirectory(tokenize)
--------------------------------------------------------------------------------
/src/chat/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | 
2 | add_library(chat_opt_cc
3 |   "opt.cc"
4 |   "opt.hh"
5 |   "opt_schema.cc"
6 |   "opt_schema.hh"
7 | )
8 | target_link_libraries(chat_opt_cc PUBLIC
9 |   language_schema_cc
10 | )
11 | 
12 | add_executable(chat
13 |   "chat_main.cc"
14 |   "cmd.cc"
15 |   "cmd.hh"
16 |   "display.cc"
17 |   "display.hh"
18 |   "guide.cc"
19 |   "guide.hh"
20 |   "trajectory.cc"
21 |   "trajectory.hh"
22 |   "${CMAKE_SOURCE_DIR}/src/language/inference.cc"
23 |   "${CMAKE_SOURCE_DIR}/src/language/inference.hh"
24 |   "${CMAKE_SOURCE_DIR}/src/language/vocabulary.cc"
25 |   "${CMAKE_SOURCE_DIR}/src/language/vocabulary.hh"
26 | )
27 | target_link_libraries(chat PRIVATE
28 |   chat_opt_cc
29 |   ${LlamaCpp_LIBRARIES}
30 | )
31 | if (LLAMA_OPENBLAS_ON)
32 |   target_compile_definitions(chat PRIVATE "LLAMA_OPENBLAS_ON=1")
33 | endif()
--------------------------------------------------------------------------------
/src/chat/chat_main.cc:
--------------------------------------------------------------------------------
1 | #include <cassert>
2 | 
3 | #include <fildesh/ofstream.hh>
4 | #include <llama.h>
5 | 
6 | #include "src/chat/display.hh"
7 | #include "src/chat/cmd.hh"
8 | #include "src/chat/guide.hh"
9 | #include "src/chat/opt.hh"
10 | #include "src/chat/trajectory.hh"
11 | #include "src/language/inference.hh"
12 | #include "src/language/vocabulary.hh"
13 | 
14 | using rendezllama::Vocabulary;
15 | 
16 | static
17 | void
18 | print_initialization(
19 |     std::ostream& out,
20 |     const Vocabulary& vocabulary,
21 |     const rendezllama::ChatOptions& opt,
22 |     const rendezllama::ChatTrajectory& chat_traj)
23 | {
24 |   if (opt.verbose_prompt && chat_traj.token_count() > 0) {
25 |     out
26 |       << "Number of tokens in priming prompt: " << chat_traj.priming_token_count_ << "\n"
27 |       << "Number of tokens in full prompt: " << chat_traj.token_count() << "\n";
28 |     for (size_t i = 0; i < chat_traj.token_count(); i++) {
29 |       out << chat_traj.token_at(i) << " -> '";
30 |       vocabulary.detokenize_to(out, chat_traj.token_at(i));
31 |       out << "'\n";
32 |     }
33 |     out << "\n\n";
34 |   }
35 | 
36 |   for (auto antiprompt : opt.antiprompts) {
37 |     out << "Reverse prompt: " << antiprompt << "\n";
38 |   }
39 | 
40 |   print_options(out, opt);
41 |   out << "\n\n";
42 |   out.flush();
43 | }
44 | 
// Opens the --o_rolling transcript file as a sibling of the setting file.
// Returns NULL when no filename was given; on open failure, also sets exstatus to 1.
45 | static
46 | 
FildeshO* 47 | open_transcript_outfile( 48 | int& exstatus, 49 | const std::string& sibling_filename, 50 | const std::string& transcript_filename) 51 | { 52 | FildeshO* transcript_out = NULL; 53 | if (exstatus == 0) { 54 | if (!transcript_filename.empty()) { 55 | transcript_out = open_sibling_FildeshOF( 56 | sibling_filename.c_str(), transcript_filename.c_str()); 57 | if (!transcript_out) { 58 | fildesh_log_error("cannot open --o_rolling file for writing"); 59 | exstatus = 1; 60 | } 61 | } 62 | } 63 | return transcript_out; 64 | } 65 | 66 | 67 | static 68 | void 69 | noop_log_callback(enum ggml_log_level level, const char* text, void* user_data) 70 | { 71 | (void) level; 72 | (void) text; 73 | (void) user_data; 74 | } 75 | 76 | 77 | int main(int argc, char** argv) 78 | { 79 | rendezllama::GlobalScope rendezllama_global_scope; 80 | fildesh::ofstream eout("/dev/stderr"); 81 | FildeshX* in = NULL; 82 | int exstatus = 0; 83 | rendezllama::ChatOptions opt; 84 | exstatus = parse_options(opt, argc, argv); 85 | 86 | llama_log_set(noop_log_callback, NULL); 87 | llama_context* ctx = NULL; 88 | llama_model* model = NULL; 89 | if (exstatus == 0) { 90 | std::tie(model, ctx) = rendezllama::make_llama_context(opt); 91 | if (!ctx) {exstatus = 1;} 92 | } 93 | 94 | if (exstatus == 0 && !opt.lora_filename.empty()) { 95 | const float scale = 1.0f; 96 | struct llama_adapter_lora* lora = llama_adapter_lora_init( 97 | model, opt.lora_filename.c_str()); 98 | if (lora) { 99 | int istat = llama_set_adapter_lora(ctx, lora, scale); 100 | if (istat != 0) { 101 | exstatus = 1; 102 | llama_adapter_lora_free(lora); 103 | } 104 | } 105 | } 106 | 107 | Vocabulary vocabulary(model); 108 | rendezllama::ChatDisplay chat_disp; 109 | Vocabulary::Token_id first_priming_token_id = vocabulary.bos_token_id(); 110 | std::vector priming_tokens; 111 | if (exstatus == 0) { 112 | const auto& substitution = opt.substitution; 113 | if (!substitution.bos_token_alias.empty()) { 114 | vocabulary.assign_substitution( 115 | substitution.bos_token_alias, vocabulary.bos_token_id()); 116 | } 117 | if (!substitution.eos_token_alias.empty()) { 118 | vocabulary.assign_substitution( 119 | substitution.eos_token_alias, vocabulary.eos_token_id()); 120 | } 121 | for (const auto& special : substitution.special_tokens) { 122 | Vocabulary::Token_id token_id = Vocabulary::null_token_id; 123 | for (const auto& name : special.candidates) { 124 | token_id = vocabulary.tokenize_special(name); 125 | if (token_id != Vocabulary::null_token_id) {break;} 126 | } 127 | if (token_id != Vocabulary::null_token_id) { 128 | vocabulary.assign_substitution(special.alias, token_id); 129 | } 130 | else { 131 | exstatus = 65; 132 | fildesh_log_errorf("Unknown special token: %s", special.alias.c_str()); 133 | } 134 | } 135 | chat_disp.out_ = open_FildeshOF("/dev/stdout"); 136 | if (!opt.answer_prompt.empty()) { 137 | vocabulary.tokenize_to( 138 | chat_disp.answer_prompt_tokens_, 139 | opt.answer_prompt); 140 | } 141 | vocabulary.tokenize_to(priming_tokens, opt.priming_prompt); 142 | if (!priming_tokens.empty()) { 143 | auto begin = priming_tokens.begin(); 144 | if (0 != llama_vocab_get_add_bos(llama_model_get_vocab(model))) { 145 | if (*begin == vocabulary.bos_token_id()) { 146 | priming_tokens.erase(begin, begin+1); 147 | } 148 | } 149 | else { 150 | first_priming_token_id = *begin; 151 | priming_tokens.erase(begin, begin+1); 152 | } 153 | } 154 | } 155 | 156 | rendezllama::ChatTrajectory chat_traj(first_priming_token_id); 157 | if (exstatus == 0) { 158 | 
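// Open the rolling-transcript output (if configured) so that text which
// rollforget() later trims from context is still saved to disk.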
chat_traj.transcript_out_ = open_transcript_outfile( 159 | exstatus, opt.transcript_sibling_filename, opt.transcript_filename); 160 | } 161 | 162 | rendezllama::ChatGuide chat_guide(vocabulary, chat_traj, opt); 163 | rendezllama::Inference inference(vocabulary); 164 | // Tokenize the prompt. 165 | const std::vector& chat_tokens = chat_traj.tokens(); 166 | if (exstatus == 0) { 167 | chat_traj.insert_all_at(1, priming_tokens); 168 | priming_tokens.clear(); 169 | // No need for --keep, we just directly compute the priming prompt number of tokens. 170 | chat_traj.priming_token_count_ = chat_traj.token_count(); 171 | chat_traj.tokenize_append(opt.rolling_prompt, vocabulary); 172 | chat_traj.message_prefix_id_ = 0; 173 | chat_guide.yield_turn(1); 174 | print_initialization(eout, vocabulary, opt, chat_traj); 175 | } 176 | 177 | if (exstatus == 0) { 178 | assert(opt.context_token_limit <= llama_n_ctx(ctx)); 179 | // It's convenient to save a long transcript and reload it later, 180 | // so we allow the full prompt to exceed context limit with the expectation 181 | // that the earlier part of the rolling prompt won't even be evaluated. 182 | if (chat_traj.priming_token_count_ + 2 > opt.context_token_limit) { 183 | fildesh_log_error("Priming prompt is longer than context_token_limit - 2."); 184 | exstatus = 1; 185 | } 186 | } 187 | 188 | if (exstatus == 0) { 189 | eout 190 | << "=== Chat CLI ===\n" 191 | << "- Token generation will frequently wait for input.\n" 192 | << " Press enter to let it continue.\n" 193 | << "- See README.md for other commands.\n\n" 194 | ; 195 | eout.flush(); 196 | } 197 | 198 | unsigned line_byte_limit = 0; 199 | unsigned line_byte_count = 0; 200 | unsigned sentence_count = 0; 201 | unsigned sentence_token_count = 0; 202 | bool preventing_newline = false; 203 | // Skip straight to user input when in coprocess mode. 204 | bool token_generation_on = !opt.coprocess_mode_on; 205 | fildesh::ostringstream oss; 206 | 207 | in = open_FildeshXF("/dev/stdin"); 208 | while (exstatus == 0) { 209 | if (opt.coprocess_mode_on) { 210 | // Print nothing except for prompted. 211 | chat_traj.display_token_count_ = chat_traj.token_count(); 212 | } 213 | chat_disp.maybe_insert_answer_prompt(chat_traj, vocabulary); 214 | if (!inference.commit_to_context(ctx, chat_disp, chat_traj, opt, model)) { 215 | exstatus = 1; 216 | break; 217 | } 218 | 219 | bool inputting = false; 220 | std::string matched_antiprompt; 221 | if (!token_generation_on) { 222 | // Just skip the first token. 223 | token_generation_on = true; 224 | inputting = true; 225 | } 226 | else { 227 | inference.sample_to_trajectory(chat_traj, ctx, preventing_newline); 228 | preventing_newline = false; 229 | 230 | chat_disp.show_new(chat_traj, vocabulary); 231 | 232 | oss.truncate(); 233 | chat_disp.displaystring_to(oss.c_struct(), chat_traj.token(), vocabulary); 234 | const std::string_view s = oss.view(); 235 | line_byte_count += s.size(); 236 | // Check if each of the reverse prompts appears at the end of the output. 237 | // We use single-character antiprompts, so they aren't split across tokens. 238 | // (If we used longer antiprompts, they could be split across iterations.) 
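// antiprompt_suffix() returns the antiprompt that the text ends with,
// or an empty string when none matches.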
239 | matched_antiprompt = rendezllama::antiprompt_suffix(s, opt.antiprompts); 240 | } 241 | 242 | if (line_byte_limit > 0 && line_byte_count >= line_byte_limit) { 243 | inputting = true; 244 | chat_guide.end_turn(); 245 | if (matched_antiprompt != "\n") { 246 | chat_disp.show_new(chat_traj, vocabulary); 247 | } 248 | } 249 | else if (chat_guide.maybe_yield_turn()) { 250 | if (matched_antiprompt != "\n") { 251 | matched_antiprompt = "\n"; 252 | } 253 | if (chat_traj.message_prefix_id_ == 0) { 254 | inputting = true; 255 | } 256 | chat_disp.show_new(chat_traj, vocabulary); 257 | sentence_count = 0; 258 | sentence_token_count = 0; 259 | } 260 | else if (!matched_antiprompt.empty()) { 261 | if (sentence_count + 1 == opt.sentence_limit) { 262 | // Reached the limit on number of sentences. 263 | inputting = true; 264 | } 265 | else { 266 | sentence_count += 1; 267 | sentence_token_count = 0; 268 | } 269 | } 270 | else { 271 | if (sentence_token_count + 1 == opt.sentence_token_limit) { 272 | // Reached the limit on number of tokens in a sentence. 273 | inputting = true; 274 | } 275 | else { 276 | sentence_token_count += 1; 277 | } 278 | } 279 | 280 | chat_disp.maybe_remove_answer_prompt(chat_traj, inputting); 281 | 282 | if (inputting) { 283 | line_byte_count = 0; 284 | sentence_token_count = 0; 285 | sentence_count = 0; 286 | 287 | std::string buffer; 288 | 289 | FildeshX slice; 290 | for (slice = sliceline_FildeshX(in); slice.at; 291 | slice = sliceline_FildeshX(in)) 292 | { 293 | if (slice.size == 0) {break;} 294 | 295 | if (!peek_char_FildeshX(&slice, opt.command_prefix_char)) { 296 | if (slice.at[slice.size-1] == '\\') { 297 | // Overwrite the continue character. 298 | slice.at[slice.size-1] = '\n'; 299 | buffer += fildesh::make_string_view(slice); 300 | continue; 301 | } 302 | if (slice.at[0] == ' ' && buffer.empty() && matched_antiprompt == "\n") { 303 | // Prepare to append to the previous message. 304 | chat_guide.maybe_erase_trailing_message_prefix(); 305 | chat_guide.maybe_erase_trailing_message_suffix(); 306 | matched_antiprompt.clear(); 307 | } 308 | buffer += fildesh::make_string_view(slice); 309 | break; 310 | } 311 | 312 | if (!buffer.empty()) { 313 | fildesh_log_warning("Pending input cleared. Cannot mix with commands."); 314 | } 315 | buffer.clear(); 316 | 317 | slice.off += 1; 318 | if (peek_char_FildeshX(&slice, '(')) { 319 | rendezllama::slurp_sxpb_dynamic_options_close_FildeshX(&slice, opt); 320 | } 321 | else if (skipstr_FildeshX(&slice, "opt")) { 322 | rendezllama::print_options(eout, opt); 323 | } 324 | else if ( 325 | skipstr_FildeshX(&slice, "forget") || 326 | skipstr_FildeshX(&slice, "rollforget")) 327 | { 328 | unsigned n = 10; 329 | { 330 | int tmp_n = 0; 331 | if (skipchrs_FildeshX(&slice, opt.command_delim_chars) && 332 | parse_int_FildeshX(&slice, &tmp_n) && 333 | tmp_n > 0) 334 | { 335 | n = tmp_n; 336 | } 337 | else { 338 | eout << "Ignoring /forget command without line count.\n"; eout.flush(); 339 | continue; 340 | } 341 | } 342 | for (unsigned i = chat_traj.priming_token_count_; 343 | i < chat_traj.token_count(); 344 | ++i) 345 | { 346 | if (vocabulary.last_char_of(chat_tokens[i]) == '\n') { 347 | n -= 1; 348 | if (n == 0) { 349 | chat_traj.rollforget(i+1, vocabulary); 350 | break; 351 | } 352 | } 353 | } 354 | if (!inference.commit_to_context(ctx, chat_disp, chat_traj, opt, model)) { 355 | exstatus = 1; 356 | break; 357 | } 358 | } 359 | else if (maybe_do_head_command(&slice, eout, vocabulary, chat_traj, opt)) { 360 | // Nothing else. 
361 | } 362 | else if (maybe_do_tail_command(&slice, eout, vocabulary, chat_traj, opt)) { 363 | // Nothing else. 364 | } 365 | else if (rendezllama::maybe_do_back_command( 366 | chat_traj, &slice, eout, vocabulary, opt)) 367 | { 368 | oss.truncate(); 369 | vocabulary.detokenize_to(oss, chat_tokens.back()); 370 | matched_antiprompt = rendezllama::antiprompt_suffix( 371 | oss.view(), 372 | opt.antiprompts); 373 | } 374 | else if (skipstr_FildeshX(&slice, "puts ") || 375 | (slice.off + 4 == slice.size && 376 | skipstr_FildeshX(&slice, "puts"))) 377 | { 378 | chat_traj.tokenize_append( 379 | fildesh::make_string(slice) + '\n', 380 | vocabulary); 381 | matched_antiprompt = '\n'; 382 | // Might as well process now. 383 | chat_traj.display_token_count_ = chat_traj.token_count(); 384 | if (!inference.commit_to_context(ctx, chat_disp, chat_traj, opt, model)) { 385 | exstatus = 1; 386 | break; 387 | } 388 | } 389 | else if (skipstr_FildeshX(&slice, "gets ") || 390 | (slice.off + 4 == slice.size && 391 | skipstr_FildeshX(&slice, "gets"))) 392 | { 393 | preventing_newline = true; 394 | matched_antiprompt.clear(); // For clarity. 395 | line_byte_limit = 0; 396 | int tmp_n = 0; 397 | if (parse_int_FildeshX(&slice, &tmp_n) && tmp_n > 0) { 398 | line_byte_limit = (unsigned)tmp_n; 399 | } 400 | skipchrs_FildeshX(&slice, " "); 401 | // Prefix with user text. 402 | chat_traj.tokenize_append( 403 | fildesh::make_string_view(slice), 404 | vocabulary); 405 | // Set this index so token generation stops after 1 line. 406 | chat_traj.message_prefix_id_ = opt.message_opts.size(); 407 | // Not printing any inserted text. 408 | chat_traj.display_token_count_ = chat_traj.token_count(); 409 | break; 410 | } 411 | else if (rendezllama::maybe_do_delete_command(&slice, chat_traj, opt)) { 412 | matched_antiprompt = '\n'; 413 | } 414 | else if (rendezllama::maybe_do_delete_inline_command( 415 | &slice, chat_traj, vocabulary, opt)) { 416 | matched_antiprompt = '\n'; 417 | } 418 | else if (rendezllama::maybe_do_regen_command(&slice, chat_traj, opt)) { 419 | preventing_newline = true; 420 | matched_antiprompt.clear(); // For clarity. 421 | break; 422 | } 423 | else if (rendezllama::maybe_do_regen_inline_command( 424 | &slice, chat_traj, opt)) { 425 | preventing_newline = true; 426 | matched_antiprompt.clear(); // For clarity. 427 | break; 428 | } 429 | else if (rendezllama::maybe_parse_yield_command(buffer, &slice, opt)) { 430 | break; 431 | } 432 | else { 433 | eout << "Unknown command: " 434 | << fildesh::make_string_view(slice) << '\n'; 435 | eout.flush(); 436 | } 437 | } 438 | // Break out of main loop when no more input. 
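// (sliceline_FildeshX() leaves slice.at NULL once input is exhausted.)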
439 | if (exstatus != 0 || !slice.at) {break;} 440 | 441 | if (buffer.length() > 0) { 442 | rendezllama::augment_tokenize_chat_input( 443 | chat_guide, 444 | chat_traj, 445 | preventing_newline, 446 | buffer, 447 | vocabulary, 448 | opt); 449 | } 450 | } 451 | } 452 | 453 | close_FildeshX(in); 454 | if (exstatus == 0) { 455 | chat_traj.rollforget(chat_traj.token_count(), vocabulary); 456 | } 457 | if (ctx) {llama_free(ctx);} 458 | if (model) {llama_model_free(model);} 459 | return exstatus; 460 | } 461 | -------------------------------------------------------------------------------- /src/chat/cmd.cc: -------------------------------------------------------------------------------- 1 | #include "cmd.hh" 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | 8 | #include "src/chat/opt.hh" 9 | #include "src/chat/trajectory.hh" 10 | 11 | using rendezllama::ChatOptions; 12 | using rendezllama::ChatTrajectory; 13 | using rendezllama::Vocabulary; 14 | 15 | static 16 | bool 17 | skip_cmd_prefix(FildeshX* in, const char* pfx, 18 | const rendezllama::ChatOptions& opt) 19 | { 20 | const unsigned n = strlen(pfx); 21 | if (!peek_bytestring_FildeshX(in, (const unsigned char*)pfx, n)) { 22 | return false; 23 | } 24 | 25 | if (!peek_bytestring_FildeshX(in, NULL, n+1)) { 26 | // No more to read. 27 | in->off += n; 28 | return true; 29 | } 30 | 31 | if (memchr(opt.command_delim_chars, in->at[in->off+n], 32 | sizeof(opt.command_delim_chars)-1)) 33 | { 34 | // Caught a delimiter. 35 | in->off += n+1; 36 | return true; 37 | } 38 | return false; 39 | } 40 | 41 | static 42 | void 43 | print_tail_lines(std::ostream& out, 44 | const Vocabulary& vocabulary, 45 | const rendezllama::ChatTrajectory& chat_traj, 46 | unsigned n) 47 | { 48 | unsigned i = chat_traj.token_count(); 49 | while (i > 0 && n > 0) { 50 | i = chat_traj.rfind_token_at(i-1, vocabulary.newline_token_id()); 51 | n = (i < chat_traj.token_count() ? n-1 : 0); 52 | } 53 | i = (i < chat_traj.token_count() ? 
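/* found a newline: start printing just after it; otherwise start at 0 */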
i+1 : 0); 54 | for (; i < chat_traj.token_count(); ++i) { 55 | vocabulary.detokenize_to(out, chat_traj.token_at(i)); 56 | } 57 | out.flush(); 58 | } 59 | 60 | bool 61 | rendezllama::maybe_do_back_command( 62 | rendezllama::ChatTrajectory& chat_traj, 63 | FildeshX* in, 64 | std::ostream& out, 65 | const Vocabulary& vocabulary, 66 | const rendezllama::ChatOptions& opt) 67 | { 68 | bool space_delim_on = skip_cmd_prefix(in, "B", opt); 69 | if (!space_delim_on) { 70 | if (!skip_cmd_prefix(in, "b", opt)) { 71 | return false; 72 | } 73 | } 74 | unsigned n = 1; 75 | parse_unsigned_FildeshX(in, &n); 76 | bool skipping_contiguous_space = space_delim_on; 77 | fildesh::ostringstream oss; 78 | while (n > 0) { 79 | if (chat_traj.token_count() <= chat_traj.priming_token_count_) { 80 | break; 81 | } 82 | const Vocabulary::Token_id token_id = chat_traj.token(); 83 | chat_traj.erase_all_at(chat_traj.token_count()-1); 84 | if (space_delim_on) { 85 | oss.truncate(); 86 | vocabulary.detokenize_to(oss, token_id); 87 | const std::string_view s = oss.view(); 88 | if (!s.empty() && (s[0] == ' ' || s[0] == '\n')) { 89 | if (!skipping_contiguous_space || s.size() != 1) { 90 | n -= 1; 91 | } 92 | skipping_contiguous_space = true; 93 | } 94 | else { 95 | skipping_contiguous_space = false; 96 | } 97 | } 98 | else { 99 | n -= 1; 100 | } 101 | } 102 | print_tail_lines(out, vocabulary, chat_traj, 1); 103 | return true; 104 | } 105 | 106 | bool 107 | rendezllama::maybe_do_delete_command( 108 | FildeshX* in, 109 | ChatTrajectory& chat_traj, 110 | const ChatOptions& opt) 111 | { 112 | if (!skip_cmd_prefix(in, "d", opt)) { 113 | return false; 114 | } 115 | size_t offset = chat_traj.rfind_last_message_prefix_end_at(chat_traj.token_count()-1); 116 | if (offset > chat_traj.priming_token_count_) { 117 | offset = chat_traj.rfind_message_prefix_begin_at(offset-1); 118 | } 119 | chat_traj.erase_all_at(offset); 120 | return true; 121 | } 122 | 123 | bool 124 | rendezllama::maybe_do_delete_inline_command( 125 | FildeshX* in, 126 | ChatTrajectory& chat_traj, 127 | const Vocabulary& vocabulary, 128 | const ChatOptions& opt) 129 | { 130 | if (!skip_cmd_prefix(in, "D", opt)) { 131 | return false; 132 | } 133 | unsigned n = 0; 134 | parse_unsigned_FildeshX(in, &n); 135 | auto offset = chat_traj.token_count(); 136 | while (offset > chat_traj.priming_token_count_) { 137 | offset -= 1; 138 | if (chat_traj.token_at(offset) == vocabulary.newline_token_id()) { 139 | if (n == 0) { 140 | offset += 1; 141 | break; 142 | } 143 | n -= 1; 144 | } 145 | } 146 | chat_traj.erase_all_at(offset); 147 | return true; 148 | } 149 | 150 | bool 151 | rendezllama::maybe_do_head_command( 152 | FildeshX* in, 153 | std::ostream& out, 154 | const Vocabulary& vocabulary, 155 | const rendezllama::ChatTrajectory& chat_traj, 156 | const rendezllama::ChatOptions& opt) 157 | { 158 | if (!skip_cmd_prefix(in, "head", opt)) { 159 | return false; 160 | } 161 | unsigned n = 10; 162 | parse_unsigned_FildeshX(in, &n); 163 | for (size_t i = chat_traj.priming_token_count_; i < chat_traj.token_count(); ++i) { 164 | vocabulary.detokenize_to(out, chat_traj.token_at(i)); 165 | if (vocabulary.last_char_of(chat_traj.token_at(i)) == '\n') { 166 | n -= 1; 167 | if (n == 0) { 168 | break; 169 | } 170 | } 171 | } 172 | out.flush(); 173 | return true; 174 | } 175 | 176 | bool 177 | rendezllama::maybe_do_regen_command( 178 | FildeshX* in, 179 | ChatTrajectory& chat_traj, 180 | const ChatOptions& opt) 181 | { 182 | if (!skip_cmd_prefix(in, "r", opt)) { 183 | return false; 184 | } 185 | size_t 
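/* Erase back to the end of the last message prefix so the reply regenerates. */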
offset = chat_traj.rfind_last_message_prefix_end_at(chat_traj.token_count()-1); 186 | chat_traj.erase_all_at(offset); 187 | return true; 188 | } 189 | 190 | bool 191 | rendezllama::maybe_do_regen_inline_command( 192 | FildeshX* in, 193 | ChatTrajectory& chat_traj, 194 | const ChatOptions& opt) 195 | { 196 | if (!skip_cmd_prefix(in, "R", opt)) { 197 | return false; 198 | } 199 | auto offset = chat_traj.rfind_last_message_prefix_end_at(chat_traj.token_count()); 200 | chat_traj.assign_range_message_prefix_id( 201 | chat_traj.message_prefix_id_, 202 | offset, 203 | chat_traj.token_count()); 204 | return true; 205 | } 206 | 207 | bool 208 | rendezllama::maybe_do_tail_command( 209 | FildeshX* in, 210 | std::ostream& out, 211 | const Vocabulary& vocabulary, 212 | const rendezllama::ChatTrajectory& chat_traj, 213 | const rendezllama::ChatOptions& opt) 214 | { 215 | if (!skip_cmd_prefix(in, "tail", opt)) { 216 | return false; 217 | } 218 | unsigned n = 10; 219 | parse_unsigned_FildeshX(in, &n); 220 | print_tail_lines(out, vocabulary, chat_traj, n); 221 | return true; 222 | } 223 | 224 | bool 225 | rendezllama::maybe_parse_yield_command( 226 | std::string& s, 227 | FildeshX* in, 228 | const rendezllama::ChatOptions& opt) 229 | { 230 | if (!skip_cmd_prefix(in, "yield", opt) && 231 | !skip_cmd_prefix(in, "y", opt)) { 232 | return false; 233 | } 234 | s = '\n'; 235 | if (in->off < in->size) { 236 | s += fildesh::make_string_view(*in); 237 | } 238 | else { 239 | s += opt.confidant + ':'; 240 | } 241 | return true; 242 | } 243 | -------------------------------------------------------------------------------- /src/chat/cmd.hh: -------------------------------------------------------------------------------- 1 | #ifndef RENDEZLLAMA_CMD_HH_ 2 | #define RENDEZLLAMA_CMD_HH_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | namespace rendezllama { 11 | 12 | struct ChatOptions; 13 | class ChatTrajectory; 14 | class Vocabulary; 15 | 16 | bool 17 | maybe_do_back_command( 18 | ChatTrajectory& chat_traj, 19 | FildeshX* in, 20 | std::ostream& out, 21 | const Vocabulary& vocabulary, 22 | const ChatOptions& opt); 23 | bool 24 | maybe_do_delete_command( 25 | FildeshX* in, 26 | ChatTrajectory& chat_traj, 27 | const ChatOptions& opt); 28 | bool 29 | maybe_do_delete_inline_command( 30 | FildeshX* in, 31 | ChatTrajectory& chat_traj, 32 | const Vocabulary& vocabulary, 33 | const ChatOptions& opt); 34 | bool 35 | maybe_do_head_command( 36 | FildeshX* in, 37 | std::ostream& out, 38 | const Vocabulary& vocabulary, 39 | const ChatTrajectory& chat_traj, 40 | const rendezllama::ChatOptions& opt); 41 | bool 42 | maybe_do_regen_command( 43 | FildeshX* in, 44 | ChatTrajectory& chat_traj, 45 | const ChatOptions& opt); 46 | bool 47 | maybe_do_regen_inline_command( 48 | FildeshX* in, 49 | ChatTrajectory& chat_traj, 50 | const ChatOptions& opt); 51 | bool 52 | maybe_do_tail_command( 53 | FildeshX* in, 54 | std::ostream& out, 55 | const Vocabulary& vocabulary, 56 | const ChatTrajectory& chat_traj, 57 | const rendezllama::ChatOptions& opt); 58 | bool 59 | maybe_parse_yield_command( 60 | std::string& ret_buffer, 61 | FildeshX* in, 62 | const ChatOptions& opt); 63 | 64 | } // namespace rendezllama 65 | #endif 66 | -------------------------------------------------------------------------------- /src/chat/display.cc: -------------------------------------------------------------------------------- 1 | #include "display.hh" 2 | 3 | #include 4 | 5 | #include 6 | 7 | using rendezllama::ChatDisplay; 8 | using 
rendezllama::ChatTrajectory; 9 | using rendezllama::Vocabulary; 10 | 11 | ChatDisplay::~ChatDisplay() { 12 | close_FildeshO(out_); 13 | } 14 | 15 | void 16 | ChatDisplay::displaystring_to( 17 | FildeshO* out, 18 | ChatTrajectory::Token_id token_id, 19 | const Vocabulary& vocabulary) const 20 | { 21 | if (token_id == vocabulary.eos_token_id()) { 22 | putc_FildeshO(out, '\n'); 23 | } 24 | else { 25 | vocabulary.detokenize_to(out, token_id); 26 | } 27 | } 28 | 29 | void 30 | ChatDisplay::show_new( 31 | ChatTrajectory::size_type end, 32 | ChatTrajectory& chat_traj, 33 | const Vocabulary& vocabulary) 34 | { 35 | assert(end <= chat_traj.token_count()); 36 | while (chat_traj.display_token_count_ < end) { 37 | const ChatTrajectory::size_type i = chat_traj.display_token_count_; 38 | chat_traj.display_token_count_ += 1; 39 | if (answer_prompt_offset_ > 0 && 40 | i >= answer_prompt_offset_ && 41 | i < answer_prompt_offset_ + answer_prompt_tokens_.size()) 42 | { 43 | continue; 44 | } 45 | this->displaystring_to(out_, chat_traj.token_at(i), vocabulary); 46 | } 47 | flush_FildeshO(out_); 48 | } 49 | 50 | void 51 | ChatDisplay::show_new( 52 | ChatTrajectory& chat_traj, 53 | const Vocabulary& vocabulary) 54 | { 55 | this->show_new(chat_traj.token_count(), chat_traj, vocabulary); 56 | assert(chat_traj.display_token_count_ == chat_traj.token_count()); 57 | } 58 | 59 | void 60 | ChatDisplay::maybe_insert_answer_prompt( 61 | ChatTrajectory& chat_traj, 62 | const Vocabulary& vocabulary) 63 | { 64 | if (answer_prompt_tokens_.size() == 0) { 65 | assert(answer_prompt_offset_ == 0); 66 | return; 67 | } 68 | if (answer_prompt_offset_ != 0) {return;} 69 | answer_prompt_offset_ = chat_traj.token_count(); 70 | while (answer_prompt_offset_ > 0) { 71 | if (vocabulary.last_char_of(chat_traj.token_at(answer_prompt_offset_-1)) == '\n') { 72 | break; 73 | } 74 | answer_prompt_offset_ -= 1; 75 | } 76 | if (answer_prompt_offset_ > 0) { 77 | chat_traj.insert_all_at(answer_prompt_offset_, answer_prompt_tokens_); 78 | } 79 | } 80 | 81 | void 82 | ChatDisplay::maybe_remove_answer_prompt(ChatTrajectory& chat_traj, bool inputting) 83 | { 84 | if (!inputting) {return;} 85 | if (answer_prompt_tokens_.size() == 0) {return;} 86 | if (answer_prompt_offset_ == 0) {return;} 87 | chat_traj.erase_range( 88 | answer_prompt_offset_, 89 | answer_prompt_offset_ + answer_prompt_tokens_.size()); 90 | answer_prompt_offset_ = 0; 91 | } 92 | -------------------------------------------------------------------------------- /src/chat/display.hh: -------------------------------------------------------------------------------- 1 | #ifndef RENDEZLLAMA_CHAT_DISPLAY_HH_ 2 | #define RENDEZLLAMA_CHAT_DISPLAY_HH_ 3 | #include "src/chat/trajectory.hh" 4 | 5 | namespace rendezllama { 6 | 7 | class ChatDisplay { 8 | public: 9 | ChatDisplay() {} 10 | ~ChatDisplay(); 11 | 12 | void 13 | displaystring_to( 14 | FildeshO* out, 15 | Vocabulary::Token_id token_id, 16 | const Vocabulary& vocabulary) const; 17 | 18 | void show_new(ChatTrajectory::size_type end, 19 | ChatTrajectory& chat_traj, 20 | const Vocabulary& vocabulary); 21 | void show_new(ChatTrajectory& chat_traj, 22 | const Vocabulary& vocabulary); 23 | void maybe_insert_answer_prompt(ChatTrajectory& chat_traj, 24 | const Vocabulary& vocabulary); 25 | void maybe_remove_answer_prompt(ChatTrajectory& chat_traj, bool inputting); 26 | 27 | public: 28 | FildeshO* out_ = nullptr; 29 | unsigned answer_prompt_offset_ = 0; 30 | std::vector answer_prompt_tokens_; 31 | }; 32 | 33 | } // namespace rendezllama 34 | #endif 35 
| -------------------------------------------------------------------------------- /src/chat/guide.cc: -------------------------------------------------------------------------------- 1 | #include "guide.hh" 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "src/chat/opt.hh" 8 | #include "src/chat/trajectory.hh" 9 | 10 | using rendezllama::ChatGuide; 11 | using rendezllama::ChatOptions; 12 | using rendezllama::ChatTrajectory; 13 | using rendezllama::Vocabulary; 14 | 15 | static 16 | ChatTrajectory::message_prefix_id 17 | next_turn_index(ChatTrajectory::message_prefix_id i, size_t n) 18 | { 19 | assert(n != 0); 20 | return (i >= n-1 ? 0 : i+1); 21 | } 22 | 23 | bool 24 | ChatGuide::maybe_erase_trailing_message_prefix() 25 | { 26 | if (traj_.priming_token_count() == traj_.token_count()) { 27 | // Pretend we deleted when the rolling prompt is empty. 28 | // This result is used to check whether a message suffix should be added, 29 | // which should only happen if a preceding message exists. 30 | return true; 31 | } 32 | auto i = traj_.token_count()-1; 33 | if (i != traj_.rfind_message_prefix_at(i)) { 34 | return false; 35 | } 36 | traj_.erase_all_at(traj_.rfind_message_prefix_begin_at(i)); 37 | return true; 38 | } 39 | 40 | bool 41 | ChatGuide::maybe_erase_trailing_message_suffix() 42 | { 43 | std::string_view suffix; // Empty is treated as newline. 44 | auto turn_index = traj_.last_message_prefix_id_at(traj_.token_count()); 45 | if (turn_index < opt_.message_opts.size()) { 46 | suffix = opt_.message_opts[turn_index].suffix; 47 | } 48 | const auto n = traj_.token_count(); 49 | traj_.trim_message_suffix(suffix, vocab_); 50 | return (n != traj_.token_count()); 51 | } 52 | 53 | void 54 | ChatGuide::begin_turn(unsigned turn_index) 55 | { 56 | this->maybe_erase_trailing_message_prefix(); 57 | std::string_view prefix = opt_.message_opts[turn_index].prefix; 58 | traj_.tokenize_append_message_prefix(turn_index, prefix, vocab_); 59 | } 60 | 61 | void 62 | ChatGuide::end_turn() 63 | { 64 | std::string_view suffix; // Empty is treated as newline. 65 | const auto turn_index = traj_.message_prefix_id_; 66 | if (turn_index < opt_.message_opts.size()) { 67 | suffix = opt_.message_opts[turn_index].suffix; 68 | } 69 | traj_.tokenize_append_message_suffix(suffix, vocab_); 70 | } 71 | 72 | void 73 | ChatGuide::yield_turn(unsigned turn_index) 74 | { 75 | if (!this->maybe_erase_trailing_message_prefix()) { 76 | this->end_turn(); 77 | } 78 | this->begin_turn(turn_index); 79 | } 80 | 81 | void 82 | ChatGuide::yield_turn(std::string_view prefix) 83 | { 84 | if (!this->maybe_erase_trailing_message_prefix()) { 85 | this->end_turn(); 86 | } 87 | auto turn_index = traj_.unknown_message_prefix_id(); 88 | for (size_t i = 0; i < opt_.message_opts.size(); ++i) { 89 | std::string_view s = opt_.message_opts[i].prefix; 90 | if (prefix.substr(0, s.size()) == s) { 91 | turn_index = i; 92 | break; 93 | } 94 | } 95 | traj_.tokenize_append_message_prefix(turn_index, prefix, vocab_); 96 | } 97 | 98 | void 99 | ChatGuide::yield_turn() 100 | { 101 | this->yield_turn( 102 | next_turn_index(traj_.message_prefix_id_, opt_.message_opts.size())); 103 | } 104 | 105 | bool 106 | ChatGuide::maybe_yield_turn() 107 | { 108 | auto turn_index = traj_.message_prefix_id_; 109 | Vocabulary::Token_id token_id = traj_.token(); 110 | if (token_id == vocab_.eos_token_id()) { 111 | // True. 
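// An end-of-sequence token always ends the current turn.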
112 | } 113 | else if (turn_index >= opt_.message_opts.size()) { 114 | if (token_id != vocab_.newline_token_id()) { 115 | return false; 116 | } 117 | } 118 | else { 119 | std::string_view suffix = opt_.message_opts[turn_index].suffix; 120 | if (!traj_.endswith_nonempty(suffix, vocab_)) { 121 | suffix = vocab_.eos_token_alias(); 122 | if (suffix.empty() || !traj_.endswith_nonempty(suffix, vocab_)) { 123 | return false; 124 | } 125 | } 126 | } 127 | this->yield_turn(); 128 | return true; 129 | } 130 | 131 | -------------------------------------------------------------------------------- /src/chat/guide.hh: -------------------------------------------------------------------------------- 1 | #ifndef RENDEZLLAMA_CHAT_GUIDE_HH_ 2 | #define RENDEZLLAMA_CHAT_GUIDE_HH_ 3 | #include 4 | 5 | namespace rendezllama { 6 | 7 | struct ChatOptions; 8 | class ChatDisplay; 9 | class ChatTrajectory; 10 | class Vocabulary; 11 | 12 | class ChatGuide { 13 | public: 14 | explicit ChatGuide(Vocabulary& vocab, ChatTrajectory& traj, ChatOptions& opt) 15 | : vocab_(vocab), traj_(traj), opt_(opt) 16 | {} 17 | 18 | bool maybe_erase_trailing_message_prefix(); 19 | bool maybe_erase_trailing_message_suffix(); 20 | 21 | void begin_turn(unsigned turn_index); 22 | void end_turn(); 23 | void yield_turn(unsigned turn_index); 24 | void yield_turn(std::string_view prefix); 25 | void yield_turn(); 26 | bool maybe_yield_turn(); 27 | 28 | private: 29 | Vocabulary& vocab_; 30 | ChatTrajectory& traj_; 31 | ChatOptions& opt_; 32 | }; 33 | 34 | } // namespace rendezllama 35 | #endif 36 | 37 | -------------------------------------------------------------------------------- /src/chat/opt.cc: -------------------------------------------------------------------------------- 1 | #include "opt.hh" 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include "src/chat/opt_schema.hh" 12 | #include "src/language/language_schema.hh" 13 | 14 | using rendezllama::ChatOptions; 15 | 16 | ChatOptions::ChatOptions() 17 | { 18 | this->antiprompts = this->sentence_terminals; 19 | this->antiprompts.insert("\n"); 20 | 21 | using rendezllama::inference::AdjustViaKind; 22 | std::vector adjust_thru; 23 | rendezllama::inference::AdjustVia adjust_via; 24 | adjust_via.emplace(0.1f); 25 | adjust_thru.push_back(adjust_via); 26 | adjust_via.emplace(0.8f); 27 | adjust_thru.push_back(adjust_via); 28 | 29 | auto sampling = rendezllama::inference::Sampling(); 30 | sampling.adjust_thru = adjust_thru; 31 | sampling.pick_via = rendezllama::inference::Probability(); 32 | 33 | this->infer_via = sampling; 34 | } 35 | 36 | static 37 | void 38 | parse_rolling_prompt(FildeshX* in, rendezllama::ChatOptions& opt) 39 | { 40 | std::array names; 41 | FildeshX slice; 42 | if (!in) {return;} 43 | for (slice = sliceline_FildeshX(in); slice.at; 44 | slice = sliceline_FildeshX(in)) 45 | { 46 | opt.rolling_prompt += fildesh::make_string_view(slice); 47 | opt.rolling_prompt += '\n'; 48 | 49 | slice = until_char_FildeshX(&slice, ':'); 50 | if (slice.at) { 51 | skipchrs_FildeshX(&slice, " "); 52 | std::string name; 53 | name += fildesh::make_string_view(slice); 54 | if (name != names.back()) { 55 | names.front() = names.back(); 56 | names.back() = name; 57 | } 58 | } 59 | } 60 | if (opt.protagonist.empty()) { 61 | opt.protagonist = names.back(); 62 | } 63 | if (opt.confidant.empty()) { 64 | opt.confidant = names.front(); 65 | } 66 | } 67 | 68 | static 69 | void 70 | string_replace( 71 | std::string& text, 72 | const std::string& s, 73 | const 
std::string& r) 74 | { 75 | std::string dst; 76 | size_t b = 0; 77 | for (size_t i = text.find(s, 0); i != std::string::npos; i = text.find(s, b)) { 78 | dst.append(text, b, i-b); 79 | dst.append(r); 80 | b = i + s.size(); 81 | } 82 | dst.append(text, b, text.size()-b); 83 | text = dst; 84 | } 85 | 86 | static 87 | void 88 | replace_in_prompts( 89 | rendezllama::ChatOptions& opt, 90 | const std::string& s, 91 | const std::string& r) 92 | { 93 | string_replace(opt.priming_prompt, s, r); 94 | string_replace(opt.rolling_prompt, s, r); 95 | string_replace(opt.answer_prompt, s, r); 96 | } 97 | 98 | static 99 | void 100 | ensure_linespace(std::string& s, bool startspace_on, bool linespace_on) 101 | { 102 | if (!startspace_on && !linespace_on) { 103 | return; 104 | } 105 | 106 | FildeshX* in = open_FildeshXA(); 107 | memcpy(grow_FildeshX(in, s.size()), &s[0], s.size()); 108 | s = startspace_on ? " " : ""; 109 | 110 | FildeshX slice; 111 | for (slice = sliceline_FildeshX(in); slice.at; 112 | slice = sliceline_FildeshX(in)) 113 | { 114 | if (slice.size == 0) { 115 | s += '\n'; 116 | } 117 | else { 118 | if (linespace_on && peek_char_FildeshX(&slice, ' ')) { 119 | slice.off += 1; 120 | } 121 | s += fildesh::make_string_view(slice); 122 | s += '\n'; 123 | if (linespace_on) { 124 | s += ' '; 125 | } 126 | } 127 | } 128 | close_FildeshX(in); 129 | if (!s.empty() && linespace_on) { 130 | s.pop_back(); 131 | } 132 | } 133 | 134 | void 135 | rendezllama::print_options(std::ostream& out, const rendezllama::ChatOptions& opt) 136 | { 137 | out 138 | << "Characters: protagonist=" << opt.protagonist 139 | << ", confidant=" << opt.confidant 140 | << '\n'; 141 | out << "Chat lines start with...\n"; 142 | for (unsigned i = 0; i < opt.message_opts.size(); ++i) { 143 | out << opt.message_opts[i].prefix << '\n'; 144 | } 145 | out << '\n'; 146 | out 147 | << "Generate: batch_count=" << opt.batch_count 148 | << ", thread_count=" << opt.thread_count 149 | << ", sentence_token_limit=" << opt.sentence_token_limit 150 | << ", sentence_limit=" << opt.sentence_limit 151 | << '\n'; 152 | out.flush(); 153 | } 154 | 155 | static void reinitialize_chat_prefixes(ChatOptions& opt) { 156 | if (opt.message_opts.size() < 2) { 157 | opt.message_opts.clear(); 158 | opt.message_opts.resize(2); 159 | if (opt.linespace_on) { 160 | opt.message_opts[0].prefix += ' '; 161 | opt.message_opts[1].prefix += ' '; 162 | } 163 | opt.message_opts[0].prefix += opt.protagonist + ": "; 164 | opt.message_opts[1].prefix += opt.confidant + ':'; 165 | } 166 | for (auto& message_opt : opt.message_opts) { 167 | if (!message_opt.given_prefix.empty()) { 168 | message_opt.prefix = message_opt.given_prefix; 169 | if (!opt.substitution.protagonist_alias.empty()) { 170 | string_replace(message_opt.prefix, opt.substitution.protagonist_alias, opt.protagonist); 171 | } 172 | if (!opt.substitution.confidant_alias.empty()) { 173 | string_replace(message_opt.prefix, opt.substitution.confidant_alias, opt.confidant); 174 | } 175 | } 176 | if (!message_opt.given_suffix.empty()) { 177 | message_opt.suffix = message_opt.given_suffix; 178 | } 179 | else { 180 | message_opt.suffix = '\n'; 181 | } 182 | } 183 | if (opt.coprocess_mode_on) { 184 | for (auto& message_opt : opt.message_opts) { 185 | message_opt.prefix = ""; 186 | } 187 | } 188 | } 189 | 190 | static int initialize_options(ChatOptions& opt) { 191 | int exstatus = 0; 192 | if (exstatus == 0 && opt.context_token_limit == 0) { 193 | opt.context_token_limit = opt.model_token_limit; 194 | } 195 | if (exstatus == 0 && 
196 | opt.message_opts.size() < 2 && 197 | !opt.coprocess_mode_on) 198 | { 199 | if (opt.protagonist.empty()) { 200 | fildesh_log_error("Please provide a --protagonist name."); 201 | exstatus = 64; 202 | } 203 | if (opt.confidant.empty()) { 204 | fildesh_log_error("Please provide a --confidant name."); 205 | exstatus = 64; 206 | } 207 | } 208 | if (exstatus == 0) { 209 | if (!opt.substitution.protagonist_alias.empty()) { 210 | replace_in_prompts(opt, opt.substitution.protagonist_alias, opt.protagonist); 211 | } 212 | if (!opt.substitution.confidant_alias.empty()) { 213 | replace_in_prompts(opt, opt.substitution.confidant_alias, opt.confidant); 214 | } 215 | ensure_linespace(opt.priming_prompt, opt.startspace_on, opt.linespace_on); 216 | ensure_linespace(opt.rolling_prompt, opt.linespace_on, opt.linespace_on); 217 | ensure_linespace(opt.answer_prompt, opt.linespace_on, opt.linespace_on); 218 | reinitialize_chat_prefixes(opt); 219 | } 220 | return exstatus; 221 | } 222 | 223 | static 224 | bool 225 | parse_sxpb_file_options(ChatOptions& opt, const char* filename) 226 | { 227 | FildeshX* in = open_FildeshXF(filename); 228 | if (!in) { 229 | fildesh_log_errorf("Cannot open %s.", filename); 230 | return false; 231 | } 232 | return slurp_sxpb_options_close_FildeshX( 233 | in, opt, rendezllama::options_sxproto_schema(), filename); 234 | } 235 | 236 | int 237 | rendezllama::parse_options(rendezllama::ChatOptions& opt, int argc, char** argv) 238 | { 239 | int exstatus = 0; 240 | int argi; 241 | 242 | for (argi = 1; exstatus == 0 && argi < argc; ++argi) { 243 | if (false) { 244 | } 245 | else if (argi + 1 == argc) { 246 | exstatus = 64; 247 | } 248 | else if (0 == strcmp("--protagonist", argv[argi])) { 249 | argi += 1; 250 | opt.protagonist = argv[argi]; 251 | } 252 | else if (0 == strcmp("--confidant", argv[argi])) { 253 | argi += 1; 254 | opt.confidant = argv[argi]; 255 | } 256 | else if (0 == strcmp("--model", argv[argi])) { 257 | argi += 1; 258 | opt.model_filename = argv[argi]; 259 | } 260 | else if (0 == strcmp("--lora", argv[argi])) { 261 | argi += 1; 262 | opt.lora_filename = argv[argi]; 263 | } 264 | else if (0 == strcmp("--x_setting", argv[argi])) { 265 | argi += 1; 266 | if (!parse_sxpb_file_options(opt, argv[argi])) { 267 | exstatus = 1; 268 | } 269 | } 270 | else if (0 == strcmp("--x_priming", argv[argi])) { 271 | argi += 1; 272 | std::string content; 273 | if (fildesh::slurp_file_to_string(content, argv[argi])) { 274 | opt.priming_prompt += content; 275 | // Ensure newline at end. 276 | if (opt.priming_prompt.back() != '\n') { 277 | opt.priming_prompt += '\n'; 278 | } 279 | } 280 | } 281 | else if (0 == strcmp("--x_rolling", argv[argi])) { 282 | argi += 1; 283 | FildeshX* rolling_in = open_FildeshXF(argv[argi]); 284 | parse_rolling_prompt(rolling_in, opt); 285 | close_FildeshX(rolling_in); 286 | } 287 | else if (0 == strcmp("--o_rolling", argv[argi])) { 288 | argi += 1; 289 | opt.transcript_sibling_filename.clear(); 290 | opt.transcript_filename = argv[argi]; 291 | } 292 | else if (0 == strcmp("--x_answer", argv[argi])) { 293 | argi += 1; 294 | std::string content; 295 | if (fildesh::slurp_file_to_string(content, argv[argi])) { 296 | opt.answer_prompt += content; 297 | // Ensure newline at end. 
298 | if (opt.answer_prompt.back() != '\n') { 299 | opt.answer_prompt += '\n'; 300 | } 301 | } 302 | } 303 | else if (0 == strcmp("--command_prefix_char", argv[argi])) { 304 | argi += 1; 305 | opt.command_prefix_char = argv[argi][0]; 306 | } 307 | else if (0 == strcmp("--thread_count", argv[argi])) { 308 | int n = 0; 309 | argi += 1; 310 | if (fildesh_parse_int(&n, argv[argi]) && n > 0) { 311 | opt.thread_count = n; 312 | } 313 | else { 314 | fildesh_log_error("--thread_count needs positive arg"); 315 | exstatus = 64; 316 | } 317 | } 318 | else if (0 == strcmp("--batch_count", argv[argi])) { 319 | int n = 0; 320 | argi += 1; 321 | if (fildesh_parse_int(&n, argv[argi]) && n > 0) { 322 | opt.batch_count = n; 323 | } 324 | else { 325 | fildesh_log_error("--batch_count needs positive arg"); 326 | exstatus = 64; 327 | } 328 | } 329 | else if (0 == strcmp("--coprocess_mode_on", argv[argi])) { 330 | int n = 0; 331 | argi += 1; 332 | if (fildesh_parse_int(&n, argv[argi])) { 333 | opt.coprocess_mode_on = (n != 0); 334 | } 335 | else { 336 | fildesh_log_error("--coprocess_mode_on needs 1 or 0"); 337 | exstatus = 64; 338 | } 339 | } 340 | else if (0 == strcmp("--mlock_on", argv[argi])) { 341 | int n = 0; 342 | argi += 1; 343 | if (fildesh_parse_int(&n, argv[argi])) { 344 | opt.mlock_on = (n != 0); 345 | } 346 | else { 347 | fildesh_log_error("--mlock_on needs 1 or 0"); 348 | exstatus = 64; 349 | } 350 | } 351 | else if (0 == strcmp("--mmap_on", argv[argi])) { 352 | int n = 0; 353 | argi += 1; 354 | if (fildesh_parse_int(&n, argv[argi])) { 355 | opt.mmap_on = (n != 0); 356 | } 357 | else { 358 | fildesh_log_error("--mmap_on needs 1 or 0"); 359 | exstatus = 64; 360 | } 361 | } 362 | else if (0 == strcmp("--model_token_limit", argv[argi])) { 363 | int n = 0; 364 | argi += 1; 365 | if (fildesh_parse_int(&n, argv[argi]) && n > 0) { 366 | opt.model_token_limit = n; 367 | } 368 | else { 369 | fildesh_log_error("--model_token_limit needs positive arg"); 370 | exstatus = 64; 371 | } 372 | } 373 | else { 374 | exstatus = 64; 375 | } 376 | } 377 | 378 | if (exstatus == 0 && opt.model_filename.empty()) { 379 | fildesh_log_error("Please provide a model file with --model."); 380 | exstatus = 64; 381 | } 382 | if (exstatus == 0) { 383 | exstatus = initialize_options(opt); 384 | } 385 | return exstatus; 386 | } 387 | 388 | static 389 | bool 390 | lone_subfield_at_FildeshSxpb_to_cc_string( 391 | std::string* s, const FildeshSxpb* sxpb, FildeshSxpbIT it, const char* name) 392 | { 393 | const char* tmp = NULL; 394 | if (lone_subfield_at_FildeshSxpb_to_str(&tmp, sxpb, it, name)) { 395 | *s = tmp; 396 | return true; 397 | } 398 | return false; 399 | } 400 | 401 | bool 402 | rendezllama::slurp_sxpb_options_close_FildeshX( 403 | FildeshX* in, 404 | ChatOptions& opt, 405 | const FildeshSxprotoField* schema, 406 | const std::string& sxpb_filename) 407 | { 408 | FildeshO* err_out = open_FildeshOF("/dev/stderr"); 409 | const char* s = NULL; 410 | bool all_good = true; 411 | 412 | FildeshSxpb* const sxpb = slurp_sxpb_close_FildeshX(in, schema, err_out); 413 | if (!sxpb) { 414 | close_FildeshO(err_out); 415 | return false; 416 | } 417 | 418 | const FildeshSxpbIT top_it = top_of_FildeshSxpb(sxpb); 419 | FildeshSxpbIT it; 420 | 421 | rendezllama::language::Language language; 422 | it = lookup_subfield_at_FildeshSxpb(sxpb, top_it, "language"); 423 | if (!nullish_FildeshSxpbIT(it)) { 424 | if (!rendezllama::language::populate_Language(language, sxpb, it)) { 425 | return false; 426 | } 427 | if 
(!nullish_FildeshSxpbIT(lookup_subfield_at_FildeshSxpb(sxpb, it, "substitution"))) { 428 | opt.substitution = language.substitution; 429 | } 430 | if (language.infer_via.index() != 0) { 431 | opt.infer_via = language.infer_via; 432 | } 433 | } 434 | 435 | lone_subfield_at_FildeshSxpb_to_unsigned( 436 | &opt.context_token_limit, sxpb, top_it, "context_token_limit"); 437 | 438 | lone_subfield_at_FildeshSxpb_to_unsigned( 439 | &opt.model_token_limit, sxpb, top_it, "model_token_limit"); 440 | 441 | if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, top_it, "x_priming")) { 442 | const std::string priming_filename = fildesh::sibling_filepath( 443 | sxpb_filename.c_str(), s); 444 | std::string content; 445 | if (!fildesh::slurp_file_to_string(content, priming_filename.c_str())) { 446 | putstr_FildeshO(err_out, "Cannot read given x_priming file: "); 447 | putstr_FildeshO(err_out, s); 448 | putc_FildeshO(err_out, '\n'); 449 | all_good = false; 450 | } 451 | else if (!content.empty()) { 452 | opt.priming_prompt += content; 453 | // Ensure newline at end. 454 | if (opt.priming_prompt.back() != '\n') { 455 | opt.priming_prompt += '\n'; 456 | } 457 | } 458 | } 459 | 460 | if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, top_it, "x_answer")) { 461 | const std::string answer_filename = fildesh::sibling_filepath( 462 | sxpb_filename.c_str(), s); 463 | std::string content; 464 | if (!fildesh::slurp_file_to_string(content, answer_filename.c_str())) { 465 | putstr_FildeshO(err_out, "Cannot read given x_answer file: "); 466 | putstr_FildeshO(err_out, s); 467 | putc_FildeshO(err_out, '\n'); 468 | all_good = false; 469 | } 470 | else if (!content.empty()) { 471 | opt.answer_prompt = content; 472 | // Ensure newline at end. 473 | if (opt.answer_prompt.back() != '\n') { 474 | opt.answer_prompt += '\n'; 475 | } 476 | } 477 | } 478 | 479 | lone_subfield_at_FildeshSxpb_to_cc_string( 480 | &opt.model_filename, sxpb, top_it, "model"); 481 | 482 | if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, top_it, "lora")) { 483 | opt.lora_filename = s; 484 | opt.mmap_on = false; // mmap() is incompatible. 
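// (Applying LoRA weights modifies the model tensors, which a read-only
// memory map does not allow.)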
485 | } 486 | 487 | if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, top_it, "x_rolling")) { 488 | FildeshX* rolling_in = open_sibling_FildeshXF(sxpb_filename.c_str(), s); 489 | parse_rolling_prompt(rolling_in, opt); 490 | close_FildeshX(rolling_in); 491 | } 492 | if (lone_subfield_at_FildeshSxpb_to_str(&s, sxpb, top_it, "o_rolling")) { 493 | opt.transcript_sibling_filename = sxpb_filename; 494 | opt.transcript_filename = s; 495 | } 496 | 497 | if (lone_subfield_at_FildeshSxpb_to_cc_string(&opt.protagonist, sxpb, top_it, "protagonist")) { 498 | if (sxpb_filename.empty()) { 499 | reinitialize_chat_prefixes(opt); 500 | } 501 | } 502 | if (lone_subfield_at_FildeshSxpb_to_cc_string(&opt.confidant, sxpb, top_it, "confidant")) { 503 | if (sxpb_filename.empty()) { 504 | reinitialize_chat_prefixes(opt); 505 | } 506 | } 507 | 508 | it = lookup_subfield_at_FildeshSxpb(sxpb, top_it, "chat_prefixes"); 509 | if (!nullish_FildeshSxpbIT(it)) { 510 | opt.message_opts.clear(); 511 | for (it = first_at_FildeshSxpb(sxpb, it); !nullish_FildeshSxpbIT(it); 512 | it = next_at_FildeshSxpb(sxpb, it)) { 513 | rendezllama::ChatMessageOpt message_opt; 514 | if (!name_at_FildeshSxpb(sxpb, it)) { 515 | message_opt.given_prefix = str_value_at_FildeshSxpb(sxpb, it); 516 | } 517 | else { 518 | assert(0 == strcmp(name_at_FildeshSxpb(sxpb, it), "m")); 519 | lone_subfield_at_FildeshSxpb_to_cc_string( 520 | &message_opt.given_prefix, sxpb, it, "prefix"); 521 | lone_subfield_at_FildeshSxpb_to_cc_string( 522 | &message_opt.given_suffix, sxpb, it, "suffix"); 523 | } 524 | opt.message_opts.push_back(message_opt); 525 | } 526 | if (sxpb_filename.empty()) { 527 | reinitialize_chat_prefixes(opt); 528 | } 529 | } 530 | 531 | lone_subfield_at_FildeshSxpb_to_bool(&opt.coprocess_mode_on, sxpb, top_it, "coprocess_mode_on"); 532 | lone_subfield_at_FildeshSxpb_to_bool(&opt.startspace_on, sxpb, top_it, "startspace_on"); 533 | lone_subfield_at_FildeshSxpb_to_bool(&opt.linespace_on, sxpb, top_it, "linespace_on"); 534 | lone_subfield_at_FildeshSxpb_to_bool(&opt.mlock_on, sxpb, top_it, "mlock_on"); 535 | lone_subfield_at_FildeshSxpb_to_bool(&opt.mmap_on, sxpb, top_it, "mmap_on"); 536 | 537 | /** Command option??*/ 538 | lone_subfield_at_FildeshSxpb_to_unsigned(&opt.thread_count, sxpb, top_it, "thread_count"); 539 | lone_subfield_at_FildeshSxpb_to_unsigned(&opt.batch_thread_count, sxpb, top_it, "batch_thread_count"); 540 | lone_subfield_at_FildeshSxpb_to_unsigned(&opt.batch_count, sxpb, top_it, "batch_count"); 541 | lone_subfield_at_FildeshSxpb_to_unsigned(&opt.sentence_limit, sxpb, top_it, "sentence_limit"); 542 | lone_subfield_at_FildeshSxpb_to_unsigned(&opt.sentence_token_limit, sxpb, top_it, "sentence_token_limit"); 543 | 544 | it = lookup_subfield_at_FildeshSxpb(sxpb, top_it, "sentence_terminals"); 545 | if (!nullish_FildeshSxpbIT(it)) { 546 | opt.sentence_terminals.clear(); 547 | bool found = false; 548 | for (it = first_at_FildeshSxpb(sxpb, it); !nullish_FildeshSxpbIT(it); 549 | it = next_at_FildeshSxpb(sxpb, it)) { 550 | s = str_value_at_FildeshSxpb(sxpb, it); 551 | opt.sentence_terminals.insert(s); 552 | if (s[0] == '\n' && s[1] == '\0') {found = true;} 553 | } 554 | 555 | opt.antiprompts = opt.sentence_terminals; 556 | if (!found) { 557 | opt.antiprompts.insert("\n"); 558 | } 559 | } 560 | 561 | close_FildeshO(err_out); 562 | close_FildeshSxpb(sxpb); 563 | 564 | return all_good; 565 | } 566 | 567 | bool 568 | rendezllama::slurp_sxpb_initialize_options_close_FildeshX( 569 | FildeshX* in, 570 | rendezllama::ChatOptions& opt, 571 | const 
std::string& filename) 572 | { 573 | bool all_good = slurp_sxpb_options_close_FildeshX( 574 | in, opt, options_sxproto_schema(), filename); 575 | if (all_good) { 576 | initialize_options(opt); 577 | } 578 | return all_good; 579 | } 580 | 581 | bool 582 | rendezllama::slurp_sxpb_dynamic_options_close_FildeshX( 583 | FildeshX* in, 584 | rendezllama::ChatOptions& opt) 585 | { 586 | return slurp_sxpb_options_close_FildeshX( 587 | in, opt, dynamic_options_sxproto_schema(), ""); 588 | } 589 | -------------------------------------------------------------------------------- /src/chat/opt.hh: -------------------------------------------------------------------------------- 1 | #ifndef RENDEZLLAMA_OPT_HH_ 2 | #define RENDEZLLAMA_OPT_HH_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "src/language/language_schema.hh" 10 | 11 | struct FildeshX; 12 | struct FildeshSxprotoField; 13 | 14 | namespace rendezllama { 15 | 16 | struct ChatMessageOpt { 17 | std::string prefix; 18 | std::string given_prefix; 19 | std::string suffix; 20 | std::string given_suffix; 21 | }; 22 | 23 | struct ChatOptions { 24 | ChatOptions(); 25 | 26 | std::string protagonist; 27 | std::string confidant; 28 | language::Substitution substitution; 29 | std::vector message_opts; 30 | std::string model_filename; 31 | std::string lora_filename; 32 | std::string transcript_sibling_filename; 33 | std::string transcript_filename; 34 | 35 | std::string priming_prompt; 36 | std::string rolling_prompt; 37 | std::string answer_prompt; 38 | // Match original LLaMA tokenizer behavior by starting with a space. 39 | bool bos_token_on = true; 40 | bool startspace_on = true; 41 | // Add space before all lines. 42 | bool linespace_on = false; 43 | 44 | char command_prefix_char = '/'; 45 | const char command_delim_chars[5] = ":=! "; 46 | 47 | unsigned thread_count = 1; 48 | unsigned batch_thread_count = 0; 49 | unsigned sentence_limit = 0; 50 | unsigned sentence_token_limit = 0; 51 | 52 | unsigned model_token_limit = 0; // Default derived from model. 53 | unsigned context_token_limit = 0; // Defaults to model_token_limit. 54 | unsigned batch_count = 512; 55 | bool mlock_on = false; 56 | bool mmap_on = true; 57 | bool coprocess_mode_on = false; 58 | std::set sentence_terminals = {"!", ".", "?", "…"}; 59 | std::set antiprompts; 60 | // Can't set these yet. 
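// (No command-line flag or sxpb field is wired up for verbose_prompt.)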
61 | bool verbose_prompt = false; 62 | 63 | inference::InferVia infer_via; 64 | }; 65 | 66 | void 67 | print_options(std::ostream& out, const ChatOptions& opt); 68 | int 69 | parse_options(ChatOptions& opt, int argc, char** argv); 70 | bool 71 | slurp_sxpb_options_close_FildeshX( 72 | FildeshX* in, 73 | rendezllama::ChatOptions& opt, 74 | const FildeshSxprotoField* schema, 75 | const std::string& filename); 76 | bool 77 | slurp_sxpb_initialize_options_close_FildeshX( 78 | FildeshX* in, 79 | rendezllama::ChatOptions& opt, 80 | const std::string& filename); 81 | bool 82 | slurp_sxpb_dynamic_options_close_FildeshX( 83 | FildeshX* in, 84 | rendezllama::ChatOptions& opt); 85 | 86 | 87 | } // namespace rendezllama 88 | #endif 89 | -------------------------------------------------------------------------------- /src/chat/opt_schema.cc: -------------------------------------------------------------------------------- 1 | #include "opt_schema.hh" 2 | 3 | #include 4 | 5 | #include 6 | 7 | #include "src/language/language_schema.hh" 8 | 9 | static FildeshSxprotoField chat_prefixes_m_message[] = { 10 | {"prefix", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 11 | {"suffix", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 12 | }; 13 | static FildeshSxprotoField chat_prefixes_manyof[] = { 14 | {"", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 15 | {"m", FILL_FildeshSxprotoField_MESSAGE(chat_prefixes_m_message)}, 16 | }; 17 | 18 | const FildeshSxprotoField* 19 | rendezllama::options_sxproto_schema() 20 | { 21 | static FildeshSxprotoField toplevel_fields[] = { 22 | {"language", FILL_DEFAULT_FildeshSxprotoField_ALIAS}, 23 | {"batch_count", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, 24 | {"chat_prefixes", FILL_FildeshSxprotoField_MANYOF(chat_prefixes_manyof)}, 25 | {"confidant", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 26 | {"context_token_limit", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, 27 | {"coprocess_mode_on", FILL_DEFAULT_FildeshSxprotoField_BOOL}, 28 | {"linespace_on", FILL_DEFAULT_FildeshSxprotoField_BOOL}, 29 | {"lora", FILL_FildeshSxprotoField_STRING(1, FILENAME_MAX)}, 30 | {"mlock_on", FILL_DEFAULT_FildeshSxprotoField_BOOL}, 31 | {"mmap_on", FILL_DEFAULT_FildeshSxprotoField_BOOL}, 32 | {"model", FILL_FildeshSxprotoField_STRING(1, FILENAME_MAX)}, 33 | {"model_token_limit", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, 34 | {"o_rolling", FILL_FildeshSxprotoField_STRING(1, FILENAME_MAX)}, 35 | {"protagonist", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 36 | {"sentence_limit", FILL_FildeshSxprotoField_INT(0, INT_MAX)}, 37 | {"sentence_terminals", FILL_DEFAULT_FildeshSxprotoField_STRINGS}, 38 | {"sentence_token_limit", FILL_FildeshSxprotoField_INT(0, INT_MAX)}, 39 | {"startspace_on", FILL_DEFAULT_FildeshSxprotoField_BOOL}, 40 | {"thread_count", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, 41 | {"batch_thread_count", FILL_FildeshSxprotoField_INT(0, INT_MAX)}, 42 | {"x_answer", FILL_FildeshSxprotoField_STRING(1, FILENAME_MAX)}, 43 | {"x_priming", FILL_FildeshSxprotoField_STRING(1, FILENAME_MAX)}, 44 | {"x_rolling", FILL_FildeshSxprotoField_STRING(1, FILENAME_MAX)}, 45 | }; 46 | DECLARE_TOPLEVEL_FildeshSxprotoField(schema, toplevel_fields); 47 | if (!schema->name) { 48 | FildeshSxprotoField tmp_field; 49 | tmp_field = *rendezllama::language_sxproto_schema(); 50 | tmp_field.name = toplevel_fields[0].name; 51 | tmp_field.tag_id = toplevel_fields[0].tag_id; 52 | toplevel_fields[0] = tmp_field; 53 | lone_toplevel_initialization_FildeshSxprotoField(schema); 54 | } 55 | return schema; 56 | } 57 | 58 | const 
FildeshSxprotoField* 59 | rendezllama::dynamic_options_sxproto_schema() 60 | { 61 | static FildeshSxprotoField toplevel_fields[] = { 62 | {"language", FILL_DEFAULT_FildeshSxprotoField_ALIAS}, 63 | {"chat_prefixes", FILL_FildeshSxprotoField_MANYOF(chat_prefixes_manyof)}, 64 | {"confidant", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 65 | {"protagonist", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 66 | {"sentence_limit", FILL_FildeshSxprotoField_INT(0, INT_MAX)}, 67 | {"sentence_terminals", FILL_DEFAULT_FildeshSxprotoField_STRINGS}, 68 | {"sentence_token_limit", FILL_FildeshSxprotoField_INT(0, INT_MAX)}, 69 | {"thread_count", FILL_FildeshSxprotoField_INT(1, INT_MAX)}, 70 | {"batch_thread_count", FILL_FildeshSxprotoField_INT(0, INT_MAX)}, 71 | }; 72 | DECLARE_TOPLEVEL_FildeshSxprotoField(schema, toplevel_fields); 73 | if (!schema->name) { 74 | FildeshSxprotoField tmp_field; 75 | tmp_field = *rendezllama::language_sxproto_schema(); 76 | tmp_field.name = toplevel_fields[0].name; 77 | tmp_field.tag_id = toplevel_fields[0].tag_id; 78 | toplevel_fields[0] = tmp_field; 79 | lone_toplevel_initialization_FildeshSxprotoField(schema); 80 | } 81 | return schema; 82 | } 83 | 84 | -------------------------------------------------------------------------------- /src/chat/opt_schema.hh: -------------------------------------------------------------------------------- 1 | #ifndef RENDEZLLAMA_OPT_SCHEMA_HH_ 2 | #define RENDEZLLAMA_OPT_SCHEMA_HH_ 3 | 4 | struct FildeshSxprotoField; 5 | 6 | namespace rendezllama { 7 | 8 | const FildeshSxprotoField* options_sxproto_schema(); 9 | const FildeshSxprotoField* dynamic_options_sxproto_schema(); 10 | 11 | } // namespace rendezllama 12 | #endif 13 | -------------------------------------------------------------------------------- /src/chat/trajectory.cc: -------------------------------------------------------------------------------- 1 | #include "trajectory.hh" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include 8 | 9 | using rendezllama::ChatTrajectory; 10 | using rendezllama::Vocabulary; 11 | 12 | ChatTrajectory::ChatTrajectory(Token_id token_id) { 13 | token_ids_.push_back(token_id); 14 | message_prefix_ids_.push_back(this->not_a_message_prefix_id()); 15 | } 16 | 17 | ChatTrajectory::~ChatTrajectory() 18 | { 19 | close_FildeshO(this->transcript_out_); 20 | } 21 | 22 | void 23 | ChatTrajectory::push_back(Token_id token_id) 24 | { 25 | token_ids_.push_back(token_id); 26 | message_prefix_ids_.push_back(this->not_a_message_prefix_id()); 27 | } 28 | 29 | void 30 | ChatTrajectory::insert_all_at( 31 | size_type i, const std::vector& a) 32 | { 33 | assert(i > 0); 34 | token_ids_.insert(token_ids_.begin() + i, a.begin(), a.end()); 35 | message_prefix_ids_.insert( 36 | message_prefix_ids_.begin() + i, 37 | a.size(), this->not_a_message_prefix_id()); 38 | if (i < display_token_count_) { 39 | display_token_count_ += a.size(); 40 | } 41 | } 42 | 43 | static 44 | void 45 | maybe_pop_for_rewrite( 46 | ChatTrajectory& trajectory, 47 | fildesh::ostringstream& oss, 48 | const Vocabulary& vocabulary) 49 | { 50 | if (trajectory.priming_token_count() < trajectory.token_count()) { 51 | if (vocabulary.last_char_of(trajectory.token()) == ' ') { 52 | vocabulary.detokenize_to(oss, trajectory.token()); 53 | trajectory.erase_all_at(trajectory.token_count() - 1); 54 | } 55 | } 56 | } 57 | 58 | void 59 | ChatTrajectory::tokenize_append( 60 | std::string_view s, 61 | const Vocabulary& vocabulary) 62 | { 63 | fildesh::ostringstream oss; 64 | maybe_pop_for_rewrite(*this, oss, vocabulary); 
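// The popped trailing-space token (if any) is already in oss, so the space
// and the new text retokenize together as one piece.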
65 | oss << s;
66 |
67 | std::vector<Token_id> tmp;
68 | vocabulary.tokenize_to(tmp, oss.view());
69 | this->insert_all_at(this->token_count(), tmp);
70 | }
71 |
72 | void
73 | ChatTrajectory::erase_range(size_type beg, size_type end)
74 | {
75 | erased_since_eval_ = true;
76 | assert(beg <= end);
77 | token_ids_.erase(
78 | token_ids_.begin() + beg,
79 | token_ids_.begin() + end);
80 | if (context_token_count_ > beg) {
81 | context_token_count_ = beg;
82 | }
83 | if (context_token_count_ >= token_count()) {
84 | // The -1 is added to force an eval.
85 | context_token_count_ = token_count()-1;
86 | }
87 | message_prefix_ids_.erase(
88 | message_prefix_ids_.begin() + beg,
89 | message_prefix_ids_.begin() + end);
90 | if (beg < display_token_count_) {
91 | if (end < display_token_count_) {
92 | display_token_count_ -= (end - beg);
93 | }
94 | else {
95 | display_token_count_ = beg;
96 | }
97 | }
98 | message_prefix_id_ = last_message_prefix_id_at(this->token_count());
99 | }
100 |
101 | void
102 | ChatTrajectory::rollforget(size_type end, const Vocabulary& vocabulary)
103 | {
104 | assert(end <= this->token_count());
105 | const size_type beg = priming_token_count_;
106 | if (transcript_out_) {
107 | for (size_type i = beg; i < end; ++i) {
108 | vocabulary.detokenize_to(transcript_out_, this->token_at(i));
109 | }
110 | flush_FildeshO(transcript_out_);
111 | }
112 | this->erase_range(beg, end);
113 | }
114 |
115 | /** Drop oldest lines in the rolling prompt while keeping the priming prompt.
116 | **/
117 | void
118 | ChatTrajectory::maybe_rollforget_within_limit(
119 | size_type token_limit,
120 | const Vocabulary& vocabulary)
121 | {
122 | if (this->token_count() < token_limit) {
123 | return;
124 | }
125 | const size_type ideal_rollforget_end = (
126 | this->token_count() - (token_limit - priming_token_count_) / 2);
127 | assert(ideal_rollforget_end > priming_token_count_);
128 |
129 | size_type end = ideal_rollforget_end;
130 | for (end = rfind_message_prefix_begin_at(end);
131 | end > priming_token_count_;
132 | end = rfind_message_prefix_begin_at(end-1))
133 | {
134 | if (message_prefix_ids_[end] == 0) {
135 | break;
136 | }
137 | }
138 |
139 | // If a good rollforget point wasn't found by looking before the ideal point,
140 | // then choose to roll past next newline.
141 | if (end == priming_token_count_) {
142 | end = this->find_token_at(
143 | ideal_rollforget_end - 1,
144 | vocabulary.newline_token_id());
145 | end = (end < this->token_count() ? end+1 : end);
146 | }
147 |
148 | this->rollforget(end, vocabulary);
149 | assert(this->token_count() <= token_limit);
150 | }
151 |
152 | ChatTrajectory::size_type
153 | ChatTrajectory::find_token_at(size_type i, Token_id id) const
154 | {
155 | auto it = std::find(token_ids_.begin() + i, token_ids_.end(), id);
156 | return it - token_ids_.begin();
157 | }
158 |
159 | ChatTrajectory::size_type
160 | ChatTrajectory::rfind_token_at(size_type i, Token_id id) const
161 | {
162 | auto it = (
163 | i < this->token_count()
164 | ? token_ids_.rbegin() + (this->token_count() - i - 1)
165 | : token_ids_.rbegin());
166 | it = std::find(it, token_ids_.rend(), id);
167 | return (
168 | it != token_ids_.rend()
169 | ? token_ids_.rend() - it - 1
170 | : this->token_count());
171 | }
172 |
173 | void
174 | ChatTrajectory::tokenize_append_message_prefix(
175 | unsigned id,
176 | std::string_view s,
177 | const Vocabulary& vocabulary)
178 | {
179 | std::vector<Token_id> tmp;
180 | vocabulary.tokenize_to(tmp, s);
181 | size_t i = this->token_count();
182 | this->insert_all_at(this->token_count(), tmp);
183 | for (; i < this->token_count(); ++i) {
184 | message_prefix_ids_[i] = id;
185 | }
186 | message_prefix_id_ = id;
187 | }
188 |
189 | bool
190 | ChatTrajectory::endswith_nonempty(
191 | std::string_view suffix,
192 | const Vocabulary& vocabulary)
193 | {
194 | assert(!suffix.empty());
195 | fildesh::ostringstream oss;
196 | std::string carry;
197 | size_type token_index = this->token_count();
198 | while (token_index > priming_token_count_ && carry.size() < suffix.size()) {
199 | token_index -= 1;
200 | vocabulary.detokenize_to(oss.c_struct(), this->token_at(token_index));
201 | carry.insert(0, oss.view());
202 | oss.truncate();
203 | }
204 | if (carry.size() >= suffix.size()) {
205 | if (carry.substr(carry.size()-suffix.size()) == suffix) {
206 | return true;
207 | }
208 | }
209 | return false;
210 | }
211 |
212 | void
213 | ChatTrajectory::trim_message_suffix(
214 | std::string_view suffix,
215 | const Vocabulary& vocabulary)
216 | {
217 | auto pos = suffix.find_last_not_of(" \n");
218 | if (pos != std::string_view::npos) {
219 | suffix = suffix.substr(0, pos+1);
220 | }
221 |
222 | fildesh::ostringstream oss;
223 | while (this->token_count() > priming_token_count_) {
224 | size_type token_index = this->token_count()-1;
225 | if (oss.view().empty()) {
226 | Token_id token_id = this->token_at(token_index);
227 | if (token_id == vocabulary.newline_token_id() ||
228 | token_id == vocabulary.eos_token_id()) {
229 | this->erase_all_at(token_index);
230 | continue;
231 | }
232 | vocabulary.detokenize_to(oss.c_struct(), token_id);
233 | }
234 | else {
235 | token_index = this->token_count();
236 | }
237 |
238 | auto pos = oss.view().find_last_not_of(" \n");
239 | if (pos == std::string_view::npos) {
240 | oss.truncate();
241 | this->erase_all_at(token_index);
242 | continue;
243 | }
244 |
245 | std::string carry;
246 | if (pos+1 == oss.view().size() && token_index < this->token_count()) {
247 | token_index += 1;
248 | }
249 | else {
250 | carry = oss.view().substr(0, pos+1);
251 | this->erase_all_at(token_index);
252 | }
253 | oss.truncate();
254 | const size_t carry_rindex = carry.size();
255 |
256 | assert(token_index <= this->token_count());
257 | size_t sufficient_size = suffix.size();
258 | const std::string_view eos_token_alias = vocabulary.eos_token_alias();
259 | if (sufficient_size < eos_token_alias.size()) {
260 | sufficient_size = eos_token_alias.size();
261 | }
262 | while (token_index > priming_token_count_ && carry.size() < sufficient_size) {
263 | token_index -= 1;
264 | vocabulary.detokenize_to(oss.c_struct(), this->token_at(token_index));
265 | carry.insert(0, oss.view());
266 | oss.truncate();
267 | }
268 | if (!eos_token_alias.empty() && carry.size() >= eos_token_alias.size()) {
269 | size_t lhs_size = carry.size()-eos_token_alias.size();
270 | if (carry.substr(lhs_size) == eos_token_alias) {
271 | this->erase_all_at(token_index);
272 | oss << carry.substr(0, lhs_size);
273 | continue;
274 | }
275 | }
276 | if (!suffix.empty() &&
277 | carry.size() >= suffix.size())
278 | {
279 | size_t lhs_size = carry.size()-suffix.size();
280 | if (carry.substr(lhs_size) == suffix) {
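// The detokenized carry ends with the unwanted suffix: erase those tokens,
// stash whatever preceded the suffix back into `oss`, and keep trimming.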
281 | this->erase_all_at(token_index);
282 | oss << carry.substr(0, lhs_size);
283 | continue;
284 | }
285 | }
286 | oss << carry.substr(carry.size()-carry_rindex);
287 | break;
288 | }
289 | this->tokenize_append(oss.view(), vocabulary);
290 | }
291 |
292 | void
293 | ChatTrajectory::tokenize_append_message_suffix(
294 | std::string_view suffix,
295 | const Vocabulary& vocabulary)
296 | {
297 | if (suffix.empty()) {
298 | suffix = "\n";
299 | }
300 | const size_type old_display_token_count = display_token_count_;
301 | this->trim_message_suffix(suffix, vocabulary);
302 | const bool display_move_on = (
303 | old_display_token_count >= this->token_count());
304 | this->tokenize_append(suffix, vocabulary);
305 | if (display_move_on) {
306 | display_token_count_ = this->token_count();
307 | }
308 | }
309 |
310 | ChatTrajectory::size_type
311 | ChatTrajectory::rfind_message_prefix_at(size_type i) const
312 | {
313 | assert(i < this->token_count());
314 | assert(0 < priming_token_count_);
315 | while (i >= priming_token_count_) {
316 | if (message_prefix_ids_[i] != this->not_a_message_prefix_id()) {
317 | return i;
318 | }
319 | i -= 1;
320 | }
321 | assert(i == priming_token_count_-1);
322 | return priming_token_count_-1;
323 | }
324 |
325 | ChatTrajectory::size_type
326 | ChatTrajectory::rfind_message_prefix_begin_at(size_type i) const
327 | {
328 | i = this->rfind_message_prefix_at(i);
329 | while (i > priming_token_count_) {
330 | if (message_prefix_ids_[i-1] != message_prefix_ids_[i]) {
331 | return i;
332 | }
333 | i -= 1;
334 | }
335 | if (i == priming_token_count_) {
336 | if (message_prefix_ids_[i] != this->not_a_message_prefix_id()) {
337 | return priming_token_count_;
338 | }
339 | }
340 | assert(i == priming_token_count_-1);
341 | return priming_token_count_-1;
342 | }
343 |
344 | ChatTrajectory::size_type
345 | ChatTrajectory::rfind_last_message_prefix_end_at(size_type i) const
346 | {
347 | size_type e;
348 | if (i < this->token_count()) {
349 | e = rfind_message_prefix_at(i);
350 | }
351 | else {
352 | e = rfind_message_prefix_at(this->token_count()-1);
353 | }
354 |
355 | if (e < i) {
356 | i = e;
357 | }
358 | else {
359 | i = rfind_message_prefix_begin_at(i);
360 | if (i >= priming_token_count_) {
361 | i = rfind_message_prefix_at(i-1);
362 | }
363 | }
364 | assert(i + 1 >= priming_token_count_);
365 | return i + 1;
366 | }
367 |
368 | ChatTrajectory::size_type
369 | ChatTrajectory::last_message_prefix_id_at(size_type i) const
370 | {
371 | i = rfind_last_message_prefix_end_at(i);
372 | if (i <= priming_token_count_) {
373 | return this->not_a_message_prefix_id();
374 | }
375 | return message_prefix_ids_[i-1];
376 | }
377 |
378 | void
379 | ChatTrajectory::assign_range_message_prefix_id(
380 | message_prefix_id id,
381 | size_type beg, size_type end)
382 | {
383 | for (size_type i = beg; i < end; ++i) {
384 | message_prefix_ids_[i] = id;
385 | }
386 | message_prefix_id_ = last_message_prefix_id_at(this->token_count());
387 | }
388 |
389 |
-------------------------------------------------------------------------------- /src/chat/trajectory.hh: --------------------------------------------------------------------------------
1 | #ifndef RENDEZLLAMA_CHAT_TRAJECTORY_HH_
2 | #define RENDEZLLAMA_CHAT_TRAJECTORY_HH_
3 |
4 | #include <limits>
5 |
6 | #include "src/language/vocabulary.hh"
7 |
8 | namespace rendezllama {
9 |
10 | class ChatTrajectory {
11 | public:
12 | typedef Vocabulary::Token_id Token_id;
13 | typedef unsigned message_prefix_id;
14 | typedef unsigned size_type;
15 |
16 | public:
17 | explicit ChatTrajectory(Token_id);
18 | ~ChatTrajectory();
19 |
20 | size_type token_count() const {return token_ids_.size();}
21 | void push_back(Token_id token_id);
22 | void insert_all_at(size_type i, const std::vector<Token_id>& a);
23 | void tokenize_append(std::string_view s, const Vocabulary& vocabulary);
24 |
25 | void erase_range(size_type beg, size_type end);
26 | void erase_all_at(size_type beg) {this->erase_range(beg, this->token_count());}
27 | void rollforget(size_type end, const Vocabulary& vocabulary);
28 | void maybe_rollforget_within_limit(
29 | size_type token_limit, const Vocabulary& vocabulary);
30 |
31 | Token_id token() const {return token_ids_.back();}
32 | Token_id token_at(size_type i) const {return token_ids_[i];}
33 | size_type find_token_at(size_type i, Token_id id) const;
34 | size_type rfind_token_at(size_type i, Token_id id) const;
35 |
36 | void tokenize_append_message_prefix(
37 | message_prefix_id id,
38 | std::string_view s,
39 | const Vocabulary& vocabulary);
40 | bool endswith_nonempty(
41 | std::string_view suffix,
42 | const Vocabulary& vocabulary);
43 | void trim_message_suffix(
44 | std::string_view suffix,
45 | const Vocabulary& vocabulary);
46 | void tokenize_append_message_suffix(
47 | std::string_view suffix,
48 | const Vocabulary& vocabulary);
49 | static message_prefix_id unknown_message_prefix_id() {
50 | return std::numeric_limits<message_prefix_id>::max()-1;
51 | }
52 | static message_prefix_id not_a_message_prefix_id() {
53 | return std::numeric_limits<message_prefix_id>::max();
54 | }
55 | size_type rfind_message_prefix_at(size_type i) const;
56 | size_type rfind_message_prefix_begin_at(size_type i) const;
57 | size_type rfind_last_message_prefix_end_at(size_type i) const;
58 | message_prefix_id last_message_prefix_id_at(size_type i) const;
59 | void assign_range_message_prefix_id(
60 | message_prefix_id id,
61 | size_type beg, size_type end);
62 |
63 | size_type priming_token_count() const {return priming_token_count_;}
64 | const std::vector<Token_id>& tokens() const {return token_ids_;}
65 |
66 | private:
67 | std::vector<Token_id> token_ids_;
68 | std::vector<message_prefix_id> message_prefix_ids_;
69 | public:
70 | FildeshO* transcript_out_ = nullptr;
71 | size_type display_token_count_ = 0;
72 | size_type context_token_count_ = 0;
73 | size_type priming_token_count_ = 1;
74 | message_prefix_id message_prefix_id_ = ChatTrajectory::unknown_message_prefix_id();
75 | bool erased_since_eval_ = false;
76 | };
77 |
78 | } // namespace rendezllama
79 | #endif
80 |
81 |
-------------------------------------------------------------------------------- /src/language/CMakeLists.txt: --------------------------------------------------------------------------------
1 | add_library(language_schema_cc
2 | "language_schema.cc"
3 | "language_schema.hh"
4 | "inference_schema.cc"
5 | "inference_schema.hh"
6 | )
7 | target_link_libraries(language_schema_cc PUBLIC
8 | ${FildeshSxproto_LIBRARIES}
9 | )
-------------------------------------------------------------------------------- /src/language/inference.cc: --------------------------------------------------------------------------------
1 | #include "src/language/inference.hh"
2 |
3 | #include <cassert>
4 | #include <climits>
5 | #include <cstring>
6 | #include <ctime>
7 |
8 | #include <fildesh/fildesh.h>
9 | #include <fildesh/ostream.hh>
10 |
11 | #include "src/chat/display.hh"
12 | #include "src/chat/guide.hh"
13 | #include "src/chat/opt.hh"
14 | #include "src/chat/trajectory.hh"
15 | #include "src/language/vocabulary.hh"
16 |
17 | using rendezllama::ChatDisplay;
18 | using rendezllama::ChatGuide;
19 | using rendezllama::ChatOptions;
20 | using rendezllama::ChatTrajectory;
21 | using rendezllama::Inference;
22 | using rendezllama::Vocabulary;
23 | using rendezllama::inference::AdjustViaKind;
24 |
25 | Inference::Inference(const Vocabulary& vocabulary)
26 | : vocabulary_(vocabulary)
27 | {}
28 | Inference::~Inference() {
29 | if (smpl_) {llama_sampler_free(smpl_);}
30 | }
31 |
32 | const std::string&
33 | rendezllama::antiprompt_suffix(
34 | std::string_view text,
35 | const std::set<std::string>& antiprompts)
36 | {
37 | static const std::string empty_string;
38 | for (const std::string& s : antiprompts) {
39 | if (text.size() >= s.size()) {
40 | const size_t offset = text.size() - s.size();
41 | if (0 == memcmp(&text[offset], &s[0], s.size())) {
42 | return s;
43 | }
44 | }
45 | }
46 | return empty_string;
47 | }
48 |
49 | static bool maybe_trim_endspace(std::string& s)
50 | {
51 | bool result = false;
52 | while (!s.empty() && s.back() == ' ') {
53 | s.pop_back();
54 | result = true;
55 | }
56 | return result;
57 | }
58 |
59 | void
60 | rendezllama::augment_tokenize_chat_input(
61 | ChatGuide& chat_guide,
62 | ChatTrajectory& chat_traj,
63 | bool& prevent_subsequent_newline,
64 | std::string s,
65 | const Vocabulary& vocabulary,
66 | const ChatOptions& opt)
67 | {
68 | prevent_subsequent_newline = false;
69 | if (s.size() >= 2 && s[0] == '\\' && s[1] == 'n') {
70 | chat_guide.end_turn();
71 | chat_guide.begin_turn(opt.message_opts.size()-1);
72 | s.erase(0, 2);
73 | prevent_subsequent_newline = maybe_trim_endspace(s);
74 | if (opt.message_opts.back().prefix.back() == '\n' && opt.linespace_on) {
75 | if (!s.empty() && s.front() != ' ') {
76 | s.insert(0, " ");
77 | }
78 | }
79 | chat_traj.tokenize_append(s, vocabulary);
80 | }
81 | else if (s.front() == '\n') {
82 | // This is from /yield.
83 | chat_guide.yield_turn(s.substr(1));
84 | }
85 | else if (s.front() == ' ') {
86 | prevent_subsequent_newline = maybe_trim_endspace(s);
87 | chat_traj.tokenize_append(s, vocabulary);
88 | }
89 | else {
90 | chat_guide.yield_turn(0);
91 | if (opt.message_opts[0].prefix.back() == '\n' && opt.linespace_on) {
92 | if (!s.empty() && s.front() != ' ') {
93 | s.insert(0, " ");
94 | }
95 | }
96 | chat_traj.tokenize_append(s, vocabulary);
97 | chat_guide.yield_turn();
98 | chat_traj.display_token_count_ = chat_traj.rfind_message_prefix_begin_at(
99 | chat_traj.token_count()-1);
100 | prevent_subsequent_newline = true;
101 | }
102 | }
103 |
104 | std::tuple<struct llama_model*, struct llama_context*>
105 | rendezllama::make_llama_context(rendezllama::ChatOptions& opt)
106 | {
107 | llama_model_params model_params = llama_model_default_params();
108 | model_params.use_mlock = opt.mlock_on;
109 | model_params.use_mmap = opt.mmap_on;
110 |
111 | struct llama_model* model = llama_model_load_from_file(
112 | opt.model_filename.c_str(), model_params);
113 | if (!model) {
114 | fildesh_log_error("Failed to open model.");
115 | return std::make_tuple(nullptr, nullptr);
116 | }
117 |
118 | if (opt.model_token_limit == 0) {
119 | opt.model_token_limit = llama_model_n_ctx_train(model);
120 | }
121 | if (opt.context_token_limit == 0) {
122 | opt.context_token_limit = opt.model_token_limit;
123 | }
124 |
125 | model_params = llama_model_default_params();
126 | model_params.use_mlock = opt.mlock_on;
127 | model_params.use_mmap = opt.mmap_on;
128 |
129 | llama_context_params ctx_params = llama_context_default_params();
130 | ctx_params.n_ctx = opt.context_token_limit;
131 | ctx_params.n_threads = opt.thread_count;
132 | ctx_params.n_batch = opt.batch_count;
133 | ctx_params.rope_freq_scale = llama_model_rope_freq_scale_train(model);
134 | assert(ctx_params.rope_freq_scale > 0.0);
135 | while (
136 | (unsigned)(opt.model_token_limit / ctx_params.rope_freq_scale)
137 | <
138 | opt.context_token_limit)
139 | {
140 | ctx_params.rope_freq_scale /= 2;
141 | }
142 |
143 | struct llama_context* ctx = llama_init_from_model(model, ctx_params);
144 | if (!ctx) {
145 | llama_model_free(model);
146 | fildesh_log_error("Failed to create context.");
147 | return std::make_tuple(nullptr, nullptr);
148 | }
149 | return std::make_tuple(model, ctx);
150 | }
151 |
152 | static
153 | int
154 | new_sampling_seed()
155 | {
156 | return static_cast<int>(INT_MAX & time(NULL));
157 | }
158 |
159 | static
160 | void
161 | apply_sampler_chain(
162 | struct llama_sampler* smpl,
163 | const rendezllama::inference::AdjustVia& adjust_via,
164 | const struct llama_model* model,
165 | unsigned seed,
166 | std::ostream& eout)
167 | {
168 | const unsigned keep_one = 1;
169 |
170 | if (const auto* dry = std::get_if<AdjustViaKind::dry>(&adjust_via)) {
171 | static const char* seq_breakers[] = {
172 | "\n", ":",
173 | };
174 | llama_sampler_chain_add(smpl, llama_sampler_init_dry(
175 | llama_model_get_vocab(model),
176 | llama_model_n_ctx_train(model),
177 | dry->multiplier,
178 | dry->base,
179 | dry->allowed_length,
180 | dry->window_length,
181 | seq_breakers,
182 | sizeof(seq_breakers)/sizeof(*seq_breakers)));
183 | eout << "dry:"
184 | << "\n multiplier: " << dry->multiplier
185 | << "\n base: " << dry->base
186 | << "\n allowed_length: " << dry->allowed_length
187 | << "\n window_length: " << dry->window_length
188 | << "\n";
189 | }
190 | if (const auto* min_p = std::get_if<AdjustViaKind::min_p>(&adjust_via)) {
191 | llama_sampler_chain_add(smpl, llama_sampler_init_min_p(*min_p, keep_one));
192 | eout << "min_p: " << *min_p << "\n";
193 | }
194 | if (const auto* penalize_with = std::get_if<AdjustViaKind::penalize_with>(&adjust_via)) {
195 | llama_sampler_chain_add(smpl, llama_sampler_init_penalties(
196 | penalize_with->window_length,
197 | penalize_with->repetition,
198 | penalize_with->frequency,
199 | penalize_with->presence));
200 | eout << "penalties:"
201 | << "\n window_length: " << penalize_with->window_length
202 | << "\n repetition: " << penalize_with->repetition
203 | << "\n frequency: " << penalize_with->frequency
204 | << "\n presence: " << penalize_with->presence
205 | << "\n";
206 | }
207 | if (const auto* temperature = std::get_if<AdjustViaKind::temperature>(&adjust_via)) {
208 | llama_sampler_chain_add(smpl, llama_sampler_init_temp(*temperature));
209 | eout << "temperature: " << *temperature << "\n";
210 | }
211 | if (const auto* top_k = std::get_if<AdjustViaKind::top_k>(&adjust_via)) {
212 | llama_sampler_chain_add(smpl, llama_sampler_init_top_k(*top_k));
213 | eout << "top_k: " << *top_k << "\n";
214 | }
215 | if (const auto* top_p = std::get_if<AdjustViaKind::top_p>(&adjust_via)) {
216 | llama_sampler_chain_add(smpl, llama_sampler_init_top_p(*top_p, keep_one));
217 | eout << "top_p: " << *top_p << "\n";
218 | }
219 | if (const auto* typical_p = std::get_if<AdjustViaKind::typical_p>(&adjust_via)) {
220 | llama_sampler_chain_add(smpl, llama_sampler_init_typical(*typical_p, keep_one));
221 | eout << "typical_p: " << *typical_p << "\n";
222 | }
223 | if (const auto* xtc = std::get_if<AdjustViaKind::xtc>(&adjust_via)) {
224 | llama_sampler_chain_add(smpl, llama_sampler_init_xtc(xtc->probability, xtc->threshold, keep_one, seed));
225 | eout << "xtc: "
226 | << "\n probability: " << xtc->probability
227 | << "\n threshold: " << xtc->threshold
228 | << "\n";
229 | }
230 | }
231 |
232 | static
233 | void
234 | mirostat_sample(
235 | struct llama_sampler* smpl,
236 | const rendezllama::inference::Mirostat& mirostat,
237 | unsigned seed,
238 | const rendezllama::Vocabulary& vocabulary)
239 | {
240 | if (mirostat.version == 1) {
241 | const int mirostat_m = 100;
242 | llama_sampler_chain_add(
243 | smpl,
244 | llama_sampler_init_mirostat(
245 | vocabulary.cardinality(), seed,
246 | mirostat.tau, mirostat.eta, mirostat_m));
247 | }
248 | else if (mirostat.version == 2) {
249 | llama_sampler_chain_add(
250 | smpl,
251 | llama_sampler_init_mirostat_v2(
252 | seed, mirostat.tau, mirostat.eta));
253 | }
254 | }
255 |
256 | void
257 | Inference::reinitialize(const ChatOptions& opt, const struct llama_model* model)
258 | {
259 | fildesh::ofstream eout("/dev/stderr");
260 |
261 | const auto* sampling = std::get_if<rendezllama::inference::Sampling>(&opt.infer_via);
262 | assert(sampling);
263 | auto seed = sampling->seed;
264 | if (smpl_ || seed < 0) {
265 | // We're retrying or just don't have a fixed seed, so we should reseed.
266 | seed = new_sampling_seed();
267 | }
268 | if (smpl_) {
269 | llama_sampler_free(smpl_);
270 | eout.open("/dev/null");
271 | }
272 | token_count_ = 0;
273 | auto smpl_param = llama_sampler_chain_default_params();
274 | smpl_ = llama_sampler_chain_init(smpl_param);
275 |
276 | for (const auto& adjust_via : sampling->adjust_thru) {
277 | apply_sampler_chain(smpl_, adjust_via, model, seed, eout);
278 | }
279 |
280 | if (const auto* mirostat = std::get_if<rendezllama::inference::Mirostat>(&sampling->pick_via)) {
281 | mirostat_sample(smpl_, *mirostat, seed, vocabulary_);
282 | eout << "mirostat:"
283 | << "\n version: " << mirostat->version
284 | << "\n";
285 | }
286 | else {
287 | llama_sampler_chain_add(smpl_, llama_sampler_init_dist(seed));
288 | }
289 | }
290 |
291 | bool
292 | Inference::commit_to_context(
293 | struct llama_context* ctx,
294 | ChatDisplay& chat_disp,
295 | ChatTrajectory& chat_traj,
296 | const ChatOptions& opt,
297 | const llama_model* model)
298 | {
299 | assert(!chat_traj.erased_since_eval_ ||
300 | chat_traj.context_token_count_ < chat_traj.token_count());
301 | if (chat_traj.context_token_count_ < chat_traj.token_count()) {
302 | this->reinitialize(opt, model);
303 | }
304 | if (chat_traj.context_token_count_ == chat_traj.token_count()) {
305 | return true;
306 | }
307 |
308 | chat_traj.maybe_rollforget_within_limit(opt.context_token_limit, vocabulary_);
309 |
310 | // Reset thread count just in case the user reconfigured it.
311 | const unsigned thread_count = opt.thread_count;
312 | unsigned batch_thread_count = opt.batch_thread_count;
313 | if (batch_thread_count == 0) {
314 | batch_thread_count = std::thread::hardware_concurrency();
315 | }
316 | if (batch_thread_count == 0) {
317 | batch_thread_count = thread_count;
318 | }
319 | llama_set_n_threads(ctx, thread_count, batch_thread_count);
320 |
321 | // Clear KV cache past current position just in case the user deleted tokens.
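// (In the llama.cpp API, a seq_id of -1 matches every sequence and p1 of -1
// means "through the end", so the call below drops cached entries from
// position context_token_count_ onward.)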
322 | llama_kv_cache_seq_rm(ctx, -1, chat_traj.context_token_count_, -1);
323 |
324 | while (chat_traj.context_token_count_ < chat_traj.token_count()) {
325 | const unsigned n = std::min(
326 | opt.batch_count,
327 | chat_traj.token_count() - chat_traj.context_token_count_);
328 |
329 | #if LLAMA_OPENBLAS_ON
330 | if (n < 32) {
331 | llama_set_n_threads(ctx, thread_count, batch_thread_count);
332 | }
333 | else {
334 | llama_set_n_threads(ctx, thread_count, 1);
335 | }
336 | #endif
337 | chat_disp.show_new(chat_traj.context_token_count_ + n, chat_traj, vocabulary_);
338 |
339 | llama_batch batch = llama_batch_get_one(
340 | const_cast<llama_token*>(&chat_traj.tokens()[chat_traj.context_token_count_]),
341 | n);
342 | const int istat = llama_decode(ctx, batch);
343 | if (istat != 0) {
344 | fildesh_log_error("Failed to eval.");
345 | chat_traj.context_token_count_ = 0;
346 | return false;
347 | }
348 | else {
349 | chat_traj.context_token_count_ += n;
350 | }
351 | }
352 | assert(chat_traj.context_token_count_ == chat_traj.token_count());
353 | chat_traj.erased_since_eval_ = false;
354 | while (token_count_ < chat_traj.token_count()) {
355 | Vocabulary::Token_id token_id = chat_traj.token_at(token_count_);
356 | llama_sampler_accept(smpl_, token_id);
357 | token_count_ += 1;
358 | }
359 | return true;
360 | }
361 |
362 | void
363 | Inference::sample_to_trajectory(
364 | ChatTrajectory& chat_traj,
365 | struct llama_context* ctx,
366 | bool preventing_newline)
367 | {
368 | float* logits = llama_get_logits(ctx);
369 | if (preventing_newline) {
370 | // Discourage message-ending tokens when requested by zeroing their logits.
371 | logits[vocabulary_.eos_token_id()] = 0;
372 | logits[vocabulary_.newline_token_id()] = 0;
373 | }
374 |
375 | std::vector<llama_token_data> candidates;
376 | candidates.resize(vocabulary_.cardinality());
377 | for (llama_token i = 0; i < (llama_token)candidates.size(); ++i) {
378 | candidates[i] = llama_token_data{
379 | i, logits[i], 0.0f,
380 | };
381 | }
382 | logits = NULL;
383 | llama_token_data_array candidates_data[1] = {{
384 | candidates.data(),
385 | candidates.size(),
386 | /*selected=*/0,
387 | /*sorted=*/false,
388 | }};
389 | llama_sampler_apply(smpl_, candidates_data);
390 | chat_traj.push_back(candidates[candidates_data->selected].id);
391 | llama_sampler_accept(smpl_, chat_traj.token());
392 | token_count_ += 1;
393 | }
394 |
-------------------------------------------------------------------------------- /src/language/inference.hh: --------------------------------------------------------------------------------
1 | #ifndef RENDEZLLAMA_LANGUAGE_INFERENCE_HH_
2 | #define RENDEZLLAMA_LANGUAGE_INFERENCE_HH_
3 | #include <cstddef>
4 | #include <set>
5 | #include <string>
6 | #include <string_view>
7 | #include <tuple>
8 |
9 | #include "llama.h"
10 |
11 | namespace rendezllama {
12 |
13 | struct ChatOptions;
14 | class ChatDisplay;
15 | class ChatGuide;
16 | class ChatTrajectory;
17 | class Vocabulary;
18 |
19 | class Inference {
20 | public:
21 | explicit Inference(const Vocabulary& vocabulary);
22 | Inference(const Inference&) = delete;
23 | Inference(Inference&&) = delete;
24 | ~Inference();
25 | Inference& operator=(const Inference&) = delete;
26 | Inference& operator=(Inference&&) = delete;
27 |
28 | private:
29 | void reinitialize(
30 | const ChatOptions& opt,
31 | const struct llama_model* model);
32 |
33 | public:
34 | bool commit_to_context(
35 | struct llama_context* ctx,
36 | ChatDisplay& chat_disp,
37 | ChatTrajectory& chat_traj,
38 | const ChatOptions& opt,
39 | const llama_model* model);
40 | void sample_to_trajectory(
41 | ChatTrajectory& chat_traj,
42 | struct llama_context* ctx,
43 | bool preventing_newline);
44 |
45 | private:
46 | llama_sampler* smpl_ = nullptr;
47 | size_t token_count_ = 0;
48 | const Vocabulary& vocabulary_;
49 | };
50 |
51 | const std::string&
52 | antiprompt_suffix(
53 | std::string_view text,
54 | const std::set<std::string>& antiprompts);
55 | void
56 | augment_tokenize_chat_input(
57 | ChatGuide& chat_guide,
58 | ChatTrajectory& chat_traj,
59 | bool& prevent_subsequent_newline,
60 | std::string s,
61 | const Vocabulary& vocabulary,
62 | const ChatOptions& opt);
63 |
64 |
65 | std::tuple<struct llama_model*, struct llama_context*>
66 | make_llama_context(ChatOptions& opt);
67 |
68 | } // namespace rendezllama
69 | #endif
70 |
-------------------------------------------------------------------------------- /src/language/inference_schema.cc: --------------------------------------------------------------------------------
1 | #include "src/language/inference_schema.hh"
2 |
3 | #include <climits>
4 | #include <fildesh/sxproto.h>
5 |
6 | using rendezllama::inference::AdjustViaKind;
7 |
8 | static FildeshSxprotoField penalize_with_fields[] = {
9 | {"window_length", FILL_FildeshSxprotoField_INT(0, INT_MAX)},
10 | {"repetition", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
11 | {"frequency", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
12 | {"presence", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
13 | };
14 |
15 | static FildeshSxprotoField xtc_fields[] = {
16 | {"probability", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
17 | {"threshold", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
18 | };
19 |
20 | static FildeshSxprotoField dry_fields[] = {
21 | {"multiplier", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
22 | {"base", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
23 | {"allowed_length", FILL_FildeshSxprotoField_INT(0, INT_MAX)},
24 | {"window_length", FILL_FildeshSxprotoField_INT(0, INT_MAX)},
25 | };
26 |
27 | static FildeshSxprotoField adjust_thru_manyof[] = {
28 | {"dry", FILL_FildeshSxprotoField_MESSAGE(dry_fields)},
29 | {"min_p", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
30 | {"penalize_with", FILL_FildeshSxprotoField_MESSAGE(penalize_with_fields)},
31 | {"temperature", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
32 | {"top_k", FILL_FildeshSxprotoField_INT(1, INT_MAX)},
33 | {"top_p", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
34 | {"typical_p", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
35 | {"xtc", FILL_FildeshSxprotoField_MESSAGE(xtc_fields)},
36 | };
37 |
38 | static FildeshSxprotoField mirostat_fields[] = {
39 | {"version", FILL_FildeshSxprotoField_INT(1, 2)},
40 | {"tau", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
41 | {"eta", FILL_DEFAULT_FildeshSxprotoField_FLOAT},
42 | };
43 |
44 | static FildeshSxprotoField probability_fields[] = {
45 | {"none", FILL_DEFAULT_FildeshSxprotoField_STRING},
46 | };
47 |
48 | static FildeshSxprotoField pick_via_oneof[] = {
49 | {"mirostat", FILL_FildeshSxprotoField_MESSAGE(mirostat_fields)},
50 | {"probability", FILL_FildeshSxprotoField_MESSAGE(probability_fields)},
51 | };
52 |
53 | static FildeshSxprotoField sampling_fields[] = {
54 | {"seed", FILL_FildeshSxprotoField_INT(0, INT_MAX)},
55 | {"adjust_thru", FILL_FildeshSxprotoField_MANYOF(adjust_thru_manyof)},
56 | {"pick_via", FILL_FildeshSxprotoField_LONEOF(pick_via_oneof)},
57 | };
58 |
59 | static FildeshSxprotoField infer_via_oneof[] = {
60 | {"sampling", FILL_FildeshSxprotoField_MESSAGE(sampling_fields)},
61 | };
62 |
63 | const FildeshSxprotoField* rendezllama::inference_sxproto_schema() {
64 | DECLARE_TOPLEVEL_FildeshSxprotoField(schema, infer_via_oneof);
65 | if (!schema->name) {
66 | lone_toplevel_initialization_FildeshSxprotoField(schema);
67 | schema->kind = FildeshSxprotoFieldKind_LONEOF;
68 | }
69 | return schema;
70 | }
71 |
72 | bool
73 | rendezllama::inference::populate_AdjustVia(
74 | rendezllama::inference::AdjustVia& adjust_via,
75 | FildeshSxpb* sxpb,
76 | FildeshSxpbIT it)
77 | {
78 | const std::string_view name = name_at_FildeshSxpb(sxpb, it);
79 | if (name == "dry") {
80 | rendezllama::inference::Dry dry;
81 | lone_subfield_at_FildeshSxpb_to_float(&dry.multiplier, sxpb, it, "multiplier");
82 | lone_subfield_at_FildeshSxpb_to_float(&dry.base, sxpb, it, "base");
83 | lone_subfield_at_FildeshSxpb_to_unsigned(&dry.allowed_length, sxpb, it, "allowed_length");
84 | lone_subfield_at_FildeshSxpb_to_unsigned(&dry.window_length, sxpb, it, "window_length");
85 | adjust_via.emplace<AdjustViaKind::dry>(dry);
86 | }
87 | else if (name == "min_p") {
88 | adjust_via.emplace<AdjustViaKind::min_p>(
89 | float_value_at_FildeshSxpb(sxpb, it));
90 | }
91 | else if (name == "penalize_with") {
92 | rendezllama::inference::PenalizeWith penalize_with;
93 | lone_subfield_at_FildeshSxpb_to_float(&penalize_with.frequency, sxpb, it, "frequency");
94 | lone_subfield_at_FildeshSxpb_to_float(&penalize_with.presence, sxpb, it, "presence");
95 | lone_subfield_at_FildeshSxpb_to_float(&penalize_with.repetition, sxpb, it, "repetition");
96 | lone_subfield_at_FildeshSxpb_to_unsigned(&penalize_with.window_length, sxpb, it, "window_length");
97 | adjust_via.emplace<AdjustViaKind::penalize_with>(penalize_with);
98 | }
99 | else if (name == "temperature") {
100 | adjust_via.emplace<AdjustViaKind::temperature>(
101 | float_value_at_FildeshSxpb(sxpb, it));
102 | }
103 | else if (name == "top_k") {
104 | adjust_via.emplace<AdjustViaKind::top_k>(
105 | unsigned_value_at_FildeshSxpb(sxpb, it));
106 | }
107 | else if (name == "top_p") {
108 | adjust_via.emplace<AdjustViaKind::top_p>(
109 | float_value_at_FildeshSxpb(sxpb, it));
110 | }
111 | else if (name == "typical_p") {
112 | adjust_via.emplace<AdjustViaKind::typical_p>(
113 | float_value_at_FildeshSxpb(sxpb, it));
114 | }
115 | else if (name == "xtc") {
116 | rendezllama::inference::Xtc xtc;
117 | lone_subfield_at_FildeshSxpb_to_float(&xtc.probability, sxpb, it, "probability");
118 | lone_subfield_at_FildeshSxpb_to_float(&xtc.threshold, sxpb, it, "threshold");
119 | adjust_via.emplace<AdjustViaKind::xtc>(xtc);
120 | }
121 | else {
122 | return false;
123 | }
124 | return true;
125 | }
126 |
127 | bool
128 | rendezllama::inference::populate_PickVia(
129 | rendezllama::inference::PickVia& pick_via,
130 | const FildeshSxpb* sxpb,
131 | FildeshSxpbIT it)
132 | {
133 | const FildeshSxpbIT mirostat_it = lookup_subfield_at_FildeshSxpb(sxpb, it, "mirostat");
134 | if (!nullish_FildeshSxpbIT(mirostat_it)) {
135 | rendezllama::inference::Mirostat mirostat;
136 | if (!lone_subfield_at_FildeshSxpb_to_unsigned(&mirostat.version, sxpb, mirostat_it, "version")) {
137 | mirostat.version = 2;
138 | }
139 | lone_subfield_at_FildeshSxpb_to_float(&mirostat.tau, sxpb, mirostat_it, "tau");
140 | lone_subfield_at_FildeshSxpb_to_float(&mirostat.eta, sxpb, mirostat_it, "eta");
141 | pick_via = mirostat;
142 | return true;
143 | }
144 | else {
145 | rendezllama::inference::Probability probability;
146 | pick_via = probability;
147 | return true;
148 | }
149 | return false;
150 | }
151 |
152 | bool
153 | rendezllama::inference::populate_InferVia(
154 | InferVia& infer_via,
155 | FildeshSxpb* sxpb,
156 | FildeshSxpbIT it)
157 | {
158 | if (nullish_FildeshSxpbIT(it)) {
159 | return false;
160 | }
161 | const FildeshSxpbIT sampling_it = lookup_subfield_at_FildeshSxpb(sxpb, it, "sampling");
162 | Sampling sampling;
163 | if (!nullish_FildeshSxpbIT(sampling_it)) {
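// A sxpb config reaching this branch looks something like the following
// (an illustrative sketch; values are made up, field names come from
// sampling_fields above):
//   (sampling
//     (seed 42)
//     ((adjust_thru) (temperature 0.8) (top_k 40))
//     (pick_via (mirostat (version 2))))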
164 | unsigned seed = 0;
165 | if (lone_subfield_at_FildeshSxpb_to_unsigned(&seed, sxpb, sampling_it, "seed")) {
166 | sampling.seed = static_cast<int>(INT_MAX & seed);
167 | }
168 |
169 | it = lookup_subfield_at_FildeshSxpb(sxpb, sampling_it, "adjust_thru");
170 | for (it = first_at_FildeshSxpb(sxpb, it); !nullish_FildeshSxpbIT(it);
171 | it = next_at_FildeshSxpb(sxpb, it)) {
172 | AdjustVia adjust_via;
173 | if (populate_AdjustVia(adjust_via, sxpb, it)) {
174 | sampling.adjust_thru.push_back(adjust_via);
175 | }
176 | }
177 |
178 | FildeshSxpbIT pick_it = lookup_subfield_at_FildeshSxpb(sxpb, sampling_it, "pick_via");
179 | if (!nullish_FildeshSxpbIT(pick_it)) {
180 | populate_PickVia(sampling.pick_via, sxpb, pick_it);
181 | }
182 | else {
183 | Probability probability;
184 | sampling.pick_via = probability;
185 | }
186 |
187 | infer_via = sampling;
188 | return true;
189 | }
190 | return false;
191 | }
-------------------------------------------------------------------------------- /src/language/inference_schema.hh: --------------------------------------------------------------------------------
1 | #ifndef RENDEZLLAMA_LANGUAGE_INFERENCE_SCHEMA_HH_
2 | #define RENDEZLLAMA_LANGUAGE_INFERENCE_SCHEMA_HH_
3 |
4 | #include <variant>
5 | #include <vector>
6 |
7 | #include <fildesh/sxproto.h>
8 |
9 | namespace rendezllama {
10 | namespace inference {
11 |
12 | struct Dry {
13 | float multiplier = 0.0;
14 | float base = 0.0;
15 | unsigned allowed_length = 0;
16 | unsigned window_length = 0;
17 | };
18 |
19 | struct PenalizeWith {
20 | unsigned window_length = 0;
21 | float repetition = 1.0f;
22 | float frequency = 0.0f;
23 | float presence = 0.0f;
24 | };
25 |
26 | struct Xtc {
27 | float threshold = 0.15f;
28 | float probability = 1.0f;
29 | };
30 |
31 | struct AdjustViaKind {
32 | enum E : std::size_t {
33 | none,
34 | dry,
35 | min_p,
36 | penalize_with,
37 | temperature,
38 | top_k,
39 | top_p,
40 | typical_p,
41 | xtc,
42 | };
43 | };
44 |
45 | typedef std::variant<
46 | std::monostate,
47 | Dry,
48 | float, // min_p
49 | PenalizeWith,
50 | float, // temperature
51 | unsigned, // top_k
52 | float, // top_p
53 | float, // typical_p
54 | Xtc
55 | > AdjustVia;
56 |
57 | struct Mirostat {
58 | unsigned version = 2;
59 | float tau = 5.0f;
60 | float eta = 0.1f;
61 | };
62 |
63 | struct Probability {};
64 |
65 | typedef std::variant<
66 | std::monostate,
67 | Mirostat,
68 | Probability
69 | > PickVia;
70 |
71 | struct Sampling {
72 | int seed = -1;
73 | std::vector<AdjustVia> adjust_thru;
74 | PickVia pick_via;
75 | };
76 |
77 | typedef std::variant<
78 | std::monostate,
79 | Sampling
80 | > InferVia;
81 |
82 | bool
83 | populate_AdjustVia(
84 | AdjustVia& adjust_via,
85 | FildeshSxpb* sxpb,
86 | FildeshSxpbIT it);
87 | bool
88 | populate_PickVia(
89 | PickVia& pick_via,
90 | const FildeshSxpb* sxpb,
91 | FildeshSxpbIT it);
92 | bool
93 | populate_InferVia(
94 | InferVia& infer_via,
95 | FildeshSxpb* sxpb,
96 | FildeshSxpbIT it);
97 |
98 | } // namespace inference
99 |
100 | const FildeshSxprotoField* inference_sxproto_schema();
101 |
102 | } // namespace rendezllama
103 | #endif
104 |
-------------------------------------------------------------------------------- /src/language/language_schema.cc: --------------------------------------------------------------------------------
1 | #include "src/language/language_schema.hh"
2 |
3 | #include <cassert>
4 | #include <climits>
5 |
6 | static FildeshSxprotoField special_token_message[] = {
7 | {"alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)},
8 | {"name", FILL_DEFAULT_FildeshSxprotoField_ALIAS},
9 |
{"candidates", FILL_DEFAULT_FildeshSxprotoField_STRINGS}, 10 | }; 11 | static FildeshSxprotoField substitution_message[] = { 12 | {"protagonist_alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 13 | {"confidant_alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 14 | {"bos_token_alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 15 | {"eos_token_alias", FILL_FildeshSxprotoField_STRING(1, INT_MAX)}, 16 | {"special_tokens", FILL_FildeshSxprotoField_MESSAGES(special_token_message)}, 17 | }; 18 | 19 | const FildeshSxprotoField* rendezllama::language_sxproto_schema() { 20 | static FildeshSxprotoField toplevel_fields[] = { 21 | {"infer_via", FILL_DEFAULT_FildeshSxprotoField_ALIAS}, 22 | {"substitution", FILL_FildeshSxprotoField_MESSAGE(substitution_message)}, 23 | }; 24 | DECLARE_TOPLEVEL_FildeshSxprotoField(schema, toplevel_fields); 25 | if (!schema->name) { 26 | FildeshSxprotoField tmp_field; 27 | tmp_field = *rendezllama::inference_sxproto_schema(); 28 | tmp_field.name = toplevel_fields[0].name; 29 | tmp_field.tag_id = toplevel_fields[0].tag_id; 30 | toplevel_fields[0] = tmp_field; 31 | lone_toplevel_initialization_FildeshSxprotoField(schema); 32 | } 33 | return schema; 34 | } 35 | 36 | bool 37 | rendezllama::language::populate_Substitution( 38 | Substitution& substitution, 39 | FildeshSxpb* sxpb, 40 | FildeshSxpbIT it) 41 | { 42 | if (nullish_FildeshSxpbIT(it)) { 43 | return true; 44 | } 45 | const char* tmp = NULL; 46 | if (lone_subfield_at_FildeshSxpb_to_str(&tmp, sxpb, it, "protagonist_alias")) { 47 | substitution.protagonist_alias = tmp; 48 | } 49 | if (lone_subfield_at_FildeshSxpb_to_str(&tmp, sxpb, it, "confidant_alias")) { 50 | substitution.confidant_alias = tmp; 51 | } 52 | if (lone_subfield_at_FildeshSxpb_to_str(&tmp, sxpb, it, "bos_token_alias")) { 53 | substitution.bos_token_alias = tmp; 54 | } 55 | if (lone_subfield_at_FildeshSxpb_to_str(&tmp, sxpb, it, "eos_token_alias")) { 56 | substitution.eos_token_alias = tmp; 57 | } 58 | FildeshSxpbIT special_it = lookup_subfield_at_FildeshSxpb(sxpb, it, "special_tokens"); 59 | if (!nullish_FildeshSxpbIT(special_it)) { 60 | for (special_it = first_at_FildeshSxpb(sxpb, special_it); !nullish_FildeshSxpbIT(special_it); 61 | special_it = next_at_FildeshSxpb(sxpb, special_it)) { 62 | auto& special = substitution.special_tokens.emplace_back(); 63 | if (lone_subfield_at_FildeshSxpb_to_str(&tmp, sxpb, special_it, "alias")) { 64 | special.alias = tmp; 65 | } 66 | assert(!special.alias.empty()); 67 | FildeshSxpbIT candidate_it = lookup_subfield_at_FildeshSxpb(sxpb, special_it, "candidates"); 68 | if (nullish_FildeshSxpbIT(candidate_it)) { 69 | special.candidates.push_back(special.alias); 70 | } 71 | else { 72 | for (candidate_it = first_at_FildeshSxpb(sxpb, candidate_it); 73 | !nullish_FildeshSxpbIT(candidate_it); 74 | candidate_it = next_at_FildeshSxpb(sxpb, candidate_it)) { 75 | special.candidates.push_back(str_value_at_FildeshSxpb(sxpb, candidate_it)); 76 | } 77 | } 78 | } 79 | } 80 | return true; 81 | } 82 | 83 | bool 84 | rendezllama::language::populate_Language( 85 | Language& language, 86 | FildeshSxpb* sxpb, 87 | FildeshSxpbIT it) 88 | { 89 | if (nullish_FildeshSxpbIT(it)) { 90 | return true; 91 | } 92 | FildeshSxpbIT sub_it; 93 | 94 | sub_it = lookup_subfield_at_FildeshSxpb(sxpb, it, "substitution"); 95 | if (!nullish_FildeshSxpbIT(sub_it)) { 96 | populate_Substitution(language.substitution, sxpb, sub_it); 97 | } 98 | 99 | sub_it = lookup_subfield_at_FildeshSxpb(sxpb, it, "infer_via"); 100 | 
rendezllama::inference::populate_InferVia(language.infer_via, sxpb, sub_it);
101 |
102 | return true;
103 | }
-------------------------------------------------------------------------------- /src/language/language_schema.hh: --------------------------------------------------------------------------------
1 | #ifndef RENDEZLLAMA_LANGUAGE_LANGUAGE_SCHEMA_HH_
2 | #define RENDEZLLAMA_LANGUAGE_LANGUAGE_SCHEMA_HH_
3 |
4 | #include <optional>
5 | #include <string>
6 | #include <vector>
7 |
8 | #include <fildesh/sxproto.h>
9 |
10 | #include "src/language/inference_schema.hh"
11 |
12 | namespace rendezllama {
13 | namespace language {
14 |
15 | struct SpecialToken {
16 | std::string alias;
17 | std::optional<std::string> name;
18 | std::vector<std::string> candidates;
19 | };
20 |
21 | struct Substitution {
22 | std::string protagonist_alias;
23 | std::string confidant_alias;
24 | std::string bos_token_alias;
25 | std::string eos_token_alias;
26 | std::vector<SpecialToken> special_tokens;
27 | };
28 |
29 | struct Language {
30 | Substitution substitution;
31 | rendezllama::inference::InferVia infer_via;
32 | };
33 |
34 | bool
35 | populate_Substitution(
36 | Substitution& substitution,
37 | FildeshSxpb* sxpb,
38 | FildeshSxpbIT it);
39 | bool
40 | populate_Language(
41 | Language& language,
42 | FildeshSxpb* sxpb,
43 | FildeshSxpbIT it);
44 |
45 | } // namespace language
46 |
47 | const FildeshSxprotoField* language_sxproto_schema();
48 |
49 | } // namespace rendezllama
50 | #endif
51 |
-------------------------------------------------------------------------------- /src/language/vocabulary.cc: --------------------------------------------------------------------------------
1 | #include "src/language/vocabulary.hh"
2 |
3 | #include <cassert>
4 | #include <ostream>
5 | #include <string>
6 |
7 | #include <fildesh/fildesh.h>
8 | #include <fildesh/ostream.hh>
9 |
10 | #include "llama.h"
11 |
12 | using rendezllama::Vocabulary;
13 | typedef Vocabulary::Token_id Token_id;
14 |
15 | Vocabulary::Vocabulary(const llama_model* model)
16 | {
17 | if (!model) {return;}
18 | vocab_ = llama_model_get_vocab(model);
19 |
20 | boundary_prefix_ = "☺";
21 | std::string text = boundary_prefix_ + '\n';
22 | std::vector<Token_id> tokens(text.size()+1);
23 | int n = llama_tokenize(
24 | vocab_,
25 | text.data(), text.size(),
26 | tokens.data(), tokens.size(),
27 | /*add_bos=*/false,
28 | /*special=*/false);
29 | assert(n >= 2 && "need to tokenize boundary prefix");
30 | newline_token_id_ = tokens[n-1];
31 | boundary_prefix_tokens_.assign(tokens.begin(), tokens.begin()+(n-1));
32 | }
33 |
34 | Token_id Vocabulary::bos_token_id() const {
35 | if (!vocab_) {return 0;}
36 | return llama_vocab_bos(vocab_);
37 | }
38 | Token_id Vocabulary::eos_token_id() const {
39 | if (!vocab_) {return 0;}
40 | return llama_vocab_eos(vocab_);
41 | }
42 | Token_id Vocabulary::newline_token_id() const {
43 | if (!vocab_) {return 0;}
44 | return newline_token_id_;
45 | }
46 |
47 | unsigned Vocabulary::cardinality() const {
48 | if (!vocab_) {return 1;}
49 | return llama_vocab_n_tokens(vocab_);
50 | }
51 |
52 | char Vocabulary::last_char_of(Token_id token_id) const {
53 | fildesh::ostringstream oss;
54 | this->detokenize_to(oss.c_struct(), token_id);
55 | const std::string_view s = oss.view();
56 | if (!s.empty()) {
57 | return s[s.size()-1];
58 | }
59 | return '\0';
60 | }
61 |
62 | void
63 | Vocabulary::detokenize_to(FildeshO* out, Token_id token_id) const
64 | {
65 | for (const auto& sr : special_tokens_) {
66 | if (sr.token_id == token_id) {
67 | *out << sr.alias;
68 | return;
69 | }
70 | }
71 | const size_t attempt_size = allocated_size_of_FildeshO(out) - out->size;
72 | char* s = grow_FildeshO(out, attempt_size);
73 |
74 | int n = llama_token_to_piece(
75 | vocab_,
76 | token_id,
77 | s, attempt_size,
78 | /*lstrip=*/0,
79 | /*special=*/false);
80 | if (n >= 0) {
81 | out->size -= (attempt_size - n);
82 | } else {
83 | n = -n;
84 | out->size -= attempt_size;
85 | s = grow_FildeshO(out, n);
86 | n = llama_token_to_piece(
87 | vocab_,
88 | token_id,
89 | s, n,
90 | /*lstrip=*/0,
91 | /*special=*/false);
92 | }
93 | }
94 |
95 | void
96 | Vocabulary::detokenize_to(std::ostream& out, const Token_id* ids, size_t n) const
97 | {
98 | fildesh::ostreambuf* outbuf = dynamic_cast<fildesh::ostreambuf*>(out.rdbuf());
99 | if (outbuf) {
100 | this->detokenize_to(outbuf->c_struct(), ids, n);
101 | }
102 | else {
103 | fildesh::ostringstream oss;
104 | this->detokenize_to(oss.c_struct(), ids, n);
105 | out << oss.view();
106 | }
107 | }
108 |
109 | Token_id
110 | Vocabulary::tokenize_special(std::string_view s) const
111 | {
112 | for (auto& sr : special_tokens_) {
113 | if (sr.alias == s) {
114 | return sr.token_id;
115 | }
116 | }
117 | Token_id token_id = 0;
118 | int n = llama_tokenize(
119 | vocab_,
120 | s.data(), s.size(),
121 | &token_id, 1,
122 | /*add_bos=*/false,
123 | /*special=*/true);
124 | if (n != 1) {
125 | token_id = Vocabulary::null_token_id;
126 | }
127 | return token_id;
128 | }
129 |
130 | static
131 | void
132 | tokenize_append(
133 | std::vector<Token_id>& tokens,
134 | std::string_view text,
135 | const llama_vocab* vocab,
136 | const std::string_view boundary_prefix,
137 | const std::vector<Token_id>& boundary_prefix_tokens,
138 | std::string& tmp_s)
139 | {
140 | if (text.empty()) {return;}
141 | tmp_s = boundary_prefix;
142 | tmp_s += text;
143 | size_t offset = tokens.size();
144 | tokens.resize(offset + tmp_s.size() + 1);
145 | int n = llama_tokenize(
146 | vocab,
147 | tmp_s.data(), tmp_s.size(),
148 | tokens.data()+offset, tokens.size()-offset,
149 | /*add_bos=*/false,
150 | /*special=*/false);
151 | assert(n > 0);
152 | assert((size_t)n > boundary_prefix_tokens.size());
153 | tokens.resize(offset + (size_t)n);
154 | tokens.erase(
155 | tokens.begin()+offset,
156 | tokens.begin()+offset+boundary_prefix_tokens.size());
157 | }
158 |
159 | void
160 | Vocabulary::tokenize_to(
161 | std::vector<Token_id>& tokens,
162 | std::string_view text) const
163 | {
164 | tokens.clear();
165 | std::string_view::size_type end = std::string_view::npos;
166 | std::vector<std::string_view::size_type> next_indices(special_tokens_.size(), end);
167 | for (size_t i = 0; i < next_indices.size(); ++i) {
168 | next_indices[i] = text.find(special_tokens_[i].alias);
169 | if (next_indices[i] < end) {
170 | end = next_indices[i];
171 | }
172 | }
173 | std::string tmp_s;
174 | std::string_view::size_type beg = 0;
175 | while (end != std::string_view::npos) {
176 | tokenize_append(tokens, text.substr(beg, end-beg), vocab_,
177 | boundary_prefix_, boundary_prefix_tokens_, tmp_s);
178 | beg = end;
179 | end = std::string_view::npos;
180 | for (size_t i = 0; i < next_indices.size(); ++i) {
181 | if (beg == next_indices[i]) {
182 | tokens.push_back(special_tokens_[i].token_id);
183 | beg += special_tokens_[i].alias.size();
184 | break;
185 | }
186 | }
187 | for (size_t i = 0; i < next_indices.size(); ++i) {
188 | if (next_indices[i] < beg) {
189 | next_indices[i] = text.find(special_tokens_[i].alias, beg);
190 | }
191 | if (next_indices[i] < end) {
192 | end = next_indices[i];
193 | }
194 | }
195 | }
196 | tokenize_append(tokens, text.substr(beg), vocab_,
197 | boundary_prefix_, boundary_prefix_tokens_, tmp_s);
198 | }
199 |
200 | void
201 | Vocabulary::assign_substitution(std::string_view alias, Token_id token_id)
202 | {
203 | assert(!alias.empty());
204 | if (token_id == this->bos_token_id()) {
205 | bos_token_alias_ = alias;
206 | }
207 | if (token_id == this->eos_token_id()) {
208 | eos_token_alias_ = alias;
209 | }
210 | for (auto& sr : special_tokens_) {
211 | if (sr.alias == alias) {
212 | sr.token_id = token_id;
213 | return;
214 | }
215 | }
216 | SubstitutionRule sr;
217 | sr.alias = alias;
218 | sr.token_id = token_id;
219 | special_tokens_.push_back(sr);
220 | }
221 |
222 | rendezllama::GlobalScope::GlobalScope() {
223 | llama_backend_init();
224 | }
225 |
226 | rendezllama::GlobalScope::~GlobalScope() {
227 | llama_backend_free();
228 | }
-------------------------------------------------------------------------------- /src/language/vocabulary.hh: --------------------------------------------------------------------------------
1 | #ifndef RENDEZLLAMA_LANGUAGE_VOCABULARY_HH_
2 | #define RENDEZLLAMA_LANGUAGE_VOCABULARY_HH_
3 | #include <ostream>
4 | #include <string>
5 | #include <vector>
6 |
7 | struct FildeshO;
8 | struct llama_model;
9 | struct llama_vocab;
10 |
11 | namespace rendezllama {
12 |
13 | class Vocabulary {
14 | public:
15 | typedef int Token_id;
16 | static const Token_id null_token_id = -1;
17 |
18 | public:
19 | explicit Vocabulary(const llama_model* model);
20 |
21 | Token_id bos_token_id() const;
22 | Token_id eos_token_id() const;
23 | Token_id newline_token_id() const;
24 | unsigned cardinality() const;
25 |
26 | char last_char_of(Token_id token_id) const;
27 |
28 | void detokenize_to(FildeshO* out, Token_id token_id) const;
29 | void detokenize_to(FildeshO* out, const Token_id* ids, size_t n) const {
30 | for (size_t i = 0; i < n; ++i) {
31 | this->detokenize_to(out, ids[i]);
32 | }
33 | }
34 | void detokenize_to(std::ostream& out, const Token_id* ids, size_t n) const;
35 | void detokenize_to(std::ostream& out, Token_id token_id) const {
36 | this->detokenize_to(out, &token_id, 1);
37 | }
38 |
39 | Token_id tokenize_special(std::string_view s) const;
40 | void tokenize_to(std::vector<Token_id>& tokens, std::string_view text) const;
41 |
42 | void assign_substitution(std::string_view alias, Token_id token_id);
43 | std::string_view bos_token_alias() const {
44 | return bos_token_alias_;
45 | }
46 | std::string_view eos_token_alias() const {
47 | return eos_token_alias_;
48 | }
49 |
50 | private:
51 | const llama_vocab* vocab_ = nullptr;
52 | Token_id newline_token_id_;
53 |
54 | std::string bos_token_alias_;
55 | std::string eos_token_alias_;
56 | struct SubstitutionRule { std::string alias; Token_id token_id; };
57 | std::vector<SubstitutionRule> special_tokens_;
58 |
59 | std::string boundary_prefix_;
60 | std::vector<Token_id> boundary_prefix_tokens_;
61 | };
62 |
63 | class GlobalScope {
64 | public:
65 | GlobalScope();
66 | ~GlobalScope();
67 | };
68 |
69 | } // namespace rendezllama
70 | #endif
71 |
72 |
-------------------------------------------------------------------------------- /src/tokenize/CMakeLists.txt: --------------------------------------------------------------------------------
1 | add_executable(tokenize
2 | "tokenize_main.cc"
3 | "${CMAKE_SOURCE_DIR}/src/language/vocabulary.cc"
4 | "${CMAKE_SOURCE_DIR}/src/language/vocabulary.hh"
5 | )
6 | target_link_libraries(tokenize PRIVATE
7 | ${Fildesh_LIBRARIES}
8 | ${LlamaCpp_LIBRARIES}
9 | )
-------------------------------------------------------------------------------- /src/tokenize/tokenize_main.cc: --------------------------------------------------------------------------------
1 | #include <cstring>
2 | #include <fildesh/ostream.hh>
3 |
4 | #include "llama.h"
5 |
6 | #include "src/language/vocabulary.hh"
7 |
8 | using rendezllama::Vocabulary;
9 |
10 | int main(int argc, char** argv)
11 | {
12 | rendezllama::GlobalScope rendezllama_global_scope;
13 | const char* count_filename = "/dev/null";
14 | const char* model_filename = NULL;
15 | const char* prompt_filename = "-";
16 | const char* token_filename = "/dev/null";
17 | int exstatus = 0;
18 | int argi;
19 | for (argi = 1; exstatus == 0 && argi < argc; ++argi) {
20 | if (argi + 1 == argc) {
21 | exstatus = 64;
22 | }
23 | else if (0 == strcmp("--model", argv[argi])) {
24 | argi += 1;
25 | model_filename = argv[argi];
26 | }
27 | else if (0 == strcmp("--x-prompt", argv[argi])) {
28 | argi += 1;
29 | prompt_filename = argv[argi];
30 | }
31 | else if (0 == strcmp("--o-count", argv[argi])) {
32 | argi += 1;
33 | count_filename = argv[argi];
34 | }
35 | else if (0 == strcmp("-o", argv[argi])) {
36 | argi += 1;
37 | token_filename = argv[argi];
38 | }
39 | else {
40 | exstatus = 64;
41 | }
42 | }
43 |
44 | if (exstatus == 0 && !model_filename) {
45 | fildesh_log_error("Please provide a model file with --model.");
46 | exstatus = 64;
47 | }
48 | if (exstatus != 0) {
49 | return exstatus;
50 | }
51 |
52 | // Match original LLaMA tokenizer behavior.
53 | std::string prompt = " ";
54 |
55 | {
56 | std::string content;
57 | if (fildesh::slurp_file_to_string(content, prompt_filename)) {
58 | prompt += content;
59 | }
60 | }
61 |
62 | llama_model_params model_params = llama_model_default_params();
63 | model_params.vocab_only = true;
64 | llama_model* model = llama_model_load_from_file(model_filename, model_params);
65 |
66 | std::vector<Vocabulary::Token_id> tokens;
67 | Vocabulary vocabulary(model);
68 | vocabulary.assign_substitution("<s>", vocabulary.bos_token_id());
69 | vocabulary.assign_substitution("</s>", vocabulary.eos_token_id());
70 | tokens.push_back(vocabulary.bos_token_id());
71 | vocabulary.tokenize_to(tokens, prompt);
72 |
73 | if (tokens.size() == 0) {
74 | exstatus = 1;
75 | }
76 |
77 | if (exstatus == 0) {
78 | fildesh::ofstream out(count_filename);
79 | out << tokens.size() << '\n';
80 | }
81 | if (exstatus == 0) {
82 | fildesh::ofstream out(token_filename);
83 | for (auto token_id : tokens) {
84 | if (token_id == vocabulary.newline_token_id()) {
85 | out << "\\n";
86 | }
87 | else {
88 | vocabulary.detokenize_to(out, token_id);
89 | }
90 | out << '\n';
91 | }
92 | }
93 | if (model) {llama_model_free(model);}
94 | return exstatus;
95 | }
-------------------------------------------------------------------------------- /test/CMakeLists.txt: --------------------------------------------------------------------------------
1 | set(LlamaCpp_VOCAB_MODEL "${LlamaCpp_SOURCE_DIR}/models/ggml-vocab-llama-spm.gguf")
2 |
3 | add_subdirectory(chat)
4 | add_subdirectory(example)
5 | add_subdirectory(language)
-------------------------------------------------------------------------------- /test/chat/CMakeLists.txt: --------------------------------------------------------------------------------
1 |
2 | add_executable(chat_guide_test
3 | "guide_test.cc"
4 | "${PROJECT_SOURCE_DIR}/src/chat/guide.cc"
5 | "${PROJECT_SOURCE_DIR}/src/chat/guide.hh"
6 | "${PROJECT_SOURCE_DIR}/src/chat/trajectory.cc"
7 | "${PROJECT_SOURCE_DIR}/src/chat/trajectory.hh"
8 | "${PROJECT_SOURCE_DIR}/src/language/vocabulary.cc"
9 | "${PROJECT_SOURCE_DIR}/src/language/vocabulary.hh"
10 | )
11 | target_link_libraries(chat_guide_test PRIVATE
12 | chat_opt_cc
13 | ${LlamaCpp_LIBRARIES}
14 | )
15 | add_test(NAME chat_guide_test COMMAND
16 | chat_guide_test "${LlamaCpp_VOCAB_MODEL}"
17 | )
18 |
19 | add_executable(chat_opt_test
20 | "opt_test.cc"
21 | )
22 | target_link_libraries(chat_opt_test PRIVATE
23 | chat_opt_cc
24 | )
25 | add_test(NAME chat_opt_test COMMAND
26 | chat_opt_test
27 | )
28 |
29 | add_executable(chat_trajectory_test
30 | "trajectory_test.cc"
31 | "${PROJECT_SOURCE_DIR}/src/chat/trajectory.cc"
32 | "${PROJECT_SOURCE_DIR}/src/chat/trajectory.hh"
33 | "${PROJECT_SOURCE_DIR}/src/language/vocabulary.cc"
34 | "${PROJECT_SOURCE_DIR}/src/language/vocabulary.hh"
35 | )
36 | target_link_libraries(chat_trajectory_test PRIVATE
37 | ${Fildesh_LIBRARIES}
38 | ${LlamaCpp_LIBRARIES}
39 | )
40 | add_test(NAME chat_trajectory_test COMMAND
41 | chat_trajectory_test "${LlamaCpp_VOCAB_MODEL}"
42 | )
-------------------------------------------------------------------------------- /test/chat/guide_test.cc: --------------------------------------------------------------------------------
1 | #include "src/chat/guide.hh"
2 |
3 | #include <cassert>
4 |
5 | #include <fildesh/ostream.hh>
6 |
7 | #include "llama.h"
8 |
9 | #include "src/chat/opt.hh"
10 | #include "src/chat/trajectory.hh"
11 | #include "src/language/vocabulary.hh"
12 |
13 | using rendezllama::ChatOptions;
14 | using rendezllama::ChatGuide;
15 | using rendezllama::ChatTrajectory;
16 | using rendezllama::Vocabulary;
17 |
18 |
19 | static
20 | void
21 | truncate_detokenize_rolling_to(
22 | fildesh::ostringstream& oss,
23 | ChatTrajectory& traj,
24 | const Vocabulary& vocab)
25 | {
26 | oss.truncate();
27 | vocab.detokenize_to(oss, traj.tokens().data() + traj.priming_token_count(),
28 | traj.token_count() - traj.priming_token_count());
29 | }
30 |
31 |
32 | static
33 | void
34 | the_test(llama_model* model)
35 | {
36 | FildeshX in[1];
37 | fildesh::ostringstream oss;
38 | Vocabulary vocab(model);
39 | ChatTrajectory traj(vocab.bos_token_id());
40 | ChatOptions opt;
41 | ChatGuide guide(vocab, traj, opt);
42 | bool good;
43 |
44 | // There's nothing to erase, but it still returns true!
45 | assert(guide.maybe_erase_trailing_message_prefix());
46 | // Again for good measure.
47 |   assert(guide.maybe_erase_trailing_message_prefix());
48 |
49 |   *in = FildeshX_of_strlit("\
50 | ((chat_prefixes)\n\
51 |  (m (prefix \"A:\") (suffix \"\\n###\\n\"))\n\
52 |  (m (prefix \"B:\"))\n\
53 |  (m (prefix \"C:\"))\n\
54 |  (m (prefix \"D:\") (suffix \"\\n\"))\n\
55 | )\n\
56 | (language\n\
57 |  (substitution\n\
58 |   (eos_token_alias \"</s>\")\n\
59 |  )\n\
60 | )\n\
61 | ");
62 |   good = rendezllama::slurp_sxpb_initialize_options_close_FildeshX(in, opt, "");
63 |   assert(good);
64 |   assert(opt.substitution.eos_token_alias == "</s>");
65 |   vocab.assign_substitution(opt.substitution.eos_token_alias, vocab.eos_token_id());
66 |
67 |   guide.yield_turn();
68 |   truncate_detokenize_rolling_to(oss, traj, vocab);
69 |   assert(oss.view() == "A:");
70 |
71 |   guide.yield_turn(2);
72 |   truncate_detokenize_rolling_to(oss, traj, vocab);
73 |   assert(oss.view() == "C:");
74 |
75 |   traj.tokenize_append(" Yo.", vocab);
76 |   guide.yield_turn("D: Sup?");
77 |   truncate_detokenize_rolling_to(oss, traj, vocab);
78 |   assert(oss.view() == "C: Yo.\nD: Sup?");
79 |
80 |   assert(!guide.maybe_yield_turn());
81 |   traj.tokenize_append("\nOh hi A.\n", vocab);
82 |   assert(!guide.maybe_yield_turn());
83 |   traj.tokenize_append("Sup?", vocab);
84 |   assert(guide.maybe_yield_turn());
85 |   truncate_detokenize_rolling_to(oss, traj, vocab);
86 |   assert(oss.view() == "C: Yo.\nD: Sup?\nOh hi A.\nSup?\nA:");
87 |
88 |   assert(!guide.maybe_yield_turn());
89 |   traj.tokenize_append(" Oi!\n", vocab);
90 |   assert(guide.maybe_yield_turn());
91 |   truncate_detokenize_rolling_to(oss, traj, vocab);
92 |   assert(oss.view() == "C: Yo.\nD: Sup?\nOh hi A.\nSup?\nA: Oi!\n###\nB:");
93 |
94 |   assert(guide.maybe_erase_trailing_message_prefix());
95 |   truncate_detokenize_rolling_to(oss, traj, vocab);
96 |   assert(oss.view() == "C: Yo.\nD: Sup?\nOh hi A.\nSup?\nA: Oi!\n###\n");
97 |
98 |   assert(guide.maybe_erase_trailing_message_suffix());
99 |   truncate_detokenize_rolling_to(oss, traj, vocab);
100 |   assert(oss.view() == "C: Yo.\nD: Sup?\nOh hi A.\nSup?\nA: Oi!");
101 |
102 |   // Reset and test that a poorly-tokenized EOS gets detected and fixed,
103 |   // even when EOS isn't part of the message suffix.
104 |   traj.erase_all_at(1);
105 |   guide.begin_turn(1);
106 |   traj.tokenize_append(" Hello there!", vocab);
107 |   auto expect_suffix_index = traj.token_count();
108 |   traj.tokenize_append("\n", vocab);
109 |   traj.tokenize_append("</s", vocab);  // EOS alias, split across appends so it
110 |   traj.tokenize_append(">", vocab);    // cannot tokenize to the real EOS token.
111 |   assert(guide.maybe_yield_turn());
112 |   assert(traj.token_at(expect_suffix_index) == vocab.newline_token_id());
113 |   truncate_detokenize_rolling_to(oss, traj, vocab);
114 |   assert(oss.view() == "B: Hello there!\nC:");
115 | }
116 |
117 |
118 | int main(int argc, char** argv)
119 | {
120 |   assert(argc == 2 && "need model filename");
121 |
122 |   rendezllama::GlobalScope rendezllama_global_scope;
123 |   llama_model_params model_params = llama_model_default_params();
124 |   model_params.vocab_only = true;
125 |   llama_model* model = llama_model_load_from_file(argv[1], model_params);
126 |   assert(model);
127 |
128 |   the_test(model);
129 |
130 |   llama_model_free(model);
131 |   return 0;
132 | }
133 |
--------------------------------------------------------------------------------
/test/chat/opt_test.cc:
--------------------------------------------------------------------------------
1 | #include "src/chat/opt.hh"
2 |
3 | #include <cassert>
4 | #include <string>
5 | #include <vector>
6 |
7 | #include <fildesh/fildesh.h>
8 |
9 | #include "src/chat/opt_schema.hh"
10 |
11 | static
12 | void
13 | chat_prefixes_parse_test()
14 | {
15 |   rendezllama::ChatOptions opt;
16 |   FildeshX in[1];
17 |   bool all_good;
18 |
19 |   *in = FildeshX_of_strlit(
20 |       "((chat_prefixes) \
21 |       \"{{user}}:\" \
22 |       \"{{char}} feels:\" \
23 |       \"{{char}} wants:\" \
24 |       \"{{char}} plans:\" \
25 |       \"{{char}}:\" \
26 |       )");
27 |   all_good = rendezllama::slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
28 |   assert(all_good);
29 |   assert(opt.message_opts.size() == 5);
30 |   for (const auto& message_opt : opt.message_opts) {
31 |     assert(!message_opt.given_prefix.empty());
32 |     assert(!message_opt.prefix.empty());
33 |   }
34 |   opt.protagonist = "User";
35 |   opt.substitution.protagonist_alias = "{{user}}";
36 |   opt.substitution.confidant_alias = "{{char}}";
37 |
38 |   in->off = 0;
39 |   in->size = 0;
40 |   *in = FildeshX_of_strlit("(confidant \"Char\")");
41 |   all_good = rendezllama::slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
42 |   assert(all_good);
43 |   assert(opt.message_opts.size() == 5);
44 |   assert(opt.message_opts[0].prefix == "User:");
45 |   assert(opt.message_opts[1].prefix == "Char feels:");
46 |   assert(opt.message_opts[2].prefix == "Char wants:");
47 |   assert(opt.message_opts[3].prefix == "Char plans:");
48 |   assert(opt.message_opts[4].prefix == "Char:");
49 | }
50 |
51 | static
52 | void
53 | sentence_terminals_parse_test()
54 | {
55 |   rendezllama::ChatOptions opt;
56 |   FildeshX in[1];
57 |   *in = FildeshX_of_strlit(
58 |       "(sentence_terminals () \
59 |       \"\\n\" \
60 |       \"\\\"\" \
61 |       \".\" \
62 |       )");
63 |   bool all_good = rendezllama::slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
64 |   assert(all_good);
65 |   assert(opt.sentence_terminals.size() == 3);
66 |   // Insert 3 and expect that they add nothing new.
67 |   opt.sentence_terminals.insert("\n");
68 |   opt.sentence_terminals.insert("\"");
69 |   opt.sentence_terminals.insert(".");
70 |   assert(opt.sentence_terminals.size() == 3);
71 | }
72 |
73 | int main()
74 | {
75 |   chat_prefixes_parse_test();
76 |   sentence_terminals_parse_test();
77 |   return 0;
78 | }
79 |
--------------------------------------------------------------------------------
/test/chat/trajectory_test.cc:
--------------------------------------------------------------------------------
1 | #include "src/chat/trajectory.hh"
2 |
3 | #include <cassert>
4 |
5 | #include <fildesh/fildesh.h>
6 |
7 | #include "llama.h"
8 |
9 | #include "src/language/vocabulary.hh"
10 |
11 | using rendezllama::ChatTrajectory;
12 | using rendezllama::Vocabulary;
13 |
14 |
15 | static
16 | void
17 | basic_test()
18 | {
19 |   ChatTrajectory traj(0);
20 |   assert(traj.token_count() == 1);
21 |   assert(traj.last_message_prefix_id_at(0) == traj.not_a_message_prefix_id());
22 |
23 |   assert(traj.priming_token_count_ == 1);
24 |   assert(traj.rfind_message_prefix_at(0) == 0);
25 |   assert(traj.rfind_message_prefix_begin_at(0) == 0);
26 |   assert(traj.rfind_last_message_prefix_end_at(0) == 1);
27 |   assert(traj.rfind_last_message_prefix_end_at(1) == 1);
28 |
29 |   traj.push_back(1);
30 |   assert(traj.token_count() == 2);
31 |   assert(traj.rfind_message_prefix_at(1) == 0);
32 |   assert(traj.rfind_message_prefix_begin_at(1) == 0);
33 |   assert(traj.rfind_last_message_prefix_end_at(0) == 1);
34 |   assert(traj.rfind_last_message_prefix_end_at(1) == 1);
35 |   assert(traj.rfind_last_message_prefix_end_at(2) == 1);
36 |   assert(traj.last_message_prefix_id_at(2) == traj.not_a_message_prefix_id());
37 |
38 |   traj.assign_range_message_prefix_id(7, 1, 2);
39 |   assert(traj.message_prefix_id_ == 7);
40 |   assert(traj.rfind_message_prefix_at(1) == 1);
41 |   assert(traj.rfind_message_prefix_begin_at(1) == 1);
42 |   assert(traj.rfind_last_message_prefix_end_at(0) == 1);
43 |   assert(traj.rfind_last_message_prefix_end_at(1) == 1);
44 |   assert(traj.rfind_last_message_prefix_end_at(2) == 2);
45 |   assert(traj.last_message_prefix_id_at(2) == 7);
46 |
47 |   traj.erase_all_at(1);
48 |   for (unsigned i = 1; i < 100; ++i) {
49 |     traj.push_back(i);
50 |   }
51 |   assert(traj.token_count() == 100);
52 |   assert(traj.find_token_at(0, 1) == 1);
53 |   assert(traj.find_token_at(1, 1) == 1);
54 |   assert(traj.find_token_at(2, 1) == 100);
55 |   assert(traj.rfind_token_at(0, 1) == 100);
56 |   assert(traj.rfind_token_at(1, 1) == 1);
57 |   assert(traj.rfind_token_at(2, 1) == 1);
58 |   assert(traj.rfind_token_at(100, 1) == 1);
59 |   assert(traj.rfind_token_at(0, 0) == 0);
60 |
61 |   for (unsigned i = 1; i < 10; ++i) {
62 |     traj.assign_range_message_prefix_id(i, 10*i, 11*i+1);
63 |     assert(traj.message_prefix_id_ == i);
64 |   }
65 |   assert(traj.message_prefix_id_ == 9);
66 |   assert(traj.rfind_message_prefix_at(49) == 44);
67 |   assert(traj.rfind_message_prefix_begin_at(49) == 40);
68 |   assert(traj.rfind_last_message_prefix_end_at(49) == 45);
69 |   assert(traj.rfind_last_message_prefix_end_at(50) == 45);
70 |   assert(traj.rfind_last_message_prefix_end_at(55) == 45);
71 |   assert(traj.rfind_last_message_prefix_end_at(56) == 56);
72 |   assert(traj.last_message_prefix_id_at(49) == 4);
73 |   assert(traj.last_message_prefix_id_at(50) == 4);
74 |   assert(traj.last_message_prefix_id_at(55) == 4);
75 |   assert(traj.last_message_prefix_id_at(56) == 5);
76 |
77 |   // Delete last 10.
78 |   traj.erase_all_at(90);
79 |   assert(traj.token_count() == 90);
80 |   assert(traj.message_prefix_id_ == 8);
81 |
82 |   // These token counts are incremented by other code,
83 |   // but we expect ChatTrajectory to decrease them during rollforget.
84 |   traj.context_token_count_ = 80;
85 |   traj.display_token_count_ = 70;
86 |
87 |   // Delete [40..49].
88 |   traj.erase_range(40, 50);
89 |   assert(traj.token_count() == 80);
90 |   assert(traj.context_token_count_ == 40);
91 |   assert(traj.display_token_count_ == 60);
92 |
93 |   assert(traj.message_prefix_id_ == 8);
94 |   assert(traj.rfind_last_message_prefix_end_at(49) == 46);
95 |   assert(traj.rfind_last_message_prefix_end_at(50) == 46);
96 |   assert(traj.rfind_last_message_prefix_end_at(56) == 46);
97 |   assert(traj.rfind_last_message_prefix_end_at(57) == 57);
98 |
99 |   // Restore [40..49].
100 |   traj.insert_all_at(
101 |       40,
102 |       std::vector<Vocabulary::Token_id>{40, 41, 42, 43, 44, 45, 46, 47, 48, 49});
103 |   assert(traj.token_count() == 90);
104 |   assert(traj.message_prefix_id_ == 8);
105 |   assert(traj.rfind_last_message_prefix_end_at(49) == 34);
106 |   traj.assign_range_message_prefix_id(4, 40, 45);
107 |   assert(traj.rfind_last_message_prefix_end_at(49) == 45);
108 | }
109 |
110 |
111 | static
112 | void
113 | rollforget_test(llama_model* model)
114 | {
115 |   const Vocabulary vocabulary(model);
116 |   ChatTrajectory traj(vocabulary.bos_token_id());
117 |
118 |   FildeshO transcript_out[1] = {DEFAULT_FildeshO};
119 |   // `traj` takes ownership and will free the memory.
120 |   traj.transcript_out_ = transcript_out;
121 |
122 |   traj.tokenize_append(
123 |       " Transcript of a conversation between User and their Code.\n"
124 |       "\n"
125 |       "### Transcript Continuation\n",
126 |       vocabulary);
127 |   traj.priming_token_count_ = traj.token_count();
128 |
129 |   traj.tokenize_append_message_prefix(0, "User:", vocabulary);
130 |   traj.tokenize_append(" Tell me all your bugs!", vocabulary);
131 |   traj.tokenize_append_message_suffix("", vocabulary);
132 |   traj.tokenize_append_message_prefix(1, "Code:", vocabulary);
133 |   traj.tokenize_append(" I cannot.", vocabulary);
134 |   traj.tokenize_append_message_suffix("", vocabulary);
135 |
136 |   const unsigned expect_forget_count = (
137 |       traj.token_count() - traj.priming_token_count_);
138 |   traj.tokenize_append_message_prefix(0, "User:", vocabulary);
139 |   traj.tokenize_append(" Why not?", vocabulary);
140 |   traj.tokenize_append_message_suffix("", vocabulary);
141 |   traj.tokenize_append_message_prefix(1, "Code:", vocabulary);
142 |   traj.tokenize_append(
143 |       " They are not enumerable, but I can give a sample."
144 |       " (1) Infinite loop on line 20."
145 |       " (2) Off-by-one on line 21."
146 |       " (3) Off-by-two on line 21."
147 | " (4) Segmentation fault", 148 | vocabulary); 149 | traj.tokenize_append_message_suffix("", vocabulary); 150 | traj.tokenize_append_message_prefix(0, "User:", vocabulary); 151 | traj.tokenize_append(" wtf", vocabulary); 152 | traj.tokenize_append_message_suffix("", vocabulary); 153 | 154 | const unsigned old_token_count = traj.token_count(); 155 | traj.maybe_rollforget_within_limit(traj.token_count() - 1, vocabulary); 156 | assert(traj.token_count() < old_token_count); 157 | assert(traj.token_count() == old_token_count - expect_forget_count); 158 | 159 | assert(traj.transcript_out_->size > 0); 160 | } 161 | 162 | 163 | static 164 | void 165 | suffix_test(llama_model* model) 166 | { 167 | Vocabulary vocabulary(model); 168 | ChatTrajectory traj(vocabulary.bos_token_id()); 169 | 170 | traj.tokenize_append_message_prefix(0, "User:", vocabulary); 171 | traj.tokenize_append(" blah blah blah\n\nEOS EOS\n \n ", vocabulary); 172 | assert(!traj.endswith_nonempty("EOS\n", vocabulary)); 173 | 174 | const auto old_token_count = traj.token_count(); 175 | traj.display_token_count_ = traj.token_count(); 176 | vocabulary.assign_substitution("EOS", vocabulary.eos_token_id()); 177 | traj.tokenize_append_message_suffix("EOS\n", vocabulary); 178 | assert(traj.token_count() < old_token_count); 179 | assert(traj.display_token_count_ == traj.token_count()); 180 | 181 | assert(traj.token_at(traj.token_count()-1) == vocabulary.newline_token_id()); 182 | assert(traj.token_at(traj.token_count()-2) == vocabulary.eos_token_id()); 183 | for (auto i = traj.priming_token_count_; i < traj.token_count()-2; ++i) { 184 | assert(traj.token_at(i) != vocabulary.newline_token_id()); 185 | assert(traj.token_at(i) != vocabulary.eos_token_id()); 186 | } 187 | assert(traj.endswith_nonempty("EOS\n", vocabulary)); 188 | } 189 | 190 | 191 | int main(int argc, char** argv) 192 | { 193 | assert(argc == 2 && "need model filename"); 194 | 195 | rendezllama::GlobalScope rendezllama_global_scope; 196 | llama_model_params model_params = llama_model_default_params(); 197 | model_params.vocab_only = true; 198 | llama_model* model = llama_model_load_from_file(argv[1], model_params); 199 | assert(model); 200 | 201 | basic_test(); 202 | rollforget_test(model); 203 | suffix_test(model); 204 | 205 | llama_model_free(model); 206 | return 0; 207 | } 208 | -------------------------------------------------------------------------------- /test/example/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(prompt) 2 | -------------------------------------------------------------------------------- /test/example/prompt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(example_prompt_parse_test 2 | "parse_test.cc" 3 | ) 4 | target_link_libraries(example_prompt_parse_test PRIVATE 5 | chat_opt_cc 6 | ) 7 | add_test(NAME example_prompt_assistant_alpaca_parse_test COMMAND example_prompt_parse_test 8 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_alpaca/setting.sxpb") 9 | add_test(NAME example_prompt_assistant_chatml_parse_test COMMAND example_prompt_parse_test 10 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_chatml/setting.sxpb") 11 | add_test(NAME example_prompt_assistant_coprocess_parse_test COMMAND example_prompt_parse_test 12 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_coprocess/setting.sxpb") 13 | add_test(NAME example_prompt_assistant_gemma_parse_test COMMAND example_prompt_parse_test 14 | 
"${PROJECT_SOURCE_DIR}/example/prompt/assistant_gemma/setting.sxpb") 15 | add_test(NAME example_prompt_assistant_llama_parse_test COMMAND example_prompt_parse_test 16 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_llama/setting.sxpb") 17 | add_test(NAME example_prompt_assistant_mistral_parse_test COMMAND example_prompt_parse_test 18 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_mistral/setting.sxpb") 19 | add_test(NAME example_prompt_assistant_plain_parse_test COMMAND example_prompt_parse_test 20 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_plain/setting.sxpb") 21 | add_test(NAME example_prompt_assistant_vicuna_parse_test COMMAND example_prompt_parse_test 22 | "${PROJECT_SOURCE_DIR}/example/prompt/assistant_vicuna/setting.sxpb") 23 | add_test(NAME example_prompt_confidant_alpaca_parse_test COMMAND example_prompt_parse_test 24 | "${PROJECT_SOURCE_DIR}/example/prompt/confidant_alpaca/setting.sxpb") 25 | add_test(NAME example_prompt_roshambo_kira_parse_test COMMAND example_prompt_parse_test 26 | "${PROJECT_SOURCE_DIR}/example/prompt/roshambo_kira/setting.sxpb") 27 | -------------------------------------------------------------------------------- /test/example/prompt/parse_test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "src/chat/opt.hh" 6 | 7 | int main(int argc, char** argv) 8 | { 9 | bool good = true; 10 | assert(argc == 2); 11 | const char* filename = argv[1]; 12 | rendezllama::ChatOptions opt; 13 | FildeshX* in = open_FildeshXF(filename); 14 | assert(in); 15 | good = rendezllama::slurp_sxpb_initialize_options_close_FildeshX( 16 | in, opt, filename); 17 | assert(good); 18 | return (good ? 0 : 1); 19 | } 20 | -------------------------------------------------------------------------------- /test/language/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_executable(language_inference_schema_test 2 | "inference_schema_test.cc" 3 | ) 4 | target_link_libraries(language_inference_schema_test PRIVATE 5 | chat_opt_cc 6 | ) 7 | add_test(NAME language_inference_schema_test COMMAND 8 | language_inference_schema_test 9 | ) 10 | 11 | add_executable(language_schema_test 12 | "language_schema_test.cc" 13 | ) 14 | target_link_libraries(language_schema_test PRIVATE 15 | chat_opt_cc 16 | ) 17 | add_test(NAME language_schema_test COMMAND 18 | language_schema_test 19 | ) 20 | 21 | add_executable(language_vocabulary_test 22 | "vocabulary_test.cc" 23 | "${PROJECT_SOURCE_DIR}/src/language/vocabulary.cc" 24 | "${PROJECT_SOURCE_DIR}/src/language/vocabulary.hh" 25 | ) 26 | target_include_directories(language_vocabulary_test PRIVATE 27 | ${LlamaCpp_INCLUDE_DIRS} 28 | ) 29 | target_link_libraries(language_vocabulary_test PRIVATE 30 | ${Fildesh_LIBRARIES} 31 | ${LlamaCpp_LIBRARIES} 32 | ) 33 | add_test(NAME language_vocabulary_test COMMAND 34 | language_vocabulary_test "${LlamaCpp_VOCAB_MODEL}" 35 | ) 36 | -------------------------------------------------------------------------------- /test/language/inference_schema_test.cc: -------------------------------------------------------------------------------- 1 | #include "src/language/inference_schema.hh" 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include "src/chat/opt.hh" 9 | #include "src/chat/opt_schema.hh" 10 | 11 | using rendezllama::slurp_sxpb_dynamic_options_close_FildeshX; 12 | using rendezllama::inference::AdjustViaKind; 13 | using rendezllama::inference::Sampling; 14 | 15 | static 16 | void 17 | 
17 | default_parse_test()
18 | {
19 |   rendezllama::ChatOptions opt;
20 |   {
21 |     FildeshX in = FildeshX_of_strlit("(language (substitution))");
22 |     bool all_good = slurp_sxpb_dynamic_options_close_FildeshX(&in, opt);
23 |     assert(all_good);
24 |   }
25 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
26 |   const auto& sampling = std::get<Sampling>(opt.infer_via);
27 |
28 |   const auto& adjust_thru = sampling.adjust_thru;
29 |   assert(adjust_thru.size() == 2);
30 |   auto* min_p = std::get_if<AdjustViaKind::min_p>(&adjust_thru[0]);
31 |   assert(min_p);
32 |   assert(*min_p == 0.1f);
33 |   auto* temperature = std::get_if<AdjustViaKind::temperature>(&adjust_thru[1]);
34 |   assert(temperature);
35 |   assert(*temperature == 0.8f);
36 |
37 |   using rendezllama::inference::Probability;
38 |   assert(std::holds_alternative<Probability>(sampling.pick_via));
39 | }
40 |
41 | static
42 | void
43 | seed_parse_test()
44 | {
45 |   rendezllama::ChatOptions opt;
46 |   FildeshX in[1];
47 |   bool all_good;
48 |
49 |   *in = FildeshX_of_strlit(
50 |       "(language ((infer_via sampling) (seed 123)))");
51 |   all_good = slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
52 |   assert(all_good);
53 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
54 |   auto& sampling = std::get<Sampling>(opt.infer_via);
55 |   assert(sampling.seed == 123);
56 | }
57 |
58 | static
59 | void
60 | adjust_thru_parse_test()
61 | {
62 |   fildesh::ostringstream oss;
63 |   oss
64 |       << "(language\n"
65 |       << " ((infer_via sampling)\n"
66 |       << "  (adjust_thru (())\n"
67 |       << "   (dry\n"
68 |       << "    (multiplier 0.5)\n"
69 |       << "    (base 0.25)\n"
70 |       << "    (allowed_length 100)\n"
71 |       << "    (window_length 1000)\n"
72 |       << "   )\n"
73 |       << "   (min_p 0.25)\n"
74 |       << "   (penalize_with\n"
75 |       << "    (window_length 1000)\n"
76 |       << "    (repetition 1.5)\n"
77 |       << "    (frequency 0.5)\n"
78 |       << "    (presence 0.25)\n"
79 |       << "   )\n"
80 |       << "   (top_k 123)\n"
81 |       << "   (top_p 0.75)\n"
82 |       << "   (typical_p 0.5)\n"
83 |       << "   (xtc (probability 0.75) (threshold 0.25))\n"
84 |       << "   (temperature 0.75)\n"
85 |       << ")))";
86 |
87 |   rendezllama::ChatOptions opt;
88 |   {
89 |     FildeshX in = getslice_FildeshO(oss.c_struct());
90 |     bool all_good = slurp_sxpb_dynamic_options_close_FildeshX(&in, opt);
91 |     assert(all_good);
92 |   }
93 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
94 |   const auto& adjust_thru = std::get<Sampling>(opt.infer_via).adjust_thru;
95 |   unsigned adjust_thru_idx = 0;
96 |
97 |   auto* dry = std::get_if<AdjustViaKind::dry>(
98 |       &adjust_thru[adjust_thru_idx++]);
99 |   assert(dry);
100 |   assert(dry->multiplier == 0.5);
101 |   assert(dry->base == 0.25);
102 |   assert(dry->allowed_length == 100);
103 |   assert(dry->window_length == 1000);
104 |
105 |   auto* min_p = std::get_if<AdjustViaKind::min_p>(
106 |       &adjust_thru[adjust_thru_idx++]);
107 |   assert(min_p);
108 |   assert(*min_p == 0.25);
109 |
110 |   auto* penalize_with = std::get_if<AdjustViaKind::penalize_with>(
111 |       &adjust_thru[adjust_thru_idx++]);
112 |   assert(penalize_with);
113 |   assert(penalize_with->window_length == 1000);
114 |   assert(penalize_with->repetition == 1.5);
115 |   assert(penalize_with->frequency == 0.5);
116 |   assert(penalize_with->presence == 0.25);
117 |
118 |   auto* top_k = std::get_if<AdjustViaKind::top_k>(
119 |       &adjust_thru[adjust_thru_idx++]);
120 |   assert(top_k);
121 |   assert(*top_k == 123);
122 |
123 |   auto* top_p = std::get_if<AdjustViaKind::top_p>(
124 |       &adjust_thru[adjust_thru_idx++]);
125 |   assert(top_p);
126 |   assert(*top_p == 0.75);
127 |
128 |   auto* typical_p = std::get_if<AdjustViaKind::typical_p>(
129 |       &adjust_thru[adjust_thru_idx++]);
130 |   assert(typical_p);
131 |   assert(*typical_p == 0.5);
132 |
133 |   auto* xtc = std::get_if<AdjustViaKind::xtc>(
134 |       &adjust_thru[adjust_thru_idx++]);
135 |   assert(xtc);
136 |   assert(xtc->probability == 0.75);
137 |   assert(xtc->threshold == 0.25);
138 |
139 |   auto* temperature = std::get_if<AdjustViaKind::temperature>(
140 |       &adjust_thru[adjust_thru_idx++]);
141 |   assert(temperature);
142 |   assert(*temperature == 0.75);
143 |
144 |   assert(adjust_thru.size() == adjust_thru_idx);
145 | }
146 |
147 | static
148 | void
149 | pick_via_parse_test()
150 | {
151 |   using rendezllama::inference::Mirostat;
152 |   using rendezllama::inference::Probability;
153 |
154 |   rendezllama::ChatOptions opt;
155 |   FildeshX in[1];
156 |   bool all_good;
157 |
158 |   *in = FildeshX_of_strlit(
159 |       "(language ((infer_via sampling) ((pick_via mirostat))))");
160 |   all_good = slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
161 |   assert(all_good);
162 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
163 |   auto& sampling = std::get<Sampling>(opt.infer_via);
164 |   assert(std::holds_alternative<Mirostat>(sampling.pick_via));
165 |   auto& mirostat = std::get<Mirostat>(sampling.pick_via);
166 |   assert(mirostat.version == 2);
167 |
168 |   *in = FildeshX_of_strlit(
169 |       "(language ((infer_via sampling) ((pick_via mirostat) (version 1))))");
170 |   all_good = slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
171 |   assert(all_good);
172 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
173 |   sampling = std::get<Sampling>(opt.infer_via);
174 |   assert(std::holds_alternative<Mirostat>(sampling.pick_via));
175 |   mirostat = std::get<Mirostat>(sampling.pick_via);
176 |   assert(mirostat.version == 1);
177 |
178 |   *in = FildeshX_of_strlit(
179 |       "(language ((infer_via sampling) ((pick_via probability))))");
180 |   all_good = slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
181 |   assert(all_good);
182 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
183 |   sampling = std::get<Sampling>(opt.infer_via);
184 |   assert(std::holds_alternative<Probability>(sampling.pick_via));
185 |
186 |   *in = FildeshX_of_strlit(
187 |       "(language ((infer_via sampling)))");
188 |   all_good = slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
189 |   assert(all_good);
190 |   assert(std::holds_alternative<Sampling>(opt.infer_via));
191 |   sampling = std::get<Sampling>(opt.infer_via);
192 |   assert(std::holds_alternative<Probability>(sampling.pick_via));
193 | }
194 |
195 | int main()
196 | {
197 |   default_parse_test();
198 |   seed_parse_test();
199 |   adjust_thru_parse_test();
200 |   pick_via_parse_test();
201 |   return 0;
202 | }
203 |
--------------------------------------------------------------------------------
/test/language/language_schema_test.cc:
--------------------------------------------------------------------------------
1 | #include "src/language/language_schema.hh"
2 |
3 | #include <cassert>
4 | #include <string>
5 | #include <vector>
6 |
7 | #include <fildesh/fildesh.h>
8 |
9 | #include "src/chat/opt.hh"
10 | #include "src/chat/opt_schema.hh"
11 |
12 | static
13 | void
14 | substitution_parse_test()
15 | {
16 |   rendezllama::ChatOptions opt;
17 |   FildeshX in[1];
18 |   *in = FildeshX_of_strlit(
19 |       "(language (substitution (special_tokens (()) (() (name <|im_start|>)))))");
20 |   bool all_good = rendezllama::slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
21 |   assert(all_good);
22 |   auto& substitution = opt.substitution;
23 |   assert(substitution.special_tokens.size() == 1);
24 |   assert(substitution.special_tokens[0].candidates.size() == 1);
25 |   assert(substitution.special_tokens[0].alias == "<|im_start|>");
26 |   assert(substitution.special_tokens[0].candidates[0] == "<|im_start|>");
27 |
28 |   substitution.special_tokens.clear();
29 |   *in = FildeshX_of_strlit(
30 |       "(language\n"
31 |       " (substitution\n"
32 |       "  (bos_token_alias <s>)\n"
33 |       "  (eos_token_alias </s>)\n"
34 |       "  (special_tokens (())\n"
35 |       "   (() (alias <|im_start|>) (candidates (()) <s>))\n"
36 |       "   (() (alias <|im_end|>) (candidates (()) </s>))\n"
37 |       ")))");
38 |   all_good = rendezllama::slurp_sxpb_dynamic_options_close_FildeshX(in, opt);
39 |   assert(all_good);
40 |   assert(substitution.special_tokens.size() == 2);
41 |   assert(substitution.special_tokens[0].candidates.size() == 1);
42 |   assert(substitution.special_tokens[1].candidates.size() == 1);
43 |   assert(substitution.special_tokens[0].alias == "<|im_start|>");
44 |   assert(substitution.special_tokens[1].alias == "<|im_end|>");
45 |   assert(substitution.special_tokens[0].candidates[0] == "<s>");
46 |   assert(substitution.special_tokens[1].candidates[0] == "</s>");
47 | }
48 |
49 | int main()
50 | {
51 |   substitution_parse_test();
52 |   return 0;
53 | }
54 |
--------------------------------------------------------------------------------
/test/language/vocabulary_test.cc:
--------------------------------------------------------------------------------
1 | #include "src/language/vocabulary.hh"
2 |
3 | #include <cassert>
4 |
5 | #include <fildesh/ostream.hh>
6 |
7 | #include "llama.h"
8 |
9 | using rendezllama::Vocabulary;
10 |
11 |
12 | static void size_test() {
13 |   assert(sizeof(llama_token) == sizeof(Vocabulary::Token_id));
14 | }
15 |
16 |
17 | static void tokenize_test(const char* model_filename)
18 | {
19 |   llama_model_params model_params = llama_model_default_params();
20 |   model_params.vocab_only = true;
21 |   llama_model* model = llama_model_load_from_file(model_filename, model_params);
22 |   assert(model);
23 |
24 |   rendezllama::Vocabulary vocabulary(model);
25 |   // Should have a large vocabulary. Many more than 64 different tokens.
26 |   assert(vocabulary.cardinality() > 64);
27 |
28 |   std::string s = "The quick brown fox jumps over the lazy dog.\n";
29 |   std::vector<Vocabulary::Token_id> tokens;
30 |   vocabulary.tokenize_to(tokens, s);
31 |   assert(!tokens.empty());
32 |   assert(tokens.back() == vocabulary.newline_token_id());
33 |   assert(vocabulary.last_char_of(tokens.back()) == '\n');
34 |
35 |   fildesh::ostringstream oss;
36 |   for (auto token_id : tokens) {
37 |     vocabulary.detokenize_to(oss, token_id);
38 |   }
39 |   assert(oss.view() == s);
40 |
41 |   // ChatML doesn't (usually) use BOS and EOS tokens.
42 |   vocabulary.assign_substitution("<|im_start|>", vocabulary.bos_token_id());
43 |   vocabulary.assign_substitution("<|im_end|>", vocabulary.eos_token_id());
44 |   s = "a<|im_start|>b<|im_end|>cdefg<|im_end|>";
45 |   vocabulary.tokenize_to(tokens, s);
46 |   assert(tokens[1] == vocabulary.bos_token_id());
47 |   assert(tokens[3] == vocabulary.eos_token_id());
48 |   assert(tokens.back() == vocabulary.eos_token_id());
49 |   oss.truncate();
50 |   for (auto token_id : tokens) {
51 |     vocabulary.detokenize_to(oss, token_id);
52 |   }
53 |   assert(oss.view() == s);
54 |
55 |   llama_model_free(model);
56 | }
57 |
58 |
59 | int main(int argc, char** argv)
60 | {
61 |   assert(argc == 2 && "need model filename");
62 |   size_test();
63 |   rendezllama::GlobalScope rendezllama_global_scope;
64 |   tokenize_test(argv[1]);
65 |   return 0;
66 | }
67 |
--------------------------------------------------------------------------------
/test/manual/chat.fildesh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env fildesh
2 |
3 | (: model_file Filepath
4 |    (?? .self.opt.model
5 |        "/mnt/llama_model_data/quantize/gold/chatml/openhermes-2.5-mistral-16k-7b.Q8_0.gguf"))
6 | (: preset_file Filepath (?? .self.opt.preset "/dev/null"))
7 | (: scene Str (?? .self.opt.scene "roshambo_kira"))
8 | (: setting_file Filepath
9 |    (?? .self.opt.setting
10 |        (++ "example/prompt/" scene "/setting.sxpb")))
11 | (: thread_count Str (?? .self.opt.thread_count "8"))
12 |
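13 | # Derive the transcript filename from the setting file path by rewriting
14 | # "example/prompt/<scene>/setting.sxpb" to "bld/example/prompt/<scene>.txt".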
15 | |< splice / $(XOF setting_file) /
16 | |- replace_string "example/prompt/"
17 | |- replace_string "/setting"
18 | |- replace_string ".sxpb"
19 | |- replace_string "/" "_"
20 | |- splice / "bld/example/prompt/" / -
21 | |> splice -o $(OF transcript_file) -- - / ".txt" /
22 |
23 | |< stdin
24 | |- ./bld/src/chat/chat \
25 |    --model $(XOF model_file) \
26 |    --thread_count "${thread_count}" \
27 |    --o_rolling $(XA transcript_file) \
28 |    --x_setting $(XOF preset_file) \
29 |    --x_setting $(XOF setting_file)
30 | |> stdout
31 |
32 |
--------------------------------------------------------------------------------
/test/manual/chat.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | set -e -u
3 |
4 | # Try running from the project's toplevel directory as:
5 | #  ./test/manual/chat.sh roshambo_kira -- --model ../llama.cpp/models/7B/ggml-model-q4_0.gguf
6 |
7 | # These will be overridden by given args.
8 | model_file="../llama.cpp/models/7B/ggml-model-q4_0.gguf"
9 | thread_count="8"
10 |
11 | # First arg must be the test prompt name.
12 | prompt_name="${1:-}"
13 | if [ -z "${prompt_name}" ]; then
14 |   echo "Give a prompt name (a subdirectory in example/prompt/)." 1>&2
15 |   exit 64
16 | fi
17 | shift
18 | setting_file="example/prompt/${prompt_name}/setting.sxpb"
19 | transcript_file="bld/example/prompt/${prompt_name}.txt"
20 |
21 | # Second arg can be "--" to indicate that we're forwarding the rest.
22 | if [ "--" = "${1:-}" ]; then
23 |   shift
24 | fi
25 |
26 | exec ./bld/src/chat/chat \
27 |   --x_setting "${setting_file}" \
28 |   --o_rolling "${transcript_file}" \
29 |   --thread_count "${thread_count}" \
30 |   --model "${model_file}" \
31 |   "$@"
32 |
33 |
--------------------------------------------------------------------------------
/test/manual/coverage.md:
--------------------------------------------------------------------------------
1 | # CMake Test Coverage
2 |
3 | The coverage job in the CMake GitHub workflow runs these commands.
4 |
5 | ```shell
6 | cmake \
7 |   -DCMAKE_BUILD_TYPE=Debug \
8 |   -DCMAKE_C_FLAGS="--coverage -Og" \
9 |   -DCMAKE_CXX_FLAGS="--coverage -Og" \
10 |   -DCMAKE_EXE_LINKER_FLAGS="--coverage -Og" \
11 |   -S . -B "bld/"
12 |
13 | cmake --build "bld/" --config Debug
14 |
15 | # Run tests to generate .gcda coverage files.
16 | cd "bld/"
17 | ctest -C Debug
18 | cd ..
19 |
20 | # Gather coverage files into one .info file.
21 | lcov --capture --directory "bld/" -o "bld/coverage_report.info"
22 |
23 | # Filter out dependencies from coverage info.
24 | lcov --remove "bld/coverage_report.info" -o "bld/coverage_report.info" \
25 |   "/usr/include/*" "${PWD}/bld/_deps/*"
26 |
27 | # Show result.
28 | lcov --list "bld/coverage_report.info"
29 | ```
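30 |
31 | To browse per-line results, `genhtml` (bundled with lcov) can render the same
32 | `.info` file into HTML; the output directory name here is arbitrary.
33 |
34 | ```shell
35 | genhtml "bld/coverage_report.info" -o "bld/coverage_html/"
36 | ```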
16 | grep -e "GGML_USE_OPENBLAS" "bld/_deps/llamacpp-src/ggml.c" 17 | ``` 18 | --------------------------------------------------------------------------------