├── .github └── workflows │ └── wheels.yml ├── .gitignore ├── .gitmodules ├── .vscode ├── c_cpp_properties.json ├── launch.json └── settings.json ├── CMakeLists.txt ├── README.md ├── examples ├── context_swapping.py ├── retrieval │ ├── test.db │ └── test.pdf ├── simple.py └── simple_low_level.py ├── models └── .gitkeep ├── pyproject.toml ├── src ├── llama2.cpp ├── llama_wrapper.cpp ├── llama_wrapper.h └── llamacpp │ ├── __init__.py │ ├── chat.py │ ├── cli.py │ ├── convert.py │ └── quantize.py └── tests ├── test_llama_context.py └── test_llama_inference.py /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: 7 | - master 8 | - main 9 | push: 10 | branches: 11 | - master 12 | - main 13 | release: 14 | types: 15 | - published 16 | 17 | jobs: 18 | build_wheels: 19 | name: Build wheels on ${{ matrix.os }} for ${{ matrix.arch }} - ${{ matrix.p_ver }} 20 | runs-on: ${{ matrix.os }} 21 | env: 22 | CIBW_BUILD: ${{ matrix.cibw_build }} 23 | CIBW_ARCHS_LINUX: ${{ matrix.arch }} 24 | CIBW_ARCHS_MACOS: ${{ matrix.arch }} 25 | strategy: 26 | matrix: 27 | os: [ubuntu-latest, windows-latest, macos-latest] 28 | arch: [auto64] 29 | cibw_build: ["cp3{7,8,9,10,11}-*"] 30 | p_ver: ["3.7-3.11"] 31 | include: 32 | - arch: aarch64 33 | os: ubuntu-latest 34 | cibw_build: "cp37*" 35 | p_ver: "3.7" 36 | - arch: aarch64 37 | os: ubuntu-latest 38 | cibw_build: "cp38*" 39 | p_ver: "3.8" 40 | - arch: aarch64 41 | os: ubuntu-latest 42 | cibw_build: "cp39*" 43 | p_ver: "3.9" 44 | - arch: aarch64 45 | os: ubuntu-latest 46 | cibw_build: "cp310*" 47 | p_ver: "3.10" 48 | - arch: aarch64 49 | os: ubuntu-latest 50 | cibw_build: "cp311*" 51 | p_ver: "3.11" 52 | 53 | steps: 54 | - uses: actions/checkout@v3 55 | with: 56 | fetch-depth: 0 57 | submodules: true 58 | 59 | - name: Set up QEMU 60 | if: matrix.os == 'ubuntu-latest' && matrix.arch == 'aarch64' 61 | uses: docker/setup-qemu-action@v2 62 | with: 63 | platforms: arm64 64 | 65 | - name: Build wheels 66 | uses: pypa/cibuildwheel@v2.12.1 67 | env: 68 | CIBW_ARCHS_MACOS: "x86_64 universal2 arm64" 69 | 70 | - uses: actions/upload-artifact@v3 71 | with: 72 | path: ./wheelhouse/*.whl 73 | 74 | make_sdist: 75 | name: Make SDist 76 | runs-on: ubuntu-latest 77 | steps: 78 | - uses: actions/checkout@v3 79 | with: 80 | fetch-depth: 0 # Optional, use if you use setuptools_scm 81 | submodules: true # Optional, use if you have submodules 82 | 83 | - name: Install setup dependencies 84 | run: python -m pip install build 85 | 86 | - name: Build source distribution 87 | run: python -m build --sdist 88 | 89 | - uses: actions/upload-artifact@v3 90 | with: 91 | path: dist/*.tar.gz 92 | 93 | upload_all: 94 | needs: [build_wheels, make_sdist] 95 | runs-on: ubuntu-latest 96 | if: github.event_name == 'release' && github.event.action == 'published' 97 | steps: 98 | - uses: actions/download-artifact@v3 99 | with: 100 | name: artifact 101 | path: dist 102 | 103 | - name: Publish to PyPI 104 | uses: pypa/gh-action-pypi-publish@v1.5.0 105 | with: 106 | user: __token__ 107 | password: ${{ secrets.PYPI_API_TOKEN }} 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # setup.py is auto-generated by poetry 2 | _skbuild 3 | setup.py 4 | 5 | models/* 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 
10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "llama.cpp"] 2 | path = vendor/llama.cpp 3 | url = https://github.com/thomasantony/llama.cpp 4 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Mac", 5 | "includePath": [ 6 | "${workspaceFolder}/", 7 | "${workspaceFolder}/**", 8 | "/usr/local/include", 9 | "/Library/Developer/CommandLineTools/usr/include/c++/v1/", 10 | "/Library/Developer/CommandLineTools//SDKs/MacOSX13.1.sdk/usr/include/" 11 | ], 12 | "defines": [], 13 | "macFrameworkPath": ["/System/Library/Frameworks", "/Library/Frameworks"], 14 | "compilerPath": "/usr/local/bin/clang", 15 | "cStandard": "c17", 16 | "cppStandard": "c++11", 17 | "intelliSenseMode": "macos-clang-arm64" 18 | } 19 | ], 20 | "version": 4 21 | } 22 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "examples/simple.py", 12 | "console": "integratedTerminal", 13 | "args": [], 14 | "justMyCode": true 15 | }, 16 | { 17 | "type": "lldb", 18 | "request": "launch", 19 | "name": "LLDB Python", 20 | "program": "/Users/thomas/miniconda3/envs/llamacpp/bin/python", 21 | "args": ["${workspaceFolder}/examples/simple.py"], 22 | "cwd": "${workspaceFolder}", 23 | "stopOnEntry": false 24 | // "env": { 25 | // "PYTHONPATH": "${workspaceFolder}/build" 26 | // } 27 | } 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "regex": "cpp" 4 | } 5 | } 6 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # CmakeLists for building python bindings 2 | cmake_minimum_required(VERSION 3.0) 3 | 4 | project(llamacpp) 5 | 6 | set(CMAKE_CXX_STANDARD 14) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | set(CMAKE_CXX_EXTENSIONS OFF) 9 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 10 | 11 | set(GGML_USE_ACCELERATE 1) 12 | find_package(pybind11 CONFIG REQUIRED) 13 | 14 | add_subdirectory(vendor/llama.cpp) 15 | pybind11_add_module(llamacpp MODULE src/llama2.cpp src/llama_wrapper.cpp src/llama_wrapper.h) 16 | target_include_directories(llamacpp PRIVATE vendor/llama.cpp) 17 | target_link_libraries(llamacpp PRIVATE pybind11::module pybind11::lto pybind11::windows_extras llama) 18 | add_link_options(-no_fixup_chains) 19 | 20 | if(NOT MSVC AND NOT ${CMAKE_BUILD_TYPE} MATCHES Debug|RelWithDebInfo) 21 | # Strip unnecessary sections of the binary on Linux/macOS 22 | pybind11_strip(llamacpp) 23 | endif() 24 | 25 | set_target_properties(llamacpp PROPERTIES CXX_VISIBILITY_PRESET "hidden" 26 | CUDA_VISIBILITY_PRESET "hidden") 27 | 28 | install(TARGETS llamacpp DESTINATION llamacpp) 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Python bindings for llama.cpp 2 | 3 | **Important** 4 | 5 | - The Python API has changed significantly in the recent weeks and as a result, I have not had a chance to update `cli.py` or `chat.py` to reflect the new changes. The scripts under `examples/simple.py` and `examples/simple_low_level.py` should give you an idea of how to use the library. 6 | 7 | 8 | ## Install 9 | ### From PyPI 10 | 11 | ``` 12 | pip install llamacpp 13 | ``` 14 | 15 | ### Build from Source 16 | 17 | ``` 18 | pip install . 19 | ``` 20 | 21 | ## Get the model weights 22 | 23 | You will need to obtain the weights for LLaMA yourself. There are a few torrents floating around as well as some huggingface repositories (e.g https://huggingface.co/nyanko7/LLaMA-7B/). Once you have them, copy them into the models folder. 24 | 25 | ``` 26 | ls ./models 27 | 65B 30B 13B 7B tokenizer_checklist.chk tokenizer.model 28 | ``` 29 | 30 | Convert the weights to GGML format using `llamacpp-convert`. Then use `llamacpp-quantize` to quantize them into INT4. 
For example, for the 7B parameter model, run
31 | 
32 | ```
33 | llamacpp-convert ./models/7B/ 1
34 | llamacpp-quantize ./models/7B/
35 | llamacpp-cli
36 | ```
37 | 
38 | **Note that running `llamacpp-convert` requires `torch`, `sentencepiece` and `numpy` to be installed. These packages are not installed by default when you install `llamacpp`.**
39 | 
40 | ## Command line interface
41 | 
42 | The package installs the command line entry point `llamacpp-cli` that points to `llamacpp/cli.py` and should provide about the same functionality as the `main` program in the original C++ repository. There is also an experimental `llamacpp-chat` that is supposed to bring up a chat interface, but this does not work correctly yet.
43 | 
44 | ## API
45 | 
46 | Documentation is TBD, but the long and short of it is that there are two interfaces:
47 | * `LlamaInference` - a high-level interface that tries to take care of most things for you. The demo script below uses this.
48 | * `LlamaContext` - a low-level interface to the underlying llama.cpp API. You can use this similarly to how the [main](https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp) example in `llama.cpp` uses the C API. This is a rough implementation and currently untested except for compiling successfully.
49 | 
50 | ## Demo script
51 | 
52 | See `llamacpp/cli.py` for a detailed example. The simplest demo would be something like the following:
53 | 
54 | ```python
55 | import sys
56 | import llamacpp
57 | 
58 | 
59 | def progress_callback(progress):
60 |     print("Progress: {:.2f}%".format(progress * 100))
61 |     sys.stdout.flush()
62 | 
63 | 
64 | params = llamacpp.InferenceParams.default_with_callback(progress_callback)
65 | params.path_model = './models/7B/ggml-model-q4_0.bin'
66 | model = llamacpp.LlamaInference(params)
67 | 
68 | prompt = "A llama is a"
69 | prompt_tokens = model.tokenize(prompt, True)
70 | model.update_input(prompt_tokens)
71 | 
72 | model.ingest_all_pending_input()
73 | 
74 | print(model.system_info())
75 | for i in range(20):
76 |     model.eval()
77 |     token = model.sample()
78 |     text = model.token_to_str(token)
79 |     print(text, end="")
80 | 
81 |     # Flush stdout
82 |     sys.stdout.flush()
83 | 
84 | model.print_timings()
85 | ```
86 | 
87 | ## ToDo
88 | 
89 | - [ ] Investigate using dynamic versioning with setuptools-scm (Example: https://github.com/pypa/setuptools_scm/blob/main/scm_hack_build_backend.py)
90 | 
-------------------------------------------------------------------------------- /examples/context_swapping.py: --------------------------------------------------------------------------------
1 | """Demonstrates that the library now supports going over the context size limit
2 | (but loses "memory" of earlier text in the process)"""
3 | import sys
4 | import llamacpp
5 | 
6 | 
7 | params = llamacpp.InferenceParams()
8 | params.path_model = './models/7B/ggml-model-f16.bin'
9 | params.seed = 69420
10 | params.repeat_penalty = 1.0
11 | params.n_ctx = 128
12 | model = llamacpp.LlamaInference(params)
13 | 
14 | prompt = " Llama is"
15 | prompt_tokens = model.tokenize(prompt, True)
16 | model.update_input(prompt_tokens)
17 | 
18 | model.ingest_all_pending_input()
19 | print(model.system_info())
20 | 
21 | print(prompt, end='')
22 | for i in range(256):
23 |     model.eval()
24 |     token = model.sample()
25 |     text = model.token_to_str(token)
26 |     print(text, end='')
27 |     sys.stdout.flush()
28 | 
29 | print()
30 | model.print_timings()
31 | 
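# Context swapping (see LlamaWrapper::eval in src/llama_wrapper.cpp): when the
# generated text would exceed n_ctx, the wrapper keeps the first `n_keep`
# tokens, re-evaluates about half of the last (n_ctx - n_keep) tokens, and then
# continues, so the oldest text is forgotten. This script leaves `n_keep` at
# its default of 0, so nothing from the prompt survives a swap; a hedged
# variation (the value below is hypothetical) would set it before constructing
# the model, e.g.:
#
#   params.n_keep = 4  # assumption: roughly the length of the tokenized prompt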
-------------------------------------------------------------------------------- /examples/retrieval/test.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasantony/llamacpp-python/41dbd7c1fa6498387b08ad99296f8dc79ebb3ee2/examples/retrieval/test.db -------------------------------------------------------------------------------- /examples/retrieval/test.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasantony/llamacpp-python/41dbd7c1fa6498387b08ad99296f8dc79ebb3ee2/examples/retrieval/test.pdf -------------------------------------------------------------------------------- /examples/simple.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import llamacpp 3 | 4 | 5 | def progress_callback(progress): 6 | # print("Progress: {:.2f}%".format(progress * 100)) 7 | # sys.stdout.flush() 8 | pass 9 | 10 | 11 | params = llamacpp.InferenceParams.default_with_callback(progress_callback) 12 | params.path_model = './models/7B/ggml-model-f16.bin' 13 | params.seed = 19472 14 | params.repeat_penalty = 1.0 15 | model = llamacpp.LlamaInference(params) 16 | 17 | prompt = " Llama is" 18 | prompt_tokens = model.tokenize(prompt, True) 19 | model.update_input(prompt_tokens) 20 | 21 | model.ingest_all_pending_input() 22 | print(model.system_info()) 23 | 24 | print(prompt, end='') 25 | for i in range(20): 26 | model.eval() 27 | token = model.sample() 28 | text = model.token_to_str(token) 29 | print(text, end="") 30 | 31 | # Flush stdout 32 | sys.stdout.flush() 33 | 34 | # model.print_timings() 35 | -------------------------------------------------------------------------------- /examples/simple_low_level.py: -------------------------------------------------------------------------------- 1 | import array 2 | import llamacpp 3 | 4 | params = llamacpp.LlamaContextParams() 5 | params.seed = 19472 6 | model = llamacpp.LlamaContext("./models/7B/ggml-model-f16.bin", params) 7 | 8 | prompt = "Llama is" 9 | # add a space in front of the first character to match OG llama tokenizer behavior 10 | prompt = f" {prompt}" 11 | 12 | # tokenize the prompt 13 | embd_inp = model.str_to_token(prompt, True) 14 | 15 | n_ctx = model.get_n_ctx() 16 | 17 | if len(embd_inp) > n_ctx - 4: 18 | raise Exception("Prompt is too long") 19 | 20 | n_past = 0 21 | n_remain = 9 22 | n_consumed = 0 23 | embd = [] 24 | 25 | while n_remain: 26 | if len(embd): 27 | if model.eval(array.array('i', embd), len(embd), n_past, 1): 28 | raise Exception("Failed to predict\n") 29 | n_past += len(embd) 30 | embd.clear() 31 | 32 | if len(embd_inp) <= n_consumed: 33 | # sample 34 | top_k = 40 35 | top_p = 0.95 36 | temp = 0.8 37 | repeat_penalty = 0.0 38 | 39 | # sending an empty array for the last n tokens 40 | id = model.sample_top_p_top_k(array.array('i', []), top_k, top_p, temp, repeat_penalty) 41 | # add it to the context 42 | embd.append(id) 43 | # decrement remaining sampling budget 44 | n_remain -= 1 45 | else: 46 | # has unconsumed input 47 | while len(embd_inp) > n_consumed: 48 | # update_input 49 | embd.append(embd_inp[n_consumed]) 50 | n_consumed += 1 51 | 52 | for id in embd: 53 | print(model.token_to_str(id), end="") 54 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/thomasantony/llamacpp-python/41dbd7c1fa6498387b08ad99296f8dc79ebb3ee2/models/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["scikit-build-core>=0.2.1", "pybind11>2.10"] 3 | build-backend = "scikit_build_core.build" 4 | 5 | [project] 6 | name = "llamacpp" 7 | version = "0.1.15" 8 | description = "Python bindings for @ggerganov's llama.cpp" 9 | authors = [ 10 | {name = "Thomas Antony", email= "mail@thomasantony.com"} 11 | ] 12 | license = {text = "MIT"} 13 | readme = "README.md" 14 | requires-python = ">=3.7" 15 | 16 | classifiers = [ 17 | "Development Status :: 4 - Beta", 18 | "License :: OSI Approved :: MIT License", 19 | "Programming Language :: Python :: 3 :: Only", 20 | "Programming Language :: Python :: 3.7", 21 | "Programming Language :: Python :: 3.8", 22 | "Programming Language :: Python :: 3.9", 23 | "Programming Language :: Python :: 3.10", 24 | "Programming Language :: Python :: 3.11", 25 | ] 26 | 27 | [project.urls] 28 | homepage = "https://github.com/thomasantony/llamacpp-python" 29 | repository = "https://github.com/thomasantony/llamacpp-python" 30 | 31 | [tool.scikit-build] 32 | wheel.expand-macos-universal-tags = true 33 | cmake.build-type = "Release" 34 | 35 | [project.scripts] 36 | llamacpp-convert = 'llamacpp.convert:main' 37 | llamacpp-quantize = 'llamacpp.quantize:main' 38 | llamacpp-cli = 'llamacpp.cli:run' 39 | llamacpp-chat = 'llamacpp.chat:run' 40 | 41 | [tool.cibuildwheel] 42 | test-command = "python -c \"import llamacpp\"" 43 | 44 | # Skip Python 3.6, PyPy and 32-bit builds 45 | skip = ["cp36-*", "pp*", "*-win32", "*-manylinux_i686", "*-musllinux_i686"] 46 | 47 | build-verbosity = 3 48 | test-skip = ["*_arm64", "*_universal2:arm64"] 49 | -------------------------------------------------------------------------------- /src/llama2.cpp: -------------------------------------------------------------------------------- 1 | #include "ggml.h" 2 | #include "llama.h" 3 | #include "llama_wrapper.h" 4 | #include 5 | #include 6 | #include "pybind11/functional.h" 7 | #include "pybind11/numpy.h" 8 | #include 9 | namespace py = pybind11; 10 | using Callback = std::function; 11 | 12 | 13 | class LlamaInference; 14 | /* Tokenizer for use with text-ui project */ 15 | class Tokenizer { 16 | const LlamaInference& llama; 17 | public: 18 | Tokenizer(const LlamaInference& llama): llama(llama) {} 19 | std::vector tokenize(const std::string & text, bool bos); 20 | std::string detokenize(const std::vector& ids); 21 | std::string detokenize(const llama_token& id); 22 | }; 23 | 24 | // Lower level API that gives more direct access to llama_context 25 | class LlamaContext 26 | { 27 | llama_context* ctx; 28 | // Flag that indicates whether the logits for all tokens should be returned 29 | // as opposed to just the last one (helpful for building kNN datastores or 30 | // computing perplexity) 31 | bool logits_all = false; 32 | int last_eval_n_tokens = 0; 33 | public: 34 | LlamaContext(std::string path_model, const llama_context_params& params): logits_all(params.logits_all) { 35 | ctx = llama_init_from_file(path_model.c_str(), params); 36 | } 37 | LlamaContext(std::string path_model, const llama_context_params& params, Callback progress_cb) 38 | { 39 | llama_context_params params_with_cb = params; 40 | params_with_cb.progress_callback = [](float progress, void* user_data) { 41 | auto cb = 
static_cast(user_data); 42 | (*cb)(progress); 43 | }; 44 | params_with_cb.progress_callback_user_data = &progress_cb; 45 | 46 | ctx = llama_init_from_file(path_model.c_str(), params_with_cb); 47 | } 48 | ~LlamaContext() 49 | { 50 | llama_free(ctx); 51 | } 52 | 53 | // Run the llama inference to obtain the logits and probabilities for the next token. 54 | // tokens + n_tokens is the provided batch of new tokens to process 55 | // n_past is the number of tokens to use from previous eval calls 56 | // Returns 0 on success 57 | int eval(py::buffer tokens, 58 | const int n_tokens, 59 | const int n_past, 60 | const int n_threads) { 61 | py::buffer_info tokens_info = tokens.request(); 62 | // Check that tokens are integers and one-dimensional 63 | if (tokens_info.format != py::format_descriptor::format() || 64 | tokens_info.ndim != 1) { 65 | throw std::runtime_error("Invalid tokens buffer format"); 66 | } 67 | // Check that the number of tokens is correct 68 | if (tokens_info.size < n_tokens) { 69 | throw std::runtime_error("Invalid number of tokens"); 70 | } 71 | llama_token* tokens_ptr = (llama_token*)tokens.request().ptr; 72 | last_eval_n_tokens = n_tokens; 73 | return llama_eval(ctx, tokens_ptr, n_tokens, n_past, n_threads); 74 | } 75 | 76 | // Sample a token from the logits 77 | llama_token sample_top_p_top_k(py::buffer last_n_tokens_data, 78 | int top_k, 79 | float top_p, 80 | float temp, 81 | float repeat_penalty) 82 | { 83 | py::buffer_info last_n_tokens_info = last_n_tokens_data.request(); 84 | // Check that tokens are integers and one-dimensional 85 | if (last_n_tokens_info.format != py::format_descriptor::format() || 86 | last_n_tokens_info.ndim != 1) { 87 | throw std::runtime_error("Invalid tokens buffer format"); 88 | } 89 | llama_token* last_n_tokens_ptr = (llama_token*)last_n_tokens_info.ptr; 90 | size_t last_n_tokens_size = last_n_tokens_info.size; 91 | return llama_sample_top_p_top_k(ctx, last_n_tokens_ptr, last_n_tokens_size, top_k, top_p, temp, repeat_penalty); 92 | } 93 | 94 | // Token logits obtained from the last call to eval() 95 | // The logits for the last token are stored in the last row 96 | // Size: n_tokens x n_vocab (n_tokens == 1 if params.logits_all == 0) 97 | py::memoryview get_logits() 98 | { 99 | if(last_eval_n_tokens == 0) 100 | { 101 | throw std::runtime_error("No logits available. 
Call eval() first."); 102 | } 103 | // Returns matrix if logits_all is true 104 | // and vector otherwise 105 | if (logits_all) 106 | { 107 | float* logits_ptr = llama_get_logits(ctx); 108 | const size_t n_vocab = llama_n_vocab(ctx); 109 | const size_t n_tokens = last_eval_n_tokens; 110 | 111 | return py::memoryview::from_buffer( 112 | static_cast(logits_ptr), /* Pointer to buffer */ 113 | sizeof(float), /* Size of one scalar */ 114 | "f", /* Python struct-style format descriptor */ 115 | { n_tokens, n_vocab }, /* Buffer dimensions */ 116 | { sizeof(float) * n_vocab, /* Assumes Row-Major */ 117 | sizeof(float) * 1 }, /* Strides (in bytes) for each index */ 118 | true /* Read only */ 119 | ); 120 | }else{ 121 | float* logits_ptr = llama_get_logits(ctx); 122 | const size_t n_vocab = llama_n_vocab(ctx); 123 | const size_t n_tokens = 1; 124 | return py::memoryview::from_buffer( 125 | static_cast(logits_ptr), /* Pointer to buffer */ 126 | sizeof(float), /* Size of one scalar */ 127 | "f", /* Python struct-style format descriptor */ 128 | { n_tokens, n_vocab }, /* Buffer dimensions */ 129 | { sizeof(float) * n_vocab, /* Assumes Row-Major */ 130 | sizeof(float) * 1 }, /* Strides (in bytes) for each index */ 131 | true /* Read only */ 132 | ); 133 | } 134 | } 135 | 136 | // Get the embeddings for the input 137 | // shape: [n_embd] (1-dimensional) 138 | py::array_t get_embeddings() const 139 | { 140 | const float* embd_ptr = llama_get_embeddings(ctx); 141 | const size_t n_embd = llama_n_embd(ctx); 142 | return py::array(n_embd, embd_ptr); 143 | } 144 | 145 | // Get the number of tokens in the vocabulary 146 | size_t get_n_vocab() const 147 | { 148 | return llama_n_vocab(ctx); 149 | } 150 | 151 | // Get the number of dimensions in the embedding 152 | size_t get_n_embd() const 153 | { 154 | return llama_n_embd(ctx); 155 | } 156 | 157 | // Get the context size 158 | size_t get_n_ctx() const 159 | { 160 | return llama_n_ctx(ctx); 161 | } 162 | 163 | 164 | // Token Id -> String. Uses the vocabulary in the provided context 165 | std::string token_to_str(llama_token token) const 166 | { 167 | return llama_token_to_str(ctx, token); 168 | } 169 | 170 | // String -> Token Id. Uses the vocabulary in the provided context 171 | py::array str_to_token(const std::string& text, bool add_bos) const 172 | { 173 | std::vector res(text.size() + (int)add_bos); 174 | int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); 175 | assert(n >= 0); 176 | res.resize(n); 177 | return py::array(res.size(), res.data()); 178 | } 179 | 180 | // Performance information 181 | void print_timings() const 182 | { 183 | llama_print_timings(ctx); 184 | } 185 | void reset_timings() 186 | { 187 | llama_reset_timings(ctx); 188 | } 189 | }; 190 | 191 | 192 | // High level API that includes input management and other convenience functions 193 | class LlamaInference { 194 | public: 195 | LlamaWrapper llama{}; 196 | InferenceParams params{}; 197 | LlamaInference(InferenceParams params): params(params), llama(params) { 198 | llama.init(); 199 | } 200 | 201 | // Get tokenizer for the provided context 202 | // Returns a Tokenizer object 203 | Tokenizer get_tokenizer() const 204 | { 205 | return Tokenizer(*this); 206 | } 207 | // Run the llama inference to obtain the logits and probabilities for the next token. 
208 | // tokens + n_tokens is the provided batch of new tokens to process 209 | // n_past is the number of tokens to use from previous eval calls 210 | // Returns 0 on success 211 | int eval() 212 | { 213 | return llama.eval(); 214 | } 215 | 216 | // Convert the provided text into tokens. 217 | // Duplicate of the version in examples/common.h in llama.cpp 218 | std::vector tokenize(const std::string& text, bool add_bos) const 219 | { 220 | // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars 221 | return llama.tokenize_text(text, add_bos); 222 | } 223 | 224 | // Token logits obtained from the last call to eval() 225 | // The logits for the last token are stored in the last row 226 | // Size: n_vocab 227 | py::array_t get_logits() const 228 | { 229 | const float* logit_ptr = llama.get_logits(); 230 | const size_t n_vocab = llama.get_n_vocab(); 231 | return py::array(n_vocab, logit_ptr); 232 | } 233 | 234 | // Get the embeddings for the input 235 | // shape: [n_embd] (1-dimensional) 236 | py::array_t get_embeddings() const 237 | { 238 | const float* embd_ptr = llama.get_embeddings(); 239 | const size_t n_embd = llama.get_n_embd(); 240 | return py::array(n_embd, embd_ptr); 241 | } 242 | 243 | // Token Id -> String. Uses the vocabulary in the provided context 244 | std::string token_to_str(llama_token token) const 245 | { 246 | return llama.token_to_str(token); 247 | } 248 | 249 | // String -> Token Id. Uses the vocabulary in the provided context 250 | llama_token sample() 251 | { 252 | return llama.sample(); 253 | } 254 | // Add BOS token to the input 255 | void add_bos() 256 | { 257 | llama.add_bos(); 258 | } 259 | // set input using tokens 260 | void set_input(const std::vector& tokens) 261 | { 262 | llama.set_input(tokens); 263 | } 264 | // set input using string 265 | void set_input(const std::string& text) 266 | { 267 | llama.set_input(text); 268 | } 269 | 270 | // set input using tokens 271 | void update_input(const std::vector& tokens) 272 | { 273 | llama.update_input(tokens); 274 | } 275 | // update input using string 276 | void update_input(const std::string& text) 277 | { 278 | llama.update_input(text); 279 | } 280 | 281 | bool has_unconsumed_input() const { 282 | return llama.has_unconsumed_input(); 283 | } 284 | 285 | void ingest_all_pending_input() 286 | { 287 | llama.ingest_all_pending_input(); 288 | } 289 | 290 | // Performance information 291 | void print_timings() 292 | { 293 | llama.print_timings(); 294 | } 295 | void reset_timings() 296 | { 297 | llama.reset_timings(); 298 | } 299 | }; 300 | 301 | std::vector Tokenizer::tokenize(const std::string & text, bool bos) { 302 | return llama.tokenize(text, bos); 303 | } 304 | std::string Tokenizer::detokenize(const std::vector& ids) { 305 | std::string output = ""; 306 | for (auto id: ids) { 307 | output += detokenize(id); 308 | } 309 | return output; 310 | } 311 | std::string Tokenizer::detokenize(const llama_token& id) { 312 | return llama.token_to_str(id); 313 | } 314 | 315 | 316 | PYBIND11_MODULE(llamacpp, m) { 317 | m.doc() = "Python bindings for C++ implementation of the LLaMA language model"; 318 | /* Wrapper for llama_context_params */ 319 | py::class_(m, "LlamaContextParams") 320 | .def(py::init<>(&llama_context_default_params)) 321 | .def_readwrite("n_ctx", &llama_context_params::n_ctx) 322 | .def_readwrite("n_parts", &llama_context_params::n_parts) 323 | .def_readwrite("seed", &llama_context_params::seed) 324 | .def_readwrite("f16_kv", &llama_context_params::f16_kv) 325 | .def_readwrite("logits_all", 
&llama_context_params::logits_all) 326 | .def_readwrite("vocab_only", &llama_context_params::vocab_only) 327 | .def_readwrite("use_mlock", &llama_context_params::use_mlock) 328 | .def_readwrite("embedding", &llama_context_params::embedding); 329 | 330 | /* Wrapper for InferenceParams */ 331 | py::class_(m, "InferenceParams") 332 | .def(py::init<>()) 333 | .def_static("default_with_callback", [](Callback cb){ 334 | InferenceParams params; 335 | params.callback = cb; 336 | return params; 337 | }, py::arg("callback")) 338 | .def_readwrite("path_model", &InferenceParams::path_model) 339 | .def_readwrite("seed", &InferenceParams::seed) 340 | .def_readwrite("n_threads", &InferenceParams::n_threads) 341 | .def_readwrite("n_predict", &InferenceParams::n_predict) 342 | .def_readwrite("repeat_last_n", &InferenceParams::repeat_last_n) 343 | .def_readwrite("n_batch", &InferenceParams::n_batch) 344 | .def_readwrite("top_k", &InferenceParams::top_k) 345 | .def_readwrite("top_p", &InferenceParams::top_p) 346 | .def_readwrite("temp", &InferenceParams::temp) 347 | .def_readwrite("repeat_penalty", &InferenceParams::repeat_penalty) 348 | .def_readwrite("use_mlock", &InferenceParams::use_mlock) 349 | .def_readwrite("memory_f16", &InferenceParams::memory_f16) 350 | .def_readwrite("n_ctx", &InferenceParams::n_ctx) 351 | .def_readwrite("callback", &InferenceParams::callback) 352 | .def_readwrite("n_keep", &InferenceParams::n_keep); 353 | 354 | /* Wrapper for LlamaContext */ 355 | py::class_(m, "LlamaContext") 356 | .def(py::init(), py::arg("path_model"), py::arg("params")) 357 | .def(py::init(), py::arg("path_model"), py::arg("params"), py::arg("progress_callback")) 358 | .def("get_n_vocab", &LlamaContext::get_n_vocab, "Get the number of tokens in the vocabulary") 359 | .def("get_n_embd", &LlamaContext::get_n_embd, "Get the number of dimensions in the embedding") 360 | .def("get_n_ctx", &LlamaContext::get_n_ctx, "Get the number of tokens in the context") 361 | .def("get_embeddings", &LlamaContext::get_embeddings, "Get the embeddings as a numpy array") 362 | .def("get_logits", &LlamaContext::get_logits, "Get the logits as a numpy array") 363 | .def("token_to_str", &LlamaContext::token_to_str, "Convert a token id to a string") 364 | .def("str_to_token", &LlamaContext::str_to_token, "Convert a string to a token id") 365 | .def("print_timings", &LlamaContext::print_timings, "Print the timings for the last call to eval()") 366 | .def("reset_timings", &LlamaContext::reset_timings, "Reset the timings for the last call to eval()") 367 | .def("eval", &LlamaContext::eval, "Run the llama inference to obtain the logits and probabilities for the next token", 368 | py::call_guard()) 369 | .def("sample_top_p_top_k", &LlamaContext::sample_top_p_top_k, "Sample a token from the logits using top-p and top-k"); 370 | 371 | /* Wrapper for LlamaInference methods */ 372 | py::class_(m, "LlamaInference") 373 | .def(py::init(), py::arg("params")) 374 | .def("set_input", py::overload_cast&>(&LlamaInference::set_input), "Set the input to the provided tokens") 375 | .def("set_input", py::overload_cast(&LlamaInference::set_input), "Set the input to the provided tokens") 376 | .def("update_input", py::overload_cast&>(&LlamaInference::update_input), "Update the input with the provided tokens") 377 | .def("update_input", py::overload_cast(&LlamaInference::update_input), "Update the input with the provided text") 378 | .def("eval", &LlamaInference::eval, "Run the llama inference to obtain the logits and probabilities for the next token", 379 | 
py::call_guard()) 380 | .def("add_bos", &LlamaInference::add_bos) 381 | .def("tokenize", &LlamaInference::tokenize, "Convert the provided text into tokens", 382 | py::arg("text"), py::arg("add_bos")) 383 | .def("has_unconsumed_input", &LlamaInference::has_unconsumed_input, "Check if there is unconsumed input") 384 | .def("ingest_all_pending_input", &LlamaInference::ingest_all_pending_input, "Ingest all pending input") 385 | .def("get_logits", &LlamaInference::get_logits, "Get the logits for the last token", py::call_guard()) 386 | .def("get_embeddings", &LlamaInference::get_embeddings, "Get the embeddings for the last token") 387 | .def("token_to_str", &LlamaInference::token_to_str, "Convert a token to a string", 388 | py::arg("token")) 389 | .def_static("token_bos", &llama_token_bos, "Get the token for the beginning of a sentence") 390 | .def_static("token_eos", &llama_token_eos, "Get the token for the end of a sentence") 391 | .def("print_timings", &LlamaInference::print_timings, "Print the timings for the last call to eval()") 392 | .def("reset_timings", &LlamaInference::reset_timings, "Reset the timings for the last call to eval()") 393 | .def_static("system_info", &llama_print_system_info, "Print system information") 394 | .def("sample", &LlamaInference::sample, "Sample a token from the logits") 395 | .def("get_tokenizer", &LlamaInference::get_tokenizer, "Get the tokenizer"); 396 | 397 | 398 | // /* Wrapper for Tokenizer */ 399 | py::class_(m, "Tokenizer") 400 | .def("tokenize", &Tokenizer::tokenize, "Tokenize text", py::arg("text"), py::arg("add_bos") = false) 401 | .def("detokenize", py::overload_cast&>(&Tokenizer::detokenize), "Detokenize text") 402 | .def("detokenize", py::overload_cast(&Tokenizer::detokenize), "Detokenize single token"); 403 | 404 | /* Wrapper for llama_model_quantize */ 405 | m.def("llama_model_quantize", &llama_model_quantize, "Quantize the LLaMA model"); 406 | } 407 | -------------------------------------------------------------------------------- /src/llama_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "llama_wrapper.h" 2 | #include 3 | 4 | static void trigger_cb(float progress, void * user_data) { 5 | if (user_data == nullptr) { 6 | return; 7 | } 8 | auto cb = static_cast(user_data); 9 | (*cb)(progress); 10 | } 11 | 12 | // Initialize the model 13 | bool LlamaWrapper::init() 14 | { 15 | if (is_initialized) 16 | { 17 | return true; 18 | } 19 | // update pointer to callback if needed 20 | if(inference_params.callback) 21 | { 22 | using raw_cb = void (*)(float, void*); 23 | inference_params.ctx_params.progress_callback = (raw_cb)trigger_cb; 24 | inference_params.ctx_params.progress_callback_user_data = &inference_params.callback; 25 | }else{ 26 | inference_params.ctx_params.progress_callback = nullptr; 27 | } 28 | 29 | // update default ctx params with our user-selected overrides 30 | inference_params.ctx_params.n_ctx = inference_params.n_ctx; 31 | inference_params.ctx_params.seed = inference_params.seed; 32 | inference_params.ctx_params.use_mlock = inference_params.use_mlock; 33 | inference_params.ctx_params.f16_kv = inference_params.memory_f16; 34 | 35 | ctx = llama_init_from_file(inference_params.path_model.c_str(), inference_params.ctx_params); 36 | 37 | n_ctx = llama_n_ctx(ctx); 38 | last_n_tokens = std::vector(n_ctx); 39 | std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); 40 | is_initialized = true; 41 | return true; 42 | } 43 | // Tokenize text 44 | const vector 
LlamaWrapper::tokenize_text(const std::string& text, bool add_bos) const 45 | { 46 | // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars 47 | std::vector res(text.size() + (int)add_bos); 48 | int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); 49 | assert(n >= 0); 50 | res.resize(n); 51 | return res; 52 | } 53 | 54 | // Add BOS token to input 55 | void LlamaWrapper::add_bos() { 56 | embd_inp.push_back(llama_token_bos()); 57 | } 58 | 59 | // Clear the model input buffer 60 | void LlamaWrapper::clear_input() 61 | { 62 | embd_inp.clear(); 63 | n_consumed = 0; 64 | std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); 65 | } 66 | 67 | // Set the model input buffer 68 | void LlamaWrapper::set_input(const std::string& text) 69 | { 70 | set_input(tokenize_text(text)); 71 | } 72 | 73 | // Set the model input buffer from tokens 74 | void LlamaWrapper::set_input(const vector& tokens) 75 | { 76 | embd_inp.clear(); 77 | update_input(tokens); 78 | std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); 79 | n_consumed = 0; 80 | n_past = 0; 81 | } 82 | 83 | // Update input with text 84 | void LlamaWrapper::update_input(const std::string& text) 85 | { 86 | update_input(tokenize_text(text)); 87 | } 88 | 89 | // Update input with tokens 90 | void LlamaWrapper::update_input(const vector& tokens) 91 | { 92 | embd_inp.insert(embd_inp.end(), tokens.begin(), tokens.end()); 93 | } 94 | 95 | // Ingest one batch of input 96 | void LlamaWrapper::ingest_input_batch() 97 | { 98 | // Copy at most n_batch elements from embd_inp to embd 99 | size_t num_copied = std::min((size_t) inference_params.n_batch+1, embd_inp.size() - n_consumed); 100 | std::copy(embd_inp.begin() + n_consumed, 101 | embd_inp.begin() + n_consumed + num_copied, 102 | std::back_inserter(embd)); 103 | n_consumed += num_copied; 104 | 105 | // Copy the last `repeat_last_n` elements copied into embd to last_n_tokens 106 | size_t num_copied_last_n = std::min(num_copied, (size_t) inference_params.repeat_last_n); 107 | last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin()+num_copied_last_n); 108 | last_n_tokens.insert(last_n_tokens.end(), embd.end() - num_copied_last_n, embd.end()); 109 | } 110 | 111 | // Ingest all input 112 | bool LlamaWrapper::ingest_all_pending_input() 113 | { 114 | while (has_unconsumed_input()) 115 | { 116 | ingest_input_batch(); 117 | eval(); 118 | } 119 | return true; 120 | } 121 | 122 | // Check if there is unconsumed input 123 | bool LlamaWrapper::has_unconsumed_input() const 124 | { 125 | return n_consumed < embd_inp.size(); 126 | } 127 | 128 | // Eval model and clear input 129 | bool LlamaWrapper::eval() 130 | { 131 | if (embd.size() > 0) { 132 | // infinite text generation via context swapping 133 | // if we run out of context: 134 | // - take the n_keep first tokens from the original prompt (via n_past) 135 | // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in a batch 136 | if (n_past + (int) embd.size() > n_ctx) { 137 | const int n_left = n_past - inference_params.n_keep; 138 | 139 | n_past = inference_params.n_keep; 140 | 141 | // insert n_left/2 tokens at the start of embd from last_n_tokens 142 | embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size()); 143 | } 144 | if (llama_eval(ctx, embd.data(), embd.size(), n_past, inference_params.n_threads) != 0) { 145 | fprintf(stderr, "Failed to predict\n"); 146 | return false; 147 | } 148 | } 149 | n_past += embd.size(); 150 | 
embd.clear(); 151 | return true; 152 | } 153 | 154 | // Sample from logits 155 | llama_token LlamaWrapper::sample() 156 | { 157 | llama_token id = 0; 158 | 159 | { 160 | id = llama_sample_top_p_top_k( 161 | ctx, 162 | last_n_tokens.data() + n_ctx - inference_params.repeat_last_n, 163 | inference_params.repeat_last_n, 164 | inference_params.top_k, 165 | inference_params.top_p, 166 | inference_params.temp, 167 | inference_params.repeat_penalty 168 | ); 169 | 170 | last_n_tokens.erase(last_n_tokens.begin()); 171 | last_n_tokens.push_back(id); 172 | embd.push_back(id); 173 | } 174 | return id; 175 | } 176 | 177 | // Get the logits for the last token 178 | const float* LlamaWrapper::get_logits() const 179 | { 180 | return llama_get_logits(ctx); 181 | } 182 | -------------------------------------------------------------------------------- /src/llama_wrapper.h: -------------------------------------------------------------------------------- 1 | #ifndef LLAMA_WRAPPER_H 2 | #define LLAMA_WRAPPER_H 3 | 4 | #include "llama.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | /* High level wrapper for the C-style LLAMA API */ 11 | using std::vector; 12 | using Callback = std::function; 13 | 14 | struct InferenceParams { 15 | // model parameters 16 | std::string path_model = ""; 17 | int32_t seed = -1; // RNG seed 18 | int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); 19 | int32_t n_predict = 128; // new tokens to predict 20 | int32_t repeat_last_n = 64; // last n tokens to penalize 21 | int32_t n_batch = 8; // batch size for prompt processing 22 | int32_t n_keep = 0; // number of tokens to keep from initial prompt 23 | 24 | // sampling parameters 25 | int32_t top_k = 40; 26 | float top_p = 0.95f; 27 | float temp = 0.80f; 28 | float repeat_penalty = 1.10f; 29 | 30 | bool use_mlock = false; 31 | bool memory_f16 = false; 32 | 33 | int n_ctx = 512; // context size 34 | 35 | llama_context_params ctx_params = llama_context_default_params(); 36 | 37 | Callback callback{}; 38 | }; 39 | 40 | class LlamaWrapper { 41 | public: 42 | // LLAMA API 43 | LlamaWrapper() = default; 44 | LlamaWrapper(InferenceParams inference_params) 45 | : is_initialized(false), inference_params(inference_params) 46 | {} 47 | ~LlamaWrapper() { 48 | if (ctx) 49 | { 50 | llama_free(ctx); 51 | } 52 | }; 53 | 54 | // Initialize the model 55 | bool init(); 56 | // Check if the model is initialized 57 | bool is_init() const { return is_initialized; } 58 | 59 | // Input processing and inference 60 | // Tokenize text 61 | const vector tokenize_text(const std::string& text, bool add_bos = false) const; 62 | // Queues up a BOS token to the model input 63 | void add_bos(); 64 | // Clears the model input buffer 65 | void clear_input(); 66 | // Set the model input buffer 67 | void set_input(const std::string& text); 68 | // Set the model input buffer from tokens 69 | void set_input(const vector& tokens); 70 | // Queues up input text to the model input 71 | void update_input(const std::string& text); 72 | // Queues up input tokens to the model input 73 | void update_input(const vector& tokens); 74 | // Ingests input previously added using update_input() 75 | void ingest_input_batch(); 76 | // Ingests all input previously added using update_input() in multiple batches 77 | // Batch size is determined by n_batch in InferenceParams 78 | bool ingest_all_pending_input(); 79 | // Checks if the model has unconsumed input to be ingested using ingest_input_batch() 80 | bool has_unconsumed_input() const; 81 | 82 | // 
Evaluate the model on a batch of input. Must call llama_ingest_input_batch() first. 83 | bool eval(); 84 | // Sample token from the model and add it to the model input 85 | llama_token sample(); 86 | 87 | // Output processing 88 | // Get logits 89 | const float* get_logits() const; 90 | 91 | // Get embeddings 92 | const float* get_embeddings() const { 93 | return llama_get_embeddings(ctx); 94 | } 95 | 96 | int get_n_vocab() const { return llama_n_vocab(ctx); } 97 | int get_n_embd() const { return llama_n_embd(ctx); } 98 | 99 | // Convert token to str 100 | std::string token_to_str(llama_token token) const { return llama_token_to_str(ctx, token); } 101 | 102 | // Print timings 103 | void print_timings() const { llama_print_timings(ctx); } 104 | // Reset timings 105 | void reset_timings() const { llama_reset_timings(ctx); } 106 | 107 | private: 108 | std::string path_model = ""; 109 | llama_context* ctx = nullptr; 110 | InferenceParams inference_params{}; 111 | 112 | // Random number generator 113 | std::mt19937 rng{}; 114 | 115 | // Tokens 116 | vector embd{}; 117 | vector embd_inp{}; 118 | vector last_n_tokens{}; 119 | 120 | int n_consumed = 0; 121 | int remaining_tokens = 0; 122 | int n_past = 0; 123 | int n_ctx = 0; 124 | size_t mem_per_token = 0; 125 | 126 | bool is_initialized = false; 127 | }; 128 | 129 | #endif /* LLAMA_WRAPPER_H */ 130 | -------------------------------------------------------------------------------- /src/llamacpp/__init__.py: -------------------------------------------------------------------------------- 1 | import llamacpp 2 | 3 | # Expose the bindings in module 4 | from .llamacpp import (InferenceParams, 5 | LlamaInference, 6 | LlamaContext, 7 | LlamaContextParams, 8 | llama_model_quantize 9 | ) 10 | -------------------------------------------------------------------------------- /src/llamacpp/chat.py: -------------------------------------------------------------------------------- 1 | """A modified version of llamacpp-cli and includes a good prompt for the chatbot""" 2 | import sys 3 | import llamacpp 4 | import argparse 5 | from typing import Dict 6 | 7 | from llamacpp.cli import main as llamacpp_main 8 | 9 | # Default prompt 10 | prompt = """Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. 11 | 12 | User: Hello, Bob. 13 | Bob: Hello. How may I help you today? 
14 | User:""" 15 | 16 | 17 | def parse_chat_params(argv) -> Dict[str, str]: 18 | """Parse chat parameters""" 19 | 20 | parser = argparse.ArgumentParser(description="LLaMa") 21 | parser.add_argument("-i", "--interactive", action="store_true", help="run in interactive mode", default=True) 22 | parser.add_argument( 23 | "--interactive-start", 24 | action="store_true", 25 | help="run in interactive mode and poll user input at startup", 26 | default=False, 27 | ) 28 | parser.add_argument( 29 | "-r", 30 | "--reverse-prompt", 31 | type=str, 32 | help="in interactive mode, poll user input upon seeing PROMPT", 33 | default="User:", 34 | ) 35 | parser.add_argument( 36 | "--color", 37 | action="store_true", 38 | help="colorise output to distinguish prompt and user input from generations", 39 | default=True, 40 | ) 41 | parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (default: -1)") 42 | parser.add_argument( 43 | "-t", 44 | "--threads", 45 | type=int, 46 | default=8, 47 | help="number of threads to use during computation (default: 4)", 48 | ) 49 | parser.add_argument( 50 | "-p", 51 | "--prompt", 52 | type=str, 53 | help="prompt to start generation with (default: random)", 54 | default=prompt, 55 | ) 56 | # parser.add_argument( 57 | # "-f", "--file", type=str, default="", help="prompt file to start generation." 58 | # ) 59 | parser.add_argument( 60 | "-n", "--n_predict", type=int, default=256, help="number of tokens to predict (default: 128)" 61 | ) 62 | parser.add_argument("--top_k", type=int, default=40, help="top-k sampling (default: 40)") 63 | parser.add_argument("--top_p", type=float, default=0.95, help="top-p sampling (default: 0.1)") 64 | parser.add_argument( 65 | "--repeat_last_n", 66 | type=int, 67 | default=64, 68 | help="last n tokens to consider for penalize (default: 0)", 69 | ) 70 | parser.add_argument( 71 | "--repeat_penalty", 72 | type=float, 73 | default=1.30, 74 | help="penalize repeat sequence of tokens (default: 0.0)", 75 | ) 76 | parser.add_argument( 77 | "-c", 78 | "--ctx_size", 79 | type=int, 80 | default=4096, 81 | help="size of the prompt context (default: 4096)", 82 | ) 83 | parser.add_argument("--temp", type=float, default=0.8, help="temperature (default: 0.7)") 84 | parser.add_argument( 85 | "-b", 86 | "--batch_size", 87 | type=int, 88 | default=8, 89 | help="batch size for prompt processing (default: 2)", 90 | ) 91 | parser.add_argument("-m", "--model", type=str, default="./models/7B/ggml-model-q4_0.bin", help="model path (default: )") 92 | parser.add_argument("--mlock", action="store_true", help="use mlock to lock memory") 93 | parser.add_argument("--memory_f16", action="store_true", help="use half-precision memory") 94 | 95 | args = parser.parse_args(argv[1:]) 96 | 97 | return args 98 | 99 | 100 | def run(): 101 | args = parse_chat_params(sys.argv) 102 | 103 | args.instruct = False 104 | 105 | return llamacpp_main(args) 106 | 107 | 108 | if __name__ == "__main__": 109 | sys.exit(run()) 110 | -------------------------------------------------------------------------------- /src/llamacpp/cli.py: -------------------------------------------------------------------------------- 1 | """Python version of main.cpp""" 2 | import sys 3 | import argparse 4 | import llamacpp 5 | from typing import Dict 6 | 7 | 8 | def parse_args_into_params(argv) -> Dict[str, str]: 9 | """Parses arguments using argparse based on usage information above""" 10 | parser = argparse.ArgumentParser(description="llama.cpp CLI") 11 | parser.add_argument("-i", "--interactive", 
action="store_true", help="run in interactive mode") 12 | parser.add_argument( 13 | "-ins", "--instruct", 14 | action="store_true", 15 | help="run in 'instruct mode' where the user is prompted to enter a command", 16 | default=False, 17 | ) 18 | parser.add_argument( 19 | "-r", 20 | "--reverse-prompt", 21 | type=str, 22 | help="in interactive mode, poll user input upon seeing PROMPT", 23 | default="", 24 | ) 25 | parser.add_argument( 26 | "--color", 27 | action="store_true", 28 | help="colorise output to distinguish prompt and user input from generations", 29 | ) 30 | parser.add_argument("-s", "--seed", type=int, default=-1, help="RNG seed (default: -1)") 31 | parser.add_argument( 32 | "-t", 33 | "--threads", 34 | type=int, 35 | default=4, 36 | help="number of threads to use during computation (default: 4)", 37 | ) 38 | parser.add_argument( 39 | "-p", 40 | "--prompt", 41 | type=str, 42 | help="prompt to start generation with (default: random)", 43 | ) 44 | parser.add_argument( 45 | "-f", "--file", type=str, default="", help="prompt file to start generation." 46 | ) 47 | parser.add_argument( 48 | "-n", "--n_predict", type=int, default=128, help="number of tokens to predict (default: 128)" 49 | ) 50 | parser.add_argument("--top_k", type=int, default=40, help="top-k sampling (default: 40)") 51 | parser.add_argument("--top_p", type=float, default=0.95, help="top-p sampling (default: 0.1)") 52 | parser.add_argument( 53 | "--repeat_last_n", 54 | type=int, 55 | default=64, 56 | help="last n tokens to consider for penalize (default: 64)", 57 | ) 58 | parser.add_argument( 59 | "--repeat_penalty", 60 | type=float, 61 | default=1.30, 62 | help="penalize repeat sequence of tokens (default: 1.30)", 63 | ) 64 | parser.add_argument( 65 | "-c", 66 | "--ctx_size", 67 | type=int, 68 | default=512, 69 | help="size of the prompt context (default: 512)", 70 | ) 71 | parser.add_argument("--temp", type=float, default=0.8, help="temperature (default: 0.7)") 72 | parser.add_argument( 73 | "-b", 74 | "--batch_size", 75 | type=int, 76 | default=8, 77 | help="batch size for prompt processing (default: 8)", 78 | ) 79 | parser.add_argument("-m", "--model", type=str, default="./models/7B/ggml-model-q4_0.bin", help="model path (default: )") 80 | parser.add_argument("--mlock", action="store_true", help="use mlock to lock memory") 81 | parser.add_argument("--memory_f16", action="store_true", help="use half-precision memory") 82 | 83 | args = parser.parse_args(argv[1:]) 84 | 85 | if args.interactive or args.instruct: 86 | print("WARNING: interactive mode and instruct mode are currently broken") 87 | return args 88 | 89 | 90 | def process_interactive_input(model: llamacpp.LlamaInference): 91 | """Process interactive input similar to the C++ version""" 92 | 93 | # Read lines as long as user is entering "\" at the end of the line 94 | # Pass each line to the model 95 | while True: 96 | line = input() 97 | if line.endswith("\\"): 98 | line = line[:-1] 99 | model.update_input(line) 100 | else: 101 | model.update_input(line) 102 | break 103 | 104 | 105 | def main(args): 106 | """Main function""" 107 | 108 | # Add a space in front of the first character to match OG llama tokenizer behavior 109 | args.prompt = " " + args.prompt 110 | 111 | params = llamacpp.InferenceParams() 112 | params.path_model = args.model 113 | params.seed = args.seed 114 | params.n_threads = args.threads 115 | 116 | params.repeat_last_n = args.repeat_last_n 117 | params.n_batch = args.batch_size 118 | params.top_k = args.top_k 119 | params.top_p = args.top_p 
120 | params.temp = args.temp 121 | params.repeat_penalty = args.repeat_penalty 122 | params.use_mlock = args.mlock 123 | params.memory_f16 = args.memory_f16 124 | params.n_ctx = args.ctx_size 125 | 126 | model = llamacpp.LlamaInference(params) 127 | model.update_input([model.token_bos()]) 128 | model.update_input(args.prompt) 129 | print(model.system_info()) 130 | 131 | inp_pfx = model.tokenize("\n\n### Instruction:\n\n", True) 132 | inp_sfx = model.tokenize("\n\n### Response:\n\n", False) 133 | 134 | if args.instruct: 135 | args.interactive = True 136 | args.reverse_prompt = "### Instruction:\n\n" 137 | 138 | # Set antiprompt if we are in interactive mode 139 | if args.reverse_prompt: 140 | args.interactive = True 141 | 142 | if args.interactive: 143 | print("== Running in interactive mode. ==") 144 | print(" - Press Ctrl+C to interject at any time.") 145 | print(" - Press Return to return control to LLaMa.") 146 | print(" - If you want to submit another line, end your input in '\\'.") 147 | print() 148 | is_interacting = True 149 | 150 | input_noecho = False 151 | is_finished = False 152 | 153 | print(args.prompt, end="") 154 | 155 | n_output = 0 156 | while n_output < args.n_predict: 157 | if model.has_unconsumed_input(): 158 | model.ingest_all_pending_input() 159 | # # reset color to default if we there is no pending user input 160 | # if (!input_noecho && args.use_color) { 161 | # printf(ANSI_COLOR_RESET); 162 | # } 163 | else: 164 | token = model.sample() 165 | text = model.token_to_str(token) 166 | print(text, end="") 167 | n_output += 1 168 | is_finished = token == model.token_eos() 169 | input_noecho = False 170 | 171 | if args.interactive: 172 | if model.is_antiprompt_present(): 173 | # reverse prompt found 174 | is_interacting = True 175 | if is_interacting: 176 | if args.instruct: 177 | model.update_input_tokens(inp_pfx) 178 | print("\n> ", end="") 179 | 180 | process_interactive_input(model) 181 | 182 | if args.instruct: 183 | model.update_input_tokens(inp_sfx) 184 | 185 | input_noecho = True 186 | is_interacting = False 187 | 188 | # end of text token was found 189 | if is_finished: 190 | if args.interactive: 191 | is_interacting = True 192 | else: 193 | print(" [end of text]") 194 | break 195 | 196 | if args.interactive and model.is_finished(): 197 | model.reset_remaining_tokens() 198 | is_interacting = True 199 | 200 | return 0 201 | 202 | 203 | def run(): 204 | # Parse params into a gpt_params object 205 | args = parse_args_into_params(sys.argv) 206 | 207 | # if args.file is specified, read the file and set the prompt to the contents 208 | if args.file: 209 | with open(args.file, "r") as f: 210 | args.prompt = f.read().strip() 211 | 212 | return main(args) 213 | 214 | 215 | if __name__ == "__main__": 216 | sys.exit(run()) 217 | -------------------------------------------------------------------------------- /src/llamacpp/convert.py: -------------------------------------------------------------------------------- 1 | # Convert a LLaMA model checkpoint to a ggjt compatible file 2 | # 3 | # Load the model using Torch 4 | # Iterate over all variables and write them to a binary file. 5 | # 6 | # For each variable, write the following: 7 | # - Number of dimensions (int) 8 | # - Name length (int) 9 | # - Dimensions (int[n_dims]) 10 | # - Name (char[name_length]) 11 | # - Data (float[n_dims]) 12 | # 13 | # At the start of the ggml file we write the model parameters 14 | # and vocabulary. 
15 | # 16 | 17 | import argparse 18 | import os 19 | import sys 20 | import json 21 | import struct 22 | import numpy as np 23 | import torch 24 | 25 | from sentencepiece import SentencePieceProcessor 26 | 27 | QK = 32 28 | 29 | GGML_TYPE_Q4_0 = 0 30 | GGML_TYPE_Q4_1 = 1 31 | GGML_TYPE_I8 = 2 32 | GGML_TYPE_I16 = 3 33 | GGML_TYPE_I32 = 4 34 | GGML_TYPE_F16 = 5 35 | GGML_TYPE_F32 = 6 36 | 37 | WTYPES = { 38 | 0: GGML_TYPE_F32, 39 | 1: GGML_TYPE_F16, 40 | 2: GGML_TYPE_Q4_0, 41 | 3: GGML_TYPE_Q4_1, 42 | } 43 | 44 | GGML_BLCK_SIZE = { 45 | GGML_TYPE_Q4_0: QK, 46 | GGML_TYPE_Q4_1: QK, 47 | GGML_TYPE_I8: 1, 48 | GGML_TYPE_I16: 1, 49 | GGML_TYPE_I32: 1, 50 | GGML_TYPE_F16: 1, 51 | GGML_TYPE_F32: 1, 52 | } 53 | 54 | GGML_TYPE_SIZE = { 55 | GGML_TYPE_Q4_0: 4 + QK//2, 56 | GGML_TYPE_Q4_1: 4*2 + QK//2, 57 | GGML_TYPE_I8: 1, 58 | GGML_TYPE_I16: 2, 59 | GGML_TYPE_I32: 4, 60 | GGML_TYPE_F16: 2, 61 | GGML_TYPE_F32: 4, 62 | } 63 | 64 | def ggml_nelements(shape): 65 | r = 1 66 | for i in shape: 67 | r *= i 68 | return r 69 | 70 | def ggml_nbytes(shape, ftype): 71 | x = ggml_nelements(shape) 72 | t = WTYPES[ftype] 73 | x *= GGML_TYPE_SIZE[t] 74 | x //= GGML_BLCK_SIZE[t] 75 | return x 76 | 77 | def parse_args(): 78 | parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') 79 | parser.add_argument('dir_model', help='directory containing the model checkpoint') 80 | parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1) 81 | parser.add_argument('vocab_only', help='only write vocab to file', type=int, default=0, nargs='?') 82 | return parser.parse_args() 83 | 84 | def get_n_parts(dim): 85 | mappings = {4096: 1, 5120: 2, 6656: 4, 8192: 8} 86 | n_parts = mappings.get(dim) 87 | if n_parts is None: 88 | print(f"Invalid dim: {dim}") 89 | sys.exit(1) 90 | 91 | print(f"n_parts = {n_parts}\n") 92 | return n_parts 93 | 94 | def load_hparams_and_tokenizer(dir_model): 95 | # `dir_model` is something like `models/7B` or `models/7B/`. 96 | # "tokenizer.model" is expected under model's parent dir. 97 | # When `dir_model` is a symlink, f"{dir_model}/../tokenizer.model" would not be found. 98 | # Let's use the model's parent dir directly. 
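    # For example, dir_model = "models/7B/" normalizes to "models/7B", whose parent
    # is "models", so the tokenizer is read from "models/tokenizer.model" and the
    # hyperparameters from "models/7B/params.json".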
99 | model_parent_dir = os.path.dirname(os.path.normpath(dir_model)) 100 | fname_hparams = f"{dir_model}/params.json" 101 | fname_tokenizer = f"{model_parent_dir}/tokenizer.model" 102 | with open(fname_hparams, "r") as f: 103 | hparams = json.load(f) 104 | print(hparams) 105 | tokenizer = SentencePieceProcessor(fname_tokenizer) 106 | hparams.update({"vocab_size": tokenizer.vocab_size()}) 107 | return hparams, tokenizer 108 | 109 | def write_header(fout, hparams, ftype): 110 | keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"] 111 | values = [ 112 | 0x67676a74, # magic: ggjt in hex 113 | 1, # file version 114 | *[hparams[key] for key in keys], 115 | hparams["dim"] // hparams["n_heads"], # rot (obsolete) 116 | ftype 117 | ] 118 | fout.write(struct.pack("i" * len(values), *values)) 119 | 120 | def write_tokens(fout, tokenizer): 121 | for i in range(tokenizer.vocab_size()): 122 | if tokenizer.is_unknown(i): 123 | text = " \u2047 ".encode("utf-8") 124 | elif tokenizer.is_control(i): 125 | text = b"" 126 | elif tokenizer.is_byte(i): 127 | piece = tokenizer.id_to_piece(i) 128 | if len(piece) != 6: 129 | print(f"Invalid token: {piece}") 130 | sys.exit(1) 131 | byte_value = int(piece[3:-1], 16) 132 | text = struct.pack("B", byte_value) 133 | else: 134 | text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") 135 | fout.write(struct.pack("i", len(text))) 136 | fout.write(text) 137 | fout.write(struct.pack("f", tokenizer.get_score(i))) 138 | 139 | def process_and_write_variables(fout, model, ftype, part_id, n_parts): 140 | for name, datao in model.items(): 141 | if name.endswith("freqs"): 142 | continue 143 | 144 | # remove dimensions with a single element 145 | data = datao.numpy().squeeze() 146 | partshape = data.shape 147 | n_dims = len(data.shape) 148 | assert n_dims in (1, 2) 149 | 150 | print(f"Processing variable: {name} with shape: {partshape} and type: {datao.dtype}") 151 | 152 | # coerce single-dimensional tensors from float16 to float32 153 | ftype_cur = 1 154 | if ftype == 0 or n_dims == 1: 155 | print(" Converting to float32") 156 | data = data.astype(np.float32) 157 | ftype_cur = 0 158 | blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]] 159 | type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]] 160 | 161 | # determine dimension along which multipart tensor is sharded 162 | # 163 | # split_dim 0 regex: 164 | # - output.* 165 | # - layers.*.attention.wq.weight 166 | # - layers.*.attention.wk.weight 167 | # - layers.*.attention.wv.weight 168 | # - layers.*.feed_forward.w1.weight 169 | # - layers.*.feed_forward.w3.weight 170 | # 171 | # split_dim 1 regex: 172 | # - tok_embeddings.* 173 | # - layers.*.attention.wo.weight 174 | # - layers.*.feed_forward.w2.weight 175 | # 176 | if n_dims > 1: 177 | split_dim = 1 178 | if "tok_embeddings" in name: 179 | split_dim = 1 180 | elif "layers" in name: 181 | if "attention.wo.weight" in name: 182 | split_dim = 1 183 | elif "feed_forward.w2.weight" in name: 184 | split_dim = 1 185 | else: 186 | split_dim = 0 187 | elif "output" in name: 188 | split_dim = 0 189 | 190 | # output tensor header 191 | fullshape = list(partshape) 192 | if n_dims > 1: 193 | fullshape[split_dim] *= n_parts 194 | sname = name.encode('utf-8') 195 | fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur)) 196 | for dim in reversed(fullshape): 197 | fout.write(struct.pack("i", dim)) 198 | fout.write(sname) 199 | 200 | # ensure tensor data is aligned 201 | tensor_data_offset = fout.tell() 202 | while tensor_data_offset % QK != 0: 203 | 
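            # zero-pad until the write offset is a multiple of QK (32) bytes so the
            # tensor data that follows starts on an aligned boundary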
            fout.write(struct.pack("B", 0))
204 |             tensor_data_offset += 1
205 | 
206 |         # output unified mappable tensor data
207 |         if n_dims == 1 or n_parts == 1:
208 |             # copy tensor which we thankfully received in one piece
209 |             if part_id == 0:
210 |                 data.tofile(fout)
211 |         elif split_dim == 0:
212 |             # reassemble multifile tensor containing some of the rows
213 |             rows_per_chunk = partshape[0]
214 |             current_row = part_id * rows_per_chunk
215 |             bytes_per_row = fullshape[1] // blck_size * type_size
216 |             offset = current_row * bytes_per_row
217 |             fout.seek(tensor_data_offset + offset)
218 |             data.tofile(fout)
219 |         elif split_dim == 1:
220 |             # reassemble multifile tensor containing some of the cols
221 |             cols_per_chunk = partshape[1]
222 |             current_col = part_id * cols_per_chunk
223 |             bytes_per_row = fullshape[1] // blck_size * type_size
224 |             offset_current_col = current_col // blck_size * type_size
225 |             for row in range(partshape[0]):
226 |                 offset_row = row * bytes_per_row
227 |                 offset = offset_row + offset_current_col
228 |                 fout.seek(tensor_data_offset + offset)
229 |                 data[row].tofile(fout)
230 | 
231 |         # advance file position to next tensor
232 |         fout.seek(tensor_data_offset + ggml_nbytes(fullshape, ftype_cur))
233 | 
234 | def main():
235 |     args = parse_args()
236 |     dir_model = args.dir_model
237 |     ftype = args.ftype
238 |     ftype_str = ["f32", "f16"]
239 |     hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
240 | 
241 |     print(args)
242 | 
243 |     # if only writing vocab to file
244 |     if args.vocab_only:
245 |         fname_model = f"{dir_model}/consolidated.00.pth"
246 |         fname_out = f"{dir_model}/ggml-vocab.bin"
247 |         print(f"Extracting only the vocab from '{fname_model}'\n")
248 |         with open(fname_out, "wb") as fout:
249 |             write_header(fout, hparams, ftype)
250 |             write_tokens(fout, tokenizer)
251 |         print(f"Done. Output file: {fname_out}\n")
252 |         return
253 | 
254 |     n_parts = get_n_parts(hparams["dim"])
255 |     fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
256 | 
257 |     # we output a single file for ggml
258 |     with open(fname_out, "wb") as fout:
259 |         write_header(fout, hparams, ftype)
260 |         write_tokens(fout, tokenizer)
261 |         offset_of_tensors = fout.tell()
262 |         # the tensors we load could be split across multiple files
263 |         for part_id in range(n_parts):
264 |             fout.seek(offset_of_tensors)
265 |             print(f"Processing part {part_id+1} of {n_parts}\n")
266 |             fname_model = f"{dir_model}/consolidated.0{part_id}.pth"
267 |             model = torch.load(fname_model, map_location="cpu")
268 |             process_and_write_variables(fout, model, ftype, part_id, n_parts)
269 |             del model
270 | 
271 |     print(f"Done. Output file: {fname_out}\n")
272 | 
273 | if __name__ == "__main__":
274 |     main()
275 | 
--------------------------------------------------------------------------------
/src/llamacpp/quantize.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | 
5 | def main():
6 |     """Pass command line arguments to llama_model_quantize"""
7 |     import llamacpp
8 | 
9 |     # Print usage if not enough arguments are provided
10 |     if len(sys.argv) < 2:
11 |         print(f"Usage: llamacpp-quantize <model_path> [bits=0]")
12 |         print("bits: 0 = q4_0, 1 = q4_1\n")
13 |         print("This script assumes that you have already used convert-pth-to-ggml.py to convert")
14 |         print("the pytorch model to a ggml model. It will then quantize the ggml model to INT4")
15 |         print("for use with the llamacpp library.\n")
16 |         print("llamacpp-quantize will walk through the model_path directory and quantize all")
17 |         print("ggml-model-f16.bin.* files it finds. The output files will be named")
18 |         print("ggml-model-q4_0.bin.* or ggml-model-q4_1.bin.* depending on the value of bits.\n")
19 |         sys.exit(1)
20 | 
21 |     model_path = sys.argv[1]
22 |     if len(sys.argv) < 3:
23 |         bits = 0
24 |     else:
25 |         bits = int(sys.argv[2])
26 | 
27 |     # Convert "bits" to input for llama_model_quantize()
28 |     if bits == 0:
29 |         q_type = 2
30 |         q_type_str = 'q4_0'
31 |     elif bits == 1:
32 |         q_type = 3
33 |         q_type_str = 'q4_1'
34 |     else:
35 |         print(f"Invalid value for bits: {bits} (expected 0 or 1)")
36 |         sys.exit(1)
37 | 
38 |     # Print the model path
39 |     print(f"Quantizing model in {model_path} to {q_type_str}")
40 | 
41 |     # Walk through files in model_path matching ggml-model-f16*.bin
42 |     # and pass them to llama_model_quantize()
43 |     for root, dirs, files in os.walk(model_path):
44 |         for file in files:
45 |             if file.startswith("ggml-model-f16") and file.endswith(".bin"):
46 |                 output_file = file.replace("-f16.bin", f"-{q_type_str}.bin")
47 |                 print(f"Quantizing file: {file} to {output_file}")
48 |                 llamacpp.llama_model_quantize(os.path.join(root, file), os.path.join(root, output_file), q_type)
49 | 
50 | 
51 | if __name__ == '__main__':
52 |     main()
53 | 
--------------------------------------------------------------------------------
/tests/test_llama_context.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import array
4 | import llamacpp
5 | import pytest
6 | 
7 | 
8 | @pytest.fixture(scope="session")
9 | def llama_context():
10 |     params = llamacpp.LlamaContextParams()
11 |     params.seed = 19472
12 |     params.logits_all = True
13 |     # Get path to current file
14 |     current_file_path = os.path.dirname(os.path.realpath(__file__))
15 |     # Get path to the model
16 |     model_path = os.path.join(current_file_path, "../models/7B/ggml-model-f16.bin")
17 |     return llamacpp.LlamaContext(model_path, params)
18 | 
19 | 
20 | def test_str_to_token(llama_context):
21 |     prompt = "Hello World"
22 |     prompt_tokens = llama_context.str_to_token(prompt, True)
23 |     assert all(prompt_tokens == [1, 10994, 2787])
24 | 
25 | 
26 | def test_token_to_str(llama_context):
27 |     tokens = [1, 10994, 2787]
28 |     text = ''.join([llama_context.token_to_str(token) for token in tokens])
29 |     assert text == "Hello World"
30 | 
31 | 
32 | def test_eval(llama_context):
33 |     embd_inp = llama_context.str_to_token(" Llama is", True)
34 |     n_past, n_remain, n_consumed = 0, 8, 0
35 |     embd = []
36 | 
37 |     output = ''
38 |     while n_remain:
39 |         if len(embd):
40 |             llama_context.eval(array.array('i', embd), len(embd), n_past, 1)
41 |             n_past += len(embd)
42 |             embd.clear()
43 | 
44 |         if len(embd_inp) <= n_consumed:
45 |             # sample
46 |             top_k = 40
47 |             top_p = 0.95
48 |             temp = 0.8
49 |             repeat_penalty = 1.0
50 | 
51 |             # sending an empty array for the last n tokens
52 |             id = llama_context.sample_top_p_top_k(array.array('i', []), top_k, top_p, temp, repeat_penalty)
53 |             # add it to the context
54 |             embd.append(id)
55 |             # decrement remaining sampling budget
56 |             n_remain -= 1
57 |         else:
58 |             # has unconsumed input
59 |             while len(embd_inp) > n_consumed:
60 |                 # update_input
61 |                 embd.append(embd_inp[n_consumed])
62 |                 n_consumed += 1
63 | 
64 |         output += ''.join([llama_context.token_to_str(id) for id in embd])
65 |     assert output == " Llama is the newest member of our growing family"
66 | 
67 | 
68 | def test_get_logits_all(llama_context):
69 |     """Verify that
get_logits() returns a numpy array of the correct shape when 70 | logits_all is set to True.""" 71 | embd_inp = llama_context.str_to_token(" Llama is", True) 72 | llama_context.eval(array.array('i', embd_inp), len(embd_inp), 0, 1) 73 | logits = llama_context.get_logits() 74 | assert logits.shape == (len(embd_inp), 32000) 75 | 76 | 77 | if __name__ == '__main__': 78 | sys.exit(pytest.main(['-s', '-v', __file__])) 79 | -------------------------------------------------------------------------------- /tests/test_llama_inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pytest 4 | import llamacpp 5 | 6 | 7 | @pytest.fixture 8 | def nominal_params(): 9 | """Nominal InfereceParams""" 10 | params = llamacpp.InferenceParams() 11 | # Get path to current file 12 | current_file_path = os.path.dirname(os.path.realpath(__file__)) 13 | # Get path to the model 14 | model_path = os.path.join(current_file_path, "../models/7B/ggml-model-f16.bin") 15 | params.path_model = model_path 16 | params.seed = 19472 17 | params.repeat_penalty = 1.0 18 | return params 19 | 20 | 21 | @pytest.fixture 22 | def limited_context_params(nominal_params): 23 | """InferenceParams with limited context""" 24 | nominal_params.n_ctx = 32 25 | return nominal_params 26 | 27 | 28 | @pytest.fixture 29 | def llama_model(nominal_params): 30 | return llamacpp.LlamaInference(nominal_params) 31 | 32 | 33 | @pytest.fixture 34 | def limited_context_llama_model(limited_context_params): 35 | return llamacpp.LlamaInference(limited_context_params) 36 | 37 | 38 | def test_update_input(llama_model): 39 | prompt_tokens = [1, 2, 3] 40 | llama_model.update_input(prompt_tokens) 41 | assert llama_model.has_unconsumed_input() 42 | llama_model.ingest_all_pending_input() 43 | assert not llama_model.has_unconsumed_input() 44 | 45 | 46 | def test_tokenize(llama_model): 47 | prompt = "Hello World" 48 | prompt_tokens = llama_model.tokenize(prompt, True) 49 | assert prompt_tokens == [1, 10994, 2787] 50 | 51 | 52 | def test_token_to_str(llama_model): 53 | tokens = [1, 10994, 2787] 54 | text = ''.join([llama_model.token_to_str(token) for token in tokens]) 55 | assert text == "Hello World" 56 | 57 | 58 | def test_eval(llama_model): 59 | prompt = " Llama is" 60 | prompt_tokens = llama_model.tokenize(prompt, True) 61 | llama_model.update_input(prompt_tokens) 62 | llama_model.ingest_all_pending_input() 63 | output = prompt 64 | for i in range(8): 65 | llama_model.eval() 66 | token = llama_model.sample() 67 | output += llama_model.token_to_str(token) 68 | 69 | assert output == " Llama is the newest member of our growing family" 70 | 71 | 72 | def test_eval_exceed_n_ctx(limited_context_llama_model): 73 | # Tests context swapping feature 74 | llama_model = limited_context_llama_model 75 | prompt = " Llama is" 76 | prompt_tokens = llama_model.tokenize(prompt, True) 77 | llama_model.update_input(prompt_tokens) 78 | llama_model.ingest_all_pending_input() 79 | output = prompt 80 | # Generate 35 tokens with n_ctx of 32 81 | for i in range(35): 82 | llama_model.eval() 83 | token = llama_model.sample() 84 | output += llama_model.token_to_str(token) 85 | 86 | assert output == " Llama is the newest member of our growing family. We’re excited to welcome him to the pack!\nLlama is a male, born in 2017" 87 | 88 | 89 | if __name__=='__main__': 90 | sys.exit(pytest.main(['-s', '-v', __file__])) 91 | --------------------------------------------------------------------------------
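For reference, here is a minimal end-to-end sketch of the high-level API exercised by src/llamacpp/cli.py and tests/test_llama_inference.py above; the model path and the 32-token generation budget are assumptions for illustration, not values required by the library.

    # Build inference parameters the same way cli.py does; the path is assumed.
    import llamacpp

    params = llamacpp.InferenceParams()
    params.path_model = "./models/7B/ggml-model-q4_0.bin"  # assumed converted + quantized model
    params.seed = 19472
    params.repeat_penalty = 1.0

    # Load the model, queue the prompt, and ingest it.
    model = llamacpp.LlamaInference(params)
    prompt = " Llama is"  # leading space matches the tokenizer behaviour noted in cli.py
    model.update_input(model.tokenize(prompt, True))
    model.ingest_all_pending_input()

    # Generate until an end-of-text token or the (assumed) 32-token budget is hit.
    output = prompt
    for _ in range(32):
        model.eval()
        token = model.sample()
        if token == model.token_eos():
            break
        output += model.token_to_str(token)
    print(output)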