├── .gitignore ├── .gitmodules ├── CHANGELOG.md ├── LICENSE ├── README.md ├── assets ├── demo_llama3_8B_fpft.gif └── demo_yi_34B_peft.gif ├── db ├── api_key_management.py └── create_apikey_table.sql ├── docker ├── Dockerfile └── README.md ├── environment.yml ├── examples ├── __init__.py └── local_rag │ ├── README.md │ ├── __init__.py │ ├── requirements.txt │ └── run.py ├── green_bit_llm ├── __init__.py ├── args_parser.py ├── common │ ├── __init__.py │ ├── enum.py │ ├── model.py │ └── utils.py ├── evaluation │ ├── README.md │ ├── __init__.py │ ├── datautils.py │ ├── evaluate.py │ ├── lmclass.py │ └── utils.py ├── inference │ ├── README.md │ ├── __init__.py │ ├── chat_base.py │ ├── chat_cli.py │ ├── conversation.py │ ├── sim_gen.py │ └── utils.py ├── langchain │ ├── README.md │ ├── __init__.py │ ├── chat_model.py │ ├── embedding.py │ └── pipeline.py ├── patches │ ├── __init__.py │ ├── deepseek_v3_moe_patch.py │ └── qwen3_moe_patch.py ├── routing │ ├── __init__.py │ └── confidence_scorer.py ├── serve │ ├── __init__.py │ ├── api │ │ ├── __init__.py │ │ └── v1 │ │ │ ├── __init__.py │ │ │ └── fastapi_server.py │ └── auth │ │ ├── __init__.py │ │ ├── api_key_auth.py │ │ └── rate_limiter.py ├── sft │ ├── README.md │ ├── __init__.py │ ├── finetune.py │ ├── optim │ │ ├── __init__.py │ │ ├── adamw8bit.py │ │ └── bnb_optimizer.py │ ├── peft_lora.py │ ├── peft_utils │ │ ├── __init__.py │ │ ├── gba_lora.py │ │ └── model.py │ ├── trainer.py │ └── utils.py └── version.py ├── requirements.txt ├── scripts └── curl_script ├── setup.py └── tests ├── __init__.py ├── test_langchain_chatmodel.py ├── test_langchain_embedding.py └── test_langchain_pipeline.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .coverage 141 | .idea/ 142 | .vscode/ 143 | .mypy_cache 144 | 145 | # downloaded dataset files 146 | train/ 147 | test/ 148 | 149 | # Logs 150 | logs/* 151 | !logs/.gitkeep 152 | 153 | # database 154 | db/* 155 | !db/.gitkeep -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "green_bit_llm/routing/libra_router"] 2 | path = green_bit_llm/routing/libra_router 3 | url = https://github.com/GreenBitAI/Libra-Router.git 4 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | The format is based on [Keep a Changelog](http://keepachangelog.com/) 5 | and this project adheres to [Semantic Versioning](http://semver.org/). 
6 | 7 | ## [0.2.6] - 2025/6/4 8 | 9 | ### Updated 10 | 11 | - Fixed RoPE type missing problem in deepseek-r1-qwen3-8B model 12 | - Update README with Qwen3 model notes and Transformers compatibility details 13 | 14 | ## [0.2.5] - 2025/5/30 15 | 16 | ### Added 17 | 18 | - Model server support 19 | - Deepseek model support 20 | - Qwen-3 model support 21 | - Langchain integration 22 | - local RAG example 23 | 24 | ### Updated 25 | 26 | - Various refactoring and improvements 27 | 28 | ## [0.2.4] - 2024/06/04 29 | 30 | ### Fixed 31 | 32 | - Source distribution (was missing `requirements.txt`) 33 | 34 | ## [0.2.3] - 2024/05/26 35 | 36 | ### Added 37 | 38 | - Evaluation results 39 | 40 | ### Fixed 41 | 42 | - Changelog order and date format 43 | - URL in README for PyPI 44 | 45 | ## [0.2.2] - 2024/05/24 46 | 47 | ### Added 48 | 49 | - Evaluation results 50 | 51 | ### Fixed 52 | 53 | - Version numbering 54 | 55 | ## [0.2.1] - 2024/05/22 56 | 57 | ### Added 58 | 59 | - Missing changelog entries 60 | 61 | ### Fixed 62 | 63 | - Version numbering 64 | 65 | ## [0.2.0] - 2024/05/20 66 | 67 | ### Added 68 | 69 | - Initial support for a classical GPTQ model using the MPQLinear layer 70 | - AutoGPTQ information and commands in the repository 71 | - Support for LoRA and GPTQ evaluation 72 | - SFT comparison updates 73 | - Missing comment to the customized trainer class 74 | 75 | ### Fixed 76 | 77 | - Issue in GbaSFTTrainer for saving non-GBA models 78 | - Mismatch issue between GPTQ and LoRA 79 | - Bug preventing quant_strategy.json from being saved during SFT 80 | 81 | ### Updated 82 | 83 | - README with new AutoGPTQ and GPTQ support information 84 | 85 | ## [0.1.0] - 2024/01/05 86 | 87 | ### Added 88 | 89 | - Integration with Bitorch Engine 90 | - Full-parameter fine-tuning and PEFT support 91 | - Fast inference capabilities 92 | - Comprehensive evaluation tools and detailed model evaluation results 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Green-Bit-LLM 2 | 3 | This Python package uses the [Bitorch Engine](https://github.com/GreenBitAI/bitorch-engine) for efficient operations on [GreenBitAI's Low-bit Language Models (LLMs)](https://huggingface.co/GreenBitAI). 4 | It enables **high-performance inference** on both cloud-based and consumer-level GPUs, and supports **full-parameter fine-tuning** directly **using quantized LLMs**. 5 | Additionally, you can use our provided **evaluation tools** to validate the model's performance on mainstream benchmark datasets. 6 | 7 | ## News 8 | - [2025/05] 9 | - Qwen-3 and Deepseek support. 10 | - [2024/10] 11 | - Langchain integration, model server support. 12 | - [2024/04] 13 | - We have launched over **200 low-bit LLMs** in [GreenBitAI's Hugging Face Model Repo](https://huggingface.co/GreenBitAI). Our release includes highly precise 2.2/2.5/3-bit models across the LLM family, featuring LLaMA 2/3, 01-Yi, Qwen, Mistral, Phi-3 and more. 14 | - We released [Bitorch Engine](https://github.com/GreenBitAI/bitorch-engine) for **low-bit** quantized neural network operations. Our release supports full-parameter fine-tuning and parameter-efficient fine-tuning (PEFT), even under extremely constrained GPU resource conditions. 15 | 16 | ## LLMs 17 | 18 | We have released over 260 highly efficient 2-4 bit models across the modern LLM family, featuring Deepseek, LLaMA, Qwen, Mistral, Phi, and more. 19 | Explore all available models in our [Hugging Face repository](https://huggingface.co/GreenBitAI). 20 | green-bit-llm is also fully compatible with all 4-bit models in the AutoGPTQ series. 21 | 22 | ## Installation 23 | 24 | This package depends on [Bitorch Engine](https://github.com/GreenBitAI/bitorch-engine), and 25 | a first experimental **binary release for Linux with CUDA 12.1 is ready.** 26 | We recommend creating a conda environment to manage the installed CUDA version and other packages. 27 | 28 | ### Conda 29 | 30 | We recommend using [Miniconda](https://docs.conda.io/en/latest/miniconda.html) for a lightweight installation. 31 | Please download the installer from the official Miniconda website and follow the setup instructions. 32 | 33 | After Conda has been installed successfully, follow these steps: 34 | 35 | 1. Create an environment for Python 3.10 and activate it: 36 | ```bash 37 | conda create -y --name bitorch-engine python=3.10 38 | conda activate bitorch-engine 39 | ``` 40 | 2. Install the target CUDA version: 41 | ```bash 42 | conda install -y -c "nvidia/label/cuda-12.1.0" cuda-toolkit 43 | ``` 44 | 3. Install bitorch engine: 45 | 46 | *Inference ONLY* 47 | ```bash 48 | pip install \ 49 | "https://packages.greenbit.ai/whl/cu121/bitorch-engine/bitorch_engine-0.2.6-cp310-cp310-linux_x86_64.whl" 50 | ``` 51 | 52 | *Training REQUIRED* 53 | 54 | Install our customized torch, which allows gradients on INT tensors, with pip (this URL is for CUDA 12.1 55 | and Python 3.10 - you can find other versions [here](https://packages.greenbit.ai/whl/)) together with bitorch engine: 56 | ```bash 57 | pip install \ 58 | "https://packages.greenbit.ai/whl/cu121/torch/torch-2.5.1-cp310-cp310-linux_x86_64.whl" \ 59 | "https://packages.greenbit.ai/whl/cu121/bitorch-engine/bitorch_engine-0.2.6-cp310-cp310-linux_x86_64.whl" 60 | ``` 61 | 62 | 4.
Install green-bit-llm: 63 | 64 | via PyPI 65 | ```bash 66 | pip install green-bit-llm 67 | ``` 68 | or from source 69 | ```bash 70 | git clone https://github.com/GreenBitAI/green-bit-llm.git 71 | cd green-bit-llm 72 | pip install -r requirements.txt 73 | ``` 74 | 75 | **Note: For Qwen3 model support, you need to install the development version of transformers:** 76 | ```bash 77 | pip install git+https://github.com/huggingface/transformers.git 78 | ``` 79 | This installs transformers version 4.53.0.dev0, which includes the necessary Qwen3 model support. 80 | 81 | 5. Install [Flash Attention](https://github.com/Dao-AILab/flash-attention) (`flash-attn`) according to their [official instructions](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features). 82 | ```bash 83 | pip install flash-attn --no-build-isolation 84 | ``` 85 | 86 | ## Examples 87 | 88 | ### Simple Generation 89 | 90 | Run the simple generation script as follows: 91 | 92 | ```bash 93 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.inference.sim_gen --model GreenBitAI/Qwen-3-1.7B-layer-mix-bpw-4.0 --max-tokens 1024 --use-flash-attention-2 94 | ``` 95 | 96 | ### FastAPI Model Server 97 | 98 | A high-performance HTTP API for text generation with GreenBitAI's low-bit models. 99 | 100 | #### Quick Start 101 | 1. Run: 102 | ```shell 103 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.serve.api.v1.fastapi_server --model GreenBitAI/Qwen-3-1.7B-layer-mix-bpw-4.0 --host 127.0.0.1 --port 11668 104 | ``` 105 | 2. Use: 106 | ```shell 107 | # Chat 108 | curl http://localhost:11668/v1/GreenBitAI-Qwen-3-17B-layer-mix-bpw-40/chat/completions -H "Content-Type: application/json" \ 109 | -d '{"model": "default_model", "messages": [{"role": "user", "content": "Hello!"}]}' 110 | 111 | # Chat stream 112 | curl http://localhost:11668/v1/GreenBitAI-Qwen-3-17B-layer-mix-bpw-40/chat/completions -H "Content-Type: application/json" \ 113 | -d '{"model": "default_model", "messages": [{"role": "user", "content": "Hello!"}], "stream": "True"}' 114 | ``` 115 | 116 | ### Full-parameter fine-tuning 117 | 118 | Full-parameter fine-tuning of the LLaMA-3 8B model using a single RTX 3090 GPU with 24GB of graphics memory: 119 | 120 | 121 | 122 | Run the script as follows to fine-tune the quantized weights of the model on the target dataset. 123 | The '--tune-qweight-only' parameter determines whether to fine-tune only the quantized weights or all weights, including non-quantized ones.
124 | 125 | ```bash 126 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.finetune --model GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0 --dataset tatsu-lab/alpaca --optimizer DiodeMix --tune-qweight-only 127 | 128 | # AutoGPTQ model Q-SFT 129 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.finetune --model astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit --dataset tatsu-lab/alpaca --tune-qweight-only --batch-size 1 130 | ``` 131 | 132 | ### Parameter-efficient fine-tuning 133 | 134 | PEFT of the 01-Yi 34B model using a single RTX 3090 GPU with 24GB of graphics memory: 135 | 136 | 137 | 138 | ```bash 139 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.peft_lora --model GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0 --dataset tatsu-lab/alpaca --lr-fp 1e-6 140 | 141 | # AutoGPTQ model with Lora 142 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.peft_lora --model astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit --dataset tatsu-lab/alpaca --lr-fp 1e-6 143 | ``` 144 | 145 | ## Further Usage 146 | 147 | Please see the descriptions of the [Inference](green_bit_llm/inference/README.md), [sft](green_bit_llm/sft/README.md) and [evaluation](green_bit_llm/evaluation/README.md) packages for details. 148 | 149 | ## License 150 | We release our code under the [Apache 2.0 License](LICENSE). 151 | Additionally, three packages are also partly based on third-party open-source code. For detailed information, please refer to the description pages of the sub-projects. 152 | -------------------------------------------------------------------------------- /assets/demo_llama3_8B_fpft.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/assets/demo_llama3_8B_fpft.gif -------------------------------------------------------------------------------- /assets/demo_yi_34B_peft.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/assets/demo_yi_34B_peft.gif -------------------------------------------------------------------------------- /db/create_apikey_table.sql: -------------------------------------------------------------------------------- 1 | -- e.g.
sqlite3 greenbit.db < create_tables.sql 2 | 3 | -- drop tables if exists 4 | --DROP TABLE IF EXISTS api_keys; 5 | --DROP TABLE IF EXISTS users; 6 | 7 | -- Create users table 8 | CREATE TABLE IF NOT EXISTS users ( 9 | id INTEGER PRIMARY KEY AUTOINCREMENT, 10 | name TEXT NOT NULL, 11 | email TEXT NOT NULL UNIQUE, 12 | organization TEXT, 13 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 14 | ); 15 | 16 | -- Create API keys table 17 | CREATE TABLE IF NOT EXISTS api_keys ( 18 | id INTEGER PRIMARY KEY AUTOINCREMENT, 19 | user_id INTEGER NOT NULL, 20 | api_key_hash TEXT NOT NULL UNIQUE, 21 | name TEXT NOT NULL, 22 | email TEXT NOT NULL, 23 | organization TEXT, 24 | tier VARCHAR(20) DEFAULT 'basic', 25 | rpm_limit INTEGER DEFAULT 60, 26 | tpm_limit INTEGER DEFAULT 40000, 27 | concurrent_requests INTEGER DEFAULT 5, 28 | max_tokens INTEGER DEFAULT 32768, 29 | permissions TEXT DEFAULT 'completion,chat', 30 | is_active BOOLEAN DEFAULT TRUE, 31 | created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 32 | last_used_at TIMESTAMP, 33 | FOREIGN KEY (user_id) REFERENCES users(id), 34 | CHECK (is_active IN (0, 1)) 35 | ); 36 | 37 | -- Create indexes 38 | CREATE INDEX IF NOT EXISTS idx_api_key_hash ON api_keys(api_key_hash); 39 | CREATE INDEX IF NOT EXISTS idx_user_email ON users(email); 40 | CREATE INDEX IF NOT EXISTS idx_last_used ON api_keys(last_used_at); -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG FROM_IMAGE="bitorch/engine" 2 | FROM ${FROM_IMAGE} as bitorch-engine-base 3 | 4 | FROM bitorch-engine-base as requirements-installed 5 | COPY "../requirements.txt" "/green-bit-llm-req.txt" 6 | RUN pip install packaging -r "/green-bit-llm-req.txt" && \ 7 | rm "/green-bit-llm-req.txt" && \ 8 | pip install flash-attn --no-build-isolation && \ 9 | pip cache purge 10 | 11 | # clone instead of mounting makes the code in the image independent from local changes 12 | # to mount your code before building, use the target above and mount your local code 13 | FROM requirements-installed as code-cloned 14 | ARG GIT_URL="https://github.com/GreenBitAI/green-bit-llm.git" 15 | ARG GIT_BRANCH="main" 16 | ARG BUILD_TARGET="." 17 | RUN git clone \ 18 | --depth 1 \ 19 | --branch "${GIT_BRANCH}" \ 20 | "${GIT_URL}" \ 21 | /green-bit-llm && \ 22 | cd /green-bit-llm && \ 23 | pip install -e . 24 | -------------------------------------------------------------------------------- /docker/README.md: -------------------------------------------------------------------------------- 1 | # Project Setup with Docker 2 | 3 | In the following, we show how to build a [Docker](https://www.docker.com/) image for this project and explain additional options. 4 | 5 | ## Build Docker Image 6 | 7 | 1. Build the bitorch engine image according to [these instructions](https://github.com/GreenBitAI/bitorch-engine/blob/HEAD/docker/README.md). 8 | 2. Now you should be able to build the image by running the following commands 9 | (if you used a custom image name or tag, you can adjust with `--build-arg FROM_IMAGE="bitorch/engine:custom-tag"`): 10 | ```bash 11 | # cd docker 12 | # you should be in this `docker` directory 13 | cp -f ../requirements.txt . 14 | docker build -t gbai/green-bit-llm . 15 | ``` 16 | 3. You can now run the container, for example with this: 17 | ```bash 18 | docker run -it --rm --gpus all gbai/green-bit-llm 19 | ``` 20 | 4. 
Alternatively, you can mount the directory `/root/.cache/huggingface/hub` which will save the downloaded model cache locally, 21 | e.g. you could use your users cache directory: 22 | ```bash 23 | docker run -it --rm --gpus all -v "${HOME}/.cache/huggingface/hub":"/root/.cache/huggingface/hub" gbai/green-bit-llm 24 | ``` 25 | 26 | ## Build Options 27 | 28 | Depending on your setup, you may want to adjust some options through build arguments: 29 | - base docker image, e.g. add `--build-arg FROM_IMAGE="bitorch/engine:custom-tag"` 30 | - repository URL, e.g. add `--build-arg GIT_URL="https://github.com/MyFork/green-bit-llm.git"` 31 | - green-bit-llm branch or tag, e.g. add `--build-arg GIT_BRANCH="v1.2.3"` 32 | - if there is a problem, set the environment variable `BUILDKIT_PROGRESS=plain` to see all output 33 | 34 | ## For Development 35 | 36 | A docker image without the code cloned, e.g. for mounting a local copy of the code, can be made easily with the target `requirements-installed`: 37 | ```bash 38 | # cd docker 39 | # you should be in this `docker` directory 40 | cp -f ../requirements.txt . 41 | docker build -t gbai/green-bit-llm:no-code --target requirements-installed . 42 | docker run -it --rm --gpus all --volume "$(pwd)/..":/green-bit-llm gbai/green-bit-llm:no-code 43 | # in the docker container: 44 | cd /green-bit-llm 45 | pip install -e . 46 | ``` 47 | However, this means the build results will not be persisted in the image, so you probably want to mount the same directory every time. 48 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gbai_cuda_lm 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.9 7 | - pip=24.0 8 | - pytorch::pytorch>=2.0 9 | - pip: 10 | - sentencepiece 11 | - huggingface-hub 12 | - transformers>=4.52.4 13 | - accelerate 14 | - colorama 15 | - datasets 16 | - lm-eval==0.3.0 17 | - termcolor 18 | - pillow 19 | - requests 20 | - prompt-toolkit 21 | - rich 22 | - optimum 23 | - auto-gptq 24 | - langchain-core 25 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/examples/__init__.py -------------------------------------------------------------------------------- /examples/local_rag/README.md: -------------------------------------------------------------------------------- 1 | # Local RAG Demo 2 | 3 | ## Overview 4 | 5 | This project demonstrates a local implementation of Retrieval-Augmented Generation (RAG) using GreenBit models, including GreenBitPipeline, ChatGreenBit, and GreenBitEmbeddings. It showcases features such as document loading, text splitting, vector store creation, and various natural language processing tasks in a CUDA environment. 6 | 7 | ## Features 8 | 9 | - Document loading from web sources 10 | - Text splitting for efficient processing 11 | - Vector store creation using BERT embeddings 12 | - Rap battle simulation 13 | - Document summarization 14 | - Question answering 15 | - Question answering with retrieval 16 | 17 | ## Installation 18 | 19 | 1. Install the required packages: 20 | ``` 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | 2. Ensure you have a CUDA-compatible environment set up on your system. 
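The components listed above fit together as sketched below. This is a condensed, illustrative version of the chain that `run.py` builds (the full script follows later in this repository); the model ID, embedding model, query, and web source are simply the defaults from the Usage command, not additional requirements.

```python
import torch
from langchain_chroma import Chroma
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_text_splitters import RecursiveCharacterTextSplitter

from green_bit_llm.langchain import ChatGreenBit, GreenBitEmbeddings, GreenBitPipeline

# 1. Load a web page and split it into chunks
data = WebBaseLoader("https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/").load()
splits = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0).split_documents(data)

# 2. Embed the chunks and build a Chroma vector store
embeddings = GreenBitEmbeddings.from_model_id(
    model_name="sentence-transformers/all-MiniLM-L12-v2",
    cache_dir="cache",
    multi_process=False,
    show_progress=False,
    model_kwargs={"trust_remote_code": True},
    encode_kwargs={"normalize_embeddings": False},
)
vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)

# 3. Wrap a low-bit GreenBit model as a LangChain chat model
pipeline = GreenBitPipeline.from_model_id(
    model_id="GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0",
    model_kwargs={"dtype": torch.half, "seqlen": 2048, "device_map": "auto"},
    pipeline_kwargs={"max_new_tokens": 300, "temperature": 0.7, "do_sample": True},
)
chat = ChatGreenBit(llm=pipeline)

# 4. Retrieval-augmented QA chain: retrieve, format, prompt, generate
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

prompt = ChatPromptTemplate.from_template(
    "Answer the question based on the context.\n\nContext:\n{context}\n\nQuestion: {question}\n\nAnswer:"
)
chain = (
    {"context": vectorstore.as_retriever() | format_docs, "question": RunnablePassthrough()}
    | prompt
    | chat
    | StrOutputParser()
)
print(chain.invoke("What are the core components of GraphRAG?"))
```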
25 | 26 | ## Usage 27 | 28 | Run the main script to execute all tasks: 29 | 30 | ``` 31 | CUDA_VISIBLE_DEVICES=0 \ 32 | python -m examples.local_rag.run --model "GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0" \ 33 | --embedding_model "sentence-transformers/all-MiniLM-L12-v2" \ 34 | --query "What are the core components of GraphRAG?" \ 35 | --max_tokens 300 \ 36 | --web_source "https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/" 37 | ``` 38 | 39 | This will perform the following tasks: 40 | 1. Initialize the model and prepare data 41 | 2. Simulate a rap battle 42 | 3. Summarize documents based on a question 43 | 4. Perform question answering 44 | 5. Perform question answering with retrieval 45 | 46 | ## Note 47 | 48 | This implementation uses GreenBit models, which are compatible with Hugging Face's transformers library and optimized for CUDA environments. Make sure you have the appropriate CUDA setup and GreenBit model files before running the demo. -------------------------------------------------------------------------------- /examples/local_rag/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/examples/local_rag/__init__.py -------------------------------------------------------------------------------- /examples/local_rag/requirements.txt: -------------------------------------------------------------------------------- 1 | langchain-core 2 | langchain-community 3 | langchain-chroma 4 | langchain-text-splitters 5 | langchain-experimental 6 | beautifulsoup4 7 | pydantic 8 | sentence-transformers -------------------------------------------------------------------------------- /examples/local_rag/run.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | import torch 4 | 5 | from langchain_community.document_loaders import WebBaseLoader 6 | from langchain_text_splitters import RecursiveCharacterTextSplitter 7 | from langchain_chroma import Chroma 8 | from langchain_core.output_parsers import StrOutputParser 9 | from langchain_core.prompts import ChatPromptTemplate 10 | from langchain_core.runnables import RunnablePassthrough 11 | 12 | from green_bit_llm.langchain import GreenBitPipeline, ChatGreenBit, GreenBitEmbeddings 13 | 14 | import warnings 15 | warnings.filterwarnings("ignore", category=UserWarning) 16 | 17 | 18 | # Helper function to format documents 19 | def format_docs(docs): 20 | return "\n\n".join(doc.page_content for doc in docs) 21 | 22 | # Helper function to print task separators 23 | def print_task_separator(task_name): 24 | print("\n" + "="*50) 25 | print(f"Task: {task_name}") 26 | print("="*50 + "\n") 27 | 28 | def clean_output(text): 29 | # Remove all non-alphanumeric characters except periods and spaces 30 | cleaned = re.sub(r'[^a-zA-Z0-9\.\s]', ' ', text) 31 | # Replace multiple spaces with a single space 32 | cleaned = re.sub(r'\s+', ' ', cleaned).strip() 33 | # Remove any mentions of "assistant" or other unwanted words 34 | cleaned = re.sub(r'\b(assistant|correct|I apologize|mistake)\b', '', cleaned, flags=re.IGNORECASE) 35 | # Remove any remaining leading/trailing whitespace 36 | cleaned = cleaned.strip() 37 | # Ensure the first letter is capitalized 38 | cleaned = cleaned.capitalize() 39 | # Ensure the answer ends with a period 40 | if cleaned and not cleaned.endswith('.'): 41 | cleaned += '.' 
42 | return cleaned 43 | 44 | def extract_answer(text): 45 | # Try to extract a single sentence answer 46 | match = re.search(r'([A-Z][^\.!?]*[\.!?])', text) 47 | if match: 48 | return match.group(1) 49 | # If no clear sentence is found, return the first 100 characters 50 | return text[:100] + '...' if len(text) > 100 else text 51 | 52 | # Load and prepare data 53 | def prepare_data(url): 54 | loader = WebBaseLoader(url) 55 | data = loader.load() 56 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0) 57 | all_splits = text_splitter.split_documents(data) 58 | return all_splits 59 | 60 | # Create vector store 61 | def create_vectorstore(documents, embedding_model): 62 | model_kwargs = {'trust_remote_code': True} 63 | encode_kwargs = {'normalize_embeddings': False} 64 | 65 | greenbit_embeddings = GreenBitEmbeddings.from_model_id( 66 | model_name=embedding_model, 67 | cache_dir="cache", 68 | multi_process=False, 69 | show_progress=False, 70 | model_kwargs=model_kwargs, 71 | encode_kwargs=encode_kwargs 72 | ) 73 | return Chroma.from_documents(documents=documents, embedding=greenbit_embeddings) 74 | 75 | # Initialize GreenBit model 76 | def init_greenbit_model(model_id, max_tokens): 77 | pipeline = GreenBitPipeline.from_model_id( 78 | model_id=model_id, 79 | model_kwargs={"dtype": torch.half, "seqlen": 2048, "device_map": "auto"}, 80 | pipeline_kwargs={"max_new_tokens": max_tokens, "temperature": 0.7, "do_sample": True}, 81 | ) 82 | 83 | return ChatGreenBit(llm=pipeline) 84 | 85 | 86 | # Task 1: Rap Battle Simulation 87 | def simulate_rap_battle(model): 88 | print_task_separator("Rap Battle Simulation") 89 | prompt = "Simulate a rap battle between rag and graphRag." 90 | response = model.invoke(prompt) 91 | print(response.content) 92 | 93 | 94 | # Task 2: Summarization 95 | def summarize_docs(model, vectorstore, question): 96 | print_task_separator("Summarization") 97 | prompt_template = "Summarize the main themes in these retrieved docs in a single, complete sentence of no more than 200 words: {docs}" 98 | prompt = ChatPromptTemplate.from_template(prompt_template) 99 | 100 | chain = ( 101 | {"docs": format_docs} 102 | | prompt 103 | | model 104 | | StrOutputParser() 105 | | clean_output 106 | | extract_answer 107 | ) 108 | docs = vectorstore.similarity_search(question) 109 | response = chain.invoke(docs) 110 | print(response) 111 | 112 | 113 | # Task 3: Q&A 114 | def question_answering(model, vectorstore, question): 115 | print_task_separator("Q&A") 116 | RAG_TEMPLATE = """ 117 | Answer the following question based on the context provided. Give a direct and concise answer in a single, complete sentence of no more than 100 words. Do not include any additional dialogue or explanation. 118 | 119 | Context: 120 | {context} 121 | 122 | Question: {question} 123 | 124 | Answer:""" 125 | 126 | rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE) 127 | chain = ( 128 | RunnablePassthrough.assign(context=lambda input: format_docs(input["context"])) 129 | | rag_prompt 130 | | model 131 | | StrOutputParser() 132 | | clean_output 133 | | extract_answer 134 | ) 135 | docs = vectorstore.similarity_search(question) 136 | response = chain.invoke({"context": docs, "question": question}) 137 | print(response) 138 | 139 | 140 | # Task 4: Q&A with Retrieval 141 | def qa_with_retrieval(model, vectorstore, question): 142 | print_task_separator("Q&A with Retrieval") 143 | RAG_TEMPLATE = """ 144 | Answer the following question based on the retrieved information. 
Provide a direct and concise answer in a single, complete sentence of no more than 100 words. Do not include any additional dialogue or explanation. 145 | 146 | Question: {question} 147 | 148 | Answer:""" 149 | 150 | rag_prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE) 151 | retriever = vectorstore.as_retriever() 152 | qa_chain = ( 153 | {"context": retriever | format_docs, "question": RunnablePassthrough()} 154 | | rag_prompt 155 | | model 156 | | StrOutputParser() 157 | | clean_output 158 | | extract_answer 159 | ) 160 | response = qa_chain.invoke(question) 161 | print(response) 162 | 163 | 164 | 165 | def main(model_id, embedding_model, query, max_tokens, web_source): 166 | print_task_separator("Initialization") 167 | print("Preparing data and initializing model...") 168 | # Prepare data and initialize model 169 | all_splits = prepare_data(web_source) 170 | vectorstore = create_vectorstore(all_splits, embedding_model) 171 | model = init_greenbit_model(model_id, max_tokens) 172 | print("Initialization complete.") 173 | 174 | # Execute tasks 175 | simulate_rap_battle(model) 176 | summarize_docs(model, vectorstore, query) 177 | question_answering(model, vectorstore, query) 178 | qa_with_retrieval(model, vectorstore, query) 179 | 180 | print("\nAll tasks completed.") 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser(description="Run NLP tasks with specified model, query, max tokens, and web source.") 184 | parser.add_argument("--model", type=str, default="GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0", help="Model ID to use for the tasks") 185 | parser.add_argument("--embedding_model", type=str, default="sentence-transformers/all-MiniLM-L12-v2", help="Embedding model to use for vector store creation") 186 | parser.add_argument("--query", type=str, required=True, help="Query to use for the tasks") 187 | parser.add_argument("--max_tokens", type=int, default=200, help="Maximum number of tokens for model output (default: 200)") 188 | parser.add_argument("--web_source", type=str, default="https://www.microsoft.com/en-us/research/blog/graphrag-unlocking-llm-discovery-on-narrative-private-data/", 189 | help="URL of the web source to load data from (default: Microsoft Research blog post on GraphRAG)") 190 | args = parser.parse_args() 191 | 192 | main(args.model, args.embedding_model, args.query, args.max_tokens, args.web_source) -------------------------------------------------------------------------------- /green_bit_llm/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | -------------------------------------------------------------------------------- /green_bit_llm/args_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | DEFAULT_MODEL_PATH = "GreenBitAI/Qwen-1.5-0.5B-layer-mix-bpw-2.2" 4 | DEFAULT_SEQLEN = 2048 5 | 6 | 7 | def setup_shared_arg_parser(parser_name="Shared argument parser for green-bit-llm scripts"): 8 | """Set up and return the argument parser with shared arguments.""" 9 | parser = argparse.ArgumentParser(description=parser_name) 10 | parser.add_argument( 11 | "--model", 12 | type=str, 13 | default=DEFAULT_MODEL_PATH, 14 | help="The path to the local model directory or Hugging Face repo.", 15 | ) 16 | parser.add_argument( 17 | "--trust-remote-code", 18 | action="store_true", 19 | help="Enable trusting remote code for tokenizer", 20 | ) 21 | parser.add_argument( 22 | "--use-flash-attention-2", 23 
| action="store_true", 24 | help="Enable using flash attention v2", 25 | ) 26 | parser.add_argument( 27 | "--seqlen", 28 | type=int, 29 | default=DEFAULT_SEQLEN, 30 | help="Sequence length" 31 | ) 32 | parser.add_argument( 33 | "--save-dir", 34 | type=str, 35 | default="output/", 36 | help="Specify save dir for eval results.", 37 | ) 38 | parser.add_argument( 39 | "--save-step", 40 | type=int, 41 | default=500, 42 | help="Specify how many steps to save a checkpoint.", 43 | ) 44 | parser.add_argument( 45 | "--dtype", 46 | type=str, 47 | choices=["float", "half"], 48 | default="half", 49 | help="Dtype used in optimizer.", 50 | ) 51 | parser.add_argument( 52 | "--dataset", 53 | type=str, 54 | default="tatsu-lab/alpaca", 55 | help="Dataset name for finetuning", 56 | ) 57 | parser.add_argument( 58 | "--optimizer", 59 | type=str, 60 | default="DiodeMix", 61 | help="Optimizer to use: 1. DiodeMix, 2. AdamW8bit" 62 | ) 63 | parser.add_argument( 64 | "--batch-size", 65 | type=int, 66 | default=4, 67 | help="Batch size" 68 | ) 69 | parser.add_argument("--weight_decay", type=float, default=0.0) 70 | return parser 71 | -------------------------------------------------------------------------------- /green_bit_llm/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import load, generate 2 | -------------------------------------------------------------------------------- /green_bit_llm/common/enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class LayerMode(Enum): 4 | LAYER_MIX = 1 5 | CHANNEL_MIX = 2 6 | LEGENCY = 3 7 | 8 | class TextGenMode(Enum): 9 | SEQUENCE = 1 10 | TOKEN = 2 -------------------------------------------------------------------------------- /green_bit_llm/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/evaluation/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/evaluation/datautils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | from datasets import load_dataset 3 | import torch 4 | import random 5 | 6 | from colorama import init, Fore, Style 7 | init(autoreset=True) 8 | 9 | def get_wikitext2(nsamples, seed, seqlen, model): 10 | """ 11 | Prepares data loaders for the Wikitext-2 dataset for training and testing. 12 | 13 | Args: 14 | nsamples (int): Number of random samples to generate from the training data. 15 | seed (int): Seed for random number generator to ensure reproducibility. 16 | seqlen (int): Sequence length for each input sample. 17 | model (str): Pretrained model identifier used for tokenization. 18 | 19 | Returns: 20 | tuple: A tuple containing the training loader and tokenized test data. 
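Example (illustrative usage only; the model ID is simply the repository's default from args_parser.py):
    >>> trainloader, testenc = get_wikitext2(nsamples=128, seed=0, seqlen=2048, model='GreenBitAI/Qwen-1.5-0.5B-layer-mix-bpw-2.2')
    >>> len(trainloader)
    128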
21 | """ 22 | print(Style.BRIGHT + Fore.CYAN + "Info: get_wikitext2") 23 | traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', cache_dir="./cache/") 24 | testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test', cache_dir="./cache/") 25 | 26 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) 27 | trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt') 28 | testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') 29 | 30 | random.seed(seed) 31 | trainloader = [] 32 | for _ in range(nsamples): 33 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 34 | j = i + seqlen 35 | inp = trainenc.input_ids[:, i:j] 36 | tar = inp.clone() 37 | tar[:, :-1] = -100 38 | trainloader.append((inp, tar)) 39 | 40 | return trainloader, testenc 41 | 42 | 43 | def get_ptb(nsamples, seed, seqlen, model): 44 | """ 45 | Load and prepare the Penn Treebank (PTB) dataset for training and validation. 46 | 47 | Args: 48 | nsamples (int): The number of samples to generate for the training loader. 49 | seed (int): The seed value for random number generation, ensuring reproducibility. 50 | seqlen (int): The sequence length of each sample. 51 | model (str): The model identifier used to load a pre-trained tokenizer. 52 | 53 | Returns: 54 | tuple: A tuple containing the training loader and tokenized validation data. 55 | """ 56 | print(Style.BRIGHT + Fore.CYAN + "Info: get_ptb") 57 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train', cache_dir="./cache/") 58 | valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation', cache_dir="./cache/") 59 | 60 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) 61 | 62 | trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt') 63 | testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt') 64 | 65 | random.seed(seed) 66 | trainloader = [] 67 | for _ in range(nsamples): 68 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 69 | j = i + seqlen 70 | inp = trainenc.input_ids[:, i:j] 71 | tar = inp.clone() 72 | tar[:, :-1] = -100 73 | trainloader.append((inp, tar)) 74 | 75 | return trainloader, testenc 76 | 77 | 78 | def get_c4(nsamples, seed, seqlen, model): 79 | """ 80 | Loads and preprocesses the C4 dataset for training and validation. 81 | Args: 82 | nsamples (int): Number of samples to generate for training. 83 | seed (int): Random seed for reproducibility. 84 | seqlen (int): The sequence length for each training sample. 85 | model (str): Model identifier for tokenizer initialization. 86 | 87 | Returns: 88 | tuple: A tuple containing training data loader and validation data tensor. 
89 | """ 90 | print(Style.BRIGHT + Fore.CYAN + "Info: get_c4") 91 | traindata = load_dataset( 92 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, 93 | split='train', 94 | cache_dir="./cache/" 95 | ) 96 | valdata = load_dataset( 97 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, 98 | split='validation', 99 | cache_dir="./cache/" 100 | ) 101 | 102 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) 103 | 104 | random.seed(seed) 105 | trainloader = [] 106 | for _ in range(nsamples): 107 | 108 | while True: 109 | i = random.randint(0, len(traindata) - 1) 110 | trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') 111 | if trainenc.input_ids.shape[1] > seqlen + 2: 112 | break 113 | 114 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 115 | j = i + seqlen 116 | inp = trainenc.input_ids[:, i:j] 117 | tar = inp.clone() 118 | tar[:, :-1] = -100 119 | trainloader.append((inp, tar)) 120 | 121 | random.seed(0) 122 | valenc = [] 123 | for _ in range(256): 124 | while True: 125 | i = random.randint(0, len(valdata) - 1) 126 | tmp = tokenizer(valdata[i]['text'], return_tensors='pt') 127 | if tmp.input_ids.shape[1] >= seqlen: 128 | break 129 | i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) 130 | j = i + seqlen 131 | valenc.append(tmp.input_ids[:, i:j]) 132 | valenc = torch.hstack(valenc) 133 | 134 | return trainloader, valenc 135 | 136 | 137 | def get_ptb_new(nsamples, seed, seqlen, model): 138 | """ 139 | Generates training and testing data loaders for the Penn Treebank dataset using a specified model tokenizer. 140 | 141 | Args: 142 | nsamples (int): Number of samples to generate in the training loader. 143 | seed (int): Random seed for reproducibility of sample selection. 144 | seqlen (int): Sequence length of each sample in the training data. 145 | model (str): Model identifier for the tokenizer (e.g., a Hugging Face model name). 146 | 147 | Returns: 148 | tuple: A tuple containing the training loader (list of tuples with input IDs and target IDs) and 149 | the tokenized test data. 150 | """ 151 | print(Style.BRIGHT + Fore.CYAN + "Info: get_ptb_new") 152 | traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train') 153 | testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test') 154 | 155 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) 156 | 157 | trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") 158 | testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") 159 | 160 | random.seed(seed) 161 | trainloader = [] 162 | for _ in range(nsamples): 163 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 164 | j = i + seqlen 165 | inp = trainenc.input_ids[:, i:j] 166 | tar = inp.clone() 167 | tar[:, :-1] = -100 168 | trainloader.append((inp, tar)) 169 | 170 | return trainloader, testenc 171 | 172 | 173 | def get_c4_new(nsamples, seed, seqlen, model): 174 | """ 175 | Load and preprocess training and validation datasets from C4 dataset, and tokenize the data. 176 | 177 | Args: 178 | nsamples (int): Number of samples to process for the training data. 179 | seed (int): Random seed for reproducibility of sample selection. 180 | seqlen (int): Length of the sequence for each input/output example. 181 | model (str): Model identifier for the tokenizer, specifying which pretrained model to use. 
182 | 183 | Returns: 184 | tuple: A tuple containing two elements: 185 | - trainloader (list of tuples): A list where each tuple contains input ids and target tensors for training. 186 | - valenc (torch.Tensor): A tensor containing the tokenized validation data. 187 | """ 188 | print(Style.BRIGHT + Fore.CYAN + "Info: get_c4_new") 189 | traindata = load_dataset( 190 | 'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train' 191 | ) 192 | valdata = load_dataset( 193 | 'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation' 194 | ) 195 | 196 | tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, trust_remote_code=True) 197 | 198 | random.seed(seed) 199 | trainloader = [] 200 | for _ in range(nsamples): 201 | while True: 202 | i = random.randint(0, len(traindata) - 1) 203 | trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") 204 | if trainenc.input_ids.shape[1] >= seqlen: 205 | break 206 | i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) 207 | j = i + seqlen 208 | inp = trainenc.input_ids[:, i:j] 209 | tar = inp.clone() 210 | tar[:, :-1] = -100 211 | trainloader.append((inp, tar)) 212 | 213 | valenc = tokenizer(" ".join(valdata[:1100]["text"]), return_tensors="pt") 214 | valenc = valenc.input_ids[:, : (256 * seqlen)] 215 | 216 | return trainloader, valenc 217 | 218 | 219 | def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=''): 220 | """ 221 | Retrieves data loaders for different datasets based on the dataset name. 222 | 223 | Args: 224 | name (str): The name of the dataset to load, which can include specific versions like 'new'. 225 | nsamples (int): The number of samples to retrieve from the dataset. 226 | seed (int): The random seed to ensure reproducibility. 227 | seqlen (int): The sequence length of the samples. 228 | model (str): The model specification that might influence data preprocessing. 229 | 230 | Returns: 231 | DataLoader: A configured data loader for the specified dataset. 232 | 233 | Raises: 234 | ValueError: If the dataset name is not recognized or supported. 
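Example (illustrative usage only; any of the supported dataset names listed above can be passed):
    >>> trainloader, testenc = get_loaders('wikitext2', nsamples=128, seed=0, seqlen=2048, model='GreenBitAI/Qwen-1.5-0.5B-layer-mix-bpw-2.2')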
235 | """ 236 | if 'wikitext2' in name: 237 | return get_wikitext2(nsamples, seed, seqlen, model) 238 | elif 'ptb' in name: 239 | if 'new' in name: 240 | return get_ptb_new(nsamples, seed, seqlen, model) 241 | else: 242 | return get_ptb(nsamples, seed, seqlen, model) 243 | elif 'c4' in name: 244 | if 'new' in name: 245 | return get_c4_new(nsamples, seed, seqlen, model) 246 | else: 247 | return get_c4(nsamples, seed, seqlen, model) 248 | else: 249 | raise ValueError(f"Only support wikitext2, c4, c4_new, ptb, ptb_new currently, but get {name}") 250 | -------------------------------------------------------------------------------- /green_bit_llm/evaluation/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import sys 4 | import warnings 5 | from pathlib import Path 6 | from tqdm import tqdm 7 | 8 | from lm_eval import evaluator 9 | from peft import PeftModel, LoraConfig, get_peft_model 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | from green_bit_llm.evaluation.lmclass import LMClass 15 | from green_bit_llm.evaluation.utils import create_logger, add_dict_to_json_file 16 | from green_bit_llm.evaluation.datautils import get_loaders 17 | from green_bit_llm.common import load 18 | from green_bit_llm.sft.peft_utils.model import * 19 | 20 | warnings.filterwarnings('ignore') 21 | 22 | # default value for arguments 23 | DEFAULT_MODEL_PATH = "GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-2.2" 24 | DEFAULT_SEQLEN = 2048 25 | DEFAULT_RANDOM_SEED = 0 26 | DTYPE = torch.half 27 | 28 | replace_peft_lora_model_with_gba_lora_model() 29 | 30 | 31 | @torch.no_grad() 32 | def lm_evaluate(lm, args, logger): 33 | """ 34 | Evaluates the language model (lm) according to the specified evaluation parameters in args. 35 | This function handles both perplexity evaluation on various datasets and few-shot learning evaluation. 36 | 37 | Parameters: 38 | lm: The language model object configured for evaluation. 39 | args: An object containing all necessary parameters and flags for evaluation, such as which datasets to use, 40 | whether to evaluate perplexity or few-shot performance, sequence length, etc. 41 | logger: A logging object used to record evaluation results and other important messages. 42 | 43 | Returns: 44 | A dictionary containing evaluation results, including perplexity values and few-shot evaluation metrics, 45 | keyed by dataset or task name. 
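Note: `args` is expected to carry the fields produced by setup_arg_parser() below, in particular eval_ppl, ppl_tasks, seed, model, seqlen, eval_few_shot, few_shot_tasks, num_fewshot and batch_size.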
46 | """ 47 | results = {} 48 | lm.model = lm.model.to(lm.device) 49 | 50 | if args.eval_ppl: 51 | for dataset in args.ppl_tasks.split(","): 52 | dataloader, testloader = get_loaders( 53 | dataset, 54 | seed=args.seed, 55 | model=args.model, 56 | seqlen=args.seqlen, 57 | ) 58 | 59 | if "c4" in dataset: 60 | testenc = testloader 61 | else: 62 | testenc = testloader.input_ids 63 | 64 | nsamples = testenc.numel() // lm.seqlen 65 | use_cache = lm.model.config.use_cache 66 | lm.model.config.use_cache = False 67 | lm.model.eval() 68 | nlls = [] 69 | 70 | for i in tqdm(range(nsamples)): 71 | batch = testenc[:, (i * lm.seqlen): ((i + 1) * lm.seqlen)].to(lm.device) 72 | with torch.no_grad(): 73 | outputs = lm.model.model(batch) 74 | hidden_states = outputs[0] 75 | logits = lm.model.lm_head(hidden_states) 76 | shift_logits = logits[:, :-1, :] 77 | shift_labels = testenc[:, (i * lm.seqlen): ((i + 1) * lm.seqlen)][ 78 | :, 1: 79 | ].to(lm.model.lm_head.weight.device) 80 | loss_fct = nn.CrossEntropyLoss() 81 | loss = loss_fct( 82 | shift_logits.view(-1, shift_logits.size(-1)), 83 | shift_labels.view(-1), 84 | ) 85 | neg_log_likelihood = loss.float() * lm.seqlen 86 | nlls.append(neg_log_likelihood) 87 | 88 | ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * lm.seqlen)) 89 | logger.info(f'{dataset} : {ppl.item()}') 90 | lm.model.config.use_cache = use_cache 91 | results[dataset] = ppl.item() 92 | 93 | if args.eval_few_shot and args.eval_few_shot != "": 94 | few_shot_tasks = args.few_shot_tasks.split(",") 95 | 96 | eval_results = evaluator.simple_evaluate( 97 | lm, 98 | tasks=few_shot_tasks, 99 | batch_size=args.batch_size, 100 | num_fewshot=args.num_fewshot, 101 | no_cache=True 102 | ) 103 | 104 | results.update({"results": eval_results["results"]}) 105 | results.update({"versions": eval_results["versions"]}) 106 | logger.info(evaluator.make_table(results)) 107 | 108 | return results 109 | 110 | def setup_arg_parser(): 111 | """Set up and return the argument parser.""" 112 | parser = argparse.ArgumentParser(description="green-bit-llm evaluate script") 113 | parser.add_argument( 114 | "--seed", 115 | type=int, 116 | default=DEFAULT_RANDOM_SEED, 117 | help="The random seed for data loader.", 118 | ) 119 | parser.add_argument( 120 | "--model", 121 | type=str, 122 | default=DEFAULT_MODEL_PATH, 123 | help="The path to the local model directory or Hugging Face repo.", 124 | ) 125 | parser.add_argument( 126 | "--cuda-device-id", 127 | type=str, 128 | default="0", 129 | help="CUDA device IDs.", 130 | ) 131 | parser.add_argument( 132 | "--trust-remote-code", 133 | action="store_true", 134 | help="Enable trusting remote code for tokenizer.", 135 | ) 136 | parser.add_argument( 137 | "--use-flash-attention-2", 138 | action="store_true", 139 | help="Enable using flash attention v2.", 140 | ) 141 | parser.add_argument( 142 | "--eos-token", 143 | type=str, 144 | default="<|im_end|>", 145 | help="End of sequence token for tokenizer.", 146 | ) 147 | parser.add_argument( 148 | "--seqlen", 149 | type=int, 150 | default=DEFAULT_SEQLEN, 151 | help="Sequence length." 
152 | ) 153 | parser.add_argument( 154 | "--eval-ppl", 155 | action="store_true", 156 | help="Evaluate LLM prediction perplexity.", 157 | ) 158 | parser.add_argument( 159 | "--ppl-tasks", 160 | type=str, 161 | default="wikitext2, c4_new, ptb", 162 | help="Specify ppl evaluation task.", 163 | ) 164 | parser.add_argument( 165 | "--eval-few-shot", 166 | action="store_true", 167 | help="Evaluate LLM few-shot learning ability.", 168 | ) 169 | parser.add_argument( 170 | "--num-fewshot", 171 | type=int, 172 | default=0, 173 | help="Specify num of few shot examples for evaluation.", 174 | ) 175 | parser.add_argument( 176 | "--few-shot-tasks", 177 | type=str, 178 | default="openbookqa,arc_easy,winogrande,hellaswag,arc_challenge,piqa,boolq,race,anli_r1,anli_r2,anli_r3,wic", 179 | help="Few-shot learning ability evaluation tasks.", 180 | ) 181 | parser.add_argument( 182 | "--batch-size", 183 | type=int, 184 | default=16, 185 | help="Batch size for few-shot evaluation.", 186 | ) 187 | parser.add_argument( 188 | "--save-dir", 189 | type=str, 190 | default="log/", 191 | help="Specify save dir for eval results.", 192 | ) 193 | parser.add_argument( 194 | "--lora-dir", 195 | type=str, 196 | default=None, 197 | help="Specify lora dir for lora merge" 198 | 199 | ) 200 | return parser 201 | 202 | 203 | def create_device_map(cuda_device_id): 204 | ids = cuda_device_id.split(',') 205 | # Create strings in the format "cuda:x" for each ID and put them into the collection 206 | device_map = {f"cuda:{id}" for id in ids} 207 | return device_map 208 | 209 | def main(args): 210 | if not os.path.exists(Path(args.save_dir)): 211 | os.mkdir(Path(args.save_dir)) 212 | 213 | # Building configs 214 | tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} 215 | 216 | pretrain_model_config = { 217 | "trust_remote_code": True if args.trust_remote_code else None, 218 | "attn_implementation": "flash_attention_2" if args.use_flash_attention_2 else None 219 | } 220 | 221 | if args.eos_token is not None: 222 | tokenizer_config["eos_token"] = args.eos_token 223 | 224 | model, tokenizer, config = load( 225 | args.model, 226 | tokenizer_config=tokenizer_config, 227 | dtype=DTYPE, 228 | device_map='auto', 229 | seqlen=args.seqlen, 230 | model_config=pretrain_model_config, 231 | requires_grad=False 232 | ) 233 | 234 | if args.lora_dir is not None: 235 | config = LoraConfig( 236 | r=64, 237 | lora_alpha=32, 238 | target_modules=["q_proj", "v_proj", "out_proj", "up_proj"], 239 | lora_dropout=0.01, 240 | bias="none", 241 | task_type="CAUSAL_LM", 242 | ) 243 | model = get_peft_model(model,config) 244 | model.load_adapter(args.lora_dir, adapter_name="default") 245 | 246 | lm = LMClass(args.model, batch_size=args.batch_size, config=config, tokenizer=tokenizer, model=model) 247 | lm.seqlen = args.seqlen 248 | 249 | logger = create_logger(Path(args.save_dir)) 250 | 251 | with torch.no_grad(): 252 | eval_results = lm_evaluate(lm=lm, args=args, logger=logger) 253 | 254 | eval_results = {"{}".format(args.model): eval_results} 255 | 256 | add_dict_to_json_file(file_path="{}".format(os.path.join(args.save_dir, "eval_results.json")), new_data=eval_results) 257 | 258 | if __name__ == "__main__": 259 | if not torch.cuda.is_available(): 260 | print("Warning: CUDA is needed to run the model.") 261 | sys.exit(0) 262 | 263 | parser = setup_arg_parser() 264 | args = parser.parse_args() 265 | 266 | main(args) 267 | -------------------------------------------------------------------------------- 
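The perplexity branch of `lm_evaluate` above is a standard sliding-window evaluation: the tokenized test stream is cut into non-overlapping `seqlen` windows, a shifted cross-entropy is accumulated per window, and the mean negative log-likelihood is exponentiated. A minimal standalone sketch of that computation, assuming a generic Hugging Face causal LM; the helper name `sliding_window_ppl` is illustrative and not part of this package:

```python
import torch
import torch.nn as nn


@torch.no_grad()
def sliding_window_ppl(model, input_ids: torch.Tensor, seqlen: int = 2048) -> float:
    """Perplexity over a pre-tokenized stream of shape [1, total_tokens]."""
    device = next(model.parameters()).device
    nsamples = input_ids.numel() // seqlen
    loss_fct = nn.CrossEntropyLoss()
    nlls = []
    for i in range(nsamples):
        batch = input_ids[:, i * seqlen:(i + 1) * seqlen].to(device)
        logits = model(batch).logits
        # Predict token t+1 from tokens up to t, mirroring the shift in lm_evaluate.
        shift_logits = logits[:, :-1, :]
        shift_labels = batch[:, 1:]
        loss = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)),
                        shift_labels.reshape(-1))
        nlls.append(loss.float() * seqlen)
    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen)).item()
```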
/green_bit_llm/evaluation/lmclass.py: -------------------------------------------------------------------------------- 1 | from lm_eval.base import BaseLM 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | from colorama import init, Fore, Style 6 | init(autoreset=True) 7 | 8 | class LMClass(BaseLM): 9 | """ 10 | Wraps a pretrained language model into a format suitable for evaluation tasks. 11 | This class adapts a given model to be used with specific language modeling evaluation tools. 12 | 13 | Args: 14 | model_name (str): Name of the language model. 15 | batch_size (int): Batch size per GPU for processing. 16 | config (dict): Configuration settings for the model. 17 | tokenizer: Tokenizer associated with the model for text processing. 18 | model (nn.Module): The pretrained neural network model. 19 | """ 20 | def __init__(self, model_name, batch_size, config, tokenizer, model): 21 | # Initializes the model wrapper class with specified model and configuration 22 | super().__init__() 23 | 24 | self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 25 | self.model_name = model_name 26 | self.batch_size_per_gpu = batch_size 27 | 28 | self.model_config = config 29 | self.tokenizer = tokenizer 30 | self.model = model 31 | self.initial() 32 | 33 | def initial(self): 34 | # Initializes the model for inference, setting up necessary parameters such as sequence length and vocab size 35 | self.seqlen = self.model.config.max_position_embeddings 36 | self.model.eval() 37 | self.vocab_size = self.tokenizer.vocab_size 38 | print(Style.BRIGHT + Fore.CYAN + "Info: vocab size: ", self.vocab_size) 39 | 40 | @property 41 | def eot_token(self) -> str: 42 | # Returns the end-of-text token as a string 43 | return self.tokenizer.eos_token 44 | 45 | @property 46 | def eot_token_id(self): 47 | # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* 48 | return self.tokenizer.eos_token_id 49 | 50 | @property 51 | def max_length(self): 52 | # Returns the maximum length of sequences the model can handle, based on the model's configuration 53 | try: 54 | return self.gpt2.config.n_ctx 55 | except AttributeError: 56 | # gptneoconfig doesn't have n_ctx apparently 57 | return self.model.config.max_position_embeddings 58 | 59 | @property 60 | def max_gen_toks(self): 61 | # Returns the maximum number of tokens that can be generated in one go 62 | print(Style.BRIGHT + Fore.CYAN + "Info: max_gen_toks fn") 63 | return 256 64 | 65 | @property 66 | def batch_size(self): 67 | # Returns the configured batch size per GPU 68 | # TODO: fix multi-gpu 69 | return self.batch_size_per_gpu # * gpus 70 | 71 | @property 72 | def device(self): 73 | # Returns the computing device (CPU or GPU) that the model is using 74 | # TODO: fix multi-gpu 75 | return self._device 76 | 77 | def tok_encode(self, string: str): 78 | # Encodes a string into its corresponding IDs using the tokenizer 79 | return self.tokenizer.encode(string, add_special_tokens=False) 80 | 81 | def tok_encode_batch(self, strings): 82 | # Encodes a batch of strings into model inputs, handling padding and special tokens 83 | return self.tokenizer( 84 | strings, 85 | padding=True, 86 | add_special_tokens=False, 87 | return_tensors="pt", 88 | ) 89 | 90 | def tok_decode(self, tokens): 91 | # Decodes a batch of token IDs back into strings 92 | return self.tokenizer.batch_decode(tokens, skip_special_tokens=True) 93 | 94 | def _model_call(self, inps): 95 | """ 96 | Performs a forward pass through the model with the 
provided inputs and returns logits 97 | 98 | Args: 99 | inps: a torch tensor of shape [batch, sequence] 100 | the size of sequence may vary from call to call 101 | returns: a torch tensor of shape [batch, sequence, vocab] with the 102 | logits returned from the model 103 | """ 104 | with torch.no_grad(): 105 | return self.model(inps)["logits"] 106 | 107 | def model_batched_set(self, inps): 108 | # Processes a set of inputs in batches and returns a list of logit tensors 109 | dataset_logits = [] 110 | for batch in inps: 111 | multi_logits = F.log_softmax( 112 | self._model_call(batch), dim=-1 113 | ).cpu() # [batch, padding_length, vocab] 114 | dataset_logits.append(multi_logits) 115 | return dataset_logits 116 | 117 | def _model_generate(self, context, max_length, eos_token_id): 118 | # Generates text based on a given context up to a specified maximum length 119 | return self.model.generate( 120 | context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False 121 | ) 122 | -------------------------------------------------------------------------------- /green_bit_llm/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import time 5 | import logging 6 | from termcolor import colored 7 | 8 | 9 | def pattern_match(patterns, source_list): 10 | """ 11 | Function to find unique matching patterns from a source list. 12 | Args: 13 | patterns: list of pattern strings to match. 14 | Returns: 15 | list: list of task name 16 | """ 17 | task_names = set() 18 | for pattern in patterns: 19 | task_names.add(pattern) 20 | return list(task_names) 21 | 22 | def add_dict_to_json_file(file_path, new_data): 23 | """ 24 | Update a JSON file based on the top-level keys in new_data. If the key exists, it replaces the existing content 25 | under that key. If it doesn't exist, it adds the new key with its value. 26 | 27 | Args: 28 | file_path: Path to the JSON file. 29 | new_data: Dictionary to add or update in the file. 30 | """ 31 | # Initialize or load existing data 32 | if os.path.exists(file_path) and os.stat(file_path).st_size > 0: 33 | with open(file_path, 'r') as file: 34 | existing_data = json.load(file) 35 | else: 36 | existing_data = {} 37 | 38 | # Merge new data into existing data based on the top-level keys 39 | existing_data.update(new_data) 40 | 41 | # Write the updated data back to the file 42 | with open(file_path, 'w') as file: 43 | json.dump(existing_data, file, indent=4) 44 | 45 | def create_logger(output_dir, dist_rank=0, name=''): 46 | """ 47 | Creates and configures a logger with console and file handlers. 48 | 49 | Args: 50 | output_dir (str): Directory where the log files will be stored. 51 | dist_rank (int): Rank of the process in distributed training. Only the master process (rank 0) will output to the console. 52 | name (str): Name of the logger, used to differentiate output if multiple loggers are used. 53 | 54 | Returns: 55 | logger: Configured logger object. 
56 | """ 57 | # create logger 58 | logger = logging.getLogger(name) 59 | logger.setLevel(logging.INFO) 60 | logger.propagate = False 61 | 62 | # create formatter 63 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 64 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 65 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 66 | 67 | # create console handlers for master process 68 | if dist_rank == 0: 69 | console_handler = logging.StreamHandler(sys.stdout) 70 | console_handler.setLevel(logging.DEBUG) 71 | console_handler.setFormatter( 72 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 73 | logger.addHandler(console_handler) 74 | 75 | # create file handlers 76 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}_{int(time.time())}.txt'), mode='a') 77 | file_handler.setLevel(logging.DEBUG) 78 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 79 | logger.addHandler(file_handler) 80 | 81 | return logger 82 | 83 | -------------------------------------------------------------------------------- /green_bit_llm/inference/README.md: -------------------------------------------------------------------------------- 1 | # Inference Package for GreenBitAI's Low-bit LLMs 2 | 3 | ## Overview 4 | 5 | This package demonstrates the capabilities of [GreenBitAI's low-bit large language models (LLMs)](https://huggingface.co/GreenBitAI) through two main features: 6 | 1. Simple generation with `sim_gen.py` script. 7 | 2. A command-line interface (CLI) based chat demo using the `chat_cli.py` script. 8 | 9 | Both tools are designed for efficient natural language processing, enabling quick setups and responses using local models. 10 | 11 | ## Installation 12 | 13 | Please follow the [main installation instructions](../../README.md#installation) for how to install the packages required to run this inference package. 14 | Further packages should not be required. 15 | 16 | ## Usage 17 | 18 | ### LLMs 19 | 20 | We have released over 200 highly precise 2.2/2.5/3/4-bit models across the modern LLM family, featuring LLaMA 2/3, 01-Yi, Qwen, Mistral, Phi-3, and more. 
21 | 22 | | Family | Bpw | Size | HF collection id | 23 | |:----------------:|:------------------:|:------------------------------:|:-----------------------------------------------------------------------------------------------------------------:| 24 | | Llama-3 | `4.0/3.0/2.5/2.2` | `8B/70B` | [`GreenBitAI Llama-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-llama-3-6627bc1ec6538e3922c5d81c) | 25 | | Llama-2 | `3.0/2.5/2.2` | `7B/13B/70B` | [`GreenBitAI Llama-2`](https://huggingface.co/collections/GreenBitAI/greenbitai-llama-2-661f87e3b073ff8e48a12834) | 26 | | Qwen-1.5 | `4.0/3.0/2.5/2.2` | `0.5B/1.8B/4B/7B/14B/32B/110B` | [`GreenBitAI Qwen 1.5`](https://huggingface.co/collections/GreenBitAI/greenbitai-qwen15-661f86ea69433f3d3062c920) | 27 | | Phi-3 | `3.0/2.5/2.2` | `mini` | [`GreenBitAI Phi-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-phi-3-6628d008cdf168398a296c92) | 28 | | Mistral | `3.0/2.5/2.2` | `7B` | [`GreenBitAI Mistral`](https://huggingface.co/collections/GreenBitAI/greenbitai-mistral-661f896c45da9d8b28a193a8) | 29 | | 01-Yi | `3.0/2.5/2.2` | `6B/34B` | [`GreenBitAI 01-Yi`](https://huggingface.co/collections/GreenBitAI/greenbitai-01-yi-661f88af0648daa766d5102f) | 30 | | Llama-3-instruct | `4.0/3.0/2.5/2.2` | `8B/70B` | [`GreenBitAI Llama-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-llama-3-6627bc1ec6538e3922c5d81c) | 31 | | Mistral-instruct | `3.0/2.5/2.2` | `7B` | [`GreenBitAI Mistral`](https://huggingface.co/collections/GreenBitAI/greenbitai-mistral-661f896c45da9d8b28a193a8) | 32 | | Phi-3-instruct | `3.0/2.5/2.2` | `mini` | [`GreenBitAI Phi-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-phi-3-6628d008cdf168398a296c92) | 33 | | Qwen-1.5-Chat | `4.0/3.0/2.5/2.2` | `0.5B/1.8B/4B/7B/14B/32B/110B` | [`GreenBitAI Qwen 1.5`](https://huggingface.co/collections/GreenBitAI/greenbitai-qwen15-661f86ea69433f3d3062c920) | 34 | | 01-Yi-Chat | `3.0/2.5/2.2` | `6B/34B` | [`GreenBitAI 01-Yi`](https://huggingface.co/collections/GreenBitAI/greenbitai-01-yi-661f88af0648daa766d5102f) | 35 | 36 | 37 | ### Simple Generation 38 | 39 | Run the simple generation script as follows: 40 | 41 | ```bash 42 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.inference.sim_gen --model GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0 --max-tokens 100 --use-flash-attention-2 --ignore-chat-template --prompt "The meaning of life is" 43 | ``` 44 | 45 | This command generates text based on the provided prompt using the specified GreenBitAI model. 46 | 47 | ### CLI-Based Chat Demo 48 | 49 | To start the chat interface: 50 | 51 | ```bash 52 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.inference.chat_cli --model GreenBitAI/Qwen-1.5-1.8B-Chat-layer-mix-bpw-2.2 --use-flash-attention-2 --multiline --mouse 53 | ``` 54 | This launches a rich command-line interface for interactive chatting. 55 | 56 | ## License 57 | - The scripts `conversation.py`, `chat_base.py`, and `chat_cli.py` have been modified from their original versions found in [FastChat-serve](https://github.com/lm-sys/FastChat/tree/main/fastchat/serve), which are released under the [Apache 2.0 License](https://github.com/lm-sys/FastChat/tree/main/LICENSE). 58 | - We release our changes and additions to these files under the [Apache 2.0 License](../../LICENSE). 
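Both entry points above are thin wrappers around the `load` and `generate` helpers from `green_bit_llm.common`, so the same generation can also be driven from Python directly. A minimal sketch that mirrors the calls made in `sim_gen.py`; the model id and sampling values are illustrative defaults taken from the examples above:

```python
import torch
from green_bit_llm.common import generate, load

model, tokenizer, config = load(
    "GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0",
    tokenizer_config={"trust_remote_code": None},
    dtype=torch.half,
    device_map="auto",
    seqlen=2048,
    model_config={"trust_remote_code": None},
    requires_grad=False,
)

# Arguments mirror the generate() call in sim_gen.py.
generate(model, tokenizer, "The meaning of life is", 0.8, 100, True, top_p=0.95)
```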
59 | -------------------------------------------------------------------------------- /green_bit_llm/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/inference/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/inference/chat_cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | cli chat demo. 3 | Code based on: https://github.com/yanghaojin/FastChat/blob/greenbit/fastchat/serve/cli.py 4 | """ 5 | import argparse 6 | import os 7 | import warnings 8 | warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.modules.module') 9 | 10 | import torch 11 | 12 | # Add the parent directory to sys.path 13 | from green_bit_llm.inference.utils import str_to_torch_dtype 14 | 15 | try: 16 | from green_bit_llm.inference.chat_base import chat_loop, SimpleChatIO, RichChatIO 17 | except Exception: 18 | raise Exception("Error occurred when import chat loop, ChatIO classes.") 19 | 20 | 21 | def main(args): 22 | # args setup 23 | if args.gpus: 24 | if len(args.gpus.split(",")) < args.num_gpus: 25 | raise ValueError( 26 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 27 | ) 28 | # NOTE that we need to set this before any other cuda operations. Otherwise will not work. 29 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 30 | 31 | if not torch.cuda.is_available(): 32 | raise Exception("Warning: CUDA is needed to run the model.") 33 | 34 | if args.style == "simple": 35 | chatio = SimpleChatIO(args.multiline) 36 | elif args.style == "rich": 37 | chatio = RichChatIO(args.multiline, args.mouse) 38 | else: 39 | raise ValueError(f"Invalid style for console: {args.style}") 40 | 41 | # Building configs 42 | tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} 43 | pretrain_model_config = { 44 | "trust_remote_code": True if args.trust_remote_code else None, 45 | "attn_implementation": "flash_attention_2" if args.use_flash_attention_2 else None 46 | } 47 | 48 | if args.eos_token is not None: 49 | tokenizer_config["eos_token"] = args.eos_token 50 | 51 | # chat 52 | try: 53 | chat_loop( 54 | args.model, 55 | tokenizer_config, 56 | pretrain_model_config, 57 | args.seqlen, 58 | args.device, 59 | str_to_torch_dtype(args.dtype), 60 | args.conv_template, 61 | args.conv_system_msg, 62 | args.temperature, 63 | args.repetition_penalty, 64 | args.max_new_tokens, 65 | chatio, 66 | judge_sent_end=args.judge_sent_end, 67 | debug=args.debug, 68 | history=not args.no_history, 69 | ) 70 | except KeyboardInterrupt: 71 | print("exit...") 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument( 77 | "--model", 78 | type=str, 79 | default="GreenBitAI/Mistral-Instruct-7B-v0.2-layer-mix-bpw-2.2", 80 | help="The path to the weights. This can be a local folder or a Hugging Face repo ID.", 81 | ) 82 | parser.add_argument( 83 | "--device", 84 | type=str, 85 | choices=["cpu", "cuda"], 86 | default="cuda", 87 | help="The device type", 88 | ) 89 | parser.add_argument( 90 | "--gpus", 91 | type=str, 92 | default='0', 93 | help="A single GPU like 1 or multiple GPUs like 0,2", 94 | ) 95 | parser.add_argument("--num-gpus", type=int, default=1) 96 | parser.add_argument( 97 | "--dtype", 98 | type=str, 99 | choices=["float32", "float16"], 100 | help="Override the default dtype. 
If not set, it will use float16 on GPU and float32 on CPU.", 101 | default="float16", 102 | ) 103 | parser.add_argument( 104 | "--conv-template", type=str, default=None, help="Conversation prompt template." 105 | ) 106 | parser.add_argument( 107 | "--conv-system-msg", type=str, default=None, help="Conversation system message." 108 | ) 109 | parser.add_argument("--temperature", type=float, default=0.7) 110 | parser.add_argument("--repetition_penalty", type=float, default=1.0) 111 | parser.add_argument("--max-new-tokens", type=int, default=256) 112 | parser.add_argument("--no-history", action="store_true", help="Disables chat history.") 113 | parser.add_argument( 114 | "--style", 115 | type=str, 116 | default="rich", 117 | choices=["simple", "rich"], 118 | help="Display style.", 119 | ) 120 | parser.add_argument( 121 | "--multiline", 122 | action="store_true", 123 | help="Enable multiline input. Use ESC+Enter for newline.", 124 | ) 125 | parser.add_argument( 126 | "--mouse", 127 | action="store_true", 128 | help="[Rich Style]: Enable mouse support for cursor positioning.", 129 | ) 130 | parser.add_argument( 131 | "--judge-sent-end", 132 | action="store_true", 133 | help="Whether enable the correction logic that interrupts the output of sentences due to EOS.", 134 | ) 135 | parser.add_argument( 136 | "--debug", 137 | action="store_true", 138 | help="Print useful debug information (e.g., prompts)", 139 | ) 140 | parser.add_argument( 141 | "--eos-token", 142 | type=str, 143 | default=None, 144 | help="End of sequence token for tokenizer", 145 | ) 146 | parser.add_argument( 147 | "--trust-remote-code", 148 | action="store_true", 149 | help="Enable trusting remote code", 150 | ) 151 | parser.add_argument( 152 | "--seqlen", type=int, default=2048, help="Sequence length" 153 | ) 154 | parser.add_argument( 155 | "--use-flash-attention-2", 156 | action="store_true", 157 | help="Enable using flash attention v2", 158 | ) 159 | args = parser.parse_args() 160 | main(args) -------------------------------------------------------------------------------- /green_bit_llm/inference/sim_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import warnings 7 | warnings.filterwarnings("ignore", category=UserWarning, module='torch.nn.modules.module') 8 | 9 | from transformers import PreTrainedTokenizer 10 | 11 | from green_bit_llm.common import generate, load 12 | from green_bit_llm.args_parser import setup_shared_arg_parser 13 | 14 | # default value for arguments 15 | DEFAULT_PROMPT = None 16 | DEFAULT_MAX_TOKENS = 100 17 | DEFAULT_TEMP = 0.8 18 | DEFAULT_TOP_P = 0.95 19 | DTYPE = torch.half 20 | 21 | 22 | def setup_arg_parser(): 23 | """Set up and return the argument parser.""" 24 | parser = setup_shared_arg_parser("green-bit-llm inference script") 25 | 26 | parser.add_argument("--num-gpus", type=int, default=1) 27 | parser.add_argument( 28 | "--gpus", 29 | type=str, 30 | default='0', 31 | help="A single GPU like 1 or multiple GPUs like 0,2", 32 | ) 33 | parser.add_argument( 34 | "--eos-token", 35 | type=str, 36 | default=None, 37 | help="End of sequence token for tokenizer", 38 | ) 39 | parser.add_argument( 40 | "--max-tokens", 41 | type=int, 42 | default=DEFAULT_MAX_TOKENS, 43 | help="Maximum number of tokens to generate", 44 | ) 45 | parser.add_argument( 46 | "--prompt", default=DEFAULT_PROMPT, help="Message to be processed by the model" 47 | ) 48 | parser.add_argument( 49 | "--temp", type=float, 
default=DEFAULT_TEMP, help="Sampling temperature" 50 | ) 51 | parser.add_argument( 52 | "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" 53 | ) 54 | parser.add_argument( 55 | "--ignore-chat-template", 56 | action="store_true", 57 | help="Use the raw prompt without the tokenizer's chat template.", 58 | ) 59 | parser.add_argument( 60 | "--use-default-chat-template", 61 | action="store_true", 62 | help="Use the default chat template", 63 | ) 64 | parser.add_argument( 65 | "--enable-thinking", 66 | action="store_true", 67 | help="Enable thinking mode for Qwen-3 models.", 68 | ) 69 | return parser 70 | 71 | 72 | def do_generate(args, model: nn.Module, tokenizer: PreTrainedTokenizer, prompt: str, enable_thinking: bool): 73 | """ 74 | This function generates text based on a given prompt using a model and tokenizer. 75 | It handles optional pre-processing with chat templates if specified in the arguments. 76 | """ 77 | if not args.ignore_chat_template and ( 78 | hasattr(tokenizer, "apply_chat_template") 79 | and tokenizer.chat_template is not None 80 | ): 81 | messages = [{"role": "user", "content": prompt}] 82 | prompt = tokenizer.apply_chat_template( 83 | messages, 84 | tokenize=False, 85 | add_generation_prompt=True, 86 | enable_thinking=enable_thinking 87 | ) 88 | else: 89 | prompt = prompt 90 | 91 | generate( 92 | model, 93 | tokenizer, 94 | prompt, 95 | args.temp, 96 | args.max_tokens, 97 | True, 98 | top_p=args.top_p 99 | ) 100 | 101 | 102 | def main(args): 103 | 104 | if args.gpus: 105 | if len(args.gpus.split(",")) < args.num_gpus: 106 | raise ValueError( 107 | f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" 108 | ) 109 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 110 | 111 | if not torch.cuda.is_available(): 112 | raise Exception("Warning: CUDA is needed to run the model.") 113 | 114 | # Building configs 115 | tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} 116 | pretrain_model_config = { 117 | "trust_remote_code": True if args.trust_remote_code else None, 118 | "attn_implementation": "flash_attention_2" if args.use_flash_attention_2 else None 119 | } 120 | 121 | if args.eos_token is not None: 122 | tokenizer_config["eos_token"] = args.eos_token 123 | 124 | model, tokenizer, config = load( 125 | args.model, 126 | tokenizer_config=tokenizer_config, 127 | dtype=DTYPE, 128 | device_map='auto', 129 | seqlen=args.seqlen, 130 | model_config=pretrain_model_config, 131 | requires_grad=False 132 | ) 133 | 134 | if args.use_default_chat_template: 135 | if tokenizer.chat_template is None: 136 | tokenizer.chat_template = tokenizer.default_chat_template 137 | 138 | if args.prompt is None: 139 | while True: 140 | user_input = input("Input prompt or type 'exit' to quit): ") 141 | if user_input.lower() in ['exit', 'quit']: 142 | break 143 | do_generate(args, model, tokenizer, user_input, args.enable_thinking) 144 | else: 145 | do_generate(args, model, tokenizer, args.prompt, args.enable_thinking) 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = setup_arg_parser() 150 | args = parser.parse_args() 151 | 152 | main(args) -------------------------------------------------------------------------------- /green_bit_llm/inference/utils.py: -------------------------------------------------------------------------------- 1 | from transformers.generation.logits_process import ( 2 | LogitsProcessorList, 3 | RepetitionPenaltyLogitsProcessor, 4 | TemperatureLogitsWarper, 5 | TopKLogitsWarper, 6 | TopPLogitsWarper, 7 | ) 8 | 9 | from 
.conversation import Conversation, get_conv_template 10 | 11 | # value is the search term in model name, 12 | # key is the name of conversation template 13 | CONV_TEMP_DICT = { 14 | "llama-2": "llama-2", 15 | "qwen-chat": "qwen", 16 | "yi-chat": "yi-", 17 | "mistral": "mistral", 18 | "gemma": "gemma", 19 | "llama-3": "llama-3", 20 | "phi-3": "phi-3", 21 | } 22 | 23 | # Models don't use the same configuration key for determining the maximum 24 | # sequence length. Store them here so we can sanely check them. 25 | # NOTE: The ordering here is important. Some models have two of these and we 26 | # have a preference for which value gets used. 27 | SEQUENCE_LENGTH_KEYS = [ 28 | "max_position_embeddings", 29 | "max_sequence_length", 30 | "seq_length", 31 | "max_seq_len", 32 | "model_max_length", 33 | ] 34 | 35 | 36 | def is_partial_stop(output: str, stop_str: str): 37 | """Check whether the output contains a partial stop str.""" 38 | for i in range(0, min(len(output), len(stop_str))): 39 | if stop_str.startswith(output[-i:]): 40 | return True 41 | return False 42 | 43 | def is_sentence_complete(output: str): 44 | """Check whether the output is a complete sentence.""" 45 | end_symbols = (".", "?", "!", "...", "。", "?", "!", "…", '"', "'", "”") 46 | return output.endswith(end_symbols) 47 | 48 | def get_context_length(config): 49 | """Get the context length of a model from a huggingface model config.""" 50 | rope_scaling = getattr(config, "rope_scaling", None) 51 | if rope_scaling: 52 | try: 53 | rope_scaling_factor = config.rope_scaling["factor"] 54 | except KeyError: 55 | rope_scaling_factor = 1 56 | else: 57 | rope_scaling_factor = 1 58 | 59 | for key in SEQUENCE_LENGTH_KEYS: 60 | val = getattr(config, key, None) 61 | if val is not None: 62 | return int(rope_scaling_factor * val) 63 | return 2048 64 | 65 | def get_conversation_template(model_path: str) -> Conversation: 66 | """Get and return a specific conversation template via checking its model path/model name.""" 67 | for key, value in CONV_TEMP_DICT.items(): 68 | # check if model path contains the value 69 | if value in model_path.lower(): 70 | return get_conv_template(key) 71 | raise Exception("Invalid model path: The provided model is not supported yet.") 72 | 73 | def prepare_logits_processor( 74 | temperature: float, repetition_penalty: float, top_p: float, top_k: int 75 | ) -> LogitsProcessorList: 76 | """ 77 | Creates and initializes a list of logits processors based on the specified parameters. 78 | Each processor applies a different modification to the logits during text generation, 79 | such as adjusting the sampling temperature, applying repetition penalties, 80 | or enforcing top-p and top-k constraints. 81 | 82 | Args: 83 | temperature (float): Scaling factor for logits; a value of 1.0 means no scaling. 84 | repetition_penalty (float): Penalty for repeated tokens to discourage repetition. 85 | top_p (float): The cumulative probability threshold for nucleus sampling, filters out the smallest probabilities. 86 | top_k (int): The number of highest probability logits to keep for top-k sampling. 87 | 88 | Returns: 89 | LogitsProcessorList: A configured list of logits processors. 90 | """ 91 | processor_list = LogitsProcessorList() 92 | # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases. 
93 | if temperature >= 1e-5 and temperature != 1.0: 94 | processor_list.append(TemperatureLogitsWarper(temperature)) 95 | if repetition_penalty > 1.0: 96 | processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) 97 | if 1e-8 <= top_p < 1.0: 98 | processor_list.append(TopPLogitsWarper(top_p)) 99 | if top_k > 0: 100 | processor_list.append(TopKLogitsWarper(top_k)) 101 | return processor_list 102 | 103 | def str_to_torch_dtype(dtype: str): 104 | """Get torch dtype via parsing the dtype string.""" 105 | import torch 106 | 107 | if dtype is None: 108 | return None 109 | elif dtype == "float32": 110 | return torch.float32 111 | elif dtype == "float16": 112 | return torch.half 113 | elif dtype == "bfloat16": 114 | return torch.bfloat16 115 | else: 116 | raise ValueError(f"Unrecognized dtype: {dtype}") 117 | 118 | def is_chat_model(path): 119 | """Check whether the input model name contains keywords like '-chat-' or '-instruct-'.""" 120 | substrings = ["-chat-", "-instruct-"] 121 | return any(substring in path for substring in substrings) -------------------------------------------------------------------------------- /green_bit_llm/langchain/README.md: -------------------------------------------------------------------------------- 1 | # GreenBit Langchain Demos 2 | 3 | ## Overview 4 | 5 | GreenBit Langchain Demos showcase the integration of GreenBit language models with the Langchain framework, enabling powerful and flexible natural language processing capabilities. 6 | 7 | ## Installation 8 | 9 | ### Step 1: Install the green-bit-llm Package 10 | 11 | ```bash 12 | pip install green-bit-llm 13 | ``` 14 | 15 | ### Step 2: Install Langchain Package 16 | 17 | ```bash 18 | pip install langchain-core 19 | ``` 20 | 21 | If you want to use the RAG demo, please make sure that the `sentence-transformers` Python package is installed. 22 | 23 | ```bash 24 | pip install sentence-transformers 25 | ``` 26 | 27 | Ensure your system has Python3 and pip installed before proceeding.
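Once `sentence-transformers` is installed, the `GreenBitEmbeddings` helper shipped in `green_bit_llm.langchain` can be exercised on its own, which is a quick way to verify the RAG setup. A minimal sketch; the model name and texts are illustrative, and the keyword dictionaries are passed explicitly rather than relying on the signature defaults of `from_model_id`:

```python
from green_bit_llm.langchain import GreenBitEmbeddings

embedder = GreenBitEmbeddings.from_model_id(
    "sentence-transformers/all-MiniLM-L12-v2",
    device="cuda",
    cache_dir=None,
    model_kwargs={},
    encode_kwargs={"normalize_embeddings": True},
)

doc_vectors = embedder.embed_documents(["Hello, world!", "This is a test."])
query_vector = embedder.embed_query("What is the meaning of life?")
print(len(doc_vectors), len(query_vector))
```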
28 | 29 | ## Usage 30 | 31 | ### Basic Example 32 | 33 | Here's a basic example of how to use the GreenBit Langchain integration: 34 | 35 | ```python 36 | from langchain_core.messages import HumanMessage 37 | from green_bit_llm.langchain import GreenBitPipeline, ChatGreenBit 38 | import torch 39 | 40 | pipeline = GreenBitPipeline.from_model_id( 41 | model_id="GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0", 42 | device="cuda:0", 43 | model_kwargs={"dtype": torch.half, "device_map": 'auto', "seqlen": 2048, "requires_grad": False}, 44 | pipeline_kwargs={"max_new_tokens": 100, "temperature": 0.7}, 45 | ) 46 | 47 | chat = ChatGreenBit(llm=pipeline) 48 | 49 | # normal generation 50 | response = chat.invoke("What is the capital of France?") 51 | print(response.content) 52 | 53 | # stream generation 54 | for chunk in chat.stream([HumanMessage(content="Tell me a story about a brave knight.")]): 55 | print(chunk.message.content, end="", flush=True) 56 | 57 | ``` 58 | -------------------------------------------------------------------------------- /green_bit_llm/langchain/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline import GreenBitPipeline 2 | from .embedding import GreenBitEmbeddings 3 | from .chat_model import ChatGreenBit -------------------------------------------------------------------------------- /green_bit_llm/langchain/chat_model.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Union, Sequence, Literal, Callable, Type 2 | from pydantic import Field 3 | from langchain_core.callbacks.manager import ( 4 | AsyncCallbackManagerForLLMRun, 5 | CallbackManagerForLLMRun, 6 | ) 7 | from langchain_core.language_models.chat_models import BaseChatModel 8 | from langchain_core.messages import ( 9 | AIMessage, 10 | BaseMessage, 11 | ChatMessage, 12 | HumanMessage, 13 | SystemMessage, 14 | ) 15 | from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult 16 | from langchain_core.language_models import LanguageModelInput 17 | from langchain_core.runnables import Runnable, RunnableConfig 18 | from langchain_core.tools import BaseTool 19 | from langchain_core.utils.function_calling import convert_to_openai_tool 20 | 21 | from green_bit_llm.langchain import GreenBitPipeline 22 | 23 | DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful, and honest assistant.""" 24 | 25 | 26 | class ChatGreenBit(BaseChatModel): 27 | """GreenBit Chat model. 28 | 29 | Example: 30 | .. code-block:: python 31 | 32 | from green_bit_llm.langchain import GreenBitPipeline 33 | 34 | model_config = { 35 | "trust_remote_code": True, 36 | "attn_implementation": "flash_attention_2" 37 | } 38 | 39 | tokenizer_config = {"trust_remote_code": True} 40 | 41 | gb = GreenBitPipeline.from_model_id( 42 | model_id="GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0", 43 | task="text-generation", 44 | pipeline_kwargs={"max_tokens": 100, "temp": 0.7}, 45 | model_kwargs={ 46 | "dtype": torch.half, 47 | "seqlen": 2048, 48 | "requires_grad": False, 49 | "model_config": model_config, 50 | "tokenizer_config": tokenizer_config 51 | } 52 | ) 53 | """ 54 | 55 | llm: GreenBitPipeline = Field(..., description="GreenBit Pipeline instance") 56 | 57 | class Config: 58 | """Configuration for this pydantic object.""" 59 | arbitrary_types_allowed = True 60 | 61 | def __init__( 62 | self, 63 | llm: GreenBitPipeline, 64 | **kwargs: Any, 65 | ) -> None: 66 | """Initialize the chat model. 
67 | 68 | Args: 69 | llm: GreenBit Pipeline instance 70 | **kwargs: Additional keyword arguments 71 | """ 72 | # First initialize with mandatory llm field 73 | super().__init__(llm=llm, **kwargs) 74 | 75 | @property 76 | def _llm_type(self) -> str: 77 | return "greenbit-chat" 78 | 79 | def _create_chat_result(self, llm_result: LLMResult) -> ChatResult: 80 | """Convert LLM result to chat messages""" 81 | generations = [] 82 | for gen in llm_result.generations: 83 | for g in gen: 84 | message = AIMessage(content=g.text.strip()) 85 | chat_generation = ChatGeneration( 86 | message=message, 87 | generation_info=g.generation_info 88 | ) 89 | generations.append(chat_generation) 90 | return ChatResult(generations=generations) 91 | 92 | def _messages_to_dict(self, messages: List[BaseMessage]) -> List[Dict[str, str]]: 93 | """Convert LangChain messages to dictionary format for apply_chat_template""" 94 | message_dicts = [] 95 | 96 | for message in messages: 97 | if isinstance(message, SystemMessage): 98 | message_dicts.append({"role": "system", "content": message.content}) 99 | elif isinstance(message, HumanMessage): 100 | message_dicts.append({"role": "user", "content": message.content}) 101 | elif isinstance(message, AIMessage): 102 | message_dicts.append({"role": "assistant", "content": message.content}) 103 | 104 | return message_dicts 105 | 106 | def _prepare_prompt(self, messages: List[BaseMessage], **kwargs) -> str: 107 | """Prepare prompt using apply_chat_template""" 108 | if not hasattr(self.llm.pipeline, 'tokenizer'): 109 | raise ValueError("Tokenizer not available in pipeline") 110 | 111 | tokenizer = self.llm.pipeline.tokenizer 112 | if not hasattr(tokenizer, 'apply_chat_template'): 113 | raise ValueError("Tokenizer does not support apply_chat_template") 114 | 115 | # Convert messages to dict format 116 | message_dicts = self._messages_to_dict(messages) 117 | 118 | # Prepare apply_chat_template arguments 119 | template_kwargs = { 120 | "add_generation_prompt": True, 121 | "tokenize": False, # Key: set to False to ensure a string is returned 122 | } 123 | 124 | # Add enable_thinking for Qwen3 models if provided 125 | enable_thinking = kwargs.get('enable_thinking') 126 | if enable_thinking is not None: 127 | template_kwargs["enable_thinking"] = enable_thinking 128 | 129 | try: 130 | result = tokenizer.apply_chat_template(message_dicts, **template_kwargs) 131 | 132 | # Make sure a string is returned 133 | if isinstance(result, list): 134 | # If token IDs are returned anyway, decode them back into a string 135 | return tokenizer.decode(result, skip_special_tokens=False) 136 | elif isinstance(result, str): 137 | return result 138 | else: 139 | raise ValueError(f"Unexpected return type from apply_chat_template: {type(result)}") 140 | 141 | except Exception as e: 142 | raise ValueError(f"Failed to apply chat template: {str(e)}") 143 | 144 | def generate( 145 | self, 146 | messages: List[BaseMessage], 147 | stop: Optional[List[str]] = None, 148 | run_manager: Optional[CallbackManagerForLLMRun] = None, 149 | **kwargs: Any, 150 | ) -> ChatResult: 151 | """Generate chat completion using the underlying pipeline""" 152 | # Prepare prompt using apply_chat_template 153 | prompt = self._prepare_prompt(messages, **kwargs) 154 | 155 | # Handle generation parameters 156 | generation_kwargs = {} 157 | if "temperature" in kwargs: 158 | generation_kwargs["temperature"] = kwargs["temperature"] 159 | if "max_new_tokens" in kwargs: 160 | generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"] 161 | elif "max_tokens" in kwargs: 162 | generation_kwargs["max_new_tokens"] = kwargs["max_tokens"] 163
| 164 | wrapped_kwargs = { 165 | "pipeline_kwargs": generation_kwargs, 166 | "stop": stop, 167 | **kwargs 168 | } 169 | 170 | # Generate using pipeline 171 | llm_result = self.llm.generate( 172 | prompts=[prompt], 173 | run_manager=run_manager, 174 | **wrapped_kwargs 175 | ) 176 | 177 | return self._create_chat_result(llm_result) 178 | 179 | def _generate( 180 | self, 181 | messages: List[BaseMessage], 182 | stop: Optional[List[str]] = None, 183 | run_manager: Optional[CallbackManagerForLLMRun] = None, 184 | **kwargs: Any, 185 | ) -> ChatResult: 186 | """Generate method required by BaseChatModel""" 187 | return self.generate(messages, stop, run_manager, **kwargs) 188 | 189 | def stream( 190 | self, 191 | messages: List[BaseMessage], 192 | stop: Optional[List[str]] = None, 193 | run_manager: Optional[CallbackManagerForLLMRun] = None, 194 | **kwargs: Any, 195 | ): 196 | """Stream chat completion""" 197 | # Prepare prompt using apply_chat_template 198 | prompt = self._prepare_prompt(messages, **kwargs) 199 | 200 | # Handle generation parameters 201 | generation_kwargs = {} 202 | if "temperature" in kwargs: 203 | generation_kwargs["temperature"] = kwargs["temperature"] 204 | if "max_new_tokens" in kwargs: 205 | generation_kwargs["max_new_tokens"] = kwargs["max_new_tokens"] 206 | elif "max_tokens" in kwargs: 207 | generation_kwargs["max_new_tokens"] = kwargs["max_tokens"] 208 | 209 | wrapped_kwargs = { 210 | "pipeline_kwargs": generation_kwargs, 211 | "stop": stop, 212 | "skip_prompt": kwargs.get("skip_prompt", True), 213 | **kwargs 214 | } 215 | 216 | # Stream using pipeline 217 | for chunk in self.llm.stream( 218 | prompt, 219 | run_manager=run_manager, 220 | **wrapped_kwargs 221 | ): 222 | yield ChatGeneration(message=AIMessage(content=chunk.text)) 223 | 224 | async def agenerate( 225 | self, 226 | messages: List[BaseMessage], 227 | stop: Optional[List[str]] = None, 228 | run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, 229 | **kwargs: Any, 230 | ) -> ChatResult: 231 | """Async generation (not implemented)""" 232 | raise NotImplementedError("Async generation not implemented yet") 233 | 234 | async def astream( 235 | self, 236 | messages: List[BaseMessage], 237 | stop: Optional[List[str]] = None, 238 | run_manager: Optional[CallbackManagerForLLMRun] = None, 239 | **kwargs: Any, 240 | ): 241 | """Async streaming (not implemented)""" 242 | raise NotImplementedError("Async stream generation not implemented yet") 243 | 244 | def bind_tools( 245 | self, 246 | tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]], 247 | *, 248 | tool_choice: Optional[Union[dict, str, Literal["auto", "none"], bool]] = None, 249 | **kwargs: Any, 250 | ) -> Runnable[LanguageModelInput, BaseMessage]: 251 | """Bind tools to the chat model""" 252 | formatted_tools = [convert_to_openai_tool(tool) for tool in tools] 253 | 254 | if tool_choice is not None and tool_choice: 255 | if len(formatted_tools) != 1: 256 | raise ValueError( 257 | "When specifying `tool_choice`, you must provide exactly one " 258 | f"tool. Received {len(formatted_tools)} tools." 
259 | ) 260 | 261 | if isinstance(tool_choice, str): 262 | if tool_choice not in ("auto", "none"): 263 | tool_choice = { 264 | "type": "function", 265 | "function": {"name": tool_choice}, 266 | } 267 | elif isinstance(tool_choice, bool): 268 | tool_choice = formatted_tools[0] 269 | elif isinstance(tool_choice, dict): 270 | if ( 271 | formatted_tools[0]["function"]["name"] 272 | != tool_choice["function"]["name"] 273 | ): 274 | raise ValueError( 275 | f"Tool choice {tool_choice} was specified, but the only " 276 | f"provided tool was {formatted_tools[0]['function']['name']}." 277 | ) 278 | else: 279 | raise ValueError( 280 | f"Unrecognized tool_choice type. Expected str, bool or dict. " 281 | f"Received: {tool_choice}" 282 | ) 283 | 284 | kwargs["tool_choice"] = tool_choice 285 | 286 | return super().bind(tools=formatted_tools, **kwargs) -------------------------------------------------------------------------------- /green_bit_llm/langchain/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Any, Optional 2 | from langchain_core.embeddings import Embeddings 3 | from pydantic import BaseModel, Field 4 | 5 | # good balance between performance and efficiency, params: 22M 6 | DEFAULT_MODEL_NAME1 = "sentence-transformers/all-MiniLM-L12-v2" 7 | # params: 110M, better NLU ability 8 | DEFAULT_MODEL_NAME2 = "sentence-transformers/all-mpnet-base-v2" 9 | # Optimized for multi-round question answering and suitable for applications 10 | # that require more complex context understanding. 11 | DEFAULT_MODEL_NAME3 = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1" 12 | 13 | 14 | class GreenBitEmbeddings(BaseModel, Embeddings): 15 | """GreenBit embedding model using sentence-transformers. 16 | 17 | This class provides an interface to generate embeddings using GreenBit's models, 18 | which are based on the sentence-transformers package. 19 | 20 | Attributes: 21 | model (Any): Embedding model. 22 | encode_kwargs (Dict[str, Any]): Additional keyword arguments for the encoding process. 23 | device (str): The device to use for computations (e.g., 'cuda' for GPU). 24 | 25 | Example: 26 | .. code-block:: python 27 | from green_bit_llm.langchain import GreenBitEmbeddings 28 | 29 | embedder = GreenBitEmbeddings.from_model_id( 30 | "sentence-transformers/all-mpnet-base-v2", 31 | device="cuda", 32 | multi_process=True, 33 | cache_dir="/path/to/cache", 34 | encode_kwargs={"normalize_embeddings": True} 35 | ) 36 | 37 | texts = ["Hello, world!", "This is a test."] 38 | document_embeddings = embedder.embed_documents(texts) 39 | query_embedding = embedder.embed_query("What is the meaning of life?") 40 | """ 41 | cache_dir: Optional[str] = "~/.cache/huggingface/hub" 42 | """Path to store models. 43 | Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.""" 44 | encode_kwargs: Dict[str, Any] = Field(default_factory=dict) 45 | """Keyword arguments to pass when calling the `encode` method of the Sentence 46 | Transformer model, such as `prompt_name`, `prompt`, `batch_size`, `precision`, 47 | `normalize_embeddings`, and more.
48 | See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode""" 49 | multi_process: bool = False 50 | """Run encode() on multiple GPUs.""" 51 | show_progress: bool = False 52 | """Whether to show a progress bar.""" 53 | device: str = "cuda" 54 | model: Any = None 55 | 56 | def __init__(self, **data): 57 | super().__init__(**data) 58 | 59 | class Config: 60 | """Configuration for this pydantic object.""" 61 | 62 | extra = "forbid" 63 | 64 | def embed_documents(self, texts: List[str]) -> List[List[float]]: 65 | """ 66 | Generate embeddings for a list of documents. 67 | 68 | Args: 69 | texts (List[str]): The list of texts to embed. 70 | 71 | Returns: 72 | List[List[float]]: The list of embeddings, one for each input text. 73 | """ 74 | import sentence_transformers 75 | texts = list(map(lambda x: x.replace("\n", " "), texts)) 76 | if self.multi_process: 77 | pool = self.model.start_multi_process_pool() 78 | embeddings = self.model.encode_multi_process(texts, pool) 79 | sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool) 80 | else: 81 | embeddings = self.model.encode( 82 | texts, show_progress_bar=self.show_progress, **self.encode_kwargs 83 | ) 84 | return embeddings.tolist() 85 | 86 | def embed_query(self, text: str) -> List[float]: 87 | """ 88 | Generate an embedding for a single query text. 89 | 90 | Args: 91 | text (str): The query text to embed. 92 | 93 | Returns: 94 | List[float]: The embedding for the input text. 95 | """ 96 | return self.embed_documents([text])[0] 97 | 98 | @classmethod 99 | def from_model_id( 100 | cls, 101 | model_name: str = DEFAULT_MODEL_NAME1, 102 | device: str = "cuda", 103 | cache_dir: Optional[str] = "", 104 | multi_process: bool = False, 105 | show_progress: bool = False, 106 | model_kwargs: Dict[str, Any] = Field(default_factory=dict), 107 | encode_kwargs: Optional[Dict[str, Any]] = Field(default_factory=dict), 108 | **kwargs 109 | ) -> "GreenBitEmbeddings": 110 | """ 111 | Create a GreenBitEmbeddings instance from a model name. 112 | 113 | Args: 114 | model_name (str): The name of the model to use. 115 | device (str): The device to use for computations (default is "cuda" for GPU). 116 | cache_dir (Optional[str]): Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable. 117 | multi_process (bool): Run encode() on multiple GPUs. 118 | show_progress (bool): Whether to show a progress bar. 119 | model_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass to the Sentence Transformer model, such as `device`, 120 | `prompts`, `default_prompt_name`, `revision`, `trust_remote_code`, or `token`. 121 | See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer 122 | encode_kwargs (Optional[Dict[str, Any]]): Keyword arguments to pass when calling the `encode` method of the SentenceTransformer model, such as `prompt_name`, `prompt`, `batch_size`, `precision`, 123 | `normalize_embeddings`, and more. See also the Sentence Transformer documentation: https://sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer.encode 124 | **kwargs: Additional keyword arguments for the GreenBitEmbeddings constructor. 125 | 126 | Returns: 127 | GreenBitEmbeddings: An instance of GreenBitEmbeddings. 
128 | """ 129 | try: 130 | from sentence_transformers import SentenceTransformer 131 | except ImportError: 132 | raise ImportError( 133 | "Could not import sentence_transformers. " 134 | "Please install it with `pip install sentence-transformers`." 135 | ) 136 | 137 | model = SentenceTransformer( 138 | model_name, 139 | device=device, 140 | cache_folder=cache_dir, 141 | **model_kwargs 142 | ) 143 | 144 | return cls( 145 | model=model, 146 | device=device, 147 | cache_dir=cache_dir, 148 | multi_process=multi_process, 149 | show_progress=show_progress, 150 | encode_kwargs=encode_kwargs or {}, 151 | **kwargs 152 | ) -------------------------------------------------------------------------------- /green_bit_llm/patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/patches/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/patches/deepseek_v3_moe_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Import DeepSeek V3 components at module level 4 | try: 5 | from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE 6 | 7 | DEEPSEEK_V3_AVAILABLE = True 8 | except ImportError: 9 | DeepseekV3MoE = None 10 | DEEPSEEK_V3_AVAILABLE = False 11 | 12 | 13 | class QuantizedDeepSeekV3MoE: 14 | """ 15 | DeepSeekV3MoE forward method optimized for quantized models 16 | """ 17 | 18 | @staticmethod 19 | def moe_token_by_token(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): 20 | """ 21 | Quantization-friendly MoE implementation using token-by-token processing 22 | """ 23 | batch_size, sequence_length, hidden_dim = hidden_states.shape 24 | total_tokens = batch_size * sequence_length 25 | 26 | # Flatten processing 27 | hidden_states_flat = hidden_states.view(-1, hidden_dim) 28 | 29 | # Initialize output with correct dtype 30 | final_hidden_states = torch.zeros_like(hidden_states_flat, dtype=topk_weights.dtype) 31 | 32 | # Pre-convert to CPU to reduce GPU-CPU transfers 33 | topk_indices_cpu = topk_indices.cpu().numpy() 34 | topk_weights_cpu = topk_weights.cpu().numpy() 35 | top_k = topk_indices.shape[-1] 36 | 37 | # Process token by token for maximum quantization compatibility 38 | for token_idx in range(total_tokens): 39 | # Get single token input - fixed batch size (1) 40 | token_input = hidden_states_flat[token_idx:token_idx + 1] # [1, hidden_dim] 41 | token_output = torch.zeros_like(token_input, dtype=topk_weights.dtype) 42 | 43 | # Get expert indices and weights for this token 44 | token_experts = topk_indices_cpu[token_idx] # [top_k] 45 | token_weights = topk_weights_cpu[token_idx] # [top_k] 46 | 47 | # Process selected experts 48 | for expert_pos in range(top_k): 49 | expert_idx = int(token_experts[expert_pos]) 50 | expert_weight = float(token_weights[expert_pos]) 51 | 52 | # Skip small weights for performance 53 | if expert_weight < 1e-6: 54 | continue 55 | 56 | # Call expert network - fixed batch size, quantization friendly 57 | expert_layer = self.experts[expert_idx] 58 | expert_output = expert_layer(token_input) 59 | 60 | # Weighted accumulation - simple scalar operations 61 | token_output = token_output + expert_output * expert_weight 62 | 63 | # Store result 64 | final_hidden_states[token_idx] = token_output[0] 65 | 66 | # Convert back to original dtype 67 | 
return final_hidden_states.type(hidden_states.dtype) 68 | 69 | @staticmethod 70 | def moe_vectorized(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor): 71 | """ 72 | Vectorized MoE implementation adapted from Qwen3 method 73 | """ 74 | batch_size, sequence_length, hidden_dim = hidden_states.shape 75 | total_tokens = batch_size * sequence_length 76 | 77 | # Flatten processing 78 | hidden_states_flat = hidden_states.view(-1, hidden_dim) 79 | top_k = topk_indices.shape[-1] 80 | 81 | # Pre-allocate expert output storage - key vectorization optimization 82 | expert_outputs = torch.zeros(total_tokens, top_k, hidden_dim, 83 | dtype=topk_weights.dtype, device=hidden_states.device) 84 | 85 | # Process experts in batches - adapted for 256 experts 86 | for expert_idx in range(len(self.experts)): 87 | # Create expert mask [total_tokens, top_k] 88 | expert_mask = (topk_indices == expert_idx) 89 | 90 | if not expert_mask.any(): 91 | continue 92 | 93 | # Find positions using current expert 94 | token_idx, topk_idx = torch.where(expert_mask) 95 | 96 | if len(token_idx) == 0: 97 | continue 98 | 99 | # Batch processing - key performance improvement 100 | expert_inputs = hidden_states_flat[token_idx] 101 | expert_result = self.experts[expert_idx](expert_inputs) 102 | 103 | # Store results 104 | expert_outputs[token_idx, topk_idx] = expert_result 105 | 106 | # Apply weights and sum - vectorized operations 107 | weights_expanded = topk_weights.unsqueeze(-1) 108 | final_hidden_states = (expert_outputs * weights_expanded).sum(dim=1) 109 | 110 | # Reshape back to original shape 111 | final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) 112 | 113 | # Convert back to original dtype 114 | return final_hidden_states.type(hidden_states.dtype) 115 | 116 | @staticmethod 117 | def should_use_vectorized(self, hidden_states): 118 | """ 119 | Determine whether to use vectorized method based on batch size efficiency 120 | """ 121 | batch_size, seq_len, hidden_dim = hidden_states.shape 122 | total_tokens = batch_size * seq_len 123 | top_k = getattr(self, 'top_k', 6) # DeepSeek V3 default 124 | 125 | # Estimate memory requirement for expert_outputs tensor 126 | estimated_memory_mb = total_tokens * top_k * hidden_dim * 2 / (1024 * 1024) # fp16 127 | 128 | # Primary criterion: batch size efficiency 129 | if total_tokens < 64: 130 | # Too small batch, vectorization has no advantage 131 | return False, estimated_memory_mb 132 | 133 | # Optional safety check for extreme cases 134 | try: 135 | total_memory = torch.cuda.get_device_properties(0).total_memory 136 | allocated_memory = torch.cuda.memory_allocated() 137 | available_memory_mb = (total_memory - allocated_memory) / (1024 * 1024) 138 | 139 | # Only fallback if memory is really insufficient (< 20% available) 140 | memory_threshold = available_memory_mb * 0.2 141 | 142 | if estimated_memory_mb > memory_threshold: 143 | return False, estimated_memory_mb 144 | 145 | except Exception: 146 | # If cannot get memory info, don't restrict 147 | pass 148 | 149 | return True, estimated_memory_mb 150 | 151 | @staticmethod 152 | def forward_hybrid(self, hidden_states): 153 | """ 154 | Hybrid strategy forward method for DeepSeek V3 MoE 155 | """ 156 | # Save for residual connection and shared experts 157 | residuals = hidden_states 158 | 159 | # Route calculation - maintain DeepSeek V3's complex routing logic 160 | topk_indices, topk_weights = self.gate(hidden_states) 161 | 162 | # Choose strategy based on memory 
and batch size 163 | use_vectorized, estimated_mb = QuantizedDeepSeekV3MoE.should_use_vectorized(self, hidden_states) 164 | 165 | if use_vectorized: 166 | moe_output = QuantizedDeepSeekV3MoE.moe_vectorized( 167 | self, hidden_states, topk_indices, topk_weights 168 | ) 169 | else: 170 | moe_output = QuantizedDeepSeekV3MoE.moe_token_by_token( 171 | self, hidden_states, topk_indices, topk_weights 172 | ) 173 | 174 | # Add shared expert output - DeepSeek V3 specific feature 175 | shared_expert_output = self.shared_experts(residuals) 176 | 177 | # Final output = MoE output + shared expert output 178 | final_output = moe_output + shared_expert_output 179 | 180 | return final_output 181 | 182 | @staticmethod 183 | def forward_conservative(self, hidden_states): 184 | """ 185 | Conservative forward method using only token-by-token processing 186 | """ 187 | residuals = hidden_states 188 | 189 | # Route calculation 190 | topk_indices, topk_weights = self.gate(hidden_states) 191 | 192 | # Use token-by-token method for maximum compatibility 193 | moe_output = QuantizedDeepSeekV3MoE.moe_token_by_token( 194 | self, hidden_states, topk_indices, topk_weights 195 | ) 196 | 197 | # Add shared expert output 198 | shared_expert_output = self.shared_experts(residuals) 199 | 200 | return moe_output + shared_expert_output 201 | 202 | @staticmethod 203 | def forward_vectorized(self, hidden_states): 204 | """ 205 | Vectorized forward method for maximum performance 206 | """ 207 | residuals = hidden_states 208 | 209 | # Route calculation 210 | topk_indices, topk_weights = self.gate(hidden_states) 211 | 212 | # Use vectorized method 213 | moe_output = QuantizedDeepSeekV3MoE.moe_vectorized( 214 | self, hidden_states, topk_indices, topk_weights 215 | ) 216 | 217 | # Add shared expert output 218 | shared_expert_output = self.shared_experts(residuals) 219 | 220 | return moe_output + shared_expert_output 221 | 222 | 223 | def apply_deepseek_v3_moe_patch(strategy='hybrid'): 224 | """ 225 | Apply DeepSeek V3 MoE patch 226 | 227 | Args: 228 | strategy: 'conservative', 'vectorized', 'hybrid' 229 | """ 230 | if not DEEPSEEK_V3_AVAILABLE: 231 | print("Error: DeepSeek V3 models not available in current transformers installation") 232 | return False 233 | 234 | # Save original method 235 | if not hasattr(DeepseekV3MoE, '_original_forward'): 236 | DeepseekV3MoE._original_forward = DeepseekV3MoE.forward 237 | 238 | # Apply strategy 239 | if strategy == 'conservative': 240 | DeepseekV3MoE.forward = QuantizedDeepSeekV3MoE.forward_conservative 241 | print("Info: Applied DeepSeek V3 conservative MoE patch") 242 | elif strategy == 'vectorized': 243 | DeepseekV3MoE.forward = QuantizedDeepSeekV3MoE.forward_vectorized 244 | print("Info: Applied DeepSeek V3 vectorized MoE patch") 245 | elif strategy == 'hybrid': 246 | DeepseekV3MoE.forward = QuantizedDeepSeekV3MoE.forward_hybrid 247 | print("Info: Applied DeepSeek V3 hybrid MoE patch") 248 | else: 249 | raise ValueError(f"Unknown strategy: {strategy}") 250 | 251 | return True 252 | 253 | 254 | def restore_deepseek_v3_moe_patch(): 255 | """ 256 | Restore original DeepSeek V3 MoE forward method 257 | """ 258 | if not DEEPSEEK_V3_AVAILABLE: 259 | return 260 | 261 | if hasattr(DeepseekV3MoE, '_original_forward'): 262 | DeepseekV3MoE.forward = DeepseekV3MoE._original_forward 263 | delattr(DeepseekV3MoE, '_original_forward') 264 | print("Info: Restored original DeepSeek V3 MoE forward method") 265 | 266 | 267 | def detect_deepseek_v3_moe_model(config): 268 | """ 269 | Detect if model is DeepSeek 
V3 MoE 270 | """ 271 | return ( 272 | hasattr(config, 'model_type') and 273 | 'deepseek_v3' in config.model_type.lower() and 274 | getattr(config, 'n_routed_experts', 0) > 0 275 | ) -------------------------------------------------------------------------------- /green_bit_llm/patches/qwen3_moe_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Tuple 4 | 5 | 6 | class QuantizedQwen3MoeSparseMoeBlock: 7 | """ 8 | Qwen3MoeSparseMoeBlock forward method optimized for quantized models 9 | """ 10 | 11 | @staticmethod 12 | def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 13 | """ 14 | Quantization-friendly MoE forward implementation using vectorized strategy 15 | """ 16 | batch_size, sequence_length, hidden_dim = hidden_states.shape 17 | total_tokens = batch_size * sequence_length 18 | 19 | # flat 20 | hidden_states_flat = hidden_states.reshape(total_tokens, hidden_dim) 21 | 22 | # 1. Route calculation 23 | router_logits = self.gate(hidden_states_flat) 24 | if len(router_logits.shape) > 2: 25 | router_logits = router_logits.reshape(total_tokens, -1) 26 | 27 | # 2. Calculating routing weight 28 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 29 | 30 | # 3. Select top experts and weights 31 | routing_weights_topk, indices_topk = torch.topk(routing_weights, self.top_k, dim=1) 32 | 33 | # 4. Normalized top weights 34 | if self.norm_topk_prob: 35 | routing_weights_topk /= routing_weights_topk.sum(dim=1, keepdim=True) 36 | routing_weights_topk = routing_weights_topk.to(hidden_states.dtype) 37 | 38 | # 5. Pre-allocated expert output storage 39 | # [total_tokens, top_k, hidden_dim] 40 | expert_outputs = torch.zeros(total_tokens, self.top_k, hidden_dim, 41 | dtype=hidden_states.dtype, device=hidden_states.device) 42 | 43 | # 6. Batch processing by experts 44 | for expert_idx in range(self.num_experts): 45 | # Create expert mask [total_tokens, top_k] 46 | expert_mask = (indices_topk == expert_idx) 47 | 48 | if not expert_mask.any(): 49 | continue 50 | 51 | # Find a location using current experts 52 | token_idx, topk_idx = torch.where(expert_mask) 53 | 54 | if len(token_idx) == 0: 55 | continue 56 | 57 | # Batch Processing 58 | expert_inputs = hidden_states_flat[token_idx] 59 | expert_result = self.experts[expert_idx](expert_inputs) 60 | 61 | # Storing Results 62 | expert_outputs[token_idx, topk_idx] = expert_result 63 | 64 | # 7. Apply weights and sum 65 | # Expand weight dimension: [total_tokens, top_k, 1] 66 | weights_expanded = routing_weights_topk.unsqueeze(-1) 67 | 68 | # Weighted sum: [total_tokens, hidden_dim] 69 | final_hidden_states = (expert_outputs * weights_expanded).sum(dim=1) 70 | 71 | # 8. 
Reshape back to original shape 72 | final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) 73 | 74 | return final_hidden_states, router_logits 75 | 76 | @staticmethod 77 | def forward_micro_batched(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 78 | """ 79 | Quantization-friendly MoE forward implementation using micro_batched strategy 80 | """ 81 | batch_size, sequence_length, hidden_dim = hidden_states.shape 82 | total_tokens = batch_size * sequence_length 83 | 84 | hidden_states_flat = hidden_states.reshape(total_tokens, hidden_dim) 85 | 86 | # Route calculation 87 | router_logits = self.gate(hidden_states_flat) 88 | if len(router_logits.shape) > 2: 89 | router_logits = router_logits.reshape(total_tokens, -1) 90 | 91 | routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) 92 | routing_weights_topk, indices_topk = torch.topk(routing_weights, self.top_k, dim=1) 93 | 94 | if self.norm_topk_prob: 95 | routing_weights_topk /= routing_weights_topk.sum(dim=1, keepdim=True) 96 | routing_weights_topk = routing_weights_topk.to(hidden_states.dtype) 97 | 98 | final_hidden_states = torch.zeros_like(hidden_states_flat) 99 | 100 | # Fixed micro-batch size - quantization friendly 101 | micro_batch_size = min(8, total_tokens) # Small fixed batch size 102 | 103 | for start_idx in range(0, total_tokens, micro_batch_size): 104 | end_idx = min(start_idx + micro_batch_size, total_tokens) 105 | 106 | # Still process token by token in micro batch - maintain quantization compatibility 107 | for token_idx in range(start_idx, end_idx): 108 | token_input = hidden_states_flat[token_idx:token_idx + 1] 109 | token_output = torch.zeros_like(token_input) 110 | 111 | for expert_pos in range(self.top_k): 112 | expert_idx = indices_topk[token_idx, expert_pos].item() 113 | expert_weight = routing_weights_topk[token_idx, expert_pos].item() 114 | 115 | if expert_weight < 1e-4: 116 | continue 117 | 118 | expert_output = self.experts[expert_idx](token_input) 119 | token_output = token_output + expert_output * expert_weight 120 | 121 | final_hidden_states[token_idx] = token_output[0] 122 | 123 | final_hidden_states = final_hidden_states.view(batch_size, sequence_length, hidden_dim) 124 | return final_hidden_states, router_logits 125 | 126 | def apply_qwen3_moe_patch(): 127 | """ 128 | Apply the monkey patch of Qwen3MoeSparseMoeBlock 129 | """ 130 | try: 131 | from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock 132 | 133 | # Save the original method (in case we need to restore it) 134 | if not hasattr(Qwen3MoeSparseMoeBlock, '_original_forward'): 135 | Qwen3MoeSparseMoeBlock._original_forward = Qwen3MoeSparseMoeBlock.forward 136 | 137 | # Replace the forward method 138 | Qwen3MoeSparseMoeBlock.forward = QuantizedQwen3MoeSparseMoeBlock.forward 139 | 140 | print("Info: Successfully applied Qwen3MoeSparseMoeBlock patch for quantized models") 141 | 142 | except ImportError as e: 143 | print(f"Error: Could not apply Qwen3MoeSparseMoeBlock patch: {e}") 144 | print(" This might be expected if Qwen3 models are not being used") 145 | 146 | def restore_qwen3_moe_patch(): 147 | """ 148 | Restore the original Qwen3MoeSparseMoeBlock forward method 149 | """ 150 | try: 151 | from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock 152 | 153 | if hasattr(Qwen3MoeSparseMoeBlock, '_original_forward'): 154 | Qwen3MoeSparseMoeBlock.forward = Qwen3MoeSparseMoeBlock._original_forward 155 | 
delattr(Qwen3MoeSparseMoeBlock, '_original_forward') 156 | print("Info: Successfully restored original Qwen3MoeSparseMoeBlock forward method") 157 | 158 | except ImportError: 159 | pass -------------------------------------------------------------------------------- /green_bit_llm/routing/__init__.py: -------------------------------------------------------------------------------- 1 | from .confidence_scorer import ConfidenceScorer -------------------------------------------------------------------------------- /green_bit_llm/routing/confidence_scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from green_bit_llm.routing.libra_router.ue_router import MahalanobisDistanceSeq 3 | 4 | 5 | class ConfidenceScorer: 6 | """ 7 | A class to compute confidence scores based on Mahalanobis distance. 8 | 9 | Attributes: 10 | parameters_path (str): Path to the model parameters 11 | json_file_path (str): Path to the uncertainty bounds JSON file 12 | device (str): Device to run computations on ('cpu', 'cuda', 'mps') 13 | 14 | Example: 15 | confidence_scorer = ConfidenceScore( 16 | parameters_path="path/to/params", 17 | json_file_path="path/to/bounds.json", 18 | device="cuda" 19 | ) 20 | 21 | confidence = confidence_scorer.calculate_confidence(hidden_states) 22 | """ 23 | 24 | def __init__( 25 | self, 26 | parameters_path: str, 27 | model_id: str, 28 | device: str = "cuda" 29 | ): 30 | """ 31 | Initialize the ConfidenceScore calculator. 32 | 33 | Args: 34 | parameters_path: Path to model parameters 35 | json_file_path: Path to uncertainty bounds JSON 36 | device: Computation device 37 | threshold: Confidence threshold for routing 38 | """ 39 | self.parameters_path = parameters_path 40 | self.device = device 41 | 42 | # Initialize Mahalanobis distance calculator 43 | try: 44 | self.mahalanobis = MahalanobisDistanceSeq( 45 | parameters_path=parameters_path, 46 | normalize=False, 47 | model_id=model_id, 48 | device=self.device 49 | ) 50 | except Exception as e: 51 | raise RuntimeError(f"Failed to initialize Mahalanobis distance calculator: {str(e)}") 52 | 53 | def calculate_confidence( 54 | self, 55 | hidden_states: torch.Tensor 56 | ) -> float: 57 | """ 58 | Calculate confidence score from hidden states. 59 | Support both single input and batch input. 
60 | 61 | Args: 62 | hidden_states: Model hidden states tensor 63 | return_uncertainty: Whether to return the raw uncertainty score 64 | 65 | Returns: 66 | Union[float, List[float]]: Single confidence score or list of confidence scores 67 | 68 | Raises: 69 | ValueError: If hidden states have invalid shape 70 | RuntimeError: If confidence calculation fails 71 | """ 72 | 73 | try: 74 | # Calculate uncertainty using Mahalanobis distance 75 | uncertainty = self.mahalanobis(hidden_states) 76 | if uncertainty is None: 77 | raise RuntimeError("Failed to calculate uncertainty") 78 | 79 | # Normalize uncertainty if bounds are available 80 | if self.mahalanobis.ue_bounds_tensor is not None: 81 | uncertainty = self.mahalanobis.normalize_ue( 82 | uncertainty[0], 83 | self.device 84 | ) 85 | else: 86 | uncertainty = uncertainty[0] 87 | 88 | # Handle both single input and batch input 89 | if uncertainty.dim() == 0: # single value 90 | confidence_score = 1.0 - uncertainty.cpu().item() 91 | return confidence_score 92 | else: # batch of values 93 | confidence_scores = 1.0 - uncertainty.cpu().tolist() 94 | return confidence_scores 95 | 96 | except Exception as e: 97 | raise RuntimeError(f"Failed to calculate confidence score: {str(e)}") -------------------------------------------------------------------------------- /green_bit_llm/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/serve/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/serve/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/serve/api/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/serve/api/v1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/serve/api/v1/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/serve/auth/__init__.py: -------------------------------------------------------------------------------- 1 | from .api_key_auth import APIKeyAuth, get_api_key_auth 2 | from .rate_limiter import RateLimiter 3 | 4 | __all__ = ['APIKeyAuth', 'get_api_key_auth', 'RateLimiter'] -------------------------------------------------------------------------------- /green_bit_llm/serve/auth/api_key_auth.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import hashlib 3 | from typing import Optional 4 | from fastapi import HTTPException, Security 5 | from fastapi.security.api_key import APIKeyHeader 6 | from starlette.status import HTTP_403_FORBIDDEN 7 | import logging 8 | from pathlib import Path 9 | from dotenv import load_dotenv 10 | import os 11 | 12 | from .rate_limiter import RateLimiter 13 | 14 | # API key header 15 | API_KEY_HEADER = APIKeyHeader(name="X-Api-Key", auto_error=True) 16 | 17 | 18 | class APIKeyAuth: 19 | def __init__(self, db_path: str): 20 | self.db_path = db_path 21 | self.rate_limiter = RateLimiter() 22 | self.logger = logging.getLogger("greenbit_server") 23 | 24 | def get_db_connection(self): 25 | return 
sqlite3.connect(self.db_path) 26 | 27 | def _hash_key(self, api_key: str) -> str: 28 | """Hash the API key for database lookup.""" 29 | return hashlib.blake2b( 30 | api_key.encode(), 31 | digest_size=32, 32 | salt=b"greenbit_storage", 33 | person=b"api_key_storage" 34 | ).hexdigest() 35 | 36 | def validate_api_key(self, api_key: str) -> dict: 37 | """Validate API key and return user info if valid.""" 38 | try: 39 | hashed_key = self._hash_key(api_key) 40 | 41 | with self.get_db_connection() as conn: 42 | cursor = conn.cursor() 43 | cursor.execute(""" 44 | SELECT 45 | ak.user_id, 46 | u.name, 47 | u.email, 48 | u.organization, 49 | ak.tier, 50 | ak.rpm_limit, 51 | ak.tpm_limit, 52 | ak.concurrent_requests, 53 | ak.max_tokens, 54 | ak.permissions, 55 | ak.is_active 56 | FROM api_keys ak 57 | JOIN users u ON ak.user_id = u.id 58 | WHERE ak.api_key_hash = ? 59 | """, (hashed_key,)) 60 | result = cursor.fetchone() 61 | 62 | if not result: 63 | raise HTTPException( 64 | status_code=HTTP_403_FORBIDDEN, 65 | detail="Invalid API key" 66 | ) 67 | 68 | user_info = { 69 | "user_id": result[0], 70 | "name": result[1], 71 | "email": result[2], 72 | "organization": result[3], 73 | "tier": result[4], 74 | "rpm_limit": result[5], 75 | "tpm_limit": result[6], 76 | "concurrent_requests": result[7], 77 | "max_tokens": result[8], 78 | "permissions": result[9].split(','), 79 | "is_active": bool(result[10]) 80 | } 81 | 82 | if not user_info["is_active"]: 83 | raise HTTPException( 84 | status_code=HTTP_403_FORBIDDEN, 85 | detail="API key is inactive" 86 | ) 87 | 88 | # Update last_used_at 89 | cursor.execute(""" 90 | UPDATE api_keys 91 | SET last_used_at = CURRENT_TIMESTAMP 92 | WHERE api_key_hash = ? 93 | """, (hashed_key,)) 94 | conn.commit() 95 | 96 | return user_info 97 | 98 | except sqlite3.Error as e: 99 | self.logger.error(f"Database error during API key validation: {str(e)}") 100 | raise HTTPException(status_code=500, detail="Internal server error") 101 | 102 | async def check_rate_limits(self, api_key: str, user_info: dict, estimated_tokens: Optional[int] = None): 103 | """Check all rate limits.""" 104 | try: 105 | # Check RPM 106 | self.rate_limiter.check_rate_limit(api_key, user_info["rpm_limit"]) 107 | 108 | # Check TPM if tokens are provided 109 | if estimated_tokens is not None: 110 | self.rate_limiter.check_token_limit( 111 | api_key, 112 | estimated_tokens, 113 | user_info["tpm_limit"] 114 | ) 115 | 116 | # Try to acquire concurrent request slot 117 | await self.rate_limiter.acquire_concurrent_request( 118 | api_key, 119 | user_info["concurrent_requests"] 120 | ) 121 | 122 | except Exception as e: 123 | self.logger.error(f"Rate limit check failed: {str(e)}") 124 | raise 125 | 126 | def check_permissions(self, user_info: dict, endpoint_type: str): 127 | """Check if the user has permission to access the endpoint.""" 128 | if endpoint_type not in user_info["permissions"]: 129 | raise HTTPException( 130 | status_code=HTTP_403_FORBIDDEN, 131 | detail=f"No permission to access {endpoint_type} endpoint" 132 | ) 133 | 134 | def check_token_limit(self, user_info: dict, requested_tokens: int): 135 | """Check if the requested tokens are within the user's limit.""" 136 | if requested_tokens > user_info["max_tokens"]: 137 | raise HTTPException( 138 | status_code=HTTP_403_FORBIDDEN, 139 | detail=f"Requested tokens ({requested_tokens}) exceed maximum allowed ({user_info['max_tokens']})" 140 | ) 141 | 142 | 143 | def load_api_key(env_file: str = None) -> str: 144 | """Load API key from environment file or 
environment variables.""" 145 | if env_file and Path(env_file).exists(): 146 | load_dotenv(env_file) 147 | 148 | api_key = os.getenv("LIBRA_API_KEY") 149 | if not api_key: 150 | raise HTTPException( 151 | status_code=HTTP_403_FORBIDDEN, 152 | detail="API key not found in environment" 153 | ) 154 | 155 | return api_key 156 | 157 | 158 | async def get_api_key_auth( 159 | api_key: str = Security(API_KEY_HEADER), 160 | env_file: str = None, 161 | estimated_tokens: Optional[int] = None 162 | ) -> dict: 163 | """FastAPI dependency for API key authentication.""" 164 | # If no API key in header, try to load from environment 165 | if not api_key: 166 | api_key = load_api_key(env_file) 167 | 168 | db_path = os.getenv("LIBRA_DB_PATH", str(Path(__file__).parent.parent.parent.parent / "db" / "greenbit.db")) 169 | auth_handler = APIKeyAuth(db_path) 170 | user_info = auth_handler.validate_api_key(api_key) 171 | await auth_handler.check_rate_limits(api_key, user_info, estimated_tokens) 172 | return user_info -------------------------------------------------------------------------------- /green_bit_llm/serve/auth/rate_limiter.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime, timedelta 2 | from typing import Dict, List, Tuple 3 | from fastapi import HTTPException 4 | from starlette.status import HTTP_429_TOO_MANY_REQUESTS 5 | import asyncio 6 | from collections import defaultdict 7 | 8 | 9 | class RateLimiter: 10 | def __init__(self): 11 | self._request_times: Dict[str, List[datetime]] = defaultdict(list) 12 | self._token_counts: Dict[str, List[Tuple[datetime, int]]] = defaultdict(list) 13 | self._concurrent_requests: Dict[str, int] = defaultdict(int) 14 | self._locks: Dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) 15 | 16 | async def acquire_concurrent_request(self, api_key: str, limit: int): 17 | """Try to acquire a concurrent request slot.""" 18 | async with self._locks[api_key]: 19 | if self._concurrent_requests[api_key] >= limit: 20 | raise HTTPException( 21 | status_code=HTTP_429_TOO_MANY_REQUESTS, 22 | detail=f"Maximum concurrent requests ({limit}) exceeded" 23 | ) 24 | self._concurrent_requests[api_key] += 1 25 | 26 | async def release_concurrent_request(self, api_key: str): 27 | """Release a concurrent request slot.""" 28 | async with self._locks[api_key]: 29 | if self._concurrent_requests[api_key] > 0: 30 | self._concurrent_requests[api_key] -= 1 31 | 32 | def check_rate_limit(self, api_key: str, rpm_limit: int): 33 | """Check requests per minute limit.""" 34 | now = datetime.now() 35 | minute_ago = now - timedelta(minutes=1) 36 | 37 | # Clean old entries 38 | self._request_times[api_key] = [ 39 | time for time in self._request_times[api_key] 40 | if time > minute_ago 41 | ] 42 | 43 | # Check RPM limit 44 | if len(self._request_times[api_key]) >= rpm_limit: 45 | raise HTTPException( 46 | status_code=HTTP_429_TOO_MANY_REQUESTS, 47 | detail=f"Rate limit exceeded. Maximum {rpm_limit} requests per minute." 
48 | ) 49 | 50 | self._request_times[api_key].append(now) 51 | 52 | def check_token_limit(self, api_key: str, new_tokens: int, tpm_limit: int): 53 | """Check tokens per minute limit.""" 54 | now = datetime.now() 55 | minute_ago = now - timedelta(minutes=1) 56 | 57 | # Clean old entries 58 | self._token_counts[api_key] = [ 59 | (time, count) for time, count in self._token_counts[api_key] 60 | if time > minute_ago 61 | ] 62 | 63 | # Calculate current token usage 64 | current_tpm = sum(count for _, count in self._token_counts[api_key]) 65 | 66 | # Check TPM limit 67 | if current_tpm + new_tokens > tpm_limit: 68 | raise HTTPException( 69 | status_code=HTTP_429_TOO_MANY_REQUESTS, 70 | detail=f"Token rate limit exceeded. Maximum {tpm_limit} tokens per minute." 71 | ) 72 | 73 | self._token_counts[api_key].append((now, new_tokens)) -------------------------------------------------------------------------------- /green_bit_llm/sft/README.md: -------------------------------------------------------------------------------- 1 | # Finetuning GreenBitAI's Low-bit LLMs 2 | 3 | ## Overview 4 | 5 | This package demonstrates the capabilities of [GreenBitAI's low-bit large language models (LLMs)](https://huggingface.co/GreenBitAI) through two main features: 6 | 1. Full-parameter fine-tuning using quantized LLMs. 7 | 2. Parameter efficient fine-tuning 8 | 9 | 10 | ## Installation 11 | 12 | Please follow the [main installation instructions](../../README.md#installation) for how to install the packages required to run this inference package. 13 | Afterward, install the following additional libraries: 14 | 15 | ```bash 16 | pip install trl 17 | pip install -U git+https://github.com/huggingface/peft.git 18 | ``` 19 | 20 | If you want to use a **8-bit customized optimizer** with the gradient low-rank projection for maximizing memory savings, you will also need to install the following package: 21 | 22 | ```bash 23 | pip install bitsandbytes galore-torch 24 | ``` 25 | 26 | ## Usage 27 | 28 | ### LLMs 29 | 30 | We have released over 200 highly precise 2.2/2.5/3/4-bit models across the modern LLM family, featuring LLaMA 2/3, 01-Yi, Qwen, Mistral, Phi-3, and more. Currently, only layer-mix quantized models are supported for sft. In addition to our low-bit models, green-bit-llm is fully compatible with the AutoGPTQ series of 4-bit quantization and compression models. 31 | 32 | Happy scaling low-bit LLMs with more data! 
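Any model id from the table below can be passed to the fine-tuning scripts via `--model`, or loaded directly in Python. The snippet below is a minimal sketch that mirrors the `load(...)` call used by `green_bit_llm/sft/finetune.py`; the model id and sequence length are only examples.

```python
from green_bit_llm.common import load

# Minimal loading sketch (mirrors the call in green_bit_llm/sft/finetune.py).
# The model id and seqlen are illustrative; requires_grad=True prepares the
# quantized weights for fine-tuning.
model, tokenizer, config = load(
    "GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0",
    tokenizer_config={"trust_remote_code": None},
    device_map="auto",
    seqlen=512,
    model_config={"trust_remote_code": None, "attn_implementation": None},
    requires_grad=True,
)
```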
33 | 34 | | Family | Bpw | Size | HF collection id | 35 | |:----------------:|:------------------:|:------------------------------:|:-----------------------------------------------------------------------------------------------------------------:| 36 | | Llama-3 | `4.0/3.0/2.5/2.2` | `8B/70B` | [`GreenBitAI Llama-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-llama-3-6627bc1ec6538e3922c5d81c) | 37 | | Llama-2 | `3.0/2.5/2.2` | `7B/13B/70B` | [`GreenBitAI Llama-2`](https://huggingface.co/collections/GreenBitAI/greenbitai-llama-2-661f87e3b073ff8e48a12834) | 38 | | Qwen-1.5 | `4.0/3.0/2.5/2.2` | `0.5B/1.8B/4B/7B/14B/32B/110B` | [`GreenBitAI Qwen 1.5`](https://huggingface.co/collections/GreenBitAI/greenbitai-qwen15-661f86ea69433f3d3062c920) | 39 | | Phi-3 | `3.0/2.5/2.2` | `mini` | [`GreenBitAI Phi-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-phi-3-6628d008cdf168398a296c92) | 40 | | Mistral | `3.0/2.5/2.2` | `7B` | [`GreenBitAI Mistral`](https://huggingface.co/collections/GreenBitAI/greenbitai-mistral-661f896c45da9d8b28a193a8) | 41 | | 01-Yi | `3.0/2.5/2.2` | `6B/34B` | [`GreenBitAI 01-Yi`](https://huggingface.co/collections/GreenBitAI/greenbitai-01-yi-661f88af0648daa766d5102f) | 42 | | Llama-3-instruct | `4.0/3.0/2.5/2.2` | `8B/70B` | [`GreenBitAI Llama-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-llama-3-6627bc1ec6538e3922c5d81c) | 43 | | Mistral-instruct | `3.0/2.5/2.2` | `7B` | [`GreenBitAI Mistral`](https://huggingface.co/collections/GreenBitAI/greenbitai-mistral-661f896c45da9d8b28a193a8) | 44 | | Phi-3-instruct | `3.0/2.5/2.2` | `mini` | [`GreenBitAI Phi-3`](https://huggingface.co/collections/GreenBitAI/greenbitai-phi-3-6628d008cdf168398a296c92) | 45 | | Qwen-1.5-Chat | `4.0/3.0/2.5/2.2` | `0.5B/1.8B/4B/7B/14B/32B/110B` | [`GreenBitAI Qwen 1.5`](https://huggingface.co/collections/GreenBitAI/greenbitai-qwen15-661f86ea69433f3d3062c920) | 46 | | 01-Yi-Chat | `3.0/2.5/2.2` | `6B/34B` | [`GreenBitAI 01-Yi`](https://huggingface.co/collections/GreenBitAI/greenbitai-01-yi-661f88af0648daa766d5102f) | 47 | 48 | 49 | ### Full-parameter fine-tuning 50 | 51 | Run the script as follows to fine-tune the quantized weights of the model on the target dataset. 52 | The **--tune-qweight-only** parameter determines whether to fine-tune only the quantized weights or all weights, including non-quantized ones. 53 | 54 | ```bash 55 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.finetune --model GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0 --dataset tatsu-lab/alpaca --tune-qweight-only 56 | 57 | # AutoGPTQ model Q-SFT 58 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.finetune --model astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit --dataset tatsu-lab/alpaca --tune-qweight-only --batch-size 1 59 | ``` 60 | If you want to further save memory, we also support [Galore](https://github.com/jiaweizzhao/GaLore): a memory-efficient low-rank training strategy. 61 | You just need to add the **--galore** parameter in your command line. However, it's important to note that Galore requires the computation of projection matrices for the gradients, which will incur additional time costs. 62 | You can think of this as a trade-off strategy where time is exchanged for space. 63 | To select the "AdamW8bit" optimizer, simply add "--optimizer AdamW8bit" to your command line. 
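For example, the two memory-saving options can be combined in a single run. The command below is a sketch that assumes `--galore` and `--optimizer AdamW8bit` compose as described above:

```bash
# Q-SFT with GaLore gradient projection plus the 8-bit AdamW optimizer
CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.finetune --model GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0 --dataset tatsu-lab/alpaca --tune-qweight-only --galore --optimizer AdamW8bit
```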
64 | 65 | ### Parameter efficient fine-tuning 66 | 67 | ```bash 68 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.peft_lora --model GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0 --dataset tatsu-lab/alpaca --lr-fp 1e-6 69 | 70 | # AutoGPTQ model with Lora 71 | CUDA_VISIBLE_DEVICES=0 python -m green_bit_llm.sft.peft_lora --model astronomer/Llama-3-8B-Instruct-GPTQ-4-Bit --dataset tatsu-lab/alpaca --lr-fp 1e-6 72 | ``` 73 | 74 | ### 0-Shot Evaluation of Q-SFT 75 | 76 | The 0-shot evaluations of quantized Llama 3 8B model under different fine-tuning settings are listed as an example. **Q-SFT** indicates quantized surpervised-finetuning. 77 | 78 | | Task | Bpw | Llama 3 8B Base | Llama 3 8B + LoRA | Llama 3 8B Q-SFT + Galore | Llama 3 8B + Q-SFT | 79 | |:-------------:|:--------:|:-----------------:|:-------------------:|:--------------------------:|:------------------:| 80 | | PIQA | 2.2 | 0.72 | 0.75 | 0.75 | 0.75 | 81 | | | 2.5 | 0.74 | 0.77 | 0.76 | 0.76 | 82 | | | 3.0 | 0.76 | 0.78 | 0.78 | 0.79 | 83 | | BoolQ | 2.2 | 0.74 | 0.77 | 0.77 | 0.78 | 84 | | | 2.5 | 0.75 | 0.76 | 0.76 | 0.78 | 85 | | | 3.0 | 0.78 | 0.80 | 0.79 | 0.80 | 86 | | Winogr. | 2.2 | 0.67 | 0.68 | 0.68 | 0.67 | 87 | | | 2.5 | 0.68 | 0.69 | 0.69 | 0.69 | 88 | | | 3.0 | 0.70 | 0.71 | 0.71 | 0.71 | 89 | | ARC-E | 2.2 | 0.73 | 0.77 | 0.76 | 0.75 | 90 | | | 2.5 | 0.76 | 0.77 | 0.77 | 0.76 | 91 | | | 3.0 | 0.77 | 0.79 | 0.79 | 0.79 | 92 | | ARC-C | 2.2 | 0.39 | 0.46 | 0.45 | 0.45 | 93 | | | 2.5 | 0.41 | 0.44 | 0.43 | 0.43 | 94 | | | 3.0 | 0.44 | 0.49 | 0.47 | 0.49 | 95 | | WiC | 2.2 | 0.50 | 0.50 | 0.50 | 0.50 | 96 | | | 2.5 | 0.51 | 0.50 | 0.52 | 0.51 | 97 | | | 3.0 | 0.52 | 0.52 | 0.57 | 0.60 | 98 | | Avg | 2.2 | 0.62 | 0.65 | 0.65 | 0.65 | 99 | | | 2.5 | 0.64 | 0.65 | 0.65 | 0.65 | 100 | | | 3.0 | 0.66 | 0.68 | 0.68 | 0.69 | 101 | 102 | Compared to traditional LoRA based fine-tuning, our approach streamlines engineering supply chain from fine-tuning to hardware deployment, while also enhancing performance. 103 | 104 | ### Current Limitations 105 | 106 | 1. Gradient clipping is currently unavailable for Full-parameter fine-tuning due to PyTorch's restrictions on the dtype of gradient tensors. The integer tensor type we use for qweight is not supported. We plan to address this issue gradually in the future. 107 | 2. Due to the need for deep modifications to the Python code related to the Int gradient tensor type, to ensure stability and safety, we currently do not support distributed or data parallel training. We plan to support this in the future. Stay tuned. 108 | 109 | 110 | ## License 111 | - The script 'optim/adamw8bit.py' has been modified from [GaLore repository](https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/adamw8bit.py), which is released under the Apache 2.0 License. 112 | - The script 'optim/bnb_optimizer.py' has been modified from [bitsandbytes repository](https://github.com/TimDettmers/bitsandbytes/blob/main/bitsandbytes/optim/optimizer.py), which is released under the MIT License. 113 | - We release our changes and additions to these files under the [Apache 2.0 License](../../LICENSE). 
114 | -------------------------------------------------------------------------------- /green_bit_llm/sft/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/sft/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/sft/finetune.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import torch 5 | 6 | from transformers import PreTrainedTokenizer, TrainingArguments 7 | from datasets import load_dataset 8 | from green_bit_llm.sft.trainer import GbaSFTTrainer 9 | 10 | from green_bit_llm.common import load 11 | from green_bit_llm.args_parser import setup_shared_arg_parser 12 | 13 | import warnings 14 | 15 | warnings.filterwarnings('ignore') 16 | 17 | try: 18 | from bitorch_engine.optim import DiodeMix 19 | from bitorch_engine.layers.qlinear.nbit import MPQLinearBase 20 | except ModuleNotFoundError as e: 21 | raise Exception(f"Error occurred while importing Bitorch Engine module '{str(e)}'.") 22 | 23 | from green_bit_llm.sft.optim import AdamW8bit 24 | from green_bit_llm.sft.utils import str_to_torch_dtype, create_param_groups 25 | 26 | 27 | # default value for arguments 28 | DEFAULT_MODEL_PATH = "GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0" 29 | DEFAULT_SEQLEN = 512 30 | DEFAULT_RANDOM_SEED = 0 31 | DEFAULT_LR = 5e-6 32 | DEFAULT_LR_GALORE = 5e-5 33 | DEFAULT_LR_ADAMW8BIT = 5e-3 34 | DEFAULT_BETAS = (0.9, 0.99) 35 | 36 | 37 | def setup_arg_parser(): 38 | """Set up and return the argument parser.""" 39 | parser = setup_shared_arg_parser("green-bit-llm finetune script") 40 | parser.add_argument( 41 | "--seed", 42 | type=int, 43 | default=DEFAULT_RANDOM_SEED, 44 | help="The random seed for data loader.", 45 | ) 46 | # GaLore parameters 47 | parser.add_argument( 48 | "--galore", 49 | action="store_true", 50 | help="Enable using galore", 51 | ) 52 | parser.add_argument("--galore-rank", type=int, default=256) 53 | parser.add_argument("--galore-update-proj-gap", type=int, default=200) 54 | parser.add_argument("--galore-scale", type=float, default=0.25) 55 | parser.add_argument("--galore-proj-type", type=str, default="std") 56 | 57 | # qweight related 58 | parser.add_argument( 59 | "--tune-qweight-only", 60 | action="store_true", 61 | help="Set whether to adjust only the low-bit qweight and keep the regular parameters unchanged during the training process.", 62 | ) 63 | parser.add_argument( 64 | "--lr-2bit", 65 | type=float, 66 | default=-1.0, 67 | help="Learning rate for 2-bit qweight." 68 | ) 69 | parser.add_argument( 70 | "--lr-4bit", 71 | type=float, 72 | default=-1.0, 73 | help="Learning rate for 4-bit qweight." 74 | ) 75 | parser.add_argument( 76 | "--lr-fp", 77 | type=float, 78 | default=DEFAULT_LR, 79 | help="Learning rate for full-precision weight." 
80 | ) 81 | return parser 82 | 83 | 84 | def main(args): 85 | 86 | # Building configs 87 | tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} 88 | pretrain_model_config = { 89 | "trust_remote_code": True if args.trust_remote_code else None, 90 | "attn_implementation": "flash_attention_2" if args.use_flash_attention_2 else None 91 | } 92 | 93 | model, tokenizer, config = load( 94 | args.model, 95 | tokenizer_config=tokenizer_config, 96 | device_map='auto', 97 | seqlen=args.seqlen, 98 | model_config=pretrain_model_config, 99 | requires_grad=True, 100 | ) 101 | 102 | # NOTE: 103 | # Typically, Hugging Face's Trainer does not support fine-tuning quantized models. 104 | # However, our tool supports this scenario. 105 | # Therefore, we need to delete this attribute after loading the model. 106 | if hasattr(model, 'is_quantized'): 107 | delattr(model, 'is_quantized') 108 | 109 | param_groups = create_param_groups(model, args, DEFAULT_BETAS, DEFAULT_LR_GALORE, DEFAULT_LR_ADAMW8BIT, DEFAULT_LR) 110 | 111 | model.train() 112 | 113 | dataset = load_dataset(args.dataset, split="train") 114 | 115 | if not args.galore: 116 | args.save_dir = os.path.join(args.save_dir, "finetune/common/", args.model) 117 | else: 118 | args.save_dir = os.path.join(args.save_dir, "finetune/galore/", args.model) 119 | 120 | 121 | train_args = TrainingArguments( 122 | output_dir=args.save_dir, 123 | gradient_checkpointing=True, 124 | #auto_find_batch_size=True, 125 | per_device_train_batch_size=args.batch_size, 126 | logging_steps=1, 127 | num_train_epochs=1, 128 | save_steps=args.save_step, 129 | save_total_limit=3, 130 | gradient_accumulation_steps=1, 131 | lr_scheduler_type='cosine', 132 | max_grad_norm=0, # NOTE: max_grad_norm MUST be <= 0 or None, otherwise raise dtype error due to the Int dtype of qweight. 133 | ) 134 | 135 | # Optimizer 136 | if 'adamw8bit' in args.optimizer.lower(): 137 | optimizer = AdamW8bit(param_groups, weight_decay=args.weight_decay, dtype=str_to_torch_dtype(args.dtype)) 138 | elif 'diodemix' in args.optimizer.lower(): 139 | optimizer = DiodeMix(param_groups, dtype=str_to_torch_dtype(args.dtype)) 140 | 141 | optimizers = (optimizer, None) 142 | 143 | # Trainer 144 | trainer = GbaSFTTrainer( 145 | model=model, 146 | args=train_args, 147 | train_dataset=dataset, 148 | dataset_text_field="text", 149 | optimizers=optimizers, 150 | max_seq_length=args.seqlen, 151 | ) 152 | 153 | trainer.train() 154 | 155 | model.save_pretrained(args.save_dir) 156 | 157 | 158 | if __name__ == "__main__": 159 | if not torch.cuda.is_available(): 160 | print("Warning: CUDA is required to run the model.") 161 | sys.exit(0) 162 | 163 | parser = setup_arg_parser() 164 | args = parser.parse_args() 165 | 166 | main(args) 167 | -------------------------------------------------------------------------------- /green_bit_llm/sft/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .adamw8bit import AdamW8bit -------------------------------------------------------------------------------- /green_bit_llm/sft/optim/adamw8bit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .bnb_optimizer import Optimizer2State 4 | 5 | try: 6 | from galore_torch.galore_projector import GaLoreProjector 7 | except ModuleNotFoundError as e: 8 | raise Exception("Error: GaLoreProjector is not available. 
Make sure 'galore-torch' has been installed on you system.") 9 | 10 | try: 11 | from bitorch_engine.layers.qlinear.nbit import MPQWeightParameter 12 | from bitorch_engine.utils.quant_operators import gptq_style_unpacking 13 | from bitorch_engine.layers.qlinear.nbit.cuda.utils import pack_fp_weight 14 | except ModuleNotFoundError as e: 15 | raise Exception(f"Error occurred while importing Bitorch Engine module '{str(e)}'.") 16 | 17 | 18 | class AdamW8bit(Optimizer2State): 19 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=.0, amsgrad=False, optim_bits=8, 20 | args=None, min_8bit_size=4096, percentile_clipping=100, block_wise=True, is_paged=False, 21 | dtype: torch.dtype = torch.float16): 22 | self.dtype = dtype 23 | super().__init__( "adam", params, lr, betas, eps, weight_decay, optim_bits, args, min_8bit_size, percentile_clipping, 24 | block_wise, is_paged=is_paged ) 25 | 26 | @torch.no_grad() 27 | def step(self, closure=None): 28 | """Performs a single optimization step. 29 | 30 | Arguments: 31 | closure (callable, optional): A closure that reevaluates the model 32 | and returns the loss. 33 | """ 34 | loss = None 35 | if closure is not None: 36 | with torch.enable_grad(): 37 | loss = closure() 38 | 39 | if not self.initialized: 40 | self.check_overrides() 41 | self.to_gpu() # needed for fairseq pure fp16 training 42 | self.initialized = True 43 | 44 | #if self.is_paged: self.page_mng.prefetch_all() 45 | for gindex, group in enumerate(self.param_groups): 46 | for pindex, p in enumerate(group["params"]): 47 | if p.grad is None: 48 | continue 49 | 50 | state = self.state[p] 51 | 52 | if "step" not in state: 53 | state["step"] = 0 54 | 55 | if isinstance(p, MPQWeightParameter): 56 | grad = p.privileged_grad.to(self.dtype).to(p.grad.device) 57 | else: 58 | grad = p.grad.to(self.dtype) 59 | 60 | # GaLore Projection 61 | if "rank" in group: 62 | if "projector" not in state: 63 | state["projector"] = GaLoreProjector(group["rank"], update_proj_gap=group["update_proj_gap"], scale=group["scale"], proj_type=group["proj_type"]) 64 | 65 | projector = state["projector"] 66 | grad = projector.project(grad, state["step"]) 67 | 68 | saved_data = None 69 | if "rank" in group or isinstance(p, MPQWeightParameter): 70 | # suboptimal implementation 71 | # In the implementation mentioned, the author sets the variable p (representing model parameters) to zero, 72 | # meaning p does not change during the update step. Instead, only the gradient states are updated, 73 | # and actual weight modifications are calculated manually later in the code. 
74 | saved_data = p.data.clone() 75 | p.data = torch.zeros_like(grad) 76 | 77 | if 'weight_decay' in group and group['weight_decay'] > 0: 78 | # ensure that the weight decay is not applied to the norm grad 79 | group['weight_decay_saved'] = group['weight_decay'] 80 | group['weight_decay'] = 0 81 | 82 | if 'state1' not in state: 83 | self.init_state(group, p, gindex, pindex, grad) 84 | 85 | self.prefetch_state(p) 86 | 87 | self.update_step(group, p, gindex, pindex, grad) 88 | 89 | torch.cuda.synchronize() 90 | 91 | if 'weight_decay_saved' in group: 92 | group['weight_decay'] = group['weight_decay_saved'] 93 | del group['weight_decay_saved'] 94 | 95 | w_unpacked = None 96 | # GaLore Projection Back 97 | if "rank" in group: 98 | # now the p.data is actually: -norm_grad*lr 99 | norm_grad = projector.project_back(p.data) 100 | 101 | if isinstance(p, MPQWeightParameter): 102 | # unpack qweight 103 | p.data = saved_data 104 | w_unpacked = gptq_style_unpacking(p).to(self.dtype).to(saved_data.device) 105 | w_unpacked.add_(norm_grad) 106 | if group["weight_decay"] > 0.0: 107 | w_unpacked.add_(w_unpacked, alpha=-group['lr'] * group['weight_decay']) 108 | # pack fp weight back to Q-weight and update qweight data 109 | p.data = pack_fp_weight(w_unpacked, p) 110 | else: 111 | p.data = saved_data.add_(norm_grad) 112 | if group["weight_decay"] > 0.0: 113 | p.data.add_(p.data, alpha=-group['lr'] * group['weight_decay']) 114 | elif isinstance(p, MPQWeightParameter): 115 | # now the p.data is actually: -norm_grad*lr 116 | norm_grad = p.data.clone() 117 | # unpack qweight 118 | p.data = saved_data 119 | w_unpacked = gptq_style_unpacking(p).to(self.dtype).to(saved_data.device) 120 | w_unpacked.add_(norm_grad) 121 | if group["weight_decay"] > 0.0: 122 | w_unpacked.add_(w_unpacked, alpha=-group['lr'] * group['weight_decay']) 123 | # pack fp weight back to Q-weight and update qweight data 124 | p.data = pack_fp_weight(w_unpacked, p) 125 | 126 | # pack fp weight back to qweight 127 | if w_unpacked is not None: 128 | del w_unpacked 129 | if saved_data is not None: 130 | del saved_data 131 | if torch.cuda.is_available(): 132 | torch.cuda.empty_cache() 133 | 134 | if self.is_paged: 135 | # all paged operation are asynchronous, we need 136 | # to sync to make sure all tensors are in the right state 137 | torch.cuda.synchronize() 138 | 139 | return loss 140 | -------------------------------------------------------------------------------- /green_bit_llm/sft/peft_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import torch 5 | 6 | from transformers import PreTrainedTokenizer, TrainingArguments 7 | from datasets import load_dataset 8 | from peft import PeftModel, LoraConfig, get_peft_model 9 | 10 | from green_bit_llm.sft.trainer import GbaSFTTrainer 11 | from green_bit_llm.common import load 12 | from green_bit_llm.args_parser import setup_shared_arg_parser 13 | from green_bit_llm.sft.peft_utils.model import * 14 | from green_bit_llm.sft.optim import AdamW8bit 15 | 16 | import warnings 17 | warnings.filterwarnings('ignore') 18 | 19 | try: 20 | from bitorch_engine.optim import DiodeMix 21 | except ModuleNotFoundError as e: 22 | raise Exception(f"Error occurred while importing Bitorch Engine module '{str(e)}'.") 23 | 24 | from green_bit_llm.sft.utils import str_to_torch_dtype, create_param_groups 25 | 26 | # default value for arguments 27 | DEFAULT_MODEL_PATH = "GreenBitAI/Qwen-1.5-1.8B-layer-mix-bpw-3.0" 28 | DEFAULT_SEQLEN = 512 29 | 
DEFAULT_RANDOM_SEED = 0 30 | DEFAULT_LR = 1e-5 31 | DEFAULT_LR_GALORE = 1e-4 32 | DEFAULT_LR_FP = 1e-6 33 | DEFAULT_BETAS = (0.9, 0.999) 34 | 35 | 36 | def setup_arg_parser(): 37 | """Set up and return the argument parser.""" 38 | parser = setup_shared_arg_parser("green-bit-llm lora script") 39 | parser.add_argument( 40 | "--seed", 41 | type=int, 42 | default=DEFAULT_RANDOM_SEED, 43 | help="The random seed for data loader.", 44 | ) 45 | # qweight related 46 | parser.add_argument( 47 | "--lr-fp", 48 | type=float, 49 | default=DEFAULT_LR_FP, 50 | help="Learning rate for full-precision weight." 51 | ) 52 | parser.add_argument("--lora-rank", type=int, default=64) 53 | parser.add_argument("--lora-alpha", type=int, default=32) 54 | parser.add_argument("--lora-dropout", type=float, default=0.01) 55 | return parser 56 | 57 | 58 | def main(args): 59 | 60 | # Building configs 61 | tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} 62 | pretrain_model_config = { 63 | "trust_remote_code": True if args.trust_remote_code else None, 64 | "attn_implementation": "flash_attention_2" if args.use_flash_attention_2 else None 65 | } 66 | 67 | model, tokenizer, config = load( 68 | args.model, 69 | tokenizer_config=tokenizer_config, 70 | device_map='auto', 71 | seqlen=args.seqlen, 72 | model_config=pretrain_model_config, 73 | requires_grad=False, 74 | ) 75 | 76 | config = LoraConfig( 77 | r=args.lora_rank, 78 | lora_alpha=args.lora_alpha, 79 | target_modules=["q_proj", "v_proj", "out_proj", "down_proj", "up_proj"], 80 | lora_dropout=args.lora_dropout, 81 | bias="none", 82 | task_type="CAUSAL_LM", 83 | ) 84 | 85 | replace_peft_lora_model_with_gba_lora_model() 86 | 87 | model = get_peft_model(model, config) 88 | 89 | param_groups = create_param_groups(model, args, DEFAULT_BETAS) 90 | 91 | model.train() 92 | 93 | dataset = load_dataset(args.dataset, split="train") 94 | 95 | args.save_dir = os.path.join(args.save_dir, "lora/", args.model) 96 | 97 | train_args = TrainingArguments( 98 | output_dir=args.save_dir, 99 | gradient_checkpointing=True, 100 | #auto_find_batch_size=True, 101 | per_device_train_batch_size=args.batch_size, 102 | logging_steps=1, 103 | num_train_epochs=1, 104 | gradient_accumulation_steps=1, 105 | save_steps=args.save_step, 106 | #warmup_ratio=0.05, 107 | max_grad_norm=0, # NOTE: max_grad_norm MUST be <= 0 or None, otherwise raise dtype error due to the Int dtype of qweight. 
108 | ) 109 | 110 | # Optimizer 111 | if 'adamw8bit' in args.optimizer.lower(): 112 | optimizer = AdamW8bit(param_groups, weight_decay=args.weight_decay, lr=5e-3, dtype=str_to_torch_dtype(args.dtype)) 113 | elif 'diodemix' in args.optimizer.lower(): 114 | optimizer = DiodeMix(param_groups, dtype=str_to_torch_dtype(args.dtype)) 115 | optimizers = (optimizer, None) 116 | 117 | for name, param in model.named_parameters(): 118 | if "qweight" not in name: 119 | param.requires_grad = True 120 | 121 | trainer = GbaSFTTrainer( 122 | model=model, 123 | args=train_args, 124 | train_dataset=dataset, 125 | dataset_text_field="text", 126 | optimizers=optimizers, 127 | max_seq_length=args.seqlen, 128 | ) 129 | 130 | trainer.train() 131 | 132 | model.save_pretrained(args.save_dir) 133 | 134 | 135 | if __name__ == "__main__": 136 | if not torch.cuda.is_available(): 137 | print("Warning: CUDA is needed to run the model.") 138 | sys.exit(0) 139 | 140 | parser = setup_arg_parser() 141 | args = parser.parse_args() 142 | 143 | main(args) 144 | -------------------------------------------------------------------------------- /green_bit_llm/sft/peft_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/green_bit_llm/sft/peft_utils/__init__.py -------------------------------------------------------------------------------- /green_bit_llm/sft/peft_utils/gba_lora.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023-present the HuggingFace Inc. team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from typing import Any, Optional 17 | 18 | ENGINE_AVAILABLE=True 19 | try: 20 | from bitorch_engine.layers.qlinear.nbit import MPQLinearBase 21 | from bitorch_engine.layers.qlinear.nbit.cuda import MPQLinearCuda, MBWQLinearCuda 22 | except ModuleNotFoundError as e: 23 | ENGINE_AVAILABLE = False 24 | raise Exception(f"Error occurred while importing Bitorch Engine module '{str(e)}'.") 25 | 26 | import torch 27 | import torch.nn as nn 28 | from transformers.pytorch_utils import Conv1D 29 | 30 | from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge 31 | from peft.tuners.lora import LoraLayer 32 | 33 | 34 | class GBALoraLayer(LoraLayer): 35 | """ 36 | GBALoraLayer class extends LoraLayer to support Gradient-Based Adapter tuning for various model layers. 37 | It maintains lists of both LoRA-specific parameters and other adapter-related parameters. 
38 | """ 39 | 40 | # All names of layers that may contain (trainable) adapter weights 41 | adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B") 42 | # All names of other parameters that may contain adapter-related parameters 43 | other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout") 44 | 45 | def __init__(self, base_layer: nn.Module, **kwargs) -> None: 46 | """ 47 | Initializes a GBALoraLayer instance. 48 | Args: 49 | base_layer: The underlying neural network layer that LoRA is being applied to. 50 | **kwargs: Additional keyword arguments for customization. 51 | 52 | This method initializes adapter components, configures the underlying base layer, and sets the 53 | feature sizes based on the base layer type. 54 | """ 55 | self.base_layer = base_layer 56 | self.r = {} 57 | self.lora_alpha = {} 58 | self.scaling = {} 59 | self.lora_dropout = nn.ModuleDict({}) 60 | self.lora_A = nn.ModuleDict({}) 61 | self.lora_B = nn.ModuleDict({}) 62 | # For Embedding layer 63 | self.lora_embedding_A = nn.ParameterDict({}) 64 | self.lora_embedding_B = nn.ParameterDict({}) 65 | # Mark the weight as unmerged 66 | self._disable_adapters = False 67 | self.merged_adapters = [] 68 | self.use_dora: dict[str, bool] = {} 69 | self.lora_magnitude_vector: Optional[torch.nn.ParameterDict] = None # for DoRA 70 | self._caches: dict[str, Any] = {} 71 | self.kwargs = kwargs 72 | 73 | base_layer = self.get_base_layer() 74 | if isinstance(base_layer, nn.Linear): 75 | in_features, out_features = base_layer.in_features, base_layer.out_features 76 | elif isinstance(base_layer, nn.Conv2d): 77 | in_features, out_features = base_layer.in_channels, base_layer.out_channels 78 | elif isinstance(base_layer, nn.Embedding): 79 | in_features, out_features = base_layer.num_embeddings, base_layer.embedding_dim 80 | elif isinstance(base_layer, Conv1D): 81 | in_features, out_features = ( 82 | base_layer.weight.ds_shape if hasattr(base_layer.weight, "ds_shape") else base_layer.weight.shape 83 | ) 84 | elif hasattr(base_layer, "infeatures") and hasattr(base_layer, "outfeatures"): 85 | # QuantLinear 86 | in_features, out_features = base_layer.infeatures, base_layer.outfeatures 87 | elif hasattr(base_layer, "input_size") and hasattr(base_layer, "output_size"): 88 | # Megatron ColumnParallelLinear,RowParallelLinear 89 | in_features, out_features = base_layer.input_size, base_layer.output_size 90 | elif hasattr(base_layer, "codebooks") and base_layer.__class__.__name__ == "QuantizedLinear": 91 | # AQLM QuantLinear 92 | in_features, out_features = base_layer.in_features, base_layer.out_features 93 | elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": 94 | # Awq layers 95 | in_features, out_features = base_layer.in_features, base_layer.out_features 96 | elif base_layer.__class__.__name__ == "MBWQLinearCuda": 97 | in_features, out_features = base_layer.in_channels, base_layer.out_channels 98 | elif base_layer.__class__.__name__ == "MPQLinearCuda": 99 | in_features, out_features = base_layer.in_channels, base_layer.out_channels 100 | else: 101 | raise ValueError(f"Unsupported layer type {type(base_layer)}") 102 | 103 | self.in_features = in_features 104 | self.out_features = out_features 105 | 106 | 107 | class GBALoraLinear(torch.nn.Module, GBALoraLayer): 108 | """ 109 | Implements a LoRA (Low-Rank Adaptation) module integrated into a dense linear layer. 
110 | This class extends functionality by allowing modifications to the layer through 111 | low-rank matrices to efficiently adapt large pre-trained models without extensive retraining. 112 | """ 113 | # Lora implemented in a dense layer 114 | def __init__( 115 | self, 116 | base_layer: torch.nn.Module, 117 | adapter_name: str, 118 | r: int = 0, 119 | lora_alpha: int = 1, 120 | lora_dropout: float = 0.0, 121 | init_lora_weights: bool = True, 122 | use_rslora: bool = False, 123 | use_dora: bool = False, 124 | **kwargs, 125 | ) -> None: 126 | """ 127 | Initializes the LoRA adapted layer with specific parameters and configurations. 128 | 129 | Parameters: 130 | base_layer (torch.nn.Module): The original base layer to which LoRA adjustments are applied. 131 | adapter_name (str): The name of the adapter for identification. 132 | r (int): The rank of the low-rank approximation matrices. 133 | lora_alpha (int): Scaling factor for the LoRA parameters. 134 | lora_dropout (float): Dropout rate for regularization during training. 135 | init_lora_weights (bool): Whether to initialize LoRA weights upon creation. 136 | use_rslora (bool): Indicates whether to use rank-stabilized LoRA. 137 | use_dora (bool): Indicates whether to use dynamic orthogonal regularization adapter. 138 | """ 139 | super().__init__() 140 | GBALoraLayer.__init__(self, base_layer) 141 | self.fan_in_fan_out = False 142 | 143 | self._active_adapter = adapter_name 144 | self.update_layer( 145 | adapter_name, 146 | r, 147 | lora_alpha=lora_alpha, 148 | lora_dropout=lora_dropout, 149 | init_lora_weights=init_lora_weights, 150 | use_rslora=use_rslora, 151 | use_dora=use_dora, 152 | ) 153 | 154 | def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: 155 | """ 156 | Defines the computation performed at every call. Applies the base layer computation 157 | and modifies the output using the configured LoRA parameters. 158 | 159 | Parameters: 160 | x (torch.Tensor): The input tensor to process. 161 | 162 | Returns: 163 | torch.Tensor: The output tensor after applying the LoRA adaptation. 164 | """ 165 | self._check_forward_args(x, *args, **kwargs) 166 | adapter_names = kwargs.pop("adapter_names", None) 167 | 168 | result = self.base_layer(x, *args, **kwargs) 169 | # As per Tim Dettmers, for 4bit, we need to defensively clone here. 170 | # The reason is that in some cases, an error can occur that backprop 171 | # does not work on a manipulated view. This issue may be solved with 172 | # newer PyTorch versions but this would need extensive testing to be 173 | # sure. 
174 | 175 | for active_adapter in self.active_adapters: 176 | if active_adapter not in self.lora_A.keys(): 177 | continue 178 | lora_A = self.lora_A[active_adapter] 179 | lora_B = self.lora_B[active_adapter] 180 | dropout = self.lora_dropout[active_adapter] 181 | scaling = self.scaling[active_adapter] 182 | 183 | requires_conversion = not torch.is_autocast_enabled() 184 | if requires_conversion: 185 | expected_dtype = result.dtype 186 | x = x.to(lora_A.weight.dtype) 187 | 188 | if not self.use_dora[active_adapter]: 189 | output = lora_B(lora_A(dropout(x))) * scaling 190 | else: 191 | output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) 192 | if requires_conversion: 193 | output = output.to(expected_dtype) 194 | 195 | result = result + output 196 | 197 | return result 198 | 199 | def __repr__(self) -> str: 200 | """ 201 | Provides a string representation of the module, enhancing the default 202 | representation with a prefix to identify it as a LoRA-adapted layer. 203 | """ 204 | rep = super().__repr__() 205 | return "lora." + rep 206 | 207 | 208 | def dispatch_gba(target: torch.nn.Module, adapter_name: str, **kwargs): 209 | new_module = None 210 | 211 | if isinstance(target, BaseTunerLayer): 212 | target_base_layer = target.get_base_layer() 213 | else: 214 | target_base_layer = target 215 | 216 | if ENGINE_AVAILABLE and issubclass(type(target_base_layer), MPQLinearBase): 217 | new_module = GBALoraLinear(target_base_layer, adapter_name, **kwargs) 218 | 219 | return new_module 220 | -------------------------------------------------------------------------------- /green_bit_llm/sft/peft_utils/model.py: -------------------------------------------------------------------------------- 1 | import peft.peft_model 2 | from peft.tuners import lora 3 | from peft.utils import _get_submodules, PeftType 4 | 5 | from green_bit_llm.sft.peft_utils.gba_lora import dispatch_gba 6 | 7 | 8 | class GBALoraModel(lora.LoraModel): 9 | """ 10 | A specialized version of LoraModel for low-rank adaptation. This class overrides the method to create new modules specifically tailored 11 | to GBA needs, by selecting appropriate backend functions to handle LoRA layers. 12 | """ 13 | @staticmethod 14 | def _create_new_module(lora_config, adapter_name, target, **kwargs): 15 | """ 16 | Creates a new module based on the provided configuration for LoRA and the type of target module. 17 | This method selects the correct dispatch function for integrating LoRA into the specified model layer. 18 | If no suitable module can be found, it raises an error. 19 | 20 | Args: 21 | lora_config: Configuration parameters for the LoRA adaptation. 22 | adapter_name: Identifier for the LoRA adapter. 23 | target: The target neural network layer to which LoRA should be applied. 24 | **kwargs: Additional keyword arguments. 25 | 26 | Returns: 27 | new_module: A new module with LoRA adaptation applied, or raises an error if the target is unsupported. 28 | 29 | Raises: 30 | ValueError: If the target module type is not supported by the currently available dispatchers. 31 | """ 32 | # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters, 33 | # because the first match is always used. Therefore, the default layers should be checked last. 
34 | dispatchers = [] 35 | 36 | dispatchers.append(dispatch_gba) 37 | 38 | # Note: dispatch_gba is the only dispatcher needed here; it handles 39 | # quantized linear layers derived from MPQLinearBase. 40 | 41 | 42 | new_module = None 43 | for dispatcher in dispatchers: 44 | new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs) 45 | if new_module is not None: # first match wins 46 | break 47 | 48 | if new_module is None: 49 | # no module could be matched 50 | raise ValueError( 51 | f"Target module {target} is not supported. Currently, only quantized linear layers " 52 | "derived from `bitorch_engine.layers.qlinear.nbit.MPQLinearBase` are supported." 53 | ) 54 | 55 | return new_module 56 | 57 | def replace_peft_lora_model_with_gba_lora_model(): 58 | """ 59 | Replaces the existing LoRA model in the PEFT framework with the GBA-enhanced LoRA model. 60 | This function patches the model mapping in PEFT to use `GBALoraModel` for LoRA configurations. 61 | """ 62 | peft.peft_model.PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GBALoraModel 63 | -------------------------------------------------------------------------------- /green_bit_llm/sft/trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import os 4 | from typing import Optional 5 | 6 | from trl import SFTTrainer 7 | 8 | from green_bit_llm.common.utils import STRATEGY_FILE_NAME 9 | from green_bit_llm.common.utils import get_model_path 10 | 11 | 12 | class GbaSFTTrainer(SFTTrainer): 13 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): 14 | """ 15 | Saves the model to the specified directory and also ensures that the 16 | 'quant_strategy.json' file is copied over to the same directory. 17 | 18 | Args: 19 | output_dir (Optional[str]): The directory to which the model and the 20 | 'quant_strategy.json' file should be saved. 21 | If None, the model will be saved to the default location. 22 | _internal_call (bool): A flag used to indicate whether this method was 23 | called internally by the library, which can affect 24 | certain behaviors (not used in this override). 25 | 26 | Note: 27 | If the 'GreenBitAI' prefix is not found in the output directory path, only exllama-related entries are removed from the saved config.json and no strategy file is copied.
28 | """ 29 | 30 | # Perform the original save model behavior of the superclass 31 | # out_dir should be os.path.join(args.save_dir, args.model) 32 | super().save_model(output_dir) 33 | 34 | # Define the prefix to look for in the output directory path 35 | gba_prefix = "GreenBitAI" + os.path.sep 36 | # Find the prefix in the output directory path 37 | start_index = output_dir.find(gba_prefix) 38 | 39 | if start_index == -1: 40 | config_path = os.path.join(output_dir, "config.json") 41 | if os.path.isfile(config_path): 42 | with open(config_path, 'r') as file: 43 | data = json.load(file) 44 | if "quantization_config" in data.keys(): 45 | quantization_config = data["quantization_config"] 46 | if "exllama_config" in quantization_config.keys(): 47 | del quantization_config["exllama_config"] 48 | if "use_exllama" in quantization_config.keys(): 49 | del quantization_config["use_exllama"] 50 | 51 | with open(config_path, 'w') as file: 52 | json.dump(data, file, indent=4) 53 | return 54 | 55 | # Ensure this is executed only on the main process 56 | if not self.is_world_process_zero(): 57 | return 58 | 59 | # save "quant_strategy.json" file 60 | start_pos = start_index + len(gba_prefix) - 1 61 | end_pos = output_dir.find(os.path.sep, start_pos + 1) 62 | 63 | if end_pos == -1: 64 | model_name = output_dir[start_index:] 65 | else: 66 | model_name = output_dir[start_index:end_pos] 67 | 68 | model_from_path = get_model_path(model_name) 69 | quant_strategy_file = os.path.join(model_from_path, STRATEGY_FILE_NAME) 70 | destination_path = os.path.join(output_dir, STRATEGY_FILE_NAME) 71 | shutil.copy(quant_strategy_file, destination_path) 72 | -------------------------------------------------------------------------------- /green_bit_llm/sft/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from colorama import init, Fore, Style 4 | init(autoreset=True) 5 | 6 | try: 7 | from bitorch_engine.layers.qlinear.nbit import MPQLinearBase 8 | except ModuleNotFoundError as e: 9 | raise Exception(f"Error occurred while importing Bitorch Engine module '{str(e)}'.") 10 | 11 | 12 | def str_to_torch_dtype(dtype: str): 13 | """Get torch dtype from the input data type string.""" 14 | if dtype is None: 15 | return None 16 | elif dtype == "float": 17 | return torch.float 18 | elif dtype == "half": 19 | return torch.float16 20 | else: 21 | raise ValueError(f"Unsupported dtype: {dtype}") 22 | 23 | 24 | def create_device_map(cuda_device_id): 25 | ids = cuda_device_id.split(',') 26 | # Create strings in the format "cuda:x" for each ID and put them into the collection 27 | device_map = {f"cuda:{id}" for id in ids} 28 | return device_map 29 | 30 | 31 | def get_learning_rate(lr_bit, galore, default_lr_galore, default_lr): 32 | """Adaptivly get the learning rate value from the input setting parameters.""" 33 | if lr_bit > 0: 34 | return lr_bit 35 | return default_lr_galore if galore else default_lr 36 | 37 | 38 | def create_param_groups(model, args, betas=(0.9, 0.999), lr_galore=1e-4, lr_adamw8b=5e-3, lr_default=1e-5): 39 | """ 40 | Create parameter groups based on the bit-width of quantized weights in the model. 41 | This function categorizes parameters into groups with different learning rates and beta values 42 | for optimizers. 43 | 44 | This function also prints out the number of trainable params and all params. 45 | 46 | Args: 47 | model (nn.Module): The neural network model. 48 | args (argparse.ArgumentParser): Command line arguments for additional parameters. 
49 | 50 | Returns: 51 | List[dict]: A list of dictionaries where each dictionary contains a parameter group. 52 | """ 53 | params_2_bit = [] 54 | params_4_bit = [] 55 | 56 | regular_trainable_numel = [] 57 | qweight_trainable_numel = [] 58 | total_numel = [] 59 | trainable_numel = [] 60 | 61 | for module_name, module in model.named_modules(): 62 | if issubclass(type(module), MPQLinearBase): 63 | if module.w_bit == 2: 64 | params_2_bit.append(module.qweight) 65 | qweight_trainable_numel.append(int(module.qweight.numel() * 32 / 2)) 66 | elif module.w_bit == 4: 67 | params_4_bit.append(module.qweight) 68 | qweight_trainable_numel.append(int(module.qweight.numel() * 32 / 4)) 69 | else: 70 | raise Exception(f"Error: Invalid qweight bit width: '{module.w_bit}'.") 71 | 72 | total_parameters = list(model.parameters()) 73 | for param in total_parameters: 74 | if not hasattr(param, "qweight"): 75 | total_numel.append(param.numel()) 76 | total_numel += qweight_trainable_numel 77 | 78 | if hasattr(args, 'lora_rank'): # peft 79 | param_groups = [] 80 | 81 | # Create list of peft parameters 82 | params_lora = [p for n, p in model.named_parameters() if "lora" in n] 83 | 84 | for param in params_lora: 85 | if param.requires_grad: 86 | trainable_numel.append(param.numel()) 87 | 88 | params_group_lora = {'params': params_lora, 'lr': args.lr_fp, 'betas': betas} 89 | 90 | param_groups.append(params_group_lora) 91 | 92 | elif hasattr(args, 'tune_qweight_only'): # full parameter finetune 93 | 94 | id_2bit_params = [id(p) for p in params_2_bit] 95 | id_4bit_params = [id(p) for p in params_4_bit] 96 | # Concatenate IDs to form a single list 97 | excluded_ids = id_2bit_params + id_4bit_params 98 | 99 | # Create list of regular parameters excluding 2-bit and 4-bit params 100 | params_regular = [p for p in model.parameters() if id(p) not in excluded_ids] 101 | for param in params_regular: 102 | if param.requires_grad: 103 | regular_trainable_numel.append(param.numel()) 104 | 105 | lr_2 = get_learning_rate( 106 | args.lr_2bit, 107 | args.galore, 108 | lr_adamw8b if 'adamw8bit' in args.optimizer.lower() else lr_galore, 109 | 1e-3 if 'adamw8bit' in args.optimizer.lower() else lr_default 110 | ) 111 | 112 | lr_4 = get_learning_rate( 113 | args.lr_4bit, 114 | args.galore, 115 | lr_adamw8b if 'adamw8bit' in args.optimizer.lower() else lr_galore, 116 | 1e-3 if 'adamw8bit' in args.optimizer.lower() else lr_default 117 | ) 118 | 119 | params_group_2bit = {'params': params_2_bit, 'lr': lr_2, 'betas': betas} 120 | params_group_4bit = {'params': params_4_bit, 'lr': lr_4, 'betas': betas} 121 | params_group_regular = {'params': params_regular, 'lr': args.lr_fp, 'betas': betas} 122 | 123 | # Optionally add extra settings from command line arguments 124 | if args.galore: 125 | galore_settings = { 126 | 'rank': args.galore_rank, 127 | 'update_proj_gap': args.galore_update_proj_gap, 128 | 'scale': args.galore_scale, 129 | 'proj_type': args.galore_proj_type 130 | } 131 | params_group_2bit.update(galore_settings) 132 | params_group_4bit.update(galore_settings) 133 | 134 | param_groups = [ 135 | params_group_2bit, 136 | params_group_4bit 137 | ] 138 | 139 | trainable_numel = qweight_trainable_numel 140 | if not args.tune_qweight_only: 141 | param_groups.append(params_group_regular) 142 | trainable_numel += regular_trainable_numel 143 | else: 144 | raise Exception("Error: invalid use case in creating param_group.") 145 | 146 | # print out trainable params info 147 | print(Style.BRIGHT + Fore.CYAN + 148 | f"Info: trainable params: 
{sum(trainable_numel):,d} || " 149 | f"all params: {sum(total_numel):,d} || " 150 | f"trainable%: {100 * sum(trainable_numel) / sum(total_numel):.4f}" 151 | ) 152 | 153 | return param_groups 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /green_bit_llm/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.2.6" 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=0.27.2 2 | colorama 3 | datasets 4 | torch>=2.0.0 5 | sentencepiece 6 | transformers>=4.52.4 7 | huggingface-hub 8 | lm-eval==0.3.0 9 | termcolor 10 | pillow 11 | requests 12 | prompt-toolkit 13 | rich 14 | optimum 15 | auto-gptq 16 | langchain-core 17 | fastapi 18 | uvicorn 19 | peewee 20 | python-dotenv -------------------------------------------------------------------------------- /scripts/curl_script: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set variables for testing 4 | API_KEY="" # Replace with your actual API key 5 | HOST="172.20.8.79:8000" 6 | 7 | # Array of models to test 8 | MODELS=( 9 | "GreenBitAI-Llama-3-8B-instruct-layer-mix-bpw-40" 10 | "GreenBitAI-Qwen-25-7B-Instruct-layer-mix-bpw-40" 11 | ) 12 | 13 | # Function to run tests for a specific model 14 | run_tests() { 15 | local MODEL_NAME=$1 16 | echo "Running tests for model: ${MODEL_NAME}" 17 | echo "==================================" 18 | 19 | # 1. Health Check 20 | echo "Testing Health Check Endpoint..." 21 | curl -X GET "http://${HOST}/health" 22 | 23 | # 2. Root Endpoint 24 | echo -e "\n\nTesting Root Endpoint..." 25 | curl -X GET "http://${HOST}/" 26 | 27 | # 3. Text Completion Endpoints 28 | echo -e "\n\nTesting Basic Completion..." 29 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/completions" \ 30 | -H "Content-Type: application/json" \ 31 | -H "X-Api-Key: ${API_KEY}" \ 32 | -d '{ 33 | "model": "'${MODEL_NAME}'", 34 | "prompt": "Write a story about a robot", 35 | "max_tokens": 100, 36 | "temperature": 0.7 37 | }' 38 | 39 | echo -e "\n\nTesting Streaming Completion..." 40 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/completions" \ 41 | -H "Content-Type: application/json" \ 42 | -H "X-Api-Key: ${API_KEY}" \ 43 | -d '{ 44 | "model": "'${MODEL_NAME}'", 45 | "prompt": "Write a story about a robot", 46 | "max_tokens": 200, 47 | "temperature": 0.7, 48 | "stream": true 49 | }' 50 | 51 | echo -e "\n\nTesting Batch Completion..." 52 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/completions" \ 53 | -H "Content-Type: application/json" \ 54 | -H "X-Api-Key: ${API_KEY}" \ 55 | -d '{ 56 | "model": "'${MODEL_NAME}'", 57 | "prompt": ["Tell me a joke", "Write a poem"], 58 | "max_tokens": 100, 59 | "temperature": 0.7 60 | }' 61 | 62 | # 4. Chat Completion Endpoints 63 | echo -e "\n\nTesting Basic Chat Completion..." 64 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/chat/completions" \ 65 | -H "Content-Type: application/json" \ 66 | -H "X-Api-Key: ${API_KEY}" \ 67 | -d '{ 68 | "model": "'${MODEL_NAME}'", 69 | "messages": [ 70 | {"role": "system", "content": "You are a helpful assistant."}, 71 | {"role": "user", "content": "What is the capital of France?"} 72 | ], 73 | "max_tokens": 100, 74 | "temperature": 0.7 75 | }' 76 | 77 | echo -e "\n\nTesting Streaming Chat Completion..." 
78 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/chat/completions" \ 79 | -H "Content-Type: application/json" \ 80 | -H "X-Api-Key: ${API_KEY}" \ 81 | -d '{ 82 | "model": "'${MODEL_NAME}'", 83 | "messages": [ 84 | {"role": "user", "content": "Write a story about a cat"} 85 | ], 86 | "max_tokens": 100, 87 | "temperature": 0.7, 88 | "stream": true 89 | }' 90 | 91 | echo -e "\n\nTesting Remote Confidence Scores..." 92 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/chat/completions" \ 93 | -H "Content-Type: application/json" \ 94 | -H "X-Api-Key: ${API_KEY}" \ 95 | -d '{ 96 | "model": "'${MODEL_NAME}'", 97 | "messages": [ 98 | {"role": "system", "content": "You are a helpful assistant."}, 99 | {"role": "user", "content": "hi"} 100 | ], 101 | "max_tokens": 1, 102 | "temperature": 0.7, 103 | "top_p": 1.0, 104 | "with_hidden_states": true, 105 | "remote_score": true 106 | }' 107 | 108 | echo -e "\n\nTesting Chat Completion with History..." 109 | curl -X POST "http://${HOST}/v1/${MODEL_NAME}/chat/completions" \ 110 | -H "Content-Type: application/json" \ 111 | -H "X-Api-Key: ${API_KEY}" \ 112 | -d '{ 113 | "model": "'${MODEL_NAME}'", 114 | "messages": [ 115 | {"role": "system", "content": "You are a friendly and knowledgeable AI assistant."}, 116 | {"role": "user", "content": "Tell me about Paris"}, 117 | {"role": "assistant", "content": "Paris is the capital of France."}, 118 | {"role": "user", "content": "What are some famous landmarks there?"} 119 | ], 120 | "max_tokens": 150, 121 | "temperature": 0.7, 122 | "top_p": 0.9 123 | }' 124 | 125 | echo -e "\n\nCompleted tests for ${MODEL_NAME}" 126 | echo "==================================\n\n" 127 | } 128 | 129 | # Run tests for each model 130 | for MODEL in "${MODELS[@]}"; do 131 | run_tests "$MODEL" 132 | done -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from setuptools import setup, find_packages 5 | 6 | package_dir = Path(__file__).parent / "green_bit_llm" 7 | with open(Path(__file__).parent / "requirements.txt") as fid: 8 | requirements = [l.strip() for l in fid.readlines()] 9 | 10 | sys.path.append(str(package_dir)) 11 | from version import __version__ 12 | 13 | setup( 14 | name="green-bit-llm", 15 | version=__version__, 16 | description="A toolkit for fine-tuning, inferencing, and evaluating GreenBitAI's LLMs.", 17 | long_description=open("README.md", encoding="utf-8").read(), 18 | long_description_content_type="text/markdown", 19 | author_email="team@greenbit.ai", 20 | author="GreenBitAI Contributors", 21 | url="https://github.com/GreenBitAI/green-bit-llm", 22 | license="Apache-2.0", 23 | install_requires=requirements, 24 | packages=find_packages(where=".", exclude=["tests", "tests.*", "examples", "examples.*"]), 25 | python_requires=">=3.9", 26 | data_files=[('.', ['requirements.txt'])], 27 | ) 28 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GreenBitAI/green-bit-llm/71a575da60524c9f7b1b04f63ad397a7c7650309/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_langchain_chatmodel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock, patch 3 | 
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage 4 | from langchain_core.outputs import LLMResult, Generation, ChatResult, ChatGeneration 5 | from green_bit_llm.langchain import GreenBitPipeline, ChatGreenBit 6 | 7 | 8 | class TestChatGreenBit(unittest.TestCase): 9 | 10 | def setUp(self): 11 | self.mock_pipeline = MagicMock(spec=GreenBitPipeline) 12 | self.mock_pipeline.pipeline = MagicMock() 13 | self.mock_pipeline.pipeline.tokenizer = MagicMock() 14 | self.mock_pipeline.pipeline.tokenizer.apply_chat_template = MagicMock(return_value="Mocked chat template") 15 | self.chat_model = ChatGreenBit(llm=self.mock_pipeline) 16 | 17 | def test_llm_type(self): 18 | self.assertEqual(self.chat_model._llm_type, "greenbit-chat") 19 | 20 | def test_messages_to_dict(self): 21 | """Test message conversion to dictionary format""" 22 | system_message = SystemMessage(content="You are an AI assistant.") 23 | human_message = HumanMessage(content="Hello, AI!") 24 | ai_message = AIMessage(content="Hello, human!") 25 | 26 | messages = [system_message, human_message, ai_message] 27 | result = self.chat_model._messages_to_dict(messages) 28 | 29 | expected = [ 30 | {"role": "system", "content": "You are an AI assistant."}, 31 | {"role": "user", "content": "Hello, AI!"}, 32 | {"role": "assistant", "content": "Hello, human!"} 33 | ] 34 | self.assertEqual(result, expected) 35 | 36 | def test_prepare_prompt(self): 37 | """Test prompt preparation using apply_chat_template""" 38 | messages = [ 39 | SystemMessage(content="You are an AI assistant."), 40 | HumanMessage(content="Hello, AI!"), 41 | AIMessage(content="Hello, human!"), 42 | HumanMessage(content="How are you?") 43 | ] 44 | 45 | result = self.chat_model._prepare_prompt(messages) 46 | 47 | self.assertEqual(result, "Mocked chat template") 48 | self.mock_pipeline.pipeline.tokenizer.apply_chat_template.assert_called_once() 49 | 50 | # Check that the call was made with correct message format 51 | call_args = self.mock_pipeline.pipeline.tokenizer.apply_chat_template.call_args 52 | messages_arg = call_args[0][0] # First positional argument 53 | expected_messages = [ 54 | {"role": "system", "content": "You are an AI assistant."}, 55 | {"role": "user", "content": "Hello, AI!"}, 56 | {"role": "assistant", "content": "Hello, human!"}, 57 | {"role": "user", "content": "How are you?"} 58 | ] 59 | self.assertEqual(messages_arg, expected_messages) 60 | 61 | def test_prepare_prompt_with_enable_thinking(self): 62 | """Test prompt preparation with enable_thinking parameter""" 63 | messages = [HumanMessage(content="Hello, AI!")] 64 | 65 | result = self.chat_model._prepare_prompt(messages, enable_thinking=True) 66 | 67 | self.assertEqual(result, "Mocked chat template") 68 | call_args = self.mock_pipeline.pipeline.tokenizer.apply_chat_template.call_args 69 | kwargs = call_args[1] # Keyword arguments 70 | self.assertTrue(kwargs.get("enable_thinking")) 71 | self.assertTrue(kwargs.get("add_generation_prompt")) 72 | 73 | def test_prepare_prompt_no_tokenizer(self): 74 | """Test error handling when tokenizer is not available""" 75 | self.mock_pipeline.pipeline.tokenizer = None 76 | messages = [HumanMessage(content="Hello, AI!")] 77 | 78 | with self.assertRaises(ValueError) as context: 79 | self.chat_model._prepare_prompt(messages) 80 | self.assertIn("Tokenizer not available", str(context.exception)) 81 | 82 | def test_prepare_prompt_no_chat_template(self): 83 | """Test error handling when apply_chat_template is not available""" 84 | del 
self.mock_pipeline.pipeline.tokenizer.apply_chat_template 85 | messages = [HumanMessage(content="Hello, AI!")] 86 | 87 | with self.assertRaises(ValueError) as context: 88 | self.chat_model._prepare_prompt(messages) 89 | self.assertIn("does not support apply_chat_template", str(context.exception)) 90 | 91 | def test_create_chat_result(self): 92 | """Test conversion from LLM result to chat result""" 93 | llm_result = LLMResult(generations=[[Generation(text="Hello, human!")]]) 94 | chat_result = self.chat_model._create_chat_result(llm_result) 95 | 96 | self.assertEqual(len(chat_result.generations), 1) 97 | self.assertIsInstance(chat_result.generations[0], ChatGeneration) 98 | self.assertEqual(chat_result.generations[0].message.content, "Hello, human!") 99 | 100 | @patch.object(ChatGreenBit, '_prepare_prompt') 101 | def test_generate(self, mock_prepare_prompt): 102 | """Test generation with mocked prompt preparation""" 103 | mock_prepare_prompt.return_value = "Mocked chat prompt" 104 | self.mock_pipeline.generate.return_value = LLMResult(generations=[[Generation(text="Generated response")]]) 105 | 106 | messages = [HumanMessage(content="Hello, AI!")] 107 | result = self.chat_model.generate(messages, temperature=0.8, max_tokens=100) 108 | 109 | # Check that prompt was prepared correctly 110 | mock_prepare_prompt.assert_called_once_with(messages, temperature=0.8, max_tokens=100) 111 | 112 | # Check that pipeline.generate was called with correct arguments 113 | self.mock_pipeline.generate.assert_called_once() 114 | call_args = self.mock_pipeline.generate.call_args 115 | 116 | # Check prompts argument 117 | self.assertEqual(call_args[1]["prompts"], ["Mocked chat prompt"]) 118 | 119 | # Check that result is ChatResult 120 | self.assertIsInstance(result, ChatResult) 121 | self.assertEqual(result.generations[0].message.content, "Generated response") 122 | 123 | @patch.object(ChatGreenBit, '_prepare_prompt') 124 | def test_stream(self, mock_prepare_prompt): 125 | """Test streaming with mocked prompt preparation""" 126 | mock_prepare_prompt.return_value = "Mocked chat prompt" 127 | mock_chunk1 = MagicMock() 128 | mock_chunk1.text = "Hello" 129 | mock_chunk2 = MagicMock() 130 | mock_chunk2.text = " human!" 
131 | self.mock_pipeline.stream.return_value = [mock_chunk1, mock_chunk2] 132 | 133 | messages = [HumanMessage(content="Hello, AI!")] 134 | stream_result = list(self.chat_model.stream(messages, temperature=0.7)) 135 | 136 | # Check that prompt was prepared correctly 137 | mock_prepare_prompt.assert_called_once_with(messages, temperature=0.7) 138 | 139 | # Check that pipeline.stream was called 140 | self.mock_pipeline.stream.assert_called_once() 141 | 142 | # Check stream results 143 | self.assertEqual(len(stream_result), 2) 144 | self.assertIsInstance(stream_result[0], ChatGeneration) 145 | self.assertEqual(stream_result[0].message.content, "Hello") 146 | self.assertEqual(stream_result[1].message.content, " human!") 147 | 148 | def test_generate_with_enable_thinking(self): 149 | """Test generation with enable_thinking parameter""" 150 | with patch.object(self.chat_model, '_prepare_prompt') as mock_prepare_prompt: 151 | mock_prepare_prompt.return_value = "Mocked chat prompt" 152 | self.mock_pipeline.generate.return_value = LLMResult(generations=[[Generation(text="Generated response")]]) 153 | 154 | messages = [HumanMessage(content="Hello, AI!")] 155 | result = self.chat_model.generate(messages, enable_thinking=True) 156 | 157 | # Check that enable_thinking was passed to _prepare_prompt 158 | mock_prepare_prompt.assert_called_once_with(messages, enable_thinking=True) 159 | 160 | 161 | if __name__ == '__main__': 162 | unittest.main() -------------------------------------------------------------------------------- /tests/test_langchain_embedding.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from typing import List 3 | from green_bit_llm.langchain import GreenBitEmbeddings 4 | 5 | class TestGreenBitEmbeddings(unittest.TestCase): 6 | def setUp(self): 7 | model_kwargs = {'trust_remote_code': True} 8 | encode_kwargs = {'normalize_embeddings': False} 9 | self.embeddings = GreenBitEmbeddings.from_model_id( 10 | model_name="sentence-transformers/all-MiniLM-L6-v2", 11 | cache_dir="cache", 12 | device="cpu", 13 | multi_process=False, 14 | model_kwargs=model_kwargs, 15 | encode_kwargs=encode_kwargs 16 | ) 17 | 18 | def test_embed_documents_returns_list(self): 19 | texts = ["Hello, world!", "This is a test."] 20 | result = self.embeddings.embed_documents(texts) 21 | self.assertIsInstance(result, list) 22 | 23 | def test_embed_documents_returns_correct_number_of_embeddings(self): 24 | texts = ["Hello, world!", "This is a test."] 25 | result = self.embeddings.embed_documents(texts) 26 | self.assertEqual(len(result), len(texts)) 27 | 28 | def test_embed_query_returns_list(self): 29 | query = "What is the meaning of life?" 30 | result = self.embeddings.embed_query(query) 31 | self.assertIsInstance(result, list) 32 | 33 | def test_embed_query_returns_non_empty_list(self): 34 | query = "What is the meaning of life?" 
35 | result = self.embeddings.embed_query(query) 36 | self.assertTrue(len(result) > 0) 37 | 38 | if __name__ == '__main__': 39 | unittest.main() -------------------------------------------------------------------------------- /tests/test_langchain_pipeline.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | import torch 4 | from green_bit_llm.langchain import GreenBitPipeline 5 | from langchain_core.outputs import LLMResult, Generation, GenerationChunk 6 | 7 | 8 | class TestGreenBitPipeline(unittest.TestCase): 9 | 10 | @patch('green_bit_llm.langchain.pipeline.check_engine_available') 11 | @patch('green_bit_llm.langchain.pipeline.load') 12 | @patch('green_bit_llm.langchain.pipeline.pipeline') 13 | def test_from_model_id(self, mock_pipeline, mock_load, mock_check_engine): 14 | """Test pipeline creation from model ID""" 15 | # Setup 16 | mock_check_engine.return_value = True 17 | mock_model = MagicMock() 18 | mock_tokenizer = MagicMock() 19 | mock_tokenizer.pad_token = None 20 | mock_tokenizer.eos_token = "<|endoftext|>" 21 | mock_tokenizer.pad_token_id = None 22 | mock_tokenizer.eos_token_id = 50256 23 | mock_load.return_value = (mock_model, mock_tokenizer, None) 24 | mock_pipe = MagicMock() 25 | mock_pipeline.return_value = mock_pipe 26 | 27 | # Test 28 | gb_pipeline = GreenBitPipeline.from_model_id( 29 | model_id="test_model", 30 | task="text-generation", 31 | model_kwargs={"dtype": torch.float16}, 32 | pipeline_kwargs={"max_length": 100} 33 | ) 34 | 35 | # Assert 36 | self.assertIsInstance(gb_pipeline, GreenBitPipeline) 37 | self.assertEqual(gb_pipeline.model_id, "test_model") 38 | self.assertEqual(gb_pipeline.task, "text-generation") 39 | self.assertEqual(gb_pipeline.model_kwargs, {"dtype": torch.float16}) 40 | self.assertEqual(gb_pipeline.pipeline_kwargs, {"max_length": 100}) 41 | mock_check_engine.assert_called_once() 42 | mock_pipeline.assert_called_once() 43 | 44 | # Check that tokenizer pad_token was set 45 | self.assertEqual(mock_tokenizer.pad_token, "<|endoftext|>") 46 | self.assertEqual(mock_tokenizer.pad_token_id, 50256) 47 | 48 | def test_identifying_params(self): 49 | """Test identifying parameters property""" 50 | gb_pipeline = GreenBitPipeline( 51 | pipeline=MagicMock(), 52 | model_id="test_model", 53 | task="text-generation", 54 | model_kwargs={"dtype": torch.float16}, 55 | pipeline_kwargs={"max_length": 100} 56 | ) 57 | params = gb_pipeline._identifying_params 58 | self.assertEqual(params["model_id"], "test_model") 59 | self.assertEqual(params["task"], "text-generation") 60 | self.assertEqual(params["model_kwargs"], {"dtype": torch.float16}) 61 | self.assertEqual(params["pipeline_kwargs"], {"max_length": 100}) 62 | 63 | def test_llm_type(self): 64 | """Test LLM type property""" 65 | gb_pipeline = GreenBitPipeline(pipeline=MagicMock(), model_kwargs={}, pipeline_kwargs={}) 66 | self.assertEqual(gb_pipeline._llm_type, "greenbit_pipeline") 67 | 68 | def test_prepare_generation_config(self): 69 | """Test generation config preparation""" 70 | mock_pipeline = MagicMock() 71 | mock_pipeline.tokenizer.pad_token_id = 0 72 | mock_pipeline.tokenizer.eos_token_id = 1 73 | 74 | gb_pipeline = GreenBitPipeline( 75 | pipeline=mock_pipeline, 76 | model_kwargs={}, 77 | pipeline_kwargs={"temperature": 0.5, "max_new_tokens": 50} 78 | ) 79 | 80 | config = gb_pipeline._prepare_generation_config({"temperature": 0.8}) 81 | 82 | self.assertEqual(config["temperature"], 0.8) # Should override 
pipeline_kwargs 83 | self.assertEqual(config["max_new_tokens"], 50) # Should use pipeline default 84 | self.assertEqual(config["pad_token_id"], 0) 85 | self.assertEqual(config["eos_token_id"], 1) 86 | self.assertTrue(config["do_sample"]) # Changed default to True 87 | 88 | def test_prepare_prompt_from_text_with_chat_template(self): 89 | """Test prompt preparation from plain text using chat template""" 90 | mock_tokenizer = MagicMock() 91 | mock_tokenizer.apply_chat_template = MagicMock( 92 | return_value="<|im_start|>user\nHello<|im_end|><|im_start|>assistant\n") 93 | 94 | mock_pipeline = MagicMock() 95 | mock_pipeline.tokenizer = mock_tokenizer 96 | 97 | gb_pipeline = GreenBitPipeline(pipeline=mock_pipeline, model_kwargs={}, pipeline_kwargs={}) 98 | 99 | result = gb_pipeline._prepare_prompt_from_text("Hello", enable_thinking=True) 100 | 101 | self.assertEqual(result, "<|im_start|>user\nHello<|im_end|><|im_start|>assistant\n") 102 | mock_tokenizer.apply_chat_template.assert_called_once() 103 | 104 | # Check the call arguments 105 | call_args = mock_tokenizer.apply_chat_template.call_args 106 | messages = call_args[0][0] 107 | kwargs = call_args[1] 108 | 109 | self.assertEqual(messages, [{"role": "user", "content": "Hello"}]) 110 | self.assertTrue(kwargs["add_generation_prompt"]) 111 | self.assertTrue(kwargs["enable_thinking"]) 112 | 113 | def test_prepare_prompt_from_text_no_chat_template(self): 114 | """Test prompt preparation fallback when no chat template available""" 115 | mock_tokenizer = MagicMock() 116 | # Remove apply_chat_template method 117 | del mock_tokenizer.apply_chat_template 118 | 119 | mock_pipeline = MagicMock() 120 | mock_pipeline.tokenizer = mock_tokenizer 121 | 122 | gb_pipeline = GreenBitPipeline(pipeline=mock_pipeline, model_kwargs={}, pipeline_kwargs={}) 123 | 124 | result = gb_pipeline._prepare_prompt_from_text("Hello") 125 | 126 | # Should return original text when no chat template 127 | self.assertEqual(result, "Hello") 128 | 129 | def test_prepare_prompt_from_text_template_error(self): 130 | """Test prompt preparation fallback when template application fails""" 131 | mock_tokenizer = MagicMock() 132 | mock_tokenizer.apply_chat_template = MagicMock(side_effect=Exception("Template error")) 133 | 134 | mock_pipeline = MagicMock() 135 | mock_pipeline.tokenizer = mock_tokenizer 136 | 137 | gb_pipeline = GreenBitPipeline(pipeline=mock_pipeline, model_kwargs={}, pipeline_kwargs={}) 138 | 139 | result = gb_pipeline._prepare_prompt_from_text("Hello") 140 | 141 | # Should return original text when template application fails 142 | self.assertEqual(result, "Hello") 143 | 144 | @patch.object(GreenBitPipeline, '_prepare_prompt_from_text') 145 | def test_generate_with_plain_text(self, mock_prepare_prompt): 146 | """Test generation with plain text prompts""" 147 | # Setup 148 | mock_prepare_prompt.return_value = "" 149 | mock_pipeline = MagicMock() 150 | mock_pipeline.tokenizer.encode.return_value = [1, 2, 3] 151 | mock_pipeline.device = "cpu" 152 | mock_pipeline.model.generate.return_value = MagicMock( 153 | sequences=torch.tensor([[1, 2, 3, 4, 5]]), 154 | hidden_states=None 155 | ) 156 | mock_pipeline.tokenizer.return_value = { 157 | 'input_ids': torch.tensor([[1, 2, 3]]), 158 | 'attention_mask': torch.tensor([[1, 1, 1]]) 159 | } 160 | mock_pipeline.tokenizer.decode.return_value = "Generated text" 161 | 162 | gb_pipeline = GreenBitPipeline( 163 | pipeline=mock_pipeline, 164 | model_kwargs={}, 165 | pipeline_kwargs={"max_new_tokens": 100} 166 | ) 167 | 168 | # Test with plain 
text (should trigger chat template) 169 | result = gb_pipeline.generate(["Hello"], enable_thinking=True) 170 | 171 | # Check that prompt was processed 172 | mock_prepare_prompt.assert_called_once_with("Hello", enable_thinking=True) 173 | 174 | # Check result 175 | self.assertIsInstance(result, LLMResult) 176 | self.assertEqual(len(result.generations), 1) 177 | 178 | def test_generate_with_formatted_prompt(self): 179 | """Test generation with already formatted prompts""" 180 | # Setup 181 | mock_pipeline = MagicMock() 182 | mock_pipeline.tokenizer.encode.return_value = [1, 2, 3] 183 | mock_pipeline.device = "cpu" 184 | mock_pipeline.model.generate.return_value = MagicMock( 185 | sequences=torch.tensor([[1, 2, 3, 4, 5]]), 186 | hidden_states=None 187 | ) 188 | mock_pipeline.tokenizer.return_value = { 189 | 'input_ids': torch.tensor([[1, 2, 3]]), 190 | 'attention_mask': torch.tensor([[1, 1, 1]]) 191 | } 192 | mock_pipeline.tokenizer.decode.return_value = "Generated text" 193 | 194 | gb_pipeline = GreenBitPipeline( 195 | pipeline=mock_pipeline, 196 | model_kwargs={}, 197 | pipeline_kwargs={"max_new_tokens": 100} 198 | ) 199 | 200 | # Test with already formatted prompt (should not trigger chat template) 201 | formatted_prompt = "<|im_start|>user\nHello<|im_end|><|im_start|>assistant\n" 202 | 203 | with patch.object(gb_pipeline, '_prepare_prompt_from_text') as mock_prepare: 204 | result = gb_pipeline.generate([formatted_prompt]) 205 | 206 | # Should not call _prepare_prompt_from_text for already formatted prompts 207 | mock_prepare.assert_not_called() 208 | 209 | @patch('green_bit_llm.langchain.pipeline.TextIteratorStreamer') 210 | @patch('green_bit_llm.langchain.pipeline.Thread') 211 | @patch.object(GreenBitPipeline, '_prepare_prompt_from_text') 212 | def test_stream(self, mock_prepare_prompt, mock_thread, mock_streamer): 213 | """Test streaming functionality""" 214 | # Setup 215 | mock_prepare_prompt.return_value = "" 216 | mock_pipeline = MagicMock() 217 | mock_pipeline.tokenizer.return_value = {'input_ids': torch.tensor([[1, 2, 3]])} 218 | mock_pipeline.device = "cpu" 219 | 220 | mock_streamer_instance = MagicMock() 221 | mock_streamer_instance.__iter__.return_value = iter(["Hello", " ", "world"]) 222 | mock_streamer.return_value = mock_streamer_instance 223 | 224 | gb_pipeline = GreenBitPipeline(pipeline=mock_pipeline, model_kwargs={}, pipeline_kwargs={}) 225 | 226 | # Test 227 | chunks = list(gb_pipeline.stream("Hi", enable_thinking=True)) 228 | 229 | # Check that prompt was processed 230 | mock_prepare_prompt.assert_called_once_with("Hi", enable_thinking=True) 231 | 232 | # Assert 233 | self.assertEqual(len(chunks), 3) 234 | self.assertIsInstance(chunks[0], GenerationChunk) 235 | self.assertEqual(chunks[0].text, "Hello") 236 | self.assertEqual(chunks[1].text, " ") 237 | self.assertEqual(chunks[2].text, "world") 238 | mock_thread.assert_called_once() 239 | mock_streamer.assert_called_once() 240 | 241 | @patch.object(GreenBitPipeline, '_prepare_prompt_from_text') 242 | def test_stream_with_formatted_prompt(self, mock_prepare_prompt): 243 | """Test streaming with already formatted prompt""" 244 | # Setup 245 | mock_pipeline = MagicMock() 246 | mock_pipeline.tokenizer.return_value = {'input_ids': torch.tensor([[1, 2, 3]])} 247 | mock_pipeline.device = "cpu" 248 | 249 | gb_pipeline = GreenBitPipeline(pipeline=mock_pipeline, model_kwargs={}, pipeline_kwargs={}) 250 | 251 | formatted_prompt = "<|start_header_id|>user<|end_header_id|>Hello<|eot_id|>" 252 | 253 | with 
patch('green_bit_llm.langchain.pipeline.TextIteratorStreamer') as mock_streamer: 254 | with patch('green_bit_llm.langchain.pipeline.Thread'): 255 | mock_streamer_instance = MagicMock() 256 | mock_streamer_instance.__iter__.return_value = iter(["Hello"]) 257 | mock_streamer.return_value = mock_streamer_instance 258 | 259 | list(gb_pipeline.stream(formatted_prompt)) 260 | 261 | # Should not call _prepare_prompt_from_text for already formatted prompts 262 | mock_prepare_prompt.assert_not_called() 263 | 264 | 265 | if __name__ == '__main__': 266 | unittest.main() --------------------------------------------------------------------------------
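For orientation, the following minimal sketch shows the end-to-end LangChain usage that the tests above exercise with mocks. The model id and the generation settings are illustrative assumptions, not values taken from this repository.

import torch
from langchain_core.messages import HumanMessage, SystemMessage
from green_bit_llm.langchain import ChatGreenBit, GreenBitPipeline

# Assumed model id for illustration; substitute any supported GreenBitAI low-bit model.
pipeline = GreenBitPipeline.from_model_id(
    model_id="GreenBitAI/Llama-3-8B-instruct-layer-mix-bpw-4.0",
    task="text-generation",
    model_kwargs={"dtype": torch.float16},
    pipeline_kwargs={"max_new_tokens": 100, "temperature": 0.7},
)

# Wrap the pipeline in the chat model and run a single chat turn.
chat = ChatGreenBit(llm=pipeline)
result = chat.generate([
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="What is the capital of France?"),
])
print(result.generations[0].message.content)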