├── .gitignore ├── LICENSE ├── README.md ├── presets └── place_your_presets_here.txt ├── requirements.txt └── webui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Presets folder 163 | presets/* 164 | !presets/place_your_presets_here.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 The Royal Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # tabbyAPI-gradio-loader 2 | A simple Gradio WebUI for loading/unloading models and loras for tabbyAPI. It provides demo functionality for accessing tabbyAPI's extensive feature set via its API, and can be run remotely on a separate system. 3 | 4 | ## Usage 5 | This repo is meant to serve as a demo of the API's features and to provide an accessible means of changing models without editing the config and restarting the instance. It supports speculative decoding and loading of multiple loras with custom scaling. 6 | 7 | This WebUI does not provide an LLM inference frontend - use any OAI-compatible inference frontend of your choosing.
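For example, a minimal sketch of pointing the official `openai` Python client at the same tabbyAPI instance this loader manages (assumes `pip install openai`; the endpoint URL, API key, and model name below are placeholders, not values from this repo):

```python
from openai import OpenAI

# Point the client at your tabbyAPI endpoint (the same URL you give the loader) plus /v1.
client = OpenAI(
    base_url="http://localhost:5000/v1",
    api_key="YOUR_TABBY_API_KEY",  # placeholder; use one of your tabbyAPI keys
)

# Ask whichever model is currently loaded through the WebUI for a completion.
response = client.chat.completions.create(
    model="your-loaded-model",  # placeholder; tabbyAPI serves the currently loaded model
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```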
8 | 9 | ## Prerequisites 10 | 11 | To get started, make sure you have the following installed on your system: 12 | 13 | - Python 3.8+ (preferably 3.11) with pip 14 | 15 | ## Installation 16 | 17 | 1. Clone this repository to your machine: `git clone https://github.com/theroyallab/tabbyAPI-gradio-loader` 18 | 2. Navigate to the project directory: `cd tabbyAPI-gradio-loader` 19 | 3. Create a python virtual environment: `python -m venv venv` 20 | 4. Activate the virtual environment: 21 | 1. On Windows (using PowerShell or Windows Terminal): `.\venv\Scripts\activate` 22 | 2. On Linux: `source venv/bin/activate` 23 | 5. Install the requirements file: `pip install -r requirements.txt` 24 | 25 | ## Launching the Application 26 | 1. Make sure you are in the project directory and have activated the venv 27 | 2. Run the WebUI application: `python webui.py` 28 | 3. Input your tabbyAPI endpoint URL and admin key, then press Connect! 29 | 30 | ## Command-line Arguments 31 | | Argument | Description | 32 | | :----------------------- | :----------------------------------------------------------- | 33 | | `-h` or `--help` | Show this help message and exit | 34 | | `-p` or `--port` | Specify port to host the WebUI on (default 7860) | 35 | | `-l` or `--listen` | Make the WebUI accessible over LAN (listen on 0.0.0.0) | 36 | | `-s` or `--share` | Share WebUI link remotely via Gradio's built-in tunnel | 37 | | `-a` or `--autolaunch` | Launch browser after starting WebUI | 38 | | `-e` or `--endpoint_url` | TabbyAPI endpoint URL (default http://localhost:5000) | 39 | | `-k` or `--admin_key` | TabbyAPI admin key, connect automatically on launch | 40 | -------------------------------------------------------------------------------- /presets/place_your_presets_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theroyallab/tabbyAPI-gradio-loader/1afeb12212f8c91549b16da4a2e1307619dd48c8/presets/place_your_presets_here.txt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | gradio==5.1.0 3 | -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | import pathlib 5 | 6 | import aiohttp 7 | import gradio as gr 8 | import requests 9 | 10 | conn_url = None 11 | conn_key = None 12 | 13 | host_url = "127.0.0.1" 14 | 15 | models = [] 16 | draft_models = [] 17 | loras = [] 18 | templates = [] 19 | overrides = [] 20 | 21 | model_load_task = None 22 | model_load_state = False 23 | download_task = None 24 | 25 | parser = argparse.ArgumentParser(description="TabbyAPI Gradio Loader") 26 | parser.add_argument( 27 | "-p", 28 | "--port", 29 | type=int, 30 | default=7860, 31 | help="Specify port to host the WebUI on (default 7860)", 32 | ) 33 | parser.add_argument( 34 | "-l", "--listen", action="store_true", help="Make the WebUI accessible over LAN (listen on 0.0.0.0)" 35 | ) 36 | parser.add_argument( 37 | "-s", 38 | "--share", 39 | action="store_true", 40 | help="Share WebUI link remotely via Gradio's built-in tunnel", 41 | ) 42 | parser.add_argument( 43 | "-a", 44 | "--autolaunch", 45 | action="store_true", 46 | help="Launch browser after starting WebUI", 47 | ) 48 | parser.add_argument( 49 | "-e", 50 | "--endpoint_url", 51 | type=str, 52 | default="http://localhost:5000", 53 | help="TabbyAPI endpoint URL
(default http://localhost:5000)", 54 | ) 55 | parser.add_argument( 56 | "-k", 57 | "--admin_key", 58 | type=str, 59 | default=None, 60 | help="TabbyAPI admin key, connect automatically on launch", 61 | ) 62 | args = parser.parse_args() 63 | if args.listen: 64 | host_url = "0.0.0.0" 65 | 66 | 67 | def read_preset(name): 68 | if not name: 69 | raise gr.Error("Please select a preset to load.") 70 | path = pathlib.Path(f"./presets/{name}.json").resolve() 71 | with open(path, "r") as openfile: 72 | data = json.load(openfile) 73 | gr.Info(f"Preset {name} loaded.") 74 | return ( 75 | gr.Dropdown(value=data.get("name")), 76 | gr.Number(value=data.get("max_seq_len")), 77 | gr.Number(value=data.get("cache_size")), 78 | gr.Checkbox(value=data.get("gpu_split_auto")), 79 | gr.Textbox(value=data.get("gpu_split")), 80 | gr.Number(value=data.get("rope_scale")), 81 | gr.Number(value=data.get("rope_alpha")), 82 | gr.Checkbox(value=data.get("model_rope_alpha_auto")), 83 | gr.Radio(value=data.get("cache_mode")), 84 | gr.Dropdown(value=data.get("prompt_template")), 85 | gr.Number(value=data.get("num_experts_per_token")), 86 | gr.Dropdown(value=data.get("draft_model_name")), 87 | gr.Number(value=data.get("draft_rope_scale")), 88 | gr.Number(value=data.get("draft_rope_alpha")), 89 | gr.Checkbox(value=data.get("draft_rope_alpha_auto")), 90 | gr.Radio(value=data.get("draft_cache_mode")), 91 | gr.Checkbox(value=data.get("tensor_parallel")), 92 | gr.Checkbox(value=data.get("vision")), 93 | gr.Textbox(value=data.get("autosplit_reserve")), 94 | gr.Number(value=data.get("chunk_size")), 95 | ) 96 | 97 | 98 | def del_preset(name): 99 | if not name: 100 | raise gr.Error("Please select a preset to delete.") 101 | path = pathlib.Path(f"./presets/{name}.json").resolve() 102 | path.unlink() 103 | gr.Info(f"Preset {name} deleted.") 104 | return get_preset_list() 105 | 106 | 107 | def write_preset( 108 | name, 109 | model_name, 110 | max_seq_len, 111 | cache_size, 112 | gpu_split_auto, 113 | gpu_split, 114 | model_rope_scale, 115 | model_rope_alpha, 116 | model_rope_alpha_auto, 117 | cache_mode, 118 | prompt_template, 119 | num_experts_per_token, 120 | draft_model_name, 121 | draft_rope_scale, 122 | draft_rope_alpha, 123 | draft_rope_alpha_auto, 124 | draft_cache_mode, 125 | tensor_parallel, 126 | vision, 127 | autosplit_reserve, 128 | chunk_size, 129 | ): 130 | if not name: 131 | raise gr.Error("Please enter a name for your new preset.") 132 | path = pathlib.Path(f"./presets/{name}.json").resolve() 133 | data = { 134 | "name": model_name, 135 | "max_seq_len": max_seq_len, 136 | "cache_size": cache_size, 137 | "gpu_split_auto": gpu_split_auto, 138 | "gpu_split": gpu_split, 139 | "rope_scale": model_rope_scale, 140 | "rope_alpha": model_rope_alpha, 141 | "model_rope_alpha_auto": model_rope_alpha_auto, 142 | "cache_mode": cache_mode, 143 | "prompt_template": prompt_template, 144 | "num_experts_per_token": num_experts_per_token, 145 | "draft_model_name": draft_model_name, 146 | "draft_rope_scale": draft_rope_scale, 147 | "draft_rope_alpha": draft_rope_alpha, 148 | "draft_rope_alpha_auto": draft_rope_alpha_auto, 149 | "draft_cache_mode": draft_cache_mode, 150 | "tensor_parallel": tensor_parallel, 151 | "vision": vision, 152 | "autosplit_reserve": autosplit_reserve, 153 | "chunk_size": chunk_size, 154 | } 155 | with open(path, "w") as outfile: 156 | json.dump(data, outfile, indent=4) 157 | gr.Info(f"Preset {name} saved.") 158 | return gr.Textbox(value=None), get_preset_list() 159 | 160 | 161 | def get_preset_list(raw=False): 162 | 
preset_path = pathlib.Path("./presets").resolve() 163 | preset_list = [] 164 | for path in preset_path.iterdir(): 165 | if path.is_file() and path.name.endswith(".json"): 166 | preset_list.append(path.stem) 167 | preset_list.sort(key=str.lower) 168 | if raw: 169 | return preset_list 170 | return gr.Dropdown(choices=[""] + preset_list, value=None) 171 | 172 | 173 | def connect(api_url, admin_key, silent=False): 174 | global conn_url 175 | global conn_key 176 | global models 177 | global draft_models 178 | global loras 179 | global templates 180 | global overrides 181 | 182 | try: 183 | a = requests.get( 184 | url=api_url + "/v1/auth/permission", headers={"X-api-key": admin_key} 185 | ) 186 | a.raise_for_status() 187 | if a.json().get("permission") != "admin": 188 | raise ValueError( 189 | "The provided authentication key must be an admin key to access the loader's functions." 190 | ) 191 | except Exception as e: 192 | raise gr.Error(e) 193 | 194 | try: 195 | m = requests.get( 196 | url=api_url + "/v1/model/list", headers={"X-api-key": admin_key} 197 | ) 198 | m.raise_for_status() 199 | d = requests.get( 200 | url=api_url + "/v1/model/draft/list", headers={"X-api-key": admin_key} 201 | ) 202 | d.raise_for_status() 203 | lo = requests.get( 204 | url=api_url + "/v1/lora/list", headers={"X-api-key": admin_key} 205 | ) 206 | lo.raise_for_status() 207 | t = requests.get( 208 | url=api_url + "/v1/template/list", headers={"X-api-key": admin_key} 209 | ) 210 | t.raise_for_status() 211 | so = requests.get( 212 | url=api_url + "/v1/sampling/override/list", headers={"X-api-key": admin_key} 213 | ) 214 | so.raise_for_status() 215 | except Exception as e: 216 | raise gr.Error(e) 217 | 218 | conn_url = api_url 219 | conn_key = admin_key 220 | 221 | models = [] 222 | for model in m.json().get("data"): 223 | models.append(model.get("id")) 224 | models.sort(key=str.lower) 225 | 226 | draft_models = [] 227 | for draft_model in d.json().get("data"): 228 | draft_models.append(draft_model.get("id")) 229 | draft_models.sort(key=str.lower) 230 | 231 | loras = [] 232 | for lora in lo.json().get("data"): 233 | loras.append(lora.get("id")) 234 | loras.sort(key=str.lower) 235 | 236 | templates = [] 237 | for template in t.json().get("data"): 238 | templates.append(template) 239 | templates.sort(key=str.lower) 240 | 241 | overrides = [] 242 | for override in so.json().get("presets"): 243 | overrides.append(override) 244 | overrides.sort(key=str.lower) 245 | 246 | if not silent: 247 | gr.Info("TabbyAPI connected.") 248 | return ( 249 | gr.Textbox(value=", ".join(models), visible=True), 250 | gr.Textbox(value=", ".join(draft_models), visible=True), 251 | gr.Textbox(value=", ".join(loras), visible=True), 252 | get_model_list(), 253 | get_draft_model_list(), 254 | get_lora_list(), 255 | get_template_list(), 256 | get_override_list(), 257 | get_current_model(), 258 | get_current_loras(), 259 | ) 260 | 261 | 262 | def get_model_list(): 263 | return gr.Dropdown(choices=[""] + models, value=None) 264 | 265 | 266 | def get_draft_model_list(): 267 | return gr.Dropdown(choices=[""] + draft_models, value=None) 268 | 269 | 270 | def get_lora_list(): 271 | return gr.Dropdown(choices=loras, value=[]) 272 | 273 | 274 | def get_template_list(): 275 | return gr.Dropdown(choices=[""] + templates, value=None) 276 | 277 | 278 | def get_override_list(): 279 | return gr.Dropdown(choices=[""] + overrides, value=None) 280 | 281 | 282 | def get_current_model(): 283 | model_card = requests.get( 284 | url=conn_url + "/v1/model", 
headers={"X-api-key": conn_key} 285 | ).json() 286 | if not model_card.get("id"): 287 | return gr.Textbox(value=None) 288 | params = model_card.get("parameters") 289 | draft_model_card = params.get("draft") 290 | model = f'{model_card.get("id")} (context: {params.get("max_seq_len")}, cache size: {params.get("cache_size")}, rope scale: {params.get("rope_scale")}, rope alpha: {params.get("rope_alpha")})' 291 | 292 | if draft_model_card: 293 | draft_params = draft_model_card.get("parameters") 294 | model += f' | {draft_model_card.get("id")} (rope scale: {draft_params.get("rope_scale")}, rope alpha: {draft_params.get("rope_alpha")})' 295 | return gr.Textbox(value=model) 296 | 297 | 298 | def get_current_loras(): 299 | lo = requests.get(url=conn_url + "/v1/lora", headers={"X-api-key": conn_key}).json() 300 | if not lo.get("data"): 301 | return gr.Textbox(value=None) 302 | lora_list = lo.get("data") 303 | loras = [] 304 | for lora in lora_list: 305 | loras.append(f'{lora.get("id")} (scaling: {lora.get("scaling")})') 306 | return gr.Textbox(value=", ".join(loras)) 307 | 308 | 309 | def update_loras_table(loras): 310 | array = [] 311 | for lora in loras: 312 | array.append(1.0) 313 | if array: 314 | return gr.List( 315 | value=[array], 316 | col_count=(len(array), "fixed"), 317 | row_count=(1, "fixed"), 318 | headers=loras, 319 | visible=True, 320 | ) 321 | else: 322 | return gr.List(value=None, visible=False) 323 | 324 | 325 | async def load_model( 326 | model_name, 327 | max_seq_len, 328 | cache_size, 329 | gpu_split_auto, 330 | gpu_split, 331 | model_rope_scale, 332 | model_rope_alpha, 333 | model_rope_alpha_auto, 334 | cache_mode, 335 | prompt_template, 336 | num_experts_per_token, 337 | draft_model_name, 338 | draft_rope_scale, 339 | draft_rope_alpha, 340 | draft_rope_alpha_auto, 341 | draft_cache_mode, 342 | tensor_parallel, 343 | vision, 344 | autosplit_reserve, 345 | chunk_size, 346 | ): 347 | global model_load_task 348 | global model_load_state 349 | model_load_state = True 350 | if not model_name: 351 | raise gr.Error("Specify a model to load!") 352 | gpu_split_parsed = [] 353 | try: 354 | if gpu_split: 355 | gpu_split_parsed = [float(i) for i in list(gpu_split.split(","))] 356 | except ValueError: 357 | raise gr.Error("Check your GPU split values and ensure they are valid!") 358 | autosplit_reserve_parsed = [] 359 | try: 360 | if autosplit_reserve: 361 | autosplit_reserve_parsed = [ 362 | float(i) for i in list(autosplit_reserve.split(",")) 363 | ] 364 | except ValueError: 365 | raise gr.Error("Check your autosplit reserve values and ensure they are valid!") 366 | if draft_model_name: 367 | draft_request = { 368 | "draft_model_name": draft_model_name, 369 | "draft_rope_scale": draft_rope_scale, 370 | "draft_rope_alpha": "auto" if draft_rope_alpha_auto else draft_rope_alpha, 371 | "draft_cache_mode": draft_cache_mode, 372 | } 373 | else: 374 | draft_request = None 375 | request = { 376 | "name": model_name, 377 | "max_seq_len": max_seq_len, 378 | "cache_size": cache_size, 379 | "gpu_split_auto": gpu_split_auto, 380 | "gpu_split": gpu_split_parsed, 381 | "rope_scale": model_rope_scale, 382 | "rope_alpha": "auto" if model_rope_alpha_auto else model_rope_alpha, 383 | "cache_mode": cache_mode, 384 | "prompt_template": prompt_template, 385 | "num_experts_per_token": num_experts_per_token, 386 | "tensor_parallel": tensor_parallel, 387 | "vision": vision, 388 | "autosplit_reserve": autosplit_reserve_parsed, 389 | "chunk_size": chunk_size, 390 | "draft": draft_request, 391 | } 392 | try: 393 | 
requests.post( 394 | url=conn_url + "/v1/model/unload", headers={"X-admin-key": conn_key} 395 | ) 396 | async with aiohttp.ClientSession() as session: 397 | gr.Info(f"Loading {model_name}.") 398 | model_load_task = asyncio.create_task( 399 | session.post( 400 | url=conn_url + "/v1/model/load", 401 | headers={"X-admin-key": conn_key}, 402 | json=request, 403 | ) 404 | ) 405 | r = await model_load_task 406 | r.raise_for_status() 407 | async for chunk in r.content: 408 | if not model_load_state: 409 | requests.post( 410 | url=conn_url + "/v1/model/unload", 411 | headers={"X-admin-key": conn_key}, 412 | ) 413 | gr.Info("Model load canceled.") 414 | break 415 | chunk_str = chunk.decode("utf-8") 416 | if chunk_str.startswith("data: "): 417 | data = json.loads(chunk_str.lstrip("data: ")) 418 | if data.get("status") == "finished": 419 | gr.Info("Model successfully loaded.") 420 | return get_current_model(), get_current_loras() 421 | except asyncio.CancelledError: 422 | requests.post( 423 | url=conn_url + "/v1/model/unload", headers={"X-admin-key": conn_key} 424 | ) 425 | gr.Info("Model load canceled.") 426 | except Exception as e: 427 | raise gr.Error(e) 428 | finally: 429 | await session.close() 430 | model_load_task = None 431 | model_load_state = False 432 | 433 | 434 | def load_loras(loras, scalings): 435 | if not loras: 436 | raise gr.Error("Specify at least one lora to load!") 437 | load_list = [] 438 | for index, lora in enumerate(loras): 439 | try: 440 | scaling = float(scalings[0][index]) 441 | load_list.append({"name": lora, "scaling": scaling}) 442 | except ValueError: 443 | raise gr.Error("Check your scaling values and ensure they are valid!") 444 | request = {"loras": load_list} 445 | try: 446 | requests.post( 447 | url=conn_url + "/v1/lora/unload", headers={"X-admin-key": conn_key} 448 | ) 449 | r = requests.post( 450 | url=conn_url + "/v1/lora/load", 451 | headers={"X-admin-key": conn_key}, 452 | json=request, 453 | ) 454 | r.raise_for_status() 455 | gr.Info("Loras successfully loaded.") 456 | return get_current_model(), get_current_loras() 457 | except Exception as e: 458 | raise gr.Error(e) 459 | 460 | 461 | def unload_model(): 462 | global model_load_task 463 | global model_load_state 464 | if model_load_task or model_load_state: 465 | model_load_task.cancel() 466 | model_load_state = False 467 | else: 468 | requests.post( 469 | url=conn_url + "/v1/model/unload", headers={"X-admin-key": conn_key} 470 | ) 471 | gr.Info("Model unloaded.") 472 | return get_current_model(), get_current_loras() 473 | 474 | 475 | def unload_loras(): 476 | try: 477 | r = requests.post( 478 | url=conn_url + "/v1/lora/unload", headers={"X-admin-key": conn_key} 479 | ) 480 | r.raise_for_status() 481 | gr.Info("All loras unloaded.") 482 | return get_current_model(), get_current_loras() 483 | except Exception as e: 484 | raise gr.Error(e) 485 | 486 | 487 | def toggle_model_rope_alpha_auto(model_rope_alpha_auto): 488 | if model_rope_alpha_auto: 489 | return gr.Number(interactive=False) 490 | else: 491 | return gr.Number(interactive=True) 492 | 493 | 494 | def toggle_draft_rope_alpha_auto(draft_rope_alpha_auto): 495 | if draft_rope_alpha_auto: 496 | return gr.Number(interactive=False) 497 | else: 498 | return gr.Number(interactive=True) 499 | 500 | 501 | def toggle_gpu_split(gpu_split_auto): 502 | if gpu_split_auto: 503 | return gr.Textbox(value=None, visible=False), gr.Textbox(visible=True) 504 | else: 505 | return gr.Textbox(visible=True), gr.Textbox(value=None, visible=False) 506 | 507 | 508 | def 
load_template(prompt_template): 509 | try: 510 | r = requests.post( 511 | url=conn_url + "/v1/template/switch", 512 | headers={"X-admin-key": conn_key}, 513 | json={"name": prompt_template}, 514 | ) 515 | r.raise_for_status() 516 | gr.Info(f"Prompt template switched to {prompt_template}.") 517 | return 518 | except Exception as e: 519 | raise gr.Error(e) 520 | 521 | 522 | def unload_template(): 523 | try: 524 | r = requests.post( 525 | url=conn_url + "/v1/template/unload", headers={"X-admin-key": conn_key} 526 | ) 527 | r.raise_for_status() 528 | gr.Info("Prompt template unloaded.") 529 | return 530 | except Exception as e: 531 | raise gr.Error(e) 532 | 533 | 534 | def load_override(sampler_override): 535 | try: 536 | r = requests.post( 537 | url=conn_url + "/v1/sampling/override/switch", 538 | headers={"X-admin-key": conn_key}, 539 | json={"preset": sampler_override}, 540 | ) 541 | r.raise_for_status() 542 | gr.Info(f"Sampler override switched to {sampler_override}.") 543 | return 544 | except Exception as e: 545 | raise gr.Error(e) 546 | 547 | 548 | def unload_override(): 549 | try: 550 | r = requests.post( 551 | url=conn_url + "/v1/sampling/override/unload", 552 | headers={"X-admin-key": conn_key}, 553 | ) 554 | r.raise_for_status() 555 | gr.Info("Sampler override unloaded.") 556 | return 557 | except Exception as e: 558 | raise gr.Error(e) 559 | 560 | 561 | async def download(repo_id, revision, repo_type, folder_name, token, include, exclude): 562 | global download_task 563 | if not folder_name: 564 | folder_name = repo_id.replace("/", "_") 565 | include_parsed = ["*"] 566 | if include: 567 | include_parsed = [i.strip() for i in list(include.split(","))] 568 | exclude_parsed = [] 569 | if exclude: 570 | exclude_parsed = [i.strip() for i in list(exclude.split(","))] 571 | request = { 572 | "repo_id": repo_id, 573 | "revision": revision, 574 | "repo_type": repo_type.lower(), 575 | "folder_name": folder_name, 576 | "token": token, 577 | "include": include_parsed, 578 | "exclude": exclude_parsed, 579 | } 580 | try: 581 | async with aiohttp.ClientSession() as session: 582 | gr.Info(f"Beginning download of {repo_id}.") 583 | download_task = asyncio.create_task( 584 | session.post( 585 | url=conn_url + "/v1/download", 586 | headers={"X-admin-key": conn_key}, 587 | json=request, 588 | ) 589 | ) 590 | r = await download_task 591 | r.raise_for_status() 592 | content = await r.json() 593 | gr.Info( 594 | f'{repo_type} {repo_id} downloaded to folder: {content.get("download_path")}.'
595 | ) 596 | except asyncio.CancelledError: 597 | gr.Info("Download canceled.") 598 | except Exception as e: 599 | raise gr.Error(e) 600 | finally: 601 | await session.close() 602 | download_task = None 603 | 604 | 605 | def cancel_download(): 606 | global download_task 607 | if download_task: 608 | download_task.cancel() 609 | 610 | 611 | # Auto-attempt connection if admin key is provided 612 | init_model_text = None 613 | init_lora_text = None 614 | if args.admin_key: 615 | try: 616 | connect(api_url=args.endpoint_url, admin_key=args.admin_key, silent=True) 617 | init_model_text = get_current_model().value 618 | init_lora_text = get_current_loras().value 619 | except Exception: 620 | print("Automatic connection failed, continuing to WebUI.") 621 | 622 | # Setup UI elements 623 | with gr.Blocks(title="TabbyAPI Gradio Loader") as webui: 624 | gr.Markdown( 625 | """ 626 | # TabbyAPI Gradio Loader 627 | """ 628 | ) 629 | current_model = gr.Textbox(value=init_model_text, label="Current Model:") 630 | current_loras = gr.Textbox(value=init_lora_text, label="Current Loras:") 631 | 632 | with gr.Tab("Connect to API"): 633 | connect_btn = gr.Button(value="Connect", variant="primary") 634 | api_url = gr.Textbox( 635 | value=args.endpoint_url, label="TabbyAPI Endpoint URL:", interactive=True 636 | ) 637 | admin_key = gr.Textbox( 638 | value=args.admin_key, label="Admin Key:", type="password", interactive=True 639 | ) 640 | model_list = gr.Textbox( 641 | value=", ".join(models), label="Available Models:", visible=bool(conn_key) 642 | ) 643 | draft_model_list = gr.Textbox( 644 | value=", ".join(draft_models), 645 | label="Available Draft Models:", 646 | visible=bool(conn_key), 647 | ) 648 | lora_list = gr.Textbox( 649 | value=", ".join(loras), label="Available Loras:", visible=bool(conn_key) 650 | ) 651 | 652 | with gr.Tab("Load Model"): 653 | with gr.Row(): 654 | load_model_btn = gr.Button(value="Load Model", variant="primary") 655 | unload_model_btn = gr.Button( 656 | value="Cancel Load/Unload Model", variant="stop" 657 | ) 658 | 659 | with gr.Accordion(open=False, label="Presets"): 660 | with gr.Row(): 661 | load_preset = gr.Dropdown( 662 | choices=[""] + get_preset_list(True), 663 | label="Load Preset:", 664 | interactive=True, 665 | ) 666 | save_preset = gr.Textbox(label="Save Preset:", interactive=True) 667 | 668 | with gr.Row(): 669 | load_preset_btn = gr.Button(value="Load Preset", variant="primary") 670 | del_preset_btn = gr.Button(value="Delete Preset", variant="stop") 671 | save_preset_btn = gr.Button(value="Save Preset", variant="primary") 672 | refresh_preset_btn = gr.Button(value="Refresh Presets") 673 | 674 | with gr.Group(): 675 | models_drop = gr.Dropdown( 676 | choices=[""] + models, label="Select Model:", interactive=True 677 | ) 678 | with gr.Row(): 679 | max_seq_len = gr.Number( 680 | value=lambda: None, 681 | label="Max Sequence Length:", 682 | precision=0, 683 | minimum=1, 684 | interactive=True, 685 | info="Configured context length to load the model with. If left blank, automatically reads from model config.", 686 | ) 687 | cache_size = gr.Number( 688 | value=lambda: None, 689 | label="Cache Size:", 690 | precision=0, 691 | minimum=1, 692 | interactive=True, 693 | info="Size of the prompt cache to allocate (in number of tokens, multiple of 256). 
Defaults to max sequence length if left blank.", 694 | ) 695 | 696 | with gr.Row(): 697 | model_rope_scale = gr.Number( 698 | value=lambda: None, 699 | label="Rope Scale:", 700 | minimum=1, 701 | interactive=True, 702 | info="Used for models trained with modified linear positional embeddings. If left blank, automatically reads from model config.", 703 | ) 704 | model_rope_alpha = gr.Number( 705 | value=lambda: None, 706 | label="Rope Alpha:", 707 | minimum=1, 708 | interactive=False, 709 | info="Factor used for NTK-aware rope scaling. Ignored if automatic calculation is selected.", 710 | ) 711 | model_rope_alpha_auto = gr.Checkbox( 712 | value=True, 713 | label="Automatic Rope Alpha", 714 | interactive=True, 715 | info="Enable automatic calculation based on your configured max_seq_len and the model's base context length.", 716 | ) 717 | 718 | with gr.Accordion(open=False, label="Speculative Decoding"): 719 | draft_models_drop = gr.Dropdown( 720 | choices=[""] + draft_models, 721 | label="Select Draft Model:", 722 | interactive=True, 723 | info="Must share the same tokenizer and vocabulary as the primary model.", 724 | ) 725 | with gr.Row(): 726 | draft_rope_scale = gr.Number( 727 | value=lambda: None, 728 | label="Draft Rope Scale:", 729 | minimum=1, 730 | interactive=True, 731 | info="Used for models trained with modified linear positional embeddings. If left blank, automatically reads from model config.", 732 | ) 733 | draft_rope_alpha = gr.Number( 734 | value=lambda: None, 735 | label="Draft Rope Alpha:", 736 | minimum=1, 737 | interactive=False, 738 | info="Factor used for NTK-aware rope scaling. Ignored if automatic calculation is selected.", 739 | ) 740 | draft_rope_alpha_auto = gr.Checkbox( 741 | value=True, 742 | label="Automatic Rope Alpha", 743 | interactive=True, 744 | info="Enable automatic calculation based on your configured max_seq_len and the model's base context length.", 745 | ) 746 | draft_cache_mode = gr.Radio( 747 | value="FP16", 748 | label="Draft Cache Mode:", 749 | choices=["Q4", "Q6", "Q8", "FP16"], 750 | interactive=True, 751 | info="Q4/Q6/Q8 cache modes sacrifice some precision to save VRAM compared to full FP16 precision.", 752 | ) 753 | 754 | with gr.Group(): 755 | with gr.Row(): 756 | cache_mode = gr.Radio( 757 | value="FP16", 758 | label="Cache Mode:", 759 | choices=["Q4", "Q6", "Q8", "FP16"], 760 | interactive=True, 761 | info="Q4/Q6/Q8 cache modes sacrifice some precision to save VRAM compared to full FP16 precision.", 762 | ) 763 | gpu_split_auto = gr.Checkbox( 764 | value=True, 765 | label="GPU Split Auto", 766 | interactive=True, 767 | info="Automatically determine how to split model layers between multiple GPUs.", 768 | ) 769 | tensor_parallel = gr.Checkbox( 770 | value=False, 771 | label="Tensor Parallel", 772 | interactive=True, 773 | info="Enable tensor parallelism on multi-GPU setups, which improves generation speed in most settings.", 774 | ) 775 | vision = gr.Checkbox( 776 | value=False, 777 | label="Vision", 778 | interactive=True, 779 | info="Enables vision support if the model supports it.", 780 | ) 781 | 782 | gpu_split = gr.Textbox( 783 | label="GPU Split:", 784 | placeholder="20.6,24", 785 | visible=False, 786 | interactive=True, 787 | info="Amount of VRAM TabbyAPI will be allowed to use on each GPU.
List of numbers separated by commas, in gigabytes.", 788 | ) 789 | autosplit_reserve = gr.Textbox( 790 | label="Auto-split Reserve:", 791 | placeholder="96", 792 | interactive=True, 793 | info="Amount of VRAM to keep reserved on each GPU when using auto split. List of numbers separated by commas, in megabytes.", 794 | ) 795 | with gr.Row(): 796 | num_experts_per_token = gr.Number( 797 | value=lambda: None, 798 | label="Number of experts per token (MoE only):", 799 | precision=0, 800 | minimum=1, 801 | interactive=True, 802 | info="Number of experts to use for simultaneous inference in mixture of experts. If left blank, automatically reads from model config.", 803 | ) 804 | chunk_size = gr.Number( 805 | value=lambda: None, 806 | label="Chunk Size:", 807 | precision=0, 808 | minimum=1, 809 | interactive=True, 810 | info="The number of prompt tokens to ingest at a time. A lower value reduces VRAM usage at the cost of ingestion speed.", 811 | ) 812 | 813 | with gr.Accordion(open=True, label="Prompt Templates"): 814 | prompt_template = gr.Dropdown( 815 | choices=[""] + templates, 816 | value="", 817 | label="Prompt Template:", 818 | allow_custom_value=True, 819 | interactive=True, 820 | info="Jinja2 prompt template to be used for the chat completions endpoint.", 821 | ) 822 | with gr.Row(): 823 | load_template_btn = gr.Button(value="Load Template", variant="primary") 824 | unload_template_btn = gr.Button(value="Unload Template", variant="stop") 825 | 826 | with gr.Accordion(open=False, label="Sampler Overrides"): 827 | sampler_override = gr.Dropdown( 828 | choices=[""] + overrides, 829 | value="", 830 | label="Select Sampler Overrides:", 831 | interactive=True, 832 | info="Select a sampler override preset to load.", 833 | ) 834 | with gr.Row(): 835 | load_override_btn = gr.Button(value="Load Override", variant="primary") 836 | unload_override_btn = gr.Button(value="Unload Override", variant="stop") 837 | 838 | with gr.Tab("Load Loras"): 839 | with gr.Row(): 840 | load_loras_btn = gr.Button(value="Load Loras", variant="primary") 841 | unload_loras_btn = gr.Button(value="Unload All Loras", variant="stop") 842 | 843 | loras_drop = gr.Dropdown( 844 | label="Select Loras:", 845 | choices=loras, 846 | multiselect=True, 847 | interactive=True, 848 | info="Select one or more loras to load; specify individual lora weights in the box that appears below (default 1.0).", 849 | ) 850 | loras_table = gr.List( 851 | label="Lora Scaling:", 852 | visible=False, 853 | datatype="number", 854 | type="array", 855 | interactive=True, 856 | ) 857 | 858 | with gr.Tab("HF Downloader"): 859 | with gr.Row(): 860 | download_btn = gr.Button(value="Download", variant="primary") 861 | cancel_download_btn = gr.Button(value="Cancel", variant="stop") 862 | 863 | with gr.Group(): 864 | with gr.Row(): 865 | repo_id = gr.Textbox( 866 | label="Repo ID:", 867 | interactive=True, 868 | info="Provided in the format user/repo.", 869 | ) 870 | revision = gr.Textbox( 871 | label="Revision/Branch:", 872 | interactive=True, 873 | info="Name of the revision/branch of the repository to download.", 874 | ) 875 | 876 | with gr.Row(): 877 | repo_type = gr.Dropdown( 878 | choices=["Model", "Lora"], 879 | value="Model", 880 | label="Repo Type:", 881 | interactive=True, 882 | info="Specify whether the repository contains a model or lora.", 883 | ) 884 | folder_name = gr.Textbox( 885 | label="Folder Name:", 886 | interactive=True, 887 | info="Name to use for the local downloaded copy of the repository.", 888 | ) 889 | 890 | with gr.Row(): 891 | include =
gr.Textbox( 892 | placeholder="adapter_config.json, adapter_model.bin", 893 | label="Include Patterns:", 894 | interactive=True, 895 | info="Comma-separated list of file patterns to download from repository (default all).", 896 | ) 897 | exclude = gr.Textbox( 898 | placeholder="*.bin, *.pth", 899 | label="Exclude Patterns:", 900 | interactive=True, 901 | info="Comma-separated list of file patterns to exclude from download.", 902 | ) 903 | 904 | with gr.Row(): 905 | token = gr.Textbox( 906 | label="HF Access Token:", 907 | type="password", 908 | info="Provide HF access token to download from private/gated repositories.", 909 | ) 910 | 911 | # Define event listeners 912 | # Connection tab 913 | connect_btn.click( 914 | fn=connect, 915 | inputs=[api_url, admin_key], 916 | outputs=[ 917 | model_list, 918 | draft_model_list, 919 | lora_list, 920 | models_drop, 921 | draft_models_drop, 922 | loras_drop, 923 | prompt_template, 924 | sampler_override, 925 | current_model, 926 | current_loras, 927 | ], 928 | ) 929 | 930 | # Model tab 931 | load_preset_btn.click( 932 | fn=read_preset, 933 | inputs=load_preset, 934 | outputs=[ 935 | models_drop, 936 | max_seq_len, 937 | cache_size, 938 | gpu_split_auto, 939 | gpu_split, 940 | model_rope_scale, 941 | model_rope_alpha, 942 | model_rope_alpha_auto, 943 | cache_mode, 944 | prompt_template, 945 | num_experts_per_token, 946 | draft_models_drop, 947 | draft_rope_scale, 948 | draft_rope_alpha, 949 | draft_rope_alpha_auto, 950 | draft_cache_mode, 951 | tensor_parallel, 952 | vision, 953 | autosplit_reserve, 954 | chunk_size, 955 | ], 956 | ) 957 | del_preset_btn.click(fn=del_preset, inputs=load_preset, outputs=load_preset) 958 | save_preset_btn.click( 959 | fn=write_preset, 960 | inputs=[ 961 | save_preset, 962 | models_drop, 963 | max_seq_len, 964 | cache_size, 965 | gpu_split_auto, 966 | gpu_split, 967 | model_rope_scale, 968 | model_rope_alpha, 969 | model_rope_alpha_auto, 970 | cache_mode, 971 | prompt_template, 972 | num_experts_per_token, 973 | draft_models_drop, 974 | draft_rope_scale, 975 | draft_rope_alpha, 976 | draft_rope_alpha_auto, 977 | draft_cache_mode, 978 | tensor_parallel, 979 | vision, 980 | autosplit_reserve, 981 | chunk_size, 982 | ], 983 | outputs=[save_preset, load_preset], 984 | ) 985 | refresh_preset_btn.click(fn=get_preset_list, outputs=load_preset) 986 | 987 | model_rope_alpha_auto.change( 988 | fn=toggle_model_rope_alpha_auto, 989 | inputs=model_rope_alpha_auto, 990 | outputs=model_rope_alpha, 991 | ) 992 | draft_rope_alpha_auto.change( 993 | fn=toggle_draft_rope_alpha_auto, 994 | inputs=draft_rope_alpha_auto, 995 | outputs=draft_rope_alpha, 996 | ) 997 | gpu_split_auto.change( 998 | fn=toggle_gpu_split, 999 | inputs=gpu_split_auto, 1000 | outputs=[gpu_split, autosplit_reserve], 1001 | ) 1002 | unload_model_btn.click(fn=unload_model, outputs=[current_model, current_loras]) 1003 | load_model_btn.click( 1004 | fn=load_model, 1005 | inputs=[ 1006 | models_drop, 1007 | max_seq_len, 1008 | cache_size, 1009 | gpu_split_auto, 1010 | gpu_split, 1011 | model_rope_scale, 1012 | model_rope_alpha, 1013 | model_rope_alpha_auto, 1014 | cache_mode, 1015 | prompt_template, 1016 | num_experts_per_token, 1017 | draft_models_drop, 1018 | draft_rope_scale, 1019 | draft_rope_alpha, 1020 | draft_rope_alpha_auto, 1021 | draft_cache_mode, 1022 | tensor_parallel, 1023 | vision, 1024 | autosplit_reserve, 1025 | chunk_size, 1026 | ], 1027 | outputs=[current_model, current_loras], 1028 | concurrency_limit=1, 1029 | ) 1030 | 
load_template_btn.click(fn=load_template, inputs=prompt_template) 1031 | unload_template_btn.click(fn=unload_template) 1032 | load_override_btn.click(fn=load_override, inputs=sampler_override) 1033 | unload_override_btn.click(fn=unload_override) 1034 | 1035 | # Loras tab 1036 | loras_drop.change(update_loras_table, inputs=loras_drop, outputs=loras_table) 1037 | unload_loras_btn.click(fn=unload_loras, outputs=[current_model, current_loras]) 1038 | load_loras_btn.click( 1039 | fn=load_loras, 1040 | inputs=[loras_drop, loras_table], 1041 | outputs=[current_model, current_loras], 1042 | ) 1043 | 1044 | # HF Downloader tab 1045 | download_btn.click( 1046 | fn=download, 1047 | inputs=[repo_id, revision, repo_type, folder_name, token, include, exclude], 1048 | concurrency_limit=1, 1049 | ) 1050 | cancel_download_btn.click(fn=cancel_download) 1051 | 1052 | webui.launch( 1053 | inbrowser=args.autolaunch, 1054 | show_api=False, 1055 | server_name=host_url, 1056 | server_port=args.port, 1057 | share=args.share, 1058 | ) 1059 | --------------------------------------------------------------------------------
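For reference, presets saved from the WebUI's Load Model tab are plain JSON files under `presets/`, one key per loader field, exactly as `write_preset()` serializes them and `read_preset()` reads them back. A minimal sketch of creating such a file by hand, outside the WebUI (the model name and numeric values are placeholders, not recommended settings):

```python
import json
import pathlib

# Field names mirror write_preset() in webui.py; every value below is a placeholder.
preset = {
    "name": "MyModel-exl2-4.0bpw",  # model folder name as reported by /v1/model/list
    "max_seq_len": 8192,
    "cache_size": 8192,
    "gpu_split_auto": True,
    "gpu_split": "",
    "rope_scale": 1.0,
    "rope_alpha": 1.0,
    "model_rope_alpha_auto": True,
    "cache_mode": "FP16",
    "prompt_template": "",
    "num_experts_per_token": None,
    "draft_model_name": "",
    "draft_rope_scale": 1.0,
    "draft_rope_alpha": 1.0,
    "draft_rope_alpha_auto": True,
    "draft_cache_mode": "FP16",
    "tensor_parallel": False,
    "vision": False,
    "autosplit_reserve": "96",
    "chunk_size": 2048,
}

# Drop the file into ./presets next to webui.py so it shows up in the
# "Load Preset" dropdown after pressing "Refresh Presets".
path = pathlib.Path("./presets/MyPreset.json")
with open(path, "w") as outfile:
    json.dump(preset, outfile, indent=4)
```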