├── .github ├── FUNDING.yml └── workflows │ └── docker-image.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md └── app ├── __init__.py ├── lib ├── __init__.py ├── endpoints.py ├── models.py └── utils.py ├── main.py └── models └── .gitkeep /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [grctest] 4 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Build and Push Docker Image 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v*.*.*" 7 | 8 | jobs: 9 | build-and-push: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout repository 13 | uses: actions/checkout@v4 14 | 15 | # Set up Python (no need for full Anaconda in CI) 16 | - name: Set up Python 3.11 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: '3.11' 20 | 21 | # Install huggingface_hub CLI 22 | - name: Install huggingface_hub CLI 23 | run: pip install -U "huggingface_hub[cli]" 24 | 25 | # Cache the downloaded model directory 26 | - name: Cache HuggingFace model 27 | id: cache-model 28 | uses: actions/cache@v4 29 | with: 30 | path: app/models/BitNet-b1.58-2B-4T 31 | key: bitnet-model-${{ hashFiles('**/docker-image.yml') }} 32 | 33 | # Download the model if not cached 34 | - name: Download BitNet model 35 | if: steps.cache-model.outputs.cache-hit != 'true' 36 | run: | 37 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir app/models/BitNet-b1.58-2B-4T 38 | 39 | - name: Set up Docker Buildx 40 | uses: docker/setup-buildx-action@v3 41 | 42 | - name: Log in to Docker Hub 43 | uses: docker/login-action@v3 44 | with: 45 | username: ${{ secrets.DOCKERHUB_USERNAME }} 46 | password: ${{ secrets.DOCKERHUB_TOKEN }} 47 | 48 | - name: Build and push Docker image 49 | uses: docker/build-push-action@v5 50 | with: 51 | context: . 52 | push: true 53 | tags: grctest/fastapi_bitnet:latest 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | app/models/*/* 164 | 165 | # Allow all files in app/lib/ 166 | !app/lib -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | WORKDIR /code 4 | 5 | COPY ./app /code 6 | 7 | # Clone BitNet with submodules directly into /code (ensures all files and submodules are present) 8 | RUN git clone --recursive https://github.com/microsoft/BitNet.git /tmp/BitNet && \ 9 | cp -r /tmp/BitNet/* /code && \ 10 | rm -rf /tmp/BitNet 11 | 12 | # Install dependencies 13 | RUN apt-get update && apt-get install -y \ 14 | wget \ 15 | lsb-release \ 16 | software-properties-common \ 17 | gnupg \ 18 | cmake \ 19 | clang \ 20 | && bash -c "$(wget -O - https://apt.llvm.org/llvm.sh)" \ 21 | && apt-get clean \ 22 | && rm -rf /var/lib/apt/lists/* 23 | 24 | # Install Python dependencies 25 | RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt && \ 26 | pip install "fastapi[standard]" "uvicorn[standard]" httpx fastapi-mcp 27 | 28 | # (Optional) Run your setup_env.py if needed 29 | RUN python /code/setup_env.py -md /code/models/BitNet-b1.58-2B-4T -q i2_s 30 | 31 | EXPOSE 8080 32 | 33 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 R 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastAPI-BitNet 2 | 3 | This project uses a combination of [Uvicorn](https://www.uvicorn.org/), [FastAPI](https://fastapi.tiangolo.com/) (Python) and [Docker](https://www.docker.com/) to provide a reliable REST API for testing [Microsoft's BitNet](https://github.com/microsoft/BitNet) out locally! 4 | 5 | It supports running the inference framework, running BitNet model benchmarks, and calculating BitNet model perplexity values. 6 | 7 | It offers the same functionality as the [Electron-BitNet](https://github.com/grctest/Electron-BitNet) project, but does so through a REST API which devs/researchers can use to automate testing/benchmarking of 1-bit BitNet models!
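Everything shown in the interactive docs can also be scripted. As a quick illustration, here is a minimal Python sketch of the initialize → chat → shutdown flow (it assumes the container described in the setup instructions below is listening on `127.0.0.1:8080` and that `httpx` is installed on the client; the port `8082` and the prompt text are arbitrary):

```
import httpx

BASE_URL = "http://127.0.0.1:8080"  # FastAPI orchestrator published by `docker run -p 8080:8080 ...`

with httpx.Client(base_url=BASE_URL, timeout=120.0) as client:
    # Launch a dedicated llama-server instance on port 8082 with its own system prompt.
    client.post("/initialize-server", params={
        "threads": 2,
        "ctx_size": 2048,
        "port": 8082,
        "system_prompt": "You are a helpful assistant.",
        "n_predict": 256,
        "temperature": 0.8,
    }).raise_for_status()

    # Chat with that instance through the orchestrator's middleman /chat endpoint.
    reply = client.post("/chat", json={"message": "Explain 1-bit quantization in one paragraph.", "port": 8082})
    print(reply.json())

    # Shut the instance down once finished.
    client.post("/shutdown-server", params={"port": 8082}).raise_for_status()
```

The `/benchmark`, `/perplexity`, `/model-sizes` and `/multichat` routes can be driven the same way.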
8 | 9 | ## Setup instructions 10 | 11 | If running in dev mode, run Docker Desktop on Windows to initialize Docker in WSL2. 12 | 13 | Launch WSL: `wsl` 14 | 15 | Install Conda: https://anaconda.org/anaconda/conda 16 | 17 | Initialize the Python environment: 18 | ``` 19 | conda init 20 | conda create -n bitnet python=3.11 21 | conda activate bitnet 22 | ``` 23 | 24 | Install the Hugging Face CLI tool to download the models: 25 | ``` 26 | pip install -U "huggingface_hub[cli]" 27 | ``` 28 | 29 | Download Microsoft's official BitNet model: 30 | ``` 31 | huggingface-cli download microsoft/BitNet-b1.58-2B-4T-gguf --local-dir app/models/BitNet-b1.58-2B-4T 32 | ``` 33 | 34 | Build the Docker image: 35 | ``` 36 | docker build -t fastapi_bitnet . 37 | ``` 38 | 39 | Run the Docker image: 40 | ``` 41 | docker run -d --name ai_container -p 8080:8080 fastapi_bitnet 42 | ``` 43 | 44 | Once it's running, navigate to http://127.0.0.1:8080/docs 45 | 46 | ## Docker Hub repository 47 | 48 | You can fetch the prebuilt Docker image from: https://hub.docker.com/repository/docker/grctest/fastapi_bitnet/general 49 | 50 | ## How to add to VS Code! 51 | 52 | Run the Docker container locally using the command above, then navigate to the VS Code Copilot chat window and find the wrench icon "Configure Tools...". 53 | 54 | In the tool configuration overview, scroll to the bottom and select 'Add more tools...', then '+ Add MCP Server', then 'HTTP'. 55 | 56 | Enter `http://127.0.0.1:8080/mcp` into the URL field, and Copilot will then be able to launch new BitNet server instances and chat with them. -------------------------------------------------------------------------------- /app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grctest/FastAPI-BitNet/754414662f4c86b3ce1da9fd9449348246c3eee5/app/__init__.py -------------------------------------------------------------------------------- /app/lib/__init__.py: -------------------------------------------------------------------------------- 1 | from .endpoints import ChatRequest 2 | from typing import List 3 | from pydantic import BaseModel 4 | 5 | __all__ = ["ChatRequest", "MultiChatRequest"] 6 | 7 | # Re-export for import convenience 8 | class MultiChatRequest(BaseModel): 9 | requests: List[ChatRequest] 10 | -------------------------------------------------------------------------------- /app/lib/endpoints.py: -------------------------------------------------------------------------------- 1 | # --- bitnet Orchestrator (Middleman Proxy) --- 2 | from pydantic import BaseModel 3 | 4 | from fastapi import FastAPI, HTTPException, Query, Depends 5 | from .models import ModelEnum, BenchmarkRequest, PerplexityRequest, InferenceRequest 6 | from .utils import run_command, parse_benchmark_data, parse_perplexity_data 7 | import os 8 | import subprocess 9 | import atexit 10 | import time 11 | import httpx 12 | 13 | from typing import List 14 | from pydantic import BaseModel, Field 15 | from fastapi import HTTPException 16 | import asyncio 17 | 18 | # --- Server Process Management --- 19 | # Each server instance is tracked by a unique (host, port) key 20 | server_processes = {} 21 | server_configs = {} 22 | 23 | def _terminate_server_process(key): 24 | proc = server_processes.get(key) 25 | if proc and proc.poll() is None: 26 | try: 27 | proc.terminate() 28 | proc.wait(timeout=5) 29 | except Exception: 30 | proc.kill() 31 | proc.wait() 32 | server_processes.pop(key, None) 33 | server_configs.pop(key, None) 34 | 35 | def
_terminate_all_servers(): 36 | for key in list(server_processes.keys()): 37 | _terminate_server_process(key) 38 | 39 | atexit.register(_terminate_all_servers) 40 | 41 | def _total_threads_in_use(): 42 | return sum(cfg['threads'] for cfg in server_configs.values() if 'threads' in cfg) 43 | 44 | def _max_threads(): 45 | return os.cpu_count() or 1 46 | 47 | async def initialize_server_endpoint( 48 | threads: int = Query(1, gt=0, le=os.cpu_count()), 49 | ctx_size: int = Query(2048, gt=0), 50 | port: int = Query(8081, gt=8081, le=65535), 51 | system_prompt: str = Query("You are a helpful assistant.", description="Unique system prompt for this server instance"), 52 | n_predict: int = Query(256, gt=0, description="Number of tokens to predict for the server instance."), 53 | temperature: float = Query(0.8, gt=0.0, le=2.0, description="Temperature for sampling") 54 | ): 55 | """ 56 | Initializes a llama-server process in the background if not already running on the given port. 57 | Will not oversubscribe threads beyond system capacity. 58 | Allows a unique system prompt per server instance. 59 | """ 60 | host = "127.0.0.1" 61 | key = (host, port) 62 | build_dir = os.getenv("BUILD_DIR", "build") 63 | server_path = os.path.join(build_dir, "bin", "llama-server") 64 | if not os.path.exists(server_path): 65 | raise HTTPException(status_code=500, detail=f"Server binary not found at '{server_path}'") 66 | # Check if already running 67 | if key in server_processes and server_processes[key].poll() is None: 68 | return {"message": f"Server already running on {host}:{port}", "pid": server_processes[key].pid, "config": server_configs[key]} 69 | # Check thread oversubscription and per-server thread limit 70 | max_threads = _max_threads() 71 | if threads > max_threads: 72 | raise HTTPException(status_code=400, detail=f"Requested threads ({threads}) exceed available CPU threads ({max_threads}).") 73 | threads_in_use = _total_threads_in_use() 74 | if threads_in_use + threads > max_threads: 75 | raise HTTPException(status_code=429, detail=f"Cannot start server: would oversubscribe CPU threads (in use: {threads_in_use}, requested: {threads}, max: {max_threads})") 76 | command = [ 77 | server_path, 78 | '-m', "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf", 79 | '-c', str(ctx_size), 80 | '-t', str(threads), 81 | '-n', str(n_predict), 82 | '-ngl', '0', 83 | '--temp', str(temperature), 84 | '--host', host, 85 | '--port', str(port), 86 | '-cb', # Enable continuous batching 87 | ] 88 | if system_prompt: 89 | command += ['-p', system_prompt] 90 | try: 91 | proc = subprocess.Popen( 92 | command, 93 | stdout=subprocess.DEVNULL, 94 | stderr=subprocess.PIPE 95 | ) 96 | time.sleep(2) 97 | if proc.poll() is not None: 98 | stderr_output = proc.stderr.read().decode(errors='ignore') if proc.stderr else '' 99 | proc = None 100 | raise HTTPException(status_code=500, detail=f"Server failed to start. 
Stderr: {stderr_output}") 101 | server_processes[key] = proc 102 | server_configs[key] = { 103 | "model": "models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf", 104 | "threads": threads, 105 | "ctx_size": ctx_size, 106 | "host": host, 107 | "port": port, 108 | "system_prompt": system_prompt, 109 | "n_predict": n_predict, 110 | "temperature": temperature, 111 | "pid": proc.pid 112 | } 113 | return {"message": f"Server started on {host}:{port}", "pid": proc.pid, "config": server_configs[key]} 114 | except Exception as e: 115 | if proc and proc.poll() is None: 116 | proc.kill() 117 | raise HTTPException(status_code=500, detail=f"Failed to start server: {str(e)}") 118 | 119 | async def shutdown_server_endpoint(port: int = Query(8081, gt=1023)): 120 | host = "127.0.0.1" 121 | key = (host, port) 122 | if key in server_processes and server_processes[key].poll() is None: 123 | pid = server_processes[key].pid 124 | _terminate_server_process(key) 125 | return {"message": f"Shutdown initiated for server (PID: {pid}) on {host}:{port}."} 126 | else: 127 | _terminate_server_process(key) 128 | return {"message": f"No running server found on {host}:{port}."} 129 | 130 | async def get_server_status(port: int = Query(8081, gt=1023)): 131 | host = "127.0.0.1" 132 | key = (host, port) 133 | proc = server_processes.get(key) 134 | cfg = server_configs.get(key) 135 | if proc and proc.poll() is None: 136 | return {"status": "running", "pid": proc.pid, "config": cfg} 137 | else: 138 | _terminate_server_process(key) 139 | return {"status": "stopped", "config": cfg} 140 | 141 | # Benchmark endpoint 142 | async def run_benchmark( 143 | model: ModelEnum, 144 | n_token: int = Query(128, gt=0), 145 | threads: int = Query(2, gt=0, le=os.cpu_count()), 146 | n_prompt: int = Query(32, gt=0) 147 | ): 148 | """Run benchmark on specified model""" 149 | request = BenchmarkRequest(model=model, n_token=n_token, threads=threads, n_prompt=n_prompt) 150 | build_dir = os.getenv("BUILD_DIR", "build") 151 | bench_path = os.path.join(build_dir, "bin", "llama-bench") 152 | if not os.path.exists(bench_path): 153 | raise HTTPException(status_code=500, detail="Benchmark binary not found") 154 | command = [ 155 | bench_path, 156 | '-m', request.model.value, 157 | '-n', str(request.n_token), 158 | '-ngl', '0', 159 | '-b', '1', 160 | '-t', str(request.threads), 161 | '-p', str(request.n_prompt), 162 | '-r', '5' 163 | ] 164 | try: 165 | result = subprocess.run(command, capture_output=True, text=True, check=True) 166 | parsed_data = parse_benchmark_data(result.stdout) 167 | return parsed_data 168 | except subprocess.CalledProcessError as e: 169 | raise HTTPException(status_code=500, detail=f"Benchmark failed: {str(e)}") 170 | 171 | # Validate prompt length for perplexity 172 | 173 | def validate_prompt_length(prompt: str = Query(..., description="Input text for perplexity calculation"), ctx_size: int = Query(10, gt=3)) -> str: 174 | token_count = len(prompt.split()) 175 | min_tokens = 2 * ctx_size 176 | if token_count < min_tokens: 177 | raise HTTPException( 178 | status_code=400, 179 | detail=f"Prompt too short. 
Needs at least {min_tokens} tokens, got {token_count}" 180 | ) 181 | return prompt 182 | 183 | # Perplexity endpoint 184 | async def run_perplexity( 185 | model: ModelEnum, 186 | prompt: str = Depends(validate_prompt_length), 187 | threads: int = Query(2, gt=0, le=os.cpu_count()), 188 | ctx_size: int = Query(10, gt=3), 189 | ppl_stride: int = Query(0, ge=0) 190 | ): 191 | """Calculate perplexity for given text and model""" 192 | try: 193 | request = PerplexityRequest( 194 | model=model, 195 | prompt=prompt, 196 | threads=threads, 197 | ctx_size=ctx_size, 198 | ppl_stride=ppl_stride 199 | ) 200 | except ValueError as e: 201 | raise HTTPException(status_code=400, detail=str(e)) 202 | 203 | build_dir = os.getenv("BUILD_DIR", "build") 204 | ppl_path = os.path.join(build_dir, "bin", "llama-perplexity") 205 | if not os.path.exists(ppl_path): 206 | raise HTTPException(status_code=500, detail="Perplexity binary not found") 207 | 208 | command = [ 209 | ppl_path, 210 | '--model', request.model.value, 211 | '--prompt', request.prompt, 212 | '--threads', str(request.threads), 213 | '--ctx-size', str(request.ctx_size), 214 | '--perplexity', 215 | '--ppl-stride', str(request.ppl_stride) 216 | ] 217 | 218 | try: 219 | result = subprocess.run(command, capture_output=True, text=True, check=True) 220 | parsed_data = parse_perplexity_data(result.stderr) 221 | return parsed_data 222 | except subprocess.CalledProcessError as e: 223 | raise HTTPException(status_code=500, detail=str(e)) 224 | 225 | # Model sizes endpoint 226 | def get_model_sizes(): 227 | """Endpoint to get the file sizes of supported .gguf models.""" 228 | model_sizes = {} 229 | models_dir = "models" 230 | for subdir in os.listdir(models_dir): 231 | subdir_path = os.path.join(models_dir, subdir) 232 | if os.path.isdir(subdir_path): 233 | for file in os.listdir(subdir_path): 234 | if file.endswith(".gguf"): 235 | file_path = os.path.join(subdir_path, file) 236 | file_size_bytes = os.path.getsize(file_path) 237 | file_size_mb = round(file_size_bytes / (1024 * 1024), 3) 238 | file_size_gb = round(file_size_bytes / (1024 * 1024 * 1024), 3) 239 | model_sizes[file] = { 240 | "bytes": file_size_bytes, 241 | "MB": file_size_mb, 242 | "GB": file_size_gb 243 | } 244 | return model_sizes 245 | 246 | class ChatRequest(BaseModel): 247 | message: str 248 | port: int = 8081 249 | threads: int = 1 250 | ctx_size: int = 2048 251 | n_predict: int = 256 252 | temperature: float = 0.8 253 | 254 | async def chat_with_bitnet( 255 | chat: ChatRequest 256 | ): 257 | """ 258 | Middleman endpoint: receives a chat message and forwards it to the specified bitnet (llama server instance) by port. 259 | Returns the response from the bitnet. 
260 | """ 261 | host = "127.0.0.1" 262 | key = (host, chat.port) 263 | proc = server_processes.get(key) 264 | cfg = server_configs.get(key) 265 | if not (proc and proc.poll() is None and cfg): 266 | raise HTTPException(status_code=404, detail=f"Server on port {chat.port} not running or not configured.") 267 | server_url = f"http://{host}:{chat.port}/completion" 268 | payload = { 269 | "prompt": chat.message, 270 | "threads": chat.threads, 271 | "ctx_size": chat.ctx_size, 272 | "n_predict": chat.n_predict, 273 | "temperature": chat.temperature 274 | } 275 | # Use httpx for async requests 276 | async def _chat(): 277 | async with httpx.AsyncClient() as client: 278 | try: 279 | response = await client.post(server_url, json=payload, timeout=60.0) # Increased timeout 280 | response.raise_for_status() # Raise an exception for bad status codes 281 | return response.json() 282 | except httpx.ReadTimeout: 283 | raise HTTPException(status_code=504, detail=f"Request to BitNet server on port {chat.port} timed out.") 284 | except httpx.ConnectError: 285 | raise HTTPException(status_code=503, detail=f"Could not connect to BitNet server on port {chat.port}.") 286 | except httpx.HTTPStatusError as e: 287 | raise HTTPException(status_code=e.response.status_code, detail=f"BitNet server error: {e.response.text}") 288 | except Exception as e: 289 | # Catch any other unexpected errors during the chat process 290 | error_detail = f"An unexpected error occurred while communicating with BitNet server on port {chat.port}: {str(e)}" 291 | raise HTTPException(status_code=500, detail=error_detail) 292 | return await _chat() 293 | 294 | class MultiChatRequest(BaseModel): 295 | requests: List[ChatRequest] 296 | 297 | async def multichat_with_bitnet(multichat: MultiChatRequest): 298 | async def run_chat(chat_req: ChatRequest): 299 | chat_fn = chat_with_bitnet(chat_req) 300 | return await chat_fn 301 | results = await asyncio.gather(*(run_chat(req) for req in multichat.requests), return_exceptions=True) 302 | # Format results: if exception, return error message 303 | formatted = [] 304 | for res in results: 305 | if isinstance(res, Exception): 306 | if isinstance(res, HTTPException): 307 | formatted.append({"error": res.detail, "status_code": res.status_code}) 308 | else: 309 | formatted.append({"error": str(res)}) 310 | else: 311 | formatted.append(res) 312 | return {"results": formatted} -------------------------------------------------------------------------------- /app/lib/models.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from pydantic import BaseModel, validator, root_validator 3 | from enum import Enum 4 | import os 5 | 6 | def create_model_enum(directory: str): 7 | """Dynamically create an Enum for models based on files in the directory.""" 8 | models = {} 9 | for subdir in os.listdir(directory): 10 | subdir_path = os.path.join(directory, subdir) 11 | if os.path.isdir(subdir_path): 12 | for file in os.listdir(subdir_path): 13 | if file.endswith(".gguf"): 14 | model_name = f"{subdir}_{file.replace('-', '_').replace('.', '_')}" 15 | models[model_name] = os.path.join(subdir_path, file) 16 | return Enum("ModelEnum", models) 17 | 18 | # Create the ModelEnum based on the models directory 19 | ModelEnum = create_model_enum("models") 20 | 21 | max_n_predict = 100000 22 | 23 | class BenchmarkRequest(BaseModel): 24 | model: ModelEnum 25 | n_token: int = 128 26 | threads: int = 2 27 | n_prompt: int = 32 28 | 29 | @validator('threads') 30 | def 
validate_threads(cls, v): 31 | max_threads = os.cpu_count() 32 | if v > max_threads: 33 | raise ValueError(f"Number of threads cannot exceed {max_threads}") 34 | return v 35 | 36 | @validator('n_token', 'n_prompt', 'threads') 37 | def validate_positive(cls, v): 38 | if v <= 0: 39 | raise ValueError("Value must be positive") 40 | return v 41 | 42 | class PerplexityRequest(BaseModel): 43 | model: ModelEnum 44 | prompt: str 45 | threads: int = 2 46 | ctx_size: int = 3 47 | ppl_stride: int = 0 48 | 49 | @validator('threads') 50 | def validate_threads(cls, v): 51 | max_threads = os.cpu_count() 52 | if v > max_threads: 53 | raise ValueError(f"Number of threads cannot exceed {max_threads}") 54 | elif v <= 0: 55 | raise ValueError("Value must be positive") 56 | return v 57 | 58 | @validator('ctx_size') 59 | def validate_positive(cls, v): 60 | if v < 3: 61 | raise ValueError("Value must be greater than 3") 62 | return v 63 | 64 | @root_validator(pre=True) 65 | def validate_prompt_length(cls, values: Dict[str, Any]) -> Dict[str, Any]: 66 | prompt = values.get('prompt') 67 | ctx_size = values.get('ctx_size') 68 | 69 | if prompt and ctx_size: 70 | token_count = len(prompt.split()) 71 | min_tokens = 2 * ctx_size 72 | 73 | if token_count < min_tokens: 74 | raise ValueError(f"Prompt too short. Needs at least {min_tokens} tokens, got {token_count}") 75 | 76 | return values 77 | 78 | class InferenceRequest(BaseModel): 79 | model: ModelEnum 80 | n_predict: int = 128 81 | prompt: str 82 | threads: int = 2 83 | ctx_size: int = 2048 84 | temperature: float = 0.8 85 | 86 | @validator('threads') 87 | def validate_threads(cls, v): 88 | max_threads = os.cpu_count() 89 | if v > max_threads: 90 | raise ValueError(f"Number of threads cannot exceed {max_threads}") 91 | return v 92 | 93 | @validator('n_predict') 94 | def validate_n_predict(cls, v): 95 | if v > max_n_predict: 96 | raise ValueError(f"Number of predictions cannot exceed {max_n_predict}") 97 | return v 98 | 99 | @validator('threads', 'ctx_size', 'temperature', 'n_predict') 100 | def validate_positive(cls, v): 101 | if v <= 0: 102 | raise ValueError("Value must be positive") 103 | return v -------------------------------------------------------------------------------- /app/lib/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, List 2 | import re 3 | from fastapi import HTTPException 4 | import subprocess 5 | 6 | def run_command(command: List[str]) -> str: 7 | """Run a system command and capture its output.""" 8 | try: 9 | result = subprocess.run(command, check=True, capture_output=True, text=True) 10 | return result.stdout 11 | except subprocess.CalledProcessError as e: 12 | raise HTTPException(status_code=500, detail=f"Error occurred while running command: {e}") 13 | except Exception as e: 14 | raise HTTPException(status_code=500, detail=f"Unexpected error: {e}") 15 | 16 | def parse_benchmark_data(log: str) -> List[Dict[str, str]]: 17 | lines = log.strip().split('\n') 18 | headers = lines[0].split('|') 19 | headers = [header.strip() for header in headers if header.strip()] 20 | data_lines = lines[2:] # Skip the header and dashes rows 21 | 22 | parsed_data = [] 23 | for line in data_lines: 24 | values = line.split('|') 25 | values = [value.strip() for value in values if value.strip()] 26 | if len(values) == len(headers): # Ensure the number of values matches the number of headers 27 | obj = {headers[i]: values[i] for i in range(len(headers))} 28 | if all(obj.values()) and 'build:' 
not in obj.get('model', ''): 29 | parsed_data.append(obj) 30 | return parsed_data 31 | 32 | def clean_key(key: str) -> str: 33 | """Remove type information from key names by taking first part before space""" 34 | return key.split()[0] if ' ' in key else key 35 | 36 | def parse_model_loader_section(lines: List[str]) -> Dict[str, Any]: 37 | data = { 38 | "general": {}, 39 | "llama": { 40 | "attention": {}, 41 | "rope": {} 42 | }, 43 | "tokenizer": { 44 | "ggml": {} 45 | } 46 | } 47 | 48 | for line in lines: 49 | if not line.startswith("llama_model_loader:"): 50 | continue 51 | 52 | # Split after prefix and number 53 | try: 54 | # Example: "llama_model_loader: - kv 0: general.architecture str = llama" 55 | prefix, content = line.split(": ", 1) # Split on first ": " 56 | _, content = content.split(":", 1) # Remove "- kv NUMBER:" 57 | key, value = content.split("=", 1) # Split on "=" 58 | 59 | # Clean and split key parts 60 | key_parts = key.strip().split(".") # Split path components 61 | final_key = clean_key(key_parts[-1]) # Clean type from last part 62 | value = value.strip() # Clean value 63 | 64 | # Map to correct section 65 | if "general" in key_parts: 66 | data["general"][final_key] = value 67 | elif "llama" in key_parts: 68 | if "attention" in key_parts: 69 | data["llama"]["attention"][final_key] = value 70 | elif "rope" in key_parts: 71 | data["llama"]["rope"][final_key] = value 72 | else: 73 | data["llama"][final_key] = value 74 | elif "tokenizer" in key_parts: 75 | data["tokenizer"]["ggml"][final_key] = value 76 | 77 | except ValueError: 78 | continue 79 | 80 | return data 81 | 82 | def parse_meta_section(lines: List[str]) -> Dict[str, str]: 83 | data = {} 84 | for line in lines: 85 | if not line.startswith("llm_load_print_meta:"): 86 | continue 87 | parts = line.split('=', 1) 88 | if len(parts) != 2: 89 | continue 90 | key = parts[0].replace('llm_load_print_meta:', '').strip() 91 | value = parts[1].strip() 92 | data[re.sub(r'\s+', '', key)] = value 93 | return data 94 | 95 | def parse_system_info(line: str) -> Dict[str, Any]: 96 | if not line.startswith("system_info:"): 97 | return {} 98 | info = {} 99 | parts = line.replace("system_info:", "").split("|") 100 | for part in parts: 101 | part = part.strip() 102 | if "=" in part: 103 | key, value = part.split("=", 1) 104 | info[key.strip()] = int(value) if value.strip().isdigit() else value.strip() 105 | return info 106 | 107 | def parse_perf_context(lines: List[str]) -> Dict[str, str]: 108 | data = {} 109 | for line in lines: 110 | if not line.startswith("llama_perf_context_print:"): 111 | continue 112 | parts = line.split('=', 1) 113 | if len(parts) != 2: 114 | continue 115 | key = parts[0].replace('llama_perf_context_print:', '').strip() 116 | value = parts[1].strip() 117 | data[re.sub(r'\s+', '', key)] = value 118 | return data 119 | 120 | def parse_perplexity_data(log: str) -> Dict[str, Any]: 121 | """Parse llama.cpp log output into structured data""" 122 | lines = log.strip().split('\n') 123 | 124 | data = { 125 | "llama_model_loader": parse_model_loader_section(lines), 126 | "llm_load_print_meta": parse_meta_section(lines), 127 | "system_info": {}, 128 | "llama_perf_context_print": parse_perf_context(lines), 129 | "final_estimate": {} 130 | } 131 | 132 | for line in lines: 133 | if 'Final estimate:' in line: # More lenient matching 134 | try: 135 | # Extract numeric values 136 | parts = line.split('Final estimate: PPL = ')[1].split(' +/- ') 137 | if len(parts) == 2: 138 | data['final_estimate'] = { 139 | 'ppl': 
float(parts[0].strip()), 140 | 'uncertainty': float(parts[1].strip()) 141 | } 142 | print(f"Found final estimate: {data['final_estimate']}") # Debug print 143 | except (IndexError, ValueError) as e: 144 | print(f"Failed to parse final estimate from: {line}") 145 | print(f"Error: {e}") 146 | elif line.startswith('system_info:'): 147 | data['system_info'] = parse_system_info(line) 148 | 149 | return data -------------------------------------------------------------------------------- /app/main.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, Query, HTTPException 2 | import os 3 | from fastapi_mcp import FastApiMCP 4 | from lib.models import ModelEnum 5 | import lib.endpoints as endpoints 6 | from lib.endpoints import chat_with_bitnet, ChatRequest, multichat_with_bitnet, MultiChatRequest 7 | import traceback 8 | 9 | app = FastAPI() 10 | 11 | @app.post("/initialize-server") 12 | async def initialize_server( 13 | threads: int = Query(os.cpu_count() // 2, gt=0, le=os.cpu_count()), 14 | ctx_size: int = Query(2048, gt=0), 15 | port: int = Query(8081, gt=1023), 16 | system_prompt: str = Query("You are a helpful assistant.", description="Unique system prompt for this server instance"), 17 | n_predict: int = Query(4096, gt=0, description="Number of tokens to predict for the server instance"), 18 | temperature: float = Query(0.8, gt=0.0, le=2.0, description="Temperature for sampling") 19 | ): 20 | try: 21 | return await endpoints.initialize_server_endpoint( 22 | threads=threads, 23 | ctx_size=ctx_size, 24 | port=port, 25 | system_prompt=system_prompt, 26 | n_predict=n_predict, 27 | temperature=temperature 28 | ) 29 | except Exception as e: 30 | print(traceback.format_exc()) 31 | raise HTTPException(status_code=500, detail=str(e)) 32 | 33 | def _max_threads(): 34 | return os.cpu_count() or 1 35 | 36 | # --- Server Initialization and Shutdown Endpoints --- 37 | def validate_thread_allocation(requests): 38 | max_threads = _max_threads() 39 | total_requested = sum(req["threads"] for req in requests) 40 | for req in requests: 41 | if req["threads"] > max_threads: 42 | raise HTTPException( 43 | status_code=400, 44 | detail=f"Requested {req['threads']} threads for a server, but only {max_threads} are available." 45 | ) 46 | if total_requested > max_threads: 47 | raise HTTPException( 48 | status_code=400, 49 | detail=f"Total requested threads ({total_requested}) exceed available threads ({max_threads})." 
50 | ) 51 | 52 | @app.post("/shutdown-server") 53 | async def shutdown_server(port: int = Query(8081, gt=1023)): 54 | try: 55 | return await endpoints.shutdown_server_endpoint(port=port) 56 | except Exception as e: 57 | print(traceback.format_exc()) 58 | raise HTTPException(status_code=500, detail=str(e)) 59 | 60 | @app.get("/server-status") 61 | async def server_status_endpoint(port: int = Query(8081, gt=1023)): # Renamed for clarity 62 | return await endpoints.get_server_status(port=port) # Call the function from endpoints.py 63 | 64 | @app.get("/benchmark") 65 | async def benchmark( 66 | model: ModelEnum, 67 | n_token: int = Query(128, gt=0), 68 | threads: int = Query(2, gt=0, le=os.cpu_count()), 69 | n_prompt: int = Query(32, gt=0) 70 | ): 71 | try: 72 | return await endpoints.run_benchmark(model, n_token, threads, n_prompt) 73 | except Exception as e: 74 | print(traceback.format_exc()) 75 | raise HTTPException(status_code=500, detail=str(e)) 76 | 77 | @app.get("/perplexity") 78 | async def perplexity( 79 | model: ModelEnum, 80 | prompt: str, 81 | threads: int = Query(2, gt=0, le=os.cpu_count()), 82 | ctx_size: int = Query(4, gt=0), 83 | ppl_stride: int = Query(0, ge=0) 84 | ): 85 | try: 86 | return await endpoints.run_perplexity(model, prompt, threads, ctx_size, ppl_stride) 87 | except Exception as e: 88 | print(traceback.format_exc()) 89 | raise HTTPException(status_code=500, detail=str(e)) 90 | 91 | @app.get("/model-sizes") 92 | def model_sizes(): 93 | return endpoints.get_model_sizes() 94 | 95 | @app.post("/chat") 96 | async def chat(chat: ChatRequest): 97 | try: 98 | return await chat_with_bitnet(chat) 99 | except Exception as e: 100 | print(traceback.format_exc()) 101 | raise HTTPException(status_code=500, detail=str(e)) 102 | 103 | # Parallel multi-chat endpoint 104 | @app.post("/multichat") 105 | async def multichat(multichat: MultiChatRequest): 106 | try: 107 | return await multichat_with_bitnet(multichat) 108 | except Exception as e: 109 | print(traceback.format_exc()) 110 | raise HTTPException(status_code=500, detail=str(e)) 111 | 112 | # Wrap with MCP for Model Context Protocol support 113 | mcp = FastApiMCP(app) 114 | 115 | # Mount the MCP server directly to your FastAPI app 116 | mcp.mount() -------------------------------------------------------------------------------- /app/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/grctest/FastAPI-BitNet/754414662f4c86b3ce1da9fd9449348246c3eee5/app/models/.gitkeep --------------------------------------------------------------------------------
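For reference, the `/multichat` route defined in `app/main.py` fans a batch of `ChatRequest` objects out to their target BitNet server instances concurrently (via `asyncio.gather`) and returns one entry per request under `results`, with failures reported as `{"error": ..., "status_code": ...}` objects rather than raised exceptions. A minimal client-side sketch of that payload shape, assuming two instances have already been initialized on ports 8082 and 8083 and that `httpx` is available:

```
import httpx

BASE_URL = "http://127.0.0.1:8080"

# One ChatRequest per running BitNet server instance; omitted fields fall back to the
# defaults declared on the ChatRequest model (threads=1, ctx_size=2048, n_predict=256, ...).
payload = {
    "requests": [
        {"message": "Summarise BitNet b1.58 in one sentence.", "port": 8082},
        {"message": "List three use cases for 1-bit LLMs.", "port": 8083, "n_predict": 128},
    ]
}

with httpx.Client(base_url=BASE_URL, timeout=180.0) as client:
    results = client.post("/multichat", json=payload).json()["results"]
    for entry in results:
        print(entry)  # Either the llama-server completion JSON or an error object
```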