├── .gitignore ├── LICENSE ├── README.md ├── batchplot.png ├── create_model.py ├── fastapi_helpers.py ├── hf_download.py ├── model_settings.py ├── ollama_registry.py ├── ollama_template.py ├── openai_types.py ├── requirements.txt ├── server.py ├── static └── status.html └── test_with_openai_module.py /.gitignore: -------------------------------------------------------------------------------- 1 | # test files for continue.dev 2 | junk.* 3 | # local db 4 | models.json 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ExLlamaV2-OpenAI-Server 2 | 3 | An implementation of the OpenAI API using the ExLlamaV2 backend. 4 | This project is not affiliated with ExLlamaV2 or OpenAI. 5 | 6 | ## Features 7 | 8 | * Continuous batching. 9 | * Streamed responses. 10 | * OpenAI compatibility for `/v1/models` and `/v1/chat/completions` endpoints 11 | * Uses Ollama model metadata information to set default prompting and parameters. 12 | * Remembers your settings per model. 13 | * Loads models on demand. 14 | * Status endpoint with graphs! (and one button!) 15 | 16 | I've been testing against the python openai module, [Ollama Web UI](https://github.com/ollama-webui/ollama-webui) and [continue.dev](https://continue.dev/). 17 | 18 | ## Origin Story 19 | 20 | This wouldn't be possible without [ExLlamaV2](https://github.com/turboderp/exllamav2) or EricLLM. 
I saw [EricLLM](https://github.com/epolewski/EricLLM) and thought it was close to
21 | doing what I wanted, and by the time I realized what I was doing, I had pretty much completely rewritten it.
22 | 
23 | My goals are to be able to figure out how to set up a model once (preferably by leveraging work by the Ollama team) and then easily use it in a variety of frontends without thinking about it again. However, I also like to be able to quantize things to meet specific memory goals, and I like the performance of ExLlamaV2. Hence this project.
24 | 
25 | ## Issues
26 | 
27 | * I have ~~no~~ _some_ idea what I'm doing.
28 | * It's currently streaming everything internally, which is almost certainly slowing down non-streaming requests.
29 | * The ExLlamaV2 class `ExLlamaV2StreamingGenerator` has too much important stuff in it to avoid using it, but it also wasn't meant to be used this way.
30 | * Prompt parsing is synchronous, token decode is serialized with model inference, ...
31 | 
32 | ## Installation
33 | 
34 | ```
35 | git clone https://github.com/bjj/exllamav2-openai-server
36 | cd exllamav2-openai-server
37 | pip install -r requirements.txt
38 | ```
39 | 
40 | Notes:
41 | * Tested on Python 3.11.7. Python 3.12+ seems to have version conflicts.
42 | * The first start will take a long time to compile `exllamav2_ext`.
43 | 
44 | ## Adding Models
45 | 
46 | To add a new model:
47 | 
48 | 1. Download an ExLlamaV2-compatible model (EXL2 or GPTQ).
49 | 2. Browse the [Ollama Library](https://ollama.ai/library) to find the matching repository. This is where we'll get prompt information and other default settings.
50 | 3. Run `python create_model.py --model-dir <model_directory> <repository>`.
51 | 4. Repeat for as many models as you want to use.
52 | 
53 | Note that tags are optional, and different tags often share the same metadata in the Ollama library. You can use them for yourself to give models unique names, for example `deepseek-coder:6.7b` and `deepseek-coder:33b`. This never downloads the GGUF files used by Ollama, so it doesn't matter what their default quantization is or what quantization tag you choose. **The quantization is determined by what EXL2 or GPTQ you download for yourself.**
54 | 
55 | You can also pass options to `create_model.py` to override options provided by Ollama. For example, to add Mixtral-8x7B-Instruct, with a model in `E:\...`, prompting from Ollama, but with batching limited to 1 and context limited to 16k (for memory):
56 | 
57 | ```
58 | python .\create_model.py --model-dir E:\exl2-llm-models\turboderp\Mixtral-8x7B-instruct-3.5bpw-exl2\ --max-batch-size 1 --max-seq-len 16384 mixtral
59 | ```
60 | 
61 | You can add models while the server is running. It rereads the `models.json` file whenever it needs to know about the models.
62 | 
63 | If there is no corresponding Ollama library config for your model, you can use `create_model.py --no-ollama` and specify everything (template, system prompt, etc.) on the command line.
64 | 
65 | ## Running the Server
66 | 
67 | You can run the server with no arguments. It will listen on `0.0.0.0:8000` by default:
68 | 
69 | ```
70 | python server.py
71 | ```
72 | 
73 | The server takes several optional arguments. The options used are selected with the following priority:
74 | 
75 | 1. The options provided in the API request (if they don't exceed limits)
76 | 2. The `server.py` command-line arguments
77 | 3. The `create_model.py` command-line arguments
78 | 4. The Ollama repository configuration data
79 | 5. The model's `config.json`
80 | 
81 | For example, you pass `--max-batch-size 8` to the server.
You get a batch size of 8 even though the model (see the example above) was limited to `--max-batch-size 1`.
82 | 
83 | You can do a quick test with `curl http://localhost:8000/v1/models`.
84 | 
85 | ## If you get "Out of Memory"
86 | 
87 | When loading the model, the automatic GPU splitting in ExLlamaV2 allocates all the memory it could possibly need to satisfy the batching and context requirements. If you run out of memory (or your model loads onto two GPUs when you are sure it should fit on one), try these ideas.
88 | 
89 | Try one of these options when creating the model with `create_model.py` or when launching `server.py`. Remember, command-line arguments override model configuration:
90 | * Use `--cache-8bit` to reduce the memory footprint of the cache. This has a significant effect on memory without sacrificing much accuracy or speed.
91 | * Use `--max-batch-size` to reduce the batching. Maximum throughput can be achieved with fairly low batch sizes. Going higher than that just lets more users see progress happening on streaming requests at once.
92 | * Use `--max-seq-len` to reduce the maximum context length. If the model supports an especially large context size, this can catch you out, for example `dolphin-mistral:7b` or `mixtral`.
93 | * Use `--max-input-len` to change an internal batching size, which has a very small effect.
94 | 
95 | If you use a manual `--gpu_split`, you will load the model without accounting for the memory needed to actually handle requests. This will work fine if you don't get many concurrent requests and/or don't use much context, but you risk running out of memory unexpectedly later.
96 | 
97 | ## Monitoring
98 | 
99 | There is a simple status webpage at `http://localhost:8000/`.
100 | 
101 | ![screenshot](batchplot.png)
102 | 
103 | ## Windows
104 | 
105 | You can run on native Windows if you can set up a working Python + CUDA environment. I recommend using the NVIDIA Control Panel to change "CUDA sysmem fallback policy" to "prefer no sysmem fallback", because I would rather OOM than do inference at the terrible speeds you get if you overflow into "shared" GPU memory.
106 | 
107 | You can run on WSL2 by setting up CUDA using [NVIDIA's guide](https://docs.nvidia.com/cuda/wsl-user-guide/index.html). The server will translate Windows paths in your `models.json` so you can experiment with it. WSL2 access to Windows filesystems is unbelievably slow, so models will take forever to load. If WSL2 is your long-term strategy, you'll want your models in a native filesystem. On WSL2, fast models are a little faster (maybe 10%) and slow models are about the same.
108 | 
109 | ## Multi GPU
110 | 
111 | You can manually specify the memory split with `--gpu_split`, but it's very finicky to get right. Otherwise the server uses ExLlamaV2's automatic splitting. Note that the auto split works by allocating as much memory as it will ever need for the maximum context length and batch size. See "If you get 'Out of Memory'" above.
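
## Testing with the openai module

Beyond the `curl` check above, a quick end-to-end test with the Python `openai` client (v1.x) might look like the sketch below. The base URL, port, and the `mixtral` model name are assumptions: use whatever host and port you started `server.py` with and a model you registered via `create_model.py`. The API key is a placeholder because the server does not appear to check it. (The repository also ships `test_with_openai_module.py`.)

```
from openai import OpenAI

# Assumptions: server running locally on the default port, a model registered
# under the name "mixtral", and a placeholder API key (not checked by the server).
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

# List the models the server knows about (same data as GET /v1/models).
print([m.id for m in client.models.list().data])

# Stream a short chat completion.
stream = client.chat.completions.create(
    model="mixtral",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()
```

For a non-streaming request, drop `stream=True` and read `response.choices[0].message.content` instead.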
112 | -------------------------------------------------------------------------------- /batchplot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bjj/exllamav2-openai-server/d33a2b6ffaa92513f7c52a1f3ba2418e84633840/batchplot.png -------------------------------------------------------------------------------- /create_model.py: -------------------------------------------------------------------------------- 1 | 2 | # Given a model directory and the ollama repository path for the same model, 3 | # generate a configuration for that model. 4 | 5 | import json 6 | import sys 7 | import os 8 | import tempfile 9 | import argparse 10 | import asyncio 11 | import time 12 | from ollama_registry import get_ollama_model_descriptor 13 | from model_settings import ModelSettings 14 | 15 | registry_path = "models.json" 16 | 17 | # Run exllamav2 from a git checkout in a sibling dir 18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/exllamav2") 19 | from exllamav2 import ExLlamaV2Config 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="Add a model description.") 24 | parser.add_argument("--model-dir", metavar="MODEL_DIRECTORY", type=str, help="Sets model_directory", required=True) 25 | parser.add_argument("--no-ollama", action='store_true', help="Make a model without ollama data") 26 | ModelSettings.add_arguments(parser) 27 | parser.add_argument("repository") 28 | args = parser.parse_args() 29 | return args, ModelSettings.from_args(args) 30 | 31 | def read_registry(): 32 | global registry_path 33 | try: 34 | with open(registry_path, 'r') as file: 35 | registry = json.load(file) 36 | except FileNotFoundError: 37 | print(f"Creating new registry {registry_path}") 38 | registry = {} 39 | return registry 40 | 41 | def write_registry(registry): 42 | global registry_path 43 | 44 | # Try to atomically update the JSON 45 | temp_fd, temp_path = tempfile.mkstemp(dir='.') 46 | with os.fdopen(temp_fd, 'w') as temp_file: 47 | json.dump(registry, temp_file, indent=4) 48 | os.replace(temp_path, registry_path) 49 | 50 | async def main(): 51 | args, settings = parse_args() 52 | 53 | registry = read_registry() 54 | 55 | # Sanity check the model 56 | config = ExLlamaV2Config() 57 | config.model_dir = args.model_dir 58 | config.prepare() 59 | 60 | # Get ollama's description of the model 61 | if not args.no_ollama: 62 | ollama_descr = await get_ollama_model_descriptor(args.repository, debug=True) 63 | else: 64 | ollama_descr = {} 65 | 66 | record = { 67 | "model_dir": args.model_dir, 68 | "settings": settings.dict(), 69 | "ollama": ollama_descr, 70 | "created": int(time.time()), 71 | } 72 | 73 | if args.repository in registry: 74 | print(f"Replacing model {args.repository}, was:\n{json.dumps(registry[args.repository], indent=4)}") 75 | else: 76 | print(f"Adding new model {args.repository}") 77 | registry[args.repository] = record 78 | 79 | write_registry(registry) 80 | 81 | 82 | if __name__ == "__main__": 83 | asyncio.run(main()) 84 | -------------------------------------------------------------------------------- /fastapi_helpers.py: -------------------------------------------------------------------------------- 1 | from fastapi.responses import StreamingResponse 2 | from fastapi.encoders import jsonable_encoder 3 | import typing, json 4 | 5 | # Helper that takes a stream of objects and streams using "Server-sent events" 6 | SyncJsonStream = typing.Iterator[typing.Any] 7 | AsyncJsonStream = 
typing.AsyncIterable[typing.Any] 8 | JsonStream = typing.Union[AsyncJsonStream, SyncJsonStream] 9 | class StreamingJSONResponse(StreamingResponse): 10 | def __init__( 11 | self, 12 | content: JsonStream, 13 | **kw 14 | ) -> None: 15 | async def json_iterator(): 16 | if isinstance(content, typing.AsyncIterable): 17 | iter = content 18 | else: 19 | from starlette.concurrency import iterate_in_threadpool 20 | iter = iterate_in_threadpool(content) 21 | 22 | async for chunk in iter: 23 | text = json.dumps(jsonable_encoder(chunk), 24 | ensure_ascii=False, 25 | allow_nan=False, 26 | indent=None, 27 | separators=(",", ":"), 28 | ) 29 | text = "data: " + text + "\n\n" 30 | yield text.encode("utf-8") 31 | yield "data: [DONE]".encode("utf-8") 32 | 33 | super().__init__(content=json_iterator(), media_type="text/event-stream", **kw) -------------------------------------------------------------------------------- /hf_download.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import HfApi 2 | import asyncio, sys 3 | import httpx 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | 7 | 8 | async def download_file(client, url, filepath: Path): 9 | async with client.stream("GET", url, follow_redirects=True) as response: 10 | response.raise_for_status() 11 | total_size = int(response.headers.get("Content-Length", 0)) 12 | with filepath.open("wb") as f: 13 | with tqdm(total=total_size, unit='B', unit_scale=True, desc=filepath.name) as progress_bar: 14 | async for chunk in response.aiter_bytes(): 15 | f.write(chunk) 16 | progress_bar.update(len(chunk)) 17 | 18 | async def download_hf_repo(repo_id, root_dir): 19 | root_path = Path(root_dir) 20 | repo_owner, repo_name = repo_id.split('/') 21 | path = root_path / repo_owner / repo_name 22 | path.mkdir(parents=True, exist_ok=True) 23 | 24 | api = HfApi() 25 | async with httpx.AsyncClient() as client: 26 | files = api.list_repo_files(repo_id=repo_id, repo_type="model") 27 | for file in files: 28 | url = f"https://huggingface.co/{repo_id}/resolve/main/{file}?download=true" 29 | await download_file(client, url, path / file) 30 | 31 | 32 | async def main(): 33 | await download_hf_repo(sys.argv[2], sys.argv[1]) 34 | 35 | if __name__ == "__main__": 36 | asyncio.run(main()) -------------------------------------------------------------------------------- /model_settings.py: -------------------------------------------------------------------------------- 1 | import argparse, re 2 | from pydantic import BaseModel, Field 3 | import typing 4 | from typing import Optional 5 | 6 | # transforms -> middle-out from openrouter.ai 7 | 8 | class ModelSettings(BaseModel): 9 | system_prompt: Optional[str] = None 10 | template: Optional[str] = None 11 | stop: Optional[list[str]] = None 12 | lora: Optional[str] = None 13 | max_seq_len: Optional[int] = None 14 | max_input_len: Optional[int] = None 15 | max_batch_size: Optional[int] = None 16 | rope_alpha: Optional[float] = None 17 | rope_scale: Optional[float] = None 18 | cache_8bit: Optional[bool] = None 19 | temperature: Optional[float] = None 20 | top_k: Optional[int] = None 21 | top_p: Optional[float] = None 22 | presence_penalty: Optional[float] = None 23 | frequency_penalty: Optional[float] = None 24 | repetition_penalty: Optional[float] = None 25 | min_p: Optional[float] = None 26 | top_a: Optional[float] = None 27 | logit_bias: Optional[dict[str, float]] = None 28 | 29 | def apply_to_exllamav2_settings(self, settings): 30 | """ 31 | Apply to 
ExLlamaV2Sampler.Settings 32 | """ 33 | settings.temperature = self.temperature 34 | settings.top_k = self.top_k 35 | settings.top_p = self.top_p 36 | settings.min_p = self.min_p 37 | settings.top_a = self.top_a 38 | settings.token_presence_penalty = self.presence_penalty 39 | settings.token_frequency_penalty = self.frequency_penalty 40 | settings.token_repetition_penalty = self.repetition_penalty 41 | settings.token_bias = self.logit_bias 42 | return settings 43 | 44 | def apply_to_config(self, config): 45 | """ 46 | Apply to ExLlamaV2Config 47 | """ 48 | config.max_batch_size = self.max_batch_size 49 | 50 | if self.max_seq_len: 51 | config.max_seq_len = self.max_seq_len 52 | 53 | if self.max_input_len: 54 | config.max_input_len = self.max_input_len 55 | 56 | if self.rope_scale is not None: 57 | config.scale_pos_emb = self.rope_scale 58 | 59 | if self.rope_alpha is not None: 60 | config.scale_rope_alpha = self.rope_alpha 61 | 62 | return config 63 | 64 | def inherit_from(self, *sets): 65 | """ 66 | Merges multiple sets together on top of this one, first is highest priority 67 | """ 68 | for s in sets: 69 | for name, field in ModelSettings.__fields__.items(): 70 | if getattr(self, name, None) is None: 71 | setattr(self, name, getattr(s, name, None)) 72 | 73 | def dict(self): 74 | result = super().dict() 75 | return {k: v for k, v in result.items() if v is not None} 76 | 77 | @staticmethod 78 | def add_arguments(parser): 79 | """ 80 | Adds command line arguments to a given parser based on the fields in the ModelSettings class. 81 | """ 82 | for name, field in ModelSettings.model_fields.items(): 83 | arg_name = (field.alias or name).replace('_', '-') 84 | required = True 85 | key_type = None 86 | is_complex = False 87 | type_ = typing.get_origin(field.annotation) 88 | if type_ is typing.Union: 89 | required = False # union w/None 90 | type_ = typing.get_args(field.annotation)[0] 91 | if typing.get_args(type_): 92 | keyvalue = (None,) + typing.get_args(type_) 93 | key_type, type_ = keyvalue[-2:] 94 | is_complex = True 95 | 96 | add = { 97 | "required": required 98 | } 99 | if type_ == bool: 100 | # store_true will default False rather than None and override 101 | add["action"] = 'store_const' 102 | add["const"] = True 103 | elif key_type is not None: # logit_bias 104 | add["action"] = _StoreDictStrFloat 105 | else: 106 | add["type"] = type_ 107 | if is_complex: 108 | add["action"] = 'append' 109 | parser.add_argument(f"--{arg_name}", **add) 110 | 111 | @staticmethod 112 | def from_args(args): 113 | """ 114 | Constructs a ModelSettings from the result of argparse.parse_args() 115 | """ 116 | kv = {k: v for k, v in vars(args).items() if v is not None and k in ModelSettings.__fields__} 117 | return ModelSettings(**kv) 118 | 119 | @staticmethod 120 | def defaults(): 121 | """ 122 | Defaults suitable to merge with. Anything not defaulted here either gets 123 | defaults from exllamav2 or can't be defaulted without more context. 124 | """ 125 | return ModelSettings( 126 | max_batch_size=4, 127 | system_prompt="", 128 | template="{{ .System }}{{ .Prompt }}", # response is implicit. 
this is "raw" 129 | stop=[], 130 | cache_8bit=False, 131 | temperature=0.8, 132 | top_k=0, # 0 means "none" in exllamav2 133 | top_p=1.0, 134 | presence_penalty=0.0, 135 | frequency_penalty=0.0, 136 | repetition_penalty= 1.0, 137 | min_p=0.0, 138 | top_a=0.0, 139 | ) 140 | 141 | class _StoreDictStrFloat(argparse.Action): 142 | def __call__(self, parser, namespace, values, option_string=None): 143 | my_dict = getattr(namespace, self.dest) or {} 144 | for kv in values.split(","): 145 | try: 146 | k, v = re.split('[:=]', kv) 147 | except ValueError: 148 | parser.error('Expected k:v or k=v') 149 | try: 150 | v = float(v) 151 | except ValueError: 152 | parser.error('Expected key=') 153 | my_dict[k] = v 154 | 155 | setattr(namespace, self.dest, my_dict) 156 | 157 | 158 | def _main(): 159 | parser = argparse.ArgumentParser() 160 | ModelSettings.add_arguments(parser) 161 | args = parser.parse_args() 162 | settings = ModelSettings.from_args(args) 163 | print(repr(settings)) 164 | print(repr(settings.dict())) 165 | 166 | if __name__ == "__main__": 167 | _main() 168 | -------------------------------------------------------------------------------- /ollama_registry.py: -------------------------------------------------------------------------------- 1 | # Fetch metadata about models known to ollama 2 | 3 | import sys 4 | import httpx 5 | import json 6 | import copy 7 | from urllib.parse import urljoin 8 | import asyncio 9 | 10 | async def get_url_of_type(url, mimetype): 11 | async with httpx.AsyncClient() as client: 12 | headers = { 13 | 'Accept': mimetype 14 | } 15 | 16 | try: 17 | response = await client.get(url, headers=headers, follow_redirects=True) 18 | 19 | # Check if the request was successful (HTTP status code 200) 20 | if response.status_code == 200: 21 | return response.text # You can also use response.content for binary data 22 | else: 23 | print(f"Request failed with status code {response.status_code}") 24 | return None 25 | 26 | except httpx.RequestError as e: 27 | print(f"Request error: {e}") 28 | return None 29 | 30 | 31 | async def get_ollama_model_descriptor(repository, debug=False): 32 | baseUrl = 'https://registry.ollama.ai/' 33 | namespace = 'library' 34 | tag = 'latest' 35 | 36 | if '/' in repository: 37 | namespace, repository = repository.split('/') 38 | if ':' in repository: 39 | repository, tag = repository.split(':') 40 | 41 | url = urljoin(baseUrl, f"v2/{namespace}/{repository}/manifests/{tag}") 42 | 43 | data = await get_url_of_type(url, 'application/vnd.docker.distribution.manifest.v2+json') 44 | 45 | if data: 46 | manifest = json.loads(data) 47 | else: 48 | print("Failed to retrieve data from the URL.") 49 | 50 | layerNames = { 51 | "application/vnd.ollama.image.model": "model", 52 | "application/vnd.ollama.image.license": "license", 53 | "application/vnd.ollama.image.template": "template", 54 | "application/vnd.ollama.image.params": "params", 55 | "application/vnd.ollama.image.system": "system", 56 | "application/vnd.docker.container.image.v1+json": "config", 57 | } 58 | 59 | blobs = copy.deepcopy(manifest['layers']) 60 | blobs.append(copy.deepcopy(manifest['config'])) 61 | 62 | pending = [] 63 | for blob in blobs: 64 | if blob['size'] > 2048: 65 | continue 66 | url = urljoin(baseUrl, f"v2/{namespace}/{repository}/blobs/{blob['digest']}") 67 | if debug: 68 | print(url) 69 | async def fetch(b, u): 70 | b['body'] = await get_url_of_type(u, b['mediaType']) 71 | pending.append(fetch(blob, url)) 72 | await asyncio.gather(*pending) 73 | 74 | descr = {} 75 | for blob in blobs: 
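        # Blobs skipped above because they were too large to fetch (such as the model weights) never get a 'body'; the KeyError check below filters them out.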
76 | try: 77 | body = blob['body'] 78 | except KeyError: 79 | continue 80 | name = layerNames[blob['mediaType']] 81 | if name in ['params', 'config']: 82 | body = json.loads(body) 83 | descr[name] = body 84 | return descr 85 | 86 | async def main(): 87 | repository = sys.argv[1] 88 | 89 | descr = await get_ollama_model_descriptor(repository, debug=True) 90 | print(json.dumps(descr, indent=4)) 91 | 92 | 93 | if __name__ == "__main__": 94 | asyncio.run(main()) 95 | -------------------------------------------------------------------------------- /ollama_template.py: -------------------------------------------------------------------------------- 1 | # Use ollama style templates to generate chat prompts. 2 | # NOTE: ollama is dependent on Go's text/template and it's not easy to emulate or wrap 3 | # this is just dealing with basic substitutions, not {{if}} shenanigans and so on 4 | 5 | import re, platform 6 | import typing 7 | from pydantic import BaseModel 8 | from openai_types import ChatCompletions 9 | from create_model import read_registry 10 | from model_settings import ModelSettings 11 | 12 | def _windows_to_wsl2_path(windows_path): 13 | # Convert backslashes to forward slashes 14 | wsl_path = windows_path.replace('\\', '/') 15 | 16 | # Replace the drive letter and colon (e.g., "C:") with "/mnt/c" 17 | if wsl_path[1:3] == ':/': 18 | wsl_path = '/mnt/' + wsl_path[0].lower() + wsl_path[2:] 19 | 20 | return wsl_path 21 | 22 | class ModelFile: 23 | repository: str 24 | model_dir: str 25 | created: int 26 | settings: ModelSettings 27 | our_settings: ModelSettings 28 | ollama_settings: ModelSettings 29 | 30 | def __init__(self, repository): 31 | self.repository = repository 32 | registry = read_registry() 33 | try: 34 | record = registry[repository] 35 | except KeyError: 36 | raise FileNotFoundError() 37 | 38 | self.model_dir = record["model_dir"] 39 | if platform.system() != "Windows": 40 | self.model_dir = _windows_to_wsl2_path(self.model_dir) 41 | self.created = record["created"] 42 | 43 | # defaults from ollama 44 | ollama = record.get("ollama", {}) 45 | ollama_params = ollama.get("params", {}) 46 | self.ollama_settings = ModelSettings( 47 | template=ollama.get("template"), 48 | system_prompt=ollama.get("system"), 49 | max_seq_len=ollama_params.get("num_ctx"), 50 | stop=ollama_params.get("stop", []), 51 | ) 52 | 53 | self.our_settings = ModelSettings(**record.get("settings", {})) 54 | 55 | self.settings = self.our_settings.copy(deep=True) 56 | self.settings.inherit_from(self.ollama_settings) 57 | 58 | class Prompt: 59 | first: bool = True 60 | system_prompt: str = "" 61 | prompt: str = "" 62 | response: str = "" 63 | template: str 64 | 65 | result: str = "" 66 | 67 | def __init__(self, settings: ModelSettings): 68 | if settings.system_prompt is not None: 69 | self.system_prompt = settings.system_prompt 70 | self.template = settings.template.strip(" ") 71 | self.template = re.sub(r"{{\s+", "{{", self.template) 72 | self.template = re.sub(r"\s+}}", "}}", self.template) 73 | 74 | def flush(self, template=None): 75 | if template is None: 76 | template = self.template 77 | subbed = template.replace("{{.System}}", self.system_prompt) 78 | subbed = subbed.replace("{{.Prompt}}", self.prompt) 79 | if "{{.Response}}" in subbed: 80 | subbed = subbed.replace("{{.Response}}", self.response) 81 | else: 82 | subbed = subbed + self.response 83 | 84 | # we're not fully text/template compatible by a long shot 85 | if '{{' in subbed: 86 | raise(f'Incomplete template substitution {template}') 87 | 88 | 
self.first = False 89 | self.system_prompt = "" 90 | self.prompt = "" 91 | self.response = "" 92 | 93 | self.result += subbed 94 | 95 | def chatString(self, messages: list[ChatCompletions.Message]): 96 | if self.result: 97 | raise ("Do not re-use this object") 98 | for m in messages: 99 | if m.role == "system": 100 | # and not self.First => does not match ollama. It 101 | # would add a whole empty exchange with the model system prompt 102 | # and then the user system prompt. 103 | # This does replacement. Concatenation also makes sense 104 | if self.system_prompt and not self.first: 105 | self.flush() 106 | self.system_prompt = m.content 107 | elif m.role == "user": 108 | if self.prompt: 109 | self.flush() 110 | self.prompt = m.content 111 | elif m.role == "assistant": 112 | self.response = m.content 113 | self.flush() 114 | else: 115 | pass 116 | 117 | if self.prompt or self.system_prompt: 118 | pre = self.template.split("{{.Response}}") 119 | self.flush(pre[0]) 120 | 121 | return self.result 122 | 123 | 124 | def main(): 125 | model = ModelFile(repository="hello") 126 | p = Prompt(model) 127 | messages = [ 128 | ChatCompletions.Message(content="MySystemMessage", role="system"), 129 | ChatCompletions.Message(content="What is my name?", role="user"), 130 | ChatCompletions.Message(content="King Arthur", role="assistant"), 131 | ChatCompletions.Message(content="What is my quest?", role="user"), 132 | ] 133 | print(p.chatString(messages)) 134 | 135 | p = Prompt(model) 136 | messages = [ 137 | ChatCompletions.Message(content="What is my name?", role="user"), 138 | ChatCompletions.Message(content="King Arthur", role="assistant"), 139 | ChatCompletions.Message(content="What is my quest?", role="user"), 140 | ] 141 | print(p.chatString(messages)) 142 | 143 | 144 | if __name__ == "__main__": 145 | main() 146 | -------------------------------------------------------------------------------- /openai_types.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from pydantic import BaseModel 3 | import typing 4 | 5 | # Don't put defaults here or it overrides everything else 6 | class ChatCompletions(BaseModel): 7 | class Message(BaseModel): 8 | content: str 9 | role: str 10 | name: str | None = None 11 | 12 | messages: list[ChatCompletions.Message] 13 | model: str 14 | frequency_penalty: float | None = None 15 | logit_bias: dict | None = None 16 | logprobs: bool = False 17 | top_logprobs: int | None = None 18 | max_tokens: int | None = None 19 | n: int = 1 20 | repetition_penalty: float | None = None # openrouter 21 | presence_penalty: float | None = None 22 | response_format: dict | None = None 23 | seed: int | None = None 24 | stop: str | list[str] | None = None 25 | stream: bool = False 26 | temperature: float | None = None 27 | top_k: int | None = None # openrouter 28 | top_p: float | None = None 29 | min_p: float | None = None # openrouter 30 | top_a: float | None = None # openrouter 31 | tools: list = [] 32 | tool_choice: str | dict | None = None 33 | user: str | None = None 34 | 35 | class ChatCompletionsResponse(BaseModel): 36 | class Choice(BaseModel): 37 | class Message(BaseModel): 38 | content: str | None = None 39 | tool_calls: list | None = None 40 | role: str 41 | 42 | finish_reason: str # stop / length / content_filter / tool_calls / function_call 43 | index: int 44 | message: ChatCompletionsResponse.Choice.Message 45 | logprobs: None = None 46 | 47 | class Usage(BaseModel): 48 | completion_tokens: int = 0 49 | 
prompt_tokens: int = 0 50 | total_tokens: int = 0 51 | 52 | id: str 53 | choices: list[ChatCompletionsResponse.Choice] = [] 54 | created: int 55 | model: str 56 | system_fingerprint: str = "exllamav2" 57 | object: str = "chat.completion" 58 | usage: ChatCompletionsResponse.Usage 59 | 60 | ChatCompletionsResponse.update_forward_refs() 61 | ChatCompletionsResponse.Choice.update_forward_refs() 62 | 63 | 64 | class ChatCompletionsChunkResponse(BaseModel): 65 | class Choice(BaseModel): 66 | class Delta(BaseModel): 67 | content: str | None = None 68 | tool_calls: list | None = None 69 | role: str 70 | 71 | delta: ChatCompletionsChunkResponse.Choice.Delta 72 | finish_reason: str | None = None # stop / length / content_filter / tool_calls / function_call 73 | index: int 74 | logprobs: None = None 75 | 76 | id: str 77 | choices: list[ChatCompletionsChunkResponse.Choice] = [] 78 | created: int 79 | model: str 80 | system_fingerprint: str = "exllamav2" 81 | object: str = "chat.completion.chunk" 82 | 83 | ChatCompletionsChunkResponse.update_forward_refs() 84 | ChatCompletionsChunkResponse.Choice.update_forward_refs() 85 | 86 | class ModelsResponse(BaseModel): 87 | class Model(BaseModel): 88 | id: str 89 | created: int 90 | object: str = "model" 91 | onwed_by: str = "system" 92 | 93 | object: str = "list" 94 | data: list[ModelsResponse.Model] 95 | 96 | ModelsResponse.update_forward_refs() 97 | ModelsResponse.Model.update_forward_refs() 98 | 99 | class ErrorResponse(BaseModel): 100 | class Error(BaseModel): 101 | message: str 102 | type: str 103 | param: typing.Any = None 104 | code: typing.Any = None 105 | 106 | error: ErrorResponse.Error 107 | 108 | ErrorResponse.update_forward_refs() 109 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | pydantic>=2 3 | starlette 4 | torch>=2.1.0 5 | exllamav2>=0.0.11 6 | uvicorn 7 | httpx 8 | tokenizers -------------------------------------------------------------------------------- /server.py: -------------------------------------------------------------------------------- 1 | import sys, os, time, torch, random, asyncio, json, argparse, pathlib, gc 2 | import uuid 3 | import typing 4 | from fastapi import FastAPI, HTTPException, Request 5 | from pydantic import BaseModel, Field 6 | from starlette.websockets import WebSocket 7 | from fastapi.responses import JSONResponse, FileResponse, HTMLResponse 8 | from fastapi.middleware.cors import CORSMiddleware 9 | from fastapi.encoders import jsonable_encoder 10 | from openai_types import * 11 | from fastapi_helpers import StreamingJSONResponse 12 | import ollama_template 13 | from create_model import read_registry 14 | from model_settings import ModelSettings 15 | 16 | # Run exllamav2 from a git checkout in a sibling dir 17 | #sys.path.append( 18 | # os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/exllamav2" 19 | #) 20 | 21 | from exllamav2 import ( 22 | ExLlamaV2, 23 | ExLlamaV2Config, 24 | ExLlamaV2Cache, 25 | ExLlamaV2Tokenizer, 26 | ExLlamaV2Cache_8bit, 27 | ExLlamaV2Lora, 28 | ) 29 | from exllamav2.generator import ExLlamaV2Sampler, ExLlamaV2StreamingGenerator 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser(description="OpenAI compatible server for exllamav2.") 33 | parser.add_argument("--verbose", action="store_true", default=False, help="Sets verbose") 34 | parser.add_argument("--model", metavar="REPOSITORY", type=str, help="Initial 
model to load") 35 | parser.add_argument("--host", metavar="HOST", type=str, default="0.0.0.0", help="Sets host") 36 | parser.add_argument("--port", metavar="PORT", type=int, default=8000, help="Sets port") 37 | parser.add_argument("--timeout", metavar="TIMEOUT", type=float, default=600.0, help="Sets HTTP timeout") 38 | parser.add_argument("--cors", action="store_true", help="Wide open CORS settings") 39 | parser.add_argument("--gpu_split", metavar="GPU_SPLIT", type=str, default="", 40 | help="Sets array gpu_split and accepts input like 16,24. Default is automatic") 41 | ModelSettings.add_arguments(parser) 42 | 43 | return parser.parse_args() 44 | 45 | def load_modelfile(repository): 46 | return ollama_template.ModelFile(repository) 47 | 48 | def activate_modelfile(_modelfile): 49 | global modelfile, active_settings 50 | modelfile = _modelfile 51 | if modelfile is None: 52 | active_settings = None 53 | else: 54 | active_settings = args_settings.copy(deep=True) 55 | active_settings.inherit_from(modelfile.settings, ModelSettings.defaults()) 56 | 57 | request_unload = False 58 | request_cancel_all = False 59 | loaded_pct = None 60 | 61 | 62 | app = FastAPI() 63 | 64 | class ServerStatus: 65 | work_item_times: list[float] = [time.time()] 66 | work_items: list[int] = [0] 67 | 68 | queue_depth_times: list[float] = [time.time()] 69 | queue_depths: list[int] = [0] 70 | 71 | token_rate_times: list[float] = [time.time()] 72 | token_rates: list[float] = [0.0] 73 | 74 | mem_times: list[float] = [time.time()] 75 | mems: list[list[float]] = [] 76 | max_mems: list[float] = [] 77 | 78 | def __init__(self): 79 | for i in range(torch.cuda.device_count()): 80 | self.mems.append([0.0]) 81 | self.max_mems.append(torch.cuda.get_device_properties(i).total_memory / 1024**3) 82 | 83 | def update_work_items(self, n): 84 | if n != self.work_items[-1]: 85 | self.work_item_times.append(time.time()) 86 | self.work_items.append(n) 87 | 88 | def update_queue_depths(self, n): 89 | if n != self.queue_depths[-1]: 90 | self.queue_depth_times.append(time.time()) 91 | self.queue_depths.append(n) 92 | def increment_queue_depths(self): 93 | self.update_queue_depths(self.queue_depths[-1] + 1) 94 | 95 | def update_token_rates(self, n, offset=0, force=False): 96 | if n != self.token_rates[-1] or force: 97 | self.token_rate_times.append(time.time() + offset) 98 | self.token_rates.append(n) 99 | 100 | def update_memory(self, force=False): 101 | now = time.time() 102 | if now - self.mem_times[-1] > 5.0 or force: 103 | m = [] 104 | d = 0 105 | for i in range(torch.cuda.device_count()): 106 | m.append(torch.cuda.memory_reserved(i) / 1024**3) 107 | d = max(d, abs(m[i] - self.mems[i][-1])) 108 | if force or d > 0.1: 109 | self.mem_times.append(now) 110 | for i in range(torch.cuda.device_count()): 111 | self.mems[i].append(m[i]) 112 | 113 | status = ServerStatus() 114 | 115 | # XXX this inheritance isn't ideal because there are config settings in here too which won't work, 116 | # but the client can only actually specify things in ChatCompletions 117 | class QueueRequest(ModelSettings): 118 | request_id: str = Field(default_factory=lambda: str(uuid.uuid4())) 119 | modelfile: typing.Any 120 | messages: list[ChatCompletions.Message] 121 | completion_queue: typing.Any # asyncio.Queue 122 | max_tokens: int | None = None 123 | stream: bool = False 124 | finish_reason: str | None = None 125 | 126 | def __init__(__pydantic_self__, **kv): 127 | # Let caller pass None meaning "unspecified" rather than literal None 128 | kv = {k: v for k, v in 
kv.items() if v is not None} 129 | super().__init__(**kv) 130 | 131 | class QueueResponse(BaseModel): 132 | content: str 133 | finish_reason: str | None = None 134 | status_code: int = 200 135 | completion_tokens: int = 0 136 | prompt_tokens: int = 0 137 | 138 | 139 | prompts_queue = asyncio.Queue() 140 | 141 | processing_started = False 142 | model = None 143 | tokenizer = None 144 | loras = [] 145 | 146 | 147 | # We need the power of ExLlamaV2StreamingGenerator but we want to 148 | # batch, so we replace the actual inference in this inner function. 149 | # Ideally refactor exllamav2 so this is not necessary 150 | def patch_gen_single_token(sampler): 151 | def _gen_single_token(self, gen_settings, prefix_token = None): 152 | if self.draft_model is not None: 153 | raise NotImplementedError 154 | 155 | logits = self.logits_queue.pop(0) 156 | token, prob, eos = ExLlamaV2Sampler.sample(logits, gen_settings, self.sequence_ids[:1, :], random.random(), self.tokenizer, prefix_token) 157 | 158 | if self.sequence_ids.shape[0] > 1 and token.shape[0] == 1: 159 | self.sequence_ids = torch.cat([self.sequence_ids, token.repeat(self.sequence_ids.shape[0], 1)], dim = 1) 160 | else: 161 | self.sequence_ids = torch.cat([self.sequence_ids, token], dim = 1) 162 | 163 | gen_settings.feed_filters(token) 164 | if hasattr(self, "no_probs"): # hacks upon hacks: try to support 0.12 165 | return token, prob, eos 166 | else: 167 | return token, eos 168 | 169 | sampler.logits_queue = [] 170 | sampler._gen_single_token = _gen_single_token.__get__(sampler, ExLlamaV2StreamingGenerator) 171 | 172 | 173 | class WorkItem: 174 | generator: any 175 | output_str: str = "" 176 | cache: any 177 | exllamav2_settings: any 178 | completion_queue: any 179 | request: QueueRequest 180 | completion_tokens: int = 0 181 | prompt_tokens: int = 0 182 | first_content: bool = True 183 | 184 | async def inference_loop(): 185 | global prompts_queue, processing_started, status, modelfile, tokenizer, model, config 186 | processing_started = True 187 | 188 | # throttle streaming to 10/s instead of making big JSON HTTP responses for every token 189 | chunk_interval = 0.1 190 | next_stream_time = asyncio.get_event_loop().time() + chunk_interval 191 | 192 | pending_model_request = None 193 | 194 | work: list[WorkItem] = [] 195 | 196 | async def add_workitem(request): 197 | if request.finish_reason: 198 | return False 199 | 200 | request.inherit_from(active_settings) 201 | chat = ollama_template.Prompt(request).chatString(request.messages) 202 | #print(chat) 203 | 204 | input_ids = tokenizer.encode(chat) 205 | n_input_tokens = input_ids.shape[-1] 206 | print(f"num input tokens {n_input_tokens}") 207 | if n_input_tokens >= config.max_seq_len: 208 | response = QueueResponse(status_code=403, finish_reason="tokens_exceeded_error", 209 | content=f"Input tokens exceeded. 
Model limit: {config.max_seq_len}.") 210 | await request.completion_queue.put(response) 211 | return False 212 | 213 | item = WorkItem() 214 | item.prompt_tokens = n_input_tokens 215 | batch_size = 1 216 | max_tokens = request.max_tokens or config.max_seq_len 217 | CacheClass = ExLlamaV2Cache_8bit if active_settings.cache_8bit else ExLlamaV2Cache 218 | item.cache = CacheClass(model, max_seq_len=min(config.max_seq_len, (n_input_tokens + max_tokens)), batch_size=batch_size) 219 | item.generator = ExLlamaV2StreamingGenerator(model, item.cache, tokenizer) 220 | patch_gen_single_token(item.generator) 221 | item.exllamav2_settings = ExLlamaV2Sampler.Settings() 222 | request.apply_to_exllamav2_settings(item.exllamav2_settings) 223 | item.generator.set_stop_conditions([tokenizer.eos_token_id, *request.stop]) 224 | item.completion_queue = request.completion_queue 225 | item.request = request 226 | token_healing_must_be_false = False # see below 227 | item.generator.begin_stream(input_ids, item.exllamav2_settings, loras=loras, token_healing=token_healing_must_be_false) 228 | work.append(item) 229 | status.update_memory() 230 | return True 231 | 232 | token_rate_start_time = asyncio.get_event_loop().time() 233 | token_rate_count = 0 234 | def update_token_rates(force=False): 235 | nonlocal token_rate_start_time, token_rate_count 236 | now = asyncio.get_event_loop().time() 237 | duration = now - token_rate_start_time 238 | if duration < 1.0 and not force: 239 | return 240 | if duration > 0: 241 | status.update_token_rates(token_rate_count / duration, -duration / 2, force) 242 | token_rate_start_time = now 243 | token_rate_count = 0 244 | 245 | def update_queue_depths(): 246 | global prompts_queue 247 | nonlocal pending_model_request 248 | status.update_queue_depths(prompts_queue.qsize() + (1 if pending_model_request is not None else 0)) 249 | 250 | while processing_started: 251 | added = False 252 | 253 | global request_cancel_all 254 | async def send_cancel(req): 255 | response = QueueResponse(status_code=429, finish_reason="rate_limit_error", 256 | content=f"Request queue was flushed while request was pending.") 257 | await req.completion_queue.put(response) 258 | 259 | if request_cancel_all: 260 | if pending_model_request is not None: 261 | await send_cancel(pending_model_request) 262 | pending_model_request = None 263 | for w in work: 264 | w.request.finish_reason = "canceled" 265 | while not prompts_queue.empty(): 266 | await send_cancel(prompts_queue.get_nowait()) 267 | if not work and prompts_queue.empty(): 268 | request_cancel_all = False 269 | 270 | global request_unload 271 | if request_unload and not work: 272 | pending_model_request = None 273 | activate_modelfile(None) 274 | unload_model() 275 | request_unload = False 276 | 277 | # If we need a new model, handle that when the work queue drains 278 | if pending_model_request and not work: 279 | update_token_rates(True) 280 | try: 281 | activate_modelfile(pending_model_request.modelfile) 282 | await load_model() 283 | if await add_workitem(pending_model_request): 284 | added = True 285 | except Exception as e: 286 | response = QueueResponse(status_code=500, finish_reason="server_error", 287 | content=f"Unable to load requested model {modelfile.repository}.") 288 | await pending_model_request.completion_queue.put(response) 289 | activate_modelfile(None) 290 | pending_model_request = None 291 | update_queue_depths() 292 | update_token_rates(True) 293 | status.update_token_rates(0, force=True) # need a better way to handle this boundary 294 | 
295 | # If pending model request, do not add more work items. 296 | # Else enter this (possibly blocking) loop if there's nothing to do (ok to block) 297 | # or if we could accept more work and the queue isn't empty (no blocking) 298 | while pending_model_request is None and (len(work) == 0 or (len(work) < config.max_batch_size and prompts_queue.qsize() != 0)): 299 | try: 300 | request: QueueRequest = await asyncio.wait_for(prompts_queue.get(), 0.5) 301 | except asyncio.TimeoutError: 302 | update_token_rates() 303 | break 304 | 305 | if request.finish_reason: 306 | update_queue_depths() 307 | continue 308 | 309 | if modelfile is None or request.modelfile.repository != modelfile.repository: 310 | pending_model_request = request 311 | break 312 | 313 | if await add_workitem(request): 314 | added = True 315 | 316 | update_queue_depths() 317 | if added: 318 | status.update_work_items(len(work)) 319 | update_token_rates(len(work) == 1) 320 | 321 | # process as long as there are incomplete requests 322 | if work: 323 | send_chunk = False 324 | now = asyncio.get_event_loop().time() 325 | if now >= next_stream_time: 326 | next_stream_time = now + chunk_interval 327 | send_chunk = True 328 | update_token_rates() 329 | 330 | inputs = torch.cat([w.generator.sequence_ids[:, -1:] for w in work], dim=0) 331 | caches = [w.cache for w in work] 332 | # NOTE: can run out of memory here. Need to handle that. torch.cuda.OutOfMemoryError 333 | logits = model.forward(inputs, caches, input_mask=None, loras=loras).float() 334 | inputs = None 335 | caches = None 336 | event = torch.cuda.Event() 337 | event.record(torch.cuda.default_stream()) 338 | token_rate_count += len(work) 339 | 340 | # yield to HTTP threads or we can't stream (and batched responses are all as slow as the last one). 341 | # Without the loop, other things will not get enough time to run (if you have a stack of functions 342 | # yielding values, only one will run each sleep(0) while making the next runnable). 343 | # Sleeping for nonzero time here is almost guaranteed to wait too long due to sleep granularity. 344 | while not event.query(): 345 | await asyncio.sleep(0) 346 | 347 | # sync with GPU 348 | logits = logits.cpu() 349 | 350 | eos = [] 351 | for i in range(len(work)): 352 | item = work[i] 353 | 354 | item.completion_tokens += 1 355 | item.generator.logits_queue.append(logits[i: i + 1, :, :]) 356 | # with token_healing off, this queue only needs depth 1. 357 | # Continuing here can't work because we must update item.generator.sequence_ids before 358 | # generating the next batch. So more invasive changes to ExLlamaV2StreamingGenerator 359 | # would be required. 
360 | #if len(item.generator.logits_queue) < 2: # from inspection, most .stream() will consume 361 | # continue 362 | 363 | chunk, stopped, tokens = item.generator.stream() 364 | item.output_str += chunk 365 | 366 | limited = item.cache.current_seq_len >= item.cache.max_seq_len 367 | final = stopped or limited 368 | finish_reason = None 369 | if final: 370 | finish_reason = "stop" if stopped and not limited else "length" 371 | 372 | if item.request.finish_reason: 373 | final = True 374 | finish_reason = item.request.finish_reason 375 | 376 | if final or (item.request.stream and send_chunk and item.output_str): 377 | try: 378 | content = item.output_str 379 | if item.first_content: 380 | content = content.lstrip() 381 | item.first_content = False 382 | if final: 383 | content = content.rstrip() 384 | response = QueueResponse(content=content, finish_reason=finish_reason, 385 | prompt_tokens=item.prompt_tokens, completion_tokens=item.completion_tokens) 386 | await item.completion_queue.put(response) 387 | except Exception as e: 388 | print(f"Error processing completed prompt: {e}") 389 | final = True 390 | if final: 391 | eos.insert(0, i) # Indices of completed prompts 392 | else: 393 | # reset after sending stream delta 394 | item.output_str = "" 395 | 396 | item = None 397 | 398 | logits = None 399 | # Remove completed prompts from the list 400 | for i in eos: 401 | work.pop(i) 402 | if eos: 403 | gc.collect() 404 | if not work and prompts_queue.qsize() == 0: 405 | update_token_rates(True) 406 | if eos and (prompts_queue.qsize() == 0 and not pending_model_request): 407 | status.update_work_items(len(work)) 408 | status.update_memory() 409 | 410 | 411 | @app.get("/", response_class=typing.Union[HTMLResponse, FileResponse]) 412 | def status_page(): 413 | file_path = pathlib.Path('static/status.html') 414 | 415 | if file_path.exists(): 416 | return FileResponse(file_path.resolve(), media_type='text/html') 417 | else: 418 | return HTMLResponse(f"Server is running model {modelfile.repository} but status.html is missing") 419 | 420 | @app.websocket("/ws/status") 421 | async def websocket_status(websocket: WebSocket): 422 | global status, loaded_pct 423 | 424 | await websocket.accept() 425 | while True: 426 | await asyncio.sleep(1) 427 | model_name = None 428 | if modelfile: 429 | model_name = modelfile.repository 430 | if loaded_pct is not None and loaded_pct < 100: 431 | model_name = f"{model_name} {loaded_pct:.1f}%" 432 | data = { 433 | "model": model_name, 434 | "queues": [ 435 | { "x": status.work_item_times, "y": status.work_items, "name": "run" }, 436 | { "x": status.queue_depth_times, "y": status.queue_depths, "name": "wait" }, 437 | ], 438 | "rates": [ 439 | { "x": status.token_rate_times, "y": status.token_rates, "name": "tok/s" }, 440 | ], 441 | "mems": [ 442 | { "x": status.mem_times, "y": status.mems[i], "name": f"gpu{i}" } for i in range(torch.cuda.device_count()) 443 | ] 444 | } 445 | try: 446 | await websocket.send_json(data) 447 | except Exception as e: 448 | break 449 | 450 | @app.post("/unload", response_class=JSONResponse) 451 | def handle_unload(): 452 | global request_unload, request_cancel_all 453 | request_cancel_all = True 454 | request_unload = True 455 | return {"status": "ok"} 456 | 457 | 458 | # This is meant to be returned. 
If raised, catch yourself and return 459 | # Maybe there's a cleaner way to do this by inheriting from HTTPException 460 | class ApiErrorResponse(JSONResponse, Exception): 461 | def __init__( 462 | self, *, 463 | status_code: int, 464 | message: str, 465 | type: str, 466 | param: typing.Any = None, 467 | code: typing.Any = None 468 | ): 469 | error = ErrorResponse.Error(message=message, type=type, param=param, code=code) 470 | response = ErrorResponse(error=error) 471 | super().__init__(status_code=status_code, content=jsonable_encoder(response)) 472 | 473 | last_queued_modelfile = None 474 | 475 | @app.post( 476 | "/v1/chat/completions", 477 | response_class=typing.Union[StreamingJSONResponse, JSONResponse, ApiErrorResponse], 478 | ) 479 | async def chat_completions(fastapi_request: Request, prompt: ChatCompletions): 480 | global modelfile, prompts_queue, config, status, last_queued_modelfile 481 | 482 | # if idle, initialize 483 | if prompts_queue.qsize() == 0: 484 | last_queued_modelfile = modelfile 485 | 486 | if not last_queued_modelfile or prompt.model != last_queued_modelfile.repository: 487 | try: 488 | newmodelfile = load_modelfile(prompt.model) 489 | except FileNotFoundError: 490 | return ApiErrorResponse(status_code=400, type="invalid_request_error", 491 | message=f"Model \"{prompt.model}\" is not available. Try adding it with create_model.py") 492 | last_queued_modelfile = newmodelfile 493 | 494 | # Listify stop 495 | stop = prompt.stop 496 | if prompt.stop is None: 497 | stop = [] 498 | else: 499 | stop = [stop] if isinstance(stop, str) else stop 500 | 501 | # what a terrible interface Request.is_disconnected() is 502 | async def poll_is_disconnected(fastapi_request, request): 503 | try: 504 | while not request.finish_reason and not await fastapi_request.is_disconnected(): 505 | await asyncio.sleep(0.5) 506 | if not request.finish_reason: 507 | print(">> Client disconnected!") 508 | request.finish_reason = "disconnected" 509 | except asyncio.CancelledError: 510 | pass 511 | 512 | request = QueueRequest( 513 | modelfile=last_queued_modelfile, 514 | completion_queue=asyncio.Queue(0), 515 | **prompt.dict() 516 | ) 517 | 518 | await prompts_queue.put(request) 519 | status.increment_queue_depths() 520 | 521 | created = int(time.time()) # constant for all chunks according to api docs 522 | asyncio.create_task(poll_is_disconnected(fastapi_request, request)) 523 | 524 | async def gen(): 525 | while request.finish_reason is None: 526 | try: 527 | qresponse: QueueResponse = await asyncio.wait_for(request.completion_queue.get(), timeout=timeout) 528 | request.finish_reason = qresponse.finish_reason 529 | except asyncio.TimeoutError: 530 | request.finish_reason = "timeout" # abort inference 531 | raise ApiErrorResponse(status_code=408, type="timeout", message=f"Processing did not complete within {timeout} seconds.") 532 | if qresponse.status_code >= 300: 533 | raise ApiErrorResponse(status_code=qresponse.status_code, type=qresponse.finish_reason, 534 | message=qresponse.content) 535 | 536 | if request.stream: 537 | delta = ChatCompletionsChunkResponse.Choice.Delta(content=qresponse.content, role="assistant") 538 | choice = ChatCompletionsChunkResponse.Choice(finish_reason=request.finish_reason, index=1, delta=delta) 539 | response = ChatCompletionsChunkResponse( 540 | id=request.request_id, 541 | choices=[choice], 542 | created=created, 543 | model=prompt.model, 544 | ) 545 | # print(".", end="\n" if finish_reason is not None else "") 546 | #print(qresponse.content, end="\n" if 
request.finish_reason is not None else "") 547 | #sys.stdout.flush() 548 | # print(repr(response)) 549 | yield response 550 | else: 551 | if request.finish_reason is None: 552 | raise HTTPException(status_code=505, detail="Tried to stream non-streaming request") 553 | message = ChatCompletionsResponse.Choice.Message(content=qresponse.content, role="assistant") 554 | choice = ChatCompletionsResponse.Choice(finish_reason=request.finish_reason, index=1, message=message) 555 | usage = ChatCompletionsResponse.Usage( 556 | prompt_tokens=qresponse.prompt_tokens, 557 | completion_tokens=qresponse.completion_tokens, 558 | total_tokens=qresponse.prompt_tokens + qresponse.completion_tokens 559 | ) 560 | response = ChatCompletionsResponse( 561 | id=request.request_id, 562 | choices=[choice], 563 | created=created, 564 | model=prompt.model, 565 | usage=usage, 566 | ) 567 | #print(repr(response)) 568 | yield response 569 | 570 | try: 571 | # catch error from first gen. I don't love this 572 | response = await gen().__anext__() 573 | if request.stream: 574 | async def concat(first, rest): 575 | yield first 576 | async for next in rest: 577 | yield next 578 | return StreamingJSONResponse(concat(response, gen())) 579 | else: 580 | return JSONResponse(jsonable_encoder(response)) 581 | except ApiErrorResponse as e: 582 | return e 583 | 584 | @app.get("/v1/models", response_class=JSONResponse) 585 | async def api_models(): 586 | registry = read_registry() 587 | models = [] 588 | # make active model first 589 | if modelfile: 590 | models.append(ModelsResponse.Model(id=modelfile.repository, created=modelfile.created)) 591 | for k, v in registry.items(): 592 | if modelfile and modelfile.repository == k: 593 | continue 594 | models.append(ModelsResponse.Model(id=k, created=v["created"])) 595 | response = ModelsResponse(data=models) 596 | return response 597 | 598 | 599 | def setup_gpu_split(args): 600 | global gpu_split 601 | gpu_split = None 602 | if args.gpu_split: 603 | gpu_split = list(map(float, args.gpu_split.split(","))) 604 | 605 | async def load_model(): 606 | global args, model, modelfile, active_settings, tokenizer, loras, config, gpu_split, loaded_pct 607 | 608 | unload_model() 609 | 610 | print("Loading model: " + modelfile.repository) 611 | print("From: " + modelfile.model_dir) 612 | print("Settings: " + repr(active_settings)) 613 | 614 | try: 615 | config = ExLlamaV2Config() 616 | config.model_dir = modelfile.model_dir 617 | config.prepare() 618 | active_settings.apply_to_config(config) 619 | 620 | # use loading status callback to yield to web thread 621 | status.update_memory(force=True) 622 | def callback_gen(idx, total): 623 | global request_unload 624 | if request_unload: 625 | raise ValueError("force unloaded") 626 | yield 100.0 * idx / total 627 | status.update_memory(force=True) 628 | 629 | model = ExLlamaV2(config) 630 | if gpu_split is not None: 631 | loader = model.load_gen(gpu_split=gpu_split, callback_gen=callback_gen) 632 | else: 633 | CacheClass = ExLlamaV2Cache_8bit if active_settings.cache_8bit else ExLlamaV2Cache 634 | scratch_cache = CacheClass(model, max_seq_len=config.max_seq_len, batch_size=config.max_batch_size, lazy = True) 635 | loader = model.load_autosplit_gen(scratch_cache, callback_gen=callback_gen) 636 | for pct in loader: 637 | loaded_pct = pct 638 | await asyncio.sleep(0) 639 | 640 | tokenizer = ExLlamaV2Tokenizer(config) 641 | if active_settings.lora: 642 | lora = ExLlamaV2Lora.from_directory(model, active_settings.lora) 643 | loras.append(lora) 644 | except 
Exception as e:
645 |         import traceback
646 |         traceback.print_exception(e)
647 |         print(f"Exception loading {modelfile.repository}: {str(e)}")
648 |         unload_model()
649 |         raise
650 |
651 |     loaded_pct = None
652 |     print(f"Model is loaded.")
653 |
654 |
655 | def unload_model():
656 |     global model, config, tokenizer, loras, loaded_pct
657 |     status.update_memory(force=True)
658 |     if model:
659 |         model.unload()
660 |     model = None
661 |     config = None
662 |     tokenizer = None
663 |     loaded_pct = None
664 |     for lora in loras:
665 |         lora.unload()
666 |     loras = []
667 |     gc.collect()
668 |     torch.cuda.empty_cache()
669 |     torch.cuda.synchronize()
670 |     status.update_memory(force=True)
671 |     gc.collect()
672 |
673 | @app.on_event("startup")
674 | async def startup_event():
675 |     global args, modelfile
676 |
677 |     print("Starting up...")
678 |     if modelfile:
679 |         await load_model()
680 |     asyncio.create_task(inference_loop())
681 |
682 |
683 | @app.on_event("shutdown")
684 | async def shutdown_event():
685 |     global processing_started
686 |     print("Shutting down...")
687 |     processing_started = False
688 |     unload_model()
689 |
690 |
691 | if __name__ == "__main__":
692 |     import uvicorn
693 |
694 |     args = parse_args()
695 |
696 |     if args.cors:
697 |         app.add_middleware(
698 |             CORSMiddleware,
699 |             allow_origins=["*"],
700 |             allow_credentials=True,
701 |             allow_methods=["*"],
702 |             allow_headers=["*"],
703 |         )
704 |
705 |     global args_settings
706 |     args_settings = ModelSettings.from_args(args)
707 |     if args.model:
708 |         try:
709 |             activate_modelfile(load_modelfile(args.model))
710 |             # xxx oops not loading model
711 |         except FileNotFoundError:
712 |             print(f"Could not load model {args.model}. Try python create_model.py...")
713 |             sys.exit(1)
714 |     else:
715 |         activate_modelfile(None)
716 |     global timeout
717 |     timeout = args.timeout
718 |     setup_gpu_split(args)
719 |
720 |     print(f"Starting a server at {args.host} on port {args.port}...")
721 |     uvicorn.run(
722 |         "__main__:app",
723 |         host=args.host,
724 |         port=args.port,
725 |         http="h11",
726 |     )
727 |
--------------------------------------------------------------------------------
/static/status.html:
--------------------------------------------------------------------------------
[The HTML markup of this 182-line status page was not recovered; only its visible placeholder text survives.]
30 | waiting for status...
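The page renders the JSON pushed by the /ws/status endpoint defined in server.py above (fields: model, queues, rates, mems). As a rough illustration only, and not part of this repository, a minimal client for the same payload could look like the following sketch; it assumes the third-party `websockets` package and a server running on localhost:8000.

# status_client_sketch.py -- illustrative only, not part of this repository.
# Assumes: pip install websockets, and server.py running on localhost:8000.
import asyncio
import json

import websockets  # third-party package (an assumption of this sketch)


async def watch_status():
    # websocket_status() in server.py pushes a JSON snapshot about once per second.
    async with websockets.connect("ws://localhost:8000/ws/status") as ws:
        while True:
            data = json.loads(await ws.recv())
            model = data.get("model")       # active model name, with load % while loading
            rates = data.get("rates", [])   # [{"x": [...], "y": [...], "name": "tok/s"}]
            latest = rates[0]["y"][-1] if rates and rates[0]["y"] else None
            print(f"model={model} tok/s={latest}")


if __name__ == "__main__":
    asyncio.run(watch_status())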
--------------------------------------------------------------------------------
/test_with_openai_module.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import asyncio
3 | import random
4 |
5 | # Set your OpenAI API key here
6 | api_key = "YOUR_API_KEY"
7 |
8 | # Define a conversation prompt
9 | conversation_prompt = (
10 |     "Write an essay on the subject of history being written by the victors."
11 | )
12 |
13 | client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key=api_key)
14 |
15 | async def request():
16 |     global client, models
17 |
18 |     start_time = asyncio.get_event_loop().time()
19 |     try:
20 |         response = await client.chat.completions.create(
21 |             model=random.choice(models).id,
22 |             #model=models[0].id, # we always return active model first
23 |             messages=[
24 |                 {"role": "system", "content": "You are a helpful assistant."},
25 |                 {"role": "user", "content": conversation_prompt},
26 |             ],
27 |             max_tokens=2000,
28 |             n=1,
29 |             stop=None,
30 |         )
31 |         content = response.choices[0].message.content
32 |         tokens = response.usage.completion_tokens
33 |     except Exception as e:
34 |         content = repr(e)
35 |         tokens = 0
36 |     duration = asyncio.get_event_loop().time() - start_time
37 |     print(f"duration {duration} {tokens / duration:.2f} t/s")
38 |     return content
39 |
40 | async def main():
41 |     global models
42 |
43 |     # wow this is an annoying interface
44 |     models = []
45 |     async for model_page in client.models.list():
46 |         models.append(model_page)
47 |
48 |     requests = []
49 |     for i in range(20):
50 |         requests.append(request())
51 |     all = await asyncio.gather(*requests)
52 |     print(all[int(len(all)/2)])
53 |
54 | if __name__ == "__main__":
55 |     asyncio.run(main())
56 |
--------------------------------------------------------------------------------
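test_with_openai_module.py only exercises non-streaming completions, although the /v1/chat/completions handler above also serves stream=True requests. The sketch below is not part of the repository; it assumes the same `openai` python package, a server on localhost:8000, and uses the hypothetical model name "mymodel" in place of whatever you registered with create_model.py.

# streaming_sketch.py -- illustrative only, not part of this repository.
# Assumes: the server from server.py is running on localhost:8000 and a model
# named "mymodel" (hypothetical) has been registered with create_model.py.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="YOUR_API_KEY")


async def main():
    stream = await client.chat.completions.create(
        model="mymodel",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        stream=True,
    )
    # Each chunk carries an incremental delta, mirroring ChatCompletionsChunkResponse.
    async for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            print(delta, end="", flush=True)
    print()


if __name__ == "__main__":
    asyncio.run(main())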