├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── lint_and_format.yml │ ├── publish_pypi.yml │ ├── publish_to_ghcr.yml │ └── pytest.yml ├── .gitignore ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── assets ├── bert.gif ├── glass-gif.gif ├── inference.gif ├── llama_images │ ├── llama-13b-class.png │ ├── llama-13b-summ.png │ ├── llama-7b-class.png │ └── llama-7b-summ.png ├── money.gif ├── overview_diagram.png ├── progress.gif ├── readme_images │ └── redpajama_results │ │ ├── RedPajama-3B Classification.png │ │ ├── RedPajama-3B Summarization.png │ │ ├── RedPajama-7B Classification.png │ │ └── RedPajama-7B Summarization.png ├── repo-main.png ├── rocket.gif ├── time.gif └── toolkit-animation.gif ├── examples ├── ablation_config.yml ├── config.yml ├── llama2_config.yml ├── mistral_config.yml └── test_suite │ ├── dot_product_tests.csv │ └── json_validity_tests.csv ├── llama2 ├── README.md ├── baseline_inference.sh ├── llama2_baseline_inference.py ├── llama2_classification.py ├── llama2_classification_inference.py ├── llama2_summarization.py ├── llama2_summarization_inference.py ├── llama_patch.py ├── prompts.py ├── run_lora.sh └── sample_ablate.sh ├── llmtune ├── __init__.py ├── cli │ ├── __init__.py │ └── toolkit.py ├── config.yml ├── constants │ └── files.py ├── data │ ├── __init__.py │ ├── dataset_generator.py │ └── ingestor.py ├── finetune │ ├── __init__.py │ ├── generics.py │ └── lora.py ├── inference │ ├── __init__.py │ ├── generics.py │ └── lora.py ├── pydantic_models │ ├── __init__.py │ └── config_model.py ├── qa │ ├── __init__.py │ ├── metric_suite.py │ ├── qa_metrics.py │ ├── qa_tests.py │ └── test_suite.py ├── ui │ ├── __init__.py │ ├── generics.py │ └── rich_ui.py └── utils │ ├── __init__.py │ ├── ablation_utils.py │ ├── rich_print_utils.py │ └── save_utils.py ├── mistral ├── README.md ├── baseline_inference.sh ├── mistral_baseline_inference.py ├── mistral_classification.py ├── mistral_classification_inference.py ├── mistral_summarization.py ├── mistral_summarization_inference.py ├── prompts.py ├── run_lora.sh └── sample_ablate.sh ├── poetry.lock ├── pyproject.toml ├── requirements.txt ├── test_utils ├── __init__.py └── test_config.py └── tests ├── data ├── test_dataset_generator.py └── test_ingestor.py ├── finetune ├── test_finetune_generics.py └── test_finetune_lora.py ├── inference ├── test_inference_generics.py └── test_inference_lora.py ├── qa ├── test_metric_suite.py ├── test_qa_metrics.py ├── test_qa_tests.py └── test_test_suite.py ├── test_ablation_utils.py ├── test_cli.py ├── test_directory_helper.py └── test_version.py /.dockerignore: -------------------------------------------------------------------------------- 1 | # .git 2 | 3 | # exlcude temp folder 4 | temp 5 | 6 | # exclude data files 7 | data/ucr/* 8 | 9 | # Jupyter 10 | *.bundle.* 11 | lib/ 12 | node_modules/ 13 | *.egg-info/ 14 | .ipynb_checkpoints 15 | .ipynb_checkpoints/ 16 | *.tsbuildinfo 17 | 18 | # IDE 19 | .vscode 20 | .vscode/* 21 | 22 | #Cython 23 | *.pyc 24 | 25 | # Packages 26 | *.egg 27 | !/tests/**/*.egg 28 | /*.egg-info 29 | /dist/* 30 | build 31 | _build 32 | .cache 33 | *.so 34 | 35 | # Installer logs 36 | pip-log.txt 37 | 38 | # Unit test / coverage reports 39 | .coverage 40 | .tox 41 | .pytest_cache 42 | 43 | .DS_Store 44 | .idea/* 45 | .python-version 46 | .vscode/* 47 | 48 | /test.py 49 | /test_*.* 50 | 51 | /setup.cfg 52 | MANIFEST.in 53 | /docs/site/* 54 | /tests/fixtures/simple_project/setup.py 55 | 
/tests/fixtures/project_with_extras/setup.py 56 | .mypy_cache 57 | 58 | .venv 59 | /releases/* 60 | pip-wheel-metadata 61 | /poetry.toml 62 | 63 | poetry/core/* 64 | 65 | # Logs 66 | log.txt 67 | logs 68 | /logs/* 69 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Environment:** 27 | - OS: [e.g. Ubuntu 22.04] 28 | - GPU model and number [e.g. 1x NVIDIA RTX 3090] 29 | - GPU driver version [e.g. 535.171.04] 30 | - CUDA Driver version [e.g. 12.2] 31 | - Packages Installed 32 | 33 | **Additional context** 34 | Add any other context about the problem here. 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/workflows/lint_and_format.yml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: pull_request 3 | jobs: 4 | lint: 5 | name: Lint & Format 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v4 9 | - uses: chartboost/ruff-action@v1 10 | name: Lint 11 | with: 12 | version: 0.3.5 13 | args: "check --output-format=full --statistics" 14 | - uses: chartboost/ruff-action@v1 15 | name: Format 16 | with: 17 | version: 0.3.5 18 | args: "format --check" 19 | -------------------------------------------------------------------------------- /.github/workflows/publish_pypi.yml: -------------------------------------------------------------------------------- 1 | name: PyPI CD 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | pypi: 7 | name: Build and Upload Release 8 | runs-on: ubuntu-latest 9 | environment: 10 | name: pypi 11 | url: https://pypi.org/p/llm-toolkit 12 | permissions: 13 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 14 | steps: 15 | # ---------------- 16 | # Set Up 17 | # ---------------- 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | with: 21 | fetch-depth: 0 22 | - name: Setup Python 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: "3.11" 26 | - name: Install poetry 27 | uses: snok/install-poetry@v1 28 | with: 29 | version: 1.5.1 30 | virtualenvs-create: true 31 | virtualenvs-in-project: true 32 | installer-parallel: true 33 | # ---------------- 34 | # Install Deps 35 | # ---------------- 36 | - name: Install Dependencies 37 | run: | 38 | poetry install --no-interaction --no-root 39 | poetry self add "poetry-dynamic-versioning[plugin]" 40 | # ---------------- 41 | # Build & Publish 42 | # ---------------- 43 | - name: Build 44 | run: poetry build 45 | - name: Publish package distributions to PyPI 46 | uses: pypa/gh-action-pypi-publish@release/v1 47 | -------------------------------------------------------------------------------- /.github/workflows/publish_to_ghcr.yml: -------------------------------------------------------------------------------- 1 | name: Github Packages CD 2 | on: 3 | release: 4 | types: [published] 5 | env: 6 | REGISTRY: ghcr.io 7 | IMAGE_NAME: ${{ github.repository }} 8 | jobs: 9 | build-and-push-image: 10 | name: Build and Push Image 11 | runs-on: ubuntu-latest 12 | permissions: 13 | contents: read 14 | packages: write 15 | steps: 16 | - name: Checkout repository 17 | uses: actions/checkout@v4 18 | - name: Log in to the Container registry 19 | uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1 20 | with: 21 | registry: ${{ env.REGISTRY }} 22 | username: ${{ github.actor }} 23 | password: ${{ secrets.GITHUB_TOKEN }} 24 | - name: Extract metadata (tags, labels) for Docker 25 | id: meta 26 | uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7 27 | with: 28 | images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 29 | - name: Build and push Docker image 30 | uses: docker/build-push-action@f2a1d5e99d037542a71f64918e516c093c6f3fc4 31 | with: 32 | context: . 
33 | push: true 34 | tags: ${{ steps.meta.outputs.tags }} 35 | labels: ${{ steps.meta.outputs.labels }} 36 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: pytest CI 2 | on: pull_request 3 | jobs: 4 | pytest: 5 | name: Run pytest and check min coverage threshold (80%) 6 | runs-on: ubuntu-latest 7 | steps: 8 | # ---------------- 9 | # Set Up 10 | # ---------------- 11 | - name: Checkout 12 | uses: actions/checkout@v4 13 | with: 14 | fetch-depth: 0 15 | - name: Setup Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: "3.11" 19 | - name: Install poetry 20 | uses: snok/install-poetry@v1 21 | with: 22 | version: 1.5.1 23 | virtualenvs-create: true 24 | virtualenvs-in-project: true 25 | installer-parallel: true 26 | # ---------------- 27 | # Install Deps 28 | # ---------------- 29 | - name: Install Dependencies 30 | run: | 31 | poetry install --no-interaction --no-root 32 | # ---------------- 33 | # Run Test 34 | # ---------------- 35 | - name: Run pytest 36 | run: poetry run pytest --cov=./ --cov-report=term 37 | - name: Check Coverage 38 | run: poetry run coverage report --fail-under=80 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # experiment files 4 | */experiments 5 | */experiment 6 | experiment/* 7 | */archive 8 | */backup 9 | */baseline_results 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # Jupyter Notebook 40 | .ipynb_checkpoints 41 | 42 | # Environments 43 | .env 44 | .venv 45 | env/ 46 | venv/ 47 | ENV/ 48 | env.bak/ 49 | venv.bak/ 50 | 51 | # Coverage Report 52 | .coverage 53 | /htmlcov 54 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | If you would like to contribute to this project, we recommend following the ["fork-and-pull" Git workflow](https://www.atlassian.com/git/tutorials/comparing-workflows/forking-workflow). 4 | 5 | 1. **Fork** the repo on GitHub 6 | 2. **Clone** the project to your own machine 7 | 3. **Commit** changes to your own branch 8 | 4. **Push** your work back up to your fork 9 | 5. Submit a **Pull request** so that we can review your changes 10 | 11 | NOTE: Be sure to merge the latest from "upstream" before making a pull request! 12 | 13 | ### Set Up Dev Environment 14 | 15 |
16 | 1. Clone Repo 17 | 18 | ```shell 19 | git clone https://github.com/georgian-io/LLM-Finetuning-Toolkit.git 20 | cd LLM-Finetuning-Toolkit/ 21 | ``` 22 | 23 |
24 | 25 |
26 | 2. Install Dependencies 27 |
28 | Install with Docker [Recommended] 29 | 30 | ```shell 31 | docker build -t llm-toolkit . 32 | ``` 33 | 34 | ```shell 35 | # CPU 36 | docker run -it llm-toolkit 37 | # GPU 38 | docker run -it --gpus all llm-toolkit 39 | ``` 40 | 41 |
42 | 43 |
44 | Poetry (recommended) 45 | 46 | See the Poetry documentation for [installation instructions](https://python-poetry.org/docs/#installation). 47 | 48 | ```shell 49 | poetry install 50 | ``` 51 | 52 |
53 |
54 | pip 55 | We recommend using a virtual environment like `venv` or `conda` for installation. 56 | 57 | ```shell 58 | pip install -e . 59 | ``` 60 | 61 |
62 |
63 | 64 | ### Checklist Before Pull Request (Optional) 65 | 66 | 1. Use `ruff check --fix` to check and fix lint errors 67 | 2. Use `ruff format` to apply formatting 68 | 3. Run `pytest` at the top level directory to run unit tests 69 | 70 | NOTE: Ruff linting and formatting checks are done when PR is raised via Git Action. Before raising a PR, it is a good practice to check and fix lint errors, as well as apply formatting. 71 | 72 | ### Releasing 73 | 74 | To manually release a PyPI package, please run: 75 | 76 | ```shell 77 | make build-release 78 | ``` 79 | 80 | Note: Make sure you have a pypi token for this [PyPI repo](https://pypi.org/project/llm-toolkit/). 81 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04 2 | RUN apt-get update && \ 3 | DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends git curl software-properties-common && \ 4 | add-apt-repository ppa:deadsnakes/ppa && apt-get update && apt install -y python3.10 python3.10-distutils && \ 5 | curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10 6 | 7 | 8 | RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1 && \ 9 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 2 && \ 10 | update-alternatives --auto python3 11 | 12 | RUN export CUDA_HOME=/usr/local/cuda/ 13 | 14 | COPY . /home/llm-finetuning-hub 15 | WORKDIR /home/llm-finetuning-hub 16 | RUN pip3 install --no-cache-dir -r ./requirements.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | 179 | Copyright 2024 Georgian Partners 180 | 181 | Licensed under the Apache License, Version 2.0 (the "License"); 182 | you may not use this file except in compliance with the License. 183 | You may obtain a copy of the License at 184 | 185 | http://www.apache.org/licenses/LICENSE-2.0 186 | 187 | Unless required by applicable law or agreed to in writing, software 188 | distributed under the License is distributed on an "AS IS" BASIS, 189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 190 | See the License for the specific language governing permissions and 191 | limitations under the License. 192 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | test-coverage: 2 | pytest --cov=llmtune tests/ 3 | 4 | fix-format: 5 | ruff check --fix 6 | ruff format 7 | 8 | build-release: 9 | rm -rf dist 10 | rm -rf build 11 | poetry build 12 | poetry publish 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM Finetuning Toolkit 2 | 3 |

4 | 5 |

6 | 7 | ## Overview 8 | 9 | LLM Finetuning Toolkit is a config-based CLI tool for launching a series of LLM fine-tuning experiments on your data and gathering their results. From a single `yaml` config file, control all elements of a typical experimentation pipeline: **prompts**, **open-source LLMs**, **optimization strategy** and **LLM testing**. 10 | 11 |

12 | 13 |

14 | 15 | ## Installation 16 | 17 | ### [pipx](https://pipx.pypa.io/stable/) (recommended) 18 | 19 | [pipx](https://pipx.pypa.io/stable/) installs the package and dependencies in a separate virtual environment. 20 | 21 | ```shell 22 | pipx install llm-toolkit 23 | ``` 24 | 25 | ### pip 26 | 27 | ```shell 28 | pip install llm-toolkit 29 | ``` 30 | 31 | ## Quick Start 32 | 33 | This guide contains three stages that will help you get the most out of this toolkit: 34 | 35 | - **Basic**: Run your first LLM fine-tuning experiment 36 | - **Intermediate**: Run a custom experiment by changing the components of the YAML configuration file 37 | - **Advanced**: Launch a series of fine-tuning experiments across different prompt templates, LLMs, and optimization techniques -- all through **one** YAML configuration file 38 | 39 | ### Basic 40 | 41 | ```shell 42 | llmtune generate config 43 | llmtune run ./config.yml 44 | ``` 45 | 46 | The first command generates a helpful starter `config.yml` file and saves it in the current working directory. It is provided so you can get started quickly and use it as a base for further modification. 47 | 48 | The second command then initiates the fine-tuning process using the settings specified in the generated YAML configuration file `config.yml`. 49 | 50 | ### Intermediate 51 | 52 | The configuration file is the central piece that defines the behavior of the toolkit. It is written in YAML format and consists of several sections that control different aspects of the process, such as data ingestion, model definition, training, inference, and quality assurance. We highlight some of the critical sections. 53 | 54 | #### Flash Attention 2 55 | 56 | To enable Flash Attention 2 for [supported models](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2), first install `flash-attn`: 57 | 58 | **pipx** 59 | 60 | ```shell 61 | pipx inject llm-toolkit flash-attn --pip-args=--no-build-isolation 62 | ``` 63 | 64 | **pip** 65 | 66 | ```shell 67 | pip install flash-attn --no-build-isolation 68 | ``` 69 | 70 | Then add the following to your config file: 71 | 72 | ```yaml 73 | model: 74 | torch_dtype: "bfloat16" # or "float16" if using older GPU 75 | attn_implementation: "flash_attention_2" 76 | ``` 77 | 78 | #### Data Ingestion 79 | 80 | An example of what the data ingestion section may look like: 81 | 82 | ```yaml 83 | data: 84 | file_type: "huggingface" 85 | path: "yahma/alpaca-cleaned" 86 | prompt: 87 | ### Instruction: {instruction} 88 | ### Input: {input} 89 | ### Output: 90 | prompt_stub: { output } 91 | test_size: 0.1 # Proportion of test as % of total; if integer then # of samples 92 | train_size: 0.9 # Proportion of train as % of total; if integer then # of samples 93 | train_test_split_seed: 42 94 | ``` 95 | 96 | - While the above example illustrates using a public dataset from Hugging Face, the config file can also ingest your own data. 97 | 98 | ```yaml 99 | file_type: "json" 100 | path: "<path to your json file>" 101 | ``` 102 | 103 | ```yaml 104 | file_type: "csv" 105 | path: "<path to your csv file>" 106 | ``` 107 | 108 | - The prompt fields help create the instructions used to fine-tune the LLM. They read data from specific columns, referenced in {} brackets, that must be present in your dataset. In the example provided, the data file is expected to have the column names `instruction`, `input` and `output`. 109 | 110 | - The prompt fields use both `prompt` and `prompt_stub` during fine-tuning. However, during testing, **only** the `prompt` section is used as input to the fine-tuned LLM, as illustrated in the sketch below.
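To make the `prompt` / `prompt_stub` behaviour concrete, here is a minimal Python sketch of how a template could be filled from one dataset row. This is an illustration only: the variable names, the sample row, and the plain `str.format` call are assumptions made for the example, not the toolkit's internal implementation.

```python
# Illustrative only -- shows how `prompt` and `prompt_stub` combine; not the toolkit's internal API.

prompt_template = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "### Instruction: {instruction}\n"
    "### Input: {input}\n"
    "### Output:"
)
prompt_stub = "{output}"

# One row of the dataset; column names must match the {placeholders} above.
row = {
    "instruction": "Summarize the text.",
    "input": "The quick brown fox jumps over the lazy dog.",
    "output": "A fox jumps over a dog.",
}

# Fine-tuning: the model is trained on prompt + prompt_stub, both filled from the row's columns.
train_example = prompt_template.format(**row) + " " + prompt_stub.format(**row)

# Testing/inference: only the filled prompt is fed to the model; it must generate the output itself.
inference_prompt = prompt_template.format(**row)

print(train_example)
print(inference_prompt)
```

The same pattern applies when several prompt templates are listed for ablation (see the Advanced section below): each template is filled from the same dataset columns in turn.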
111 | 112 | #### LLM Definition 113 | 114 | ```yaml 115 | model: 116 | hf_model_ckpt: "NousResearch/Llama-2-7b-hf" 117 | quantize: true 118 | bitsandbytes: 119 | load_in_4bit: true 120 | bnb_4bit_compute_dtype: "bf16" 121 | bnb_4bit_quant_type: "nf4" 122 | 123 | # LoRA Params ------------------- 124 | lora: 125 | task_type: "CAUSAL_LM" 126 | r: 32 127 | lora_dropout: 0.1 128 | target_modules: 129 | - q_proj 130 | - v_proj 131 | - k_proj 132 | - o_proj 133 | - up_proj 134 | - down_proj 135 | - gate_proj 136 | ``` 137 | 138 | - While the above example showcases Llama 2 7B, in theory, any open-source LLM supported by Hugging Face can be used in this toolkit. 139 | 140 | ```yaml 141 | hf_model_ckpt: "mistralai/Mistral-7B-v0.1" 142 | ``` 143 | 144 | ```yaml 145 | hf_model_ckpt: "tiiuae/falcon-7b" 146 | ``` 147 | 148 | - The parameters for LoRA, such as the rank `r` and dropout, can be altered. 149 | 150 | ```yaml 151 | lora: 152 | r: 64 153 | lora_dropout: 0.25 154 | ``` 155 | 156 | #### Quality Assurance 157 | 158 | ```yaml 159 | qa: 160 | llm_metrics: 161 | - length_test 162 | - word_overlap_test 163 | ``` 164 | 165 | - To ensure that the fine-tuned LLM behaves as expected, you can add tests that check whether the desired behaviour is being attained. For example, for an LLM fine-tuned on a summarization task, we may want to check that the generated summary is indeed shorter than the input text, and to measure the word overlap between the original text and the generated summary. 166 | 167 | #### Artifact Outputs 168 | 169 | This config will run fine-tuning and save the results under the directory `./experiment/[unique_hash]`. Each unique configuration generates a unique hash, so the tool can automatically pick up where it left off. For example, if you need to exit in the middle of training, relaunching the script with the same config will load the dataset already generated under that directory instead of regenerating it. 170 | 171 | After the script finishes running, you will see these artifacts: 172 | 173 | ```shell 174 | /dataset # generated pkl file in hf datasets format 175 | /model # peft model weights in hf format 176 | /results # csv of prompt, ground truth, and predicted values 177 | /qa # csv of test results: e.g. vector similarity between ground truth and prediction 178 | ``` 179 | 180 | Once all the changes have been incorporated in the YAML file, you can use it to run a custom fine-tuning experiment: 181 | 182 | ```shell 183 | llmtune run <path to your custom config.yml> 184 | ``` 185 | 186 | ### Advanced 187 | 188 | Fine-tuning workflows typically involve running ablation studies across various LLMs, prompt designs and optimization techniques. The configuration file can be altered to support such ablation studies. 189 | 190 | - Specify different prompt templates to experiment with while fine-tuning.
191 | 192 | ```yaml 193 | data: 194 | file_type: "huggingface" 195 | path: "yahma/alpaca-cleaned" 196 | prompt: 197 | - >- 198 | This is the first prompt template to iterate over 199 | ### Input: {input} 200 | ### Output: 201 | - >- 202 | This is the second prompt template 203 | ### Instruction: {instruction} 204 | ### Input: {input} 205 | ### Output: 206 | prompt_stub: { output } 207 | test_size: 0.1 # Proportion of test as % of total; if integer then # of samples 208 | train_size: 0.9 # Proportion of train as % of total; if integer then # of samples 209 | train_test_split_seed: 42 210 | ``` 211 | 212 | - Specify various LLMs that you would like to experiment with. 213 | 214 | ```yaml 215 | model: 216 | hf_model_ckpt: 217 | [ 218 | "NousResearch/Llama-2-7b-hf", 219 | mistralai/Mistral-7B-v0.1", 220 | "tiiuae/falcon-7b", 221 | ] 222 | quantize: true 223 | bitsandbytes: 224 | load_in_4bit: true 225 | bnb_4bit_compute_dtype: "bf16" 226 | bnb_4bit_quant_type: "nf4" 227 | ``` 228 | 229 | - Specify different configurations of LoRA that you would like to ablate over. 230 | 231 | ```yaml 232 | lora: 233 | r: [16, 32, 64] 234 | lora_dropout: [0.25, 0.50] 235 | ``` 236 | 237 | ## Extending 238 | 239 | The toolkit provides a modular and extensible architecture that allows developers to customize and enhance its functionality to suit their specific needs. Each component of the toolkit, such as data ingestion, fine-tuning, inference, and quality assurance testing, is designed to be easily extendable. 240 | 241 | ## Contributing 242 | 243 | Open-source contributions to this toolkit are welcome and encouraged. 244 | If you would like to contribute, please see [CONTRIBUTING.md](CONTRIBUTING.md). 245 | -------------------------------------------------------------------------------- /assets/bert.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/bert.gif -------------------------------------------------------------------------------- /assets/glass-gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/glass-gif.gif -------------------------------------------------------------------------------- /assets/inference.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/inference.gif -------------------------------------------------------------------------------- /assets/llama_images/llama-13b-class.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/llama_images/llama-13b-class.png -------------------------------------------------------------------------------- /assets/llama_images/llama-13b-summ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/llama_images/llama-13b-summ.png -------------------------------------------------------------------------------- /assets/llama_images/llama-7b-class.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/llama_images/llama-7b-class.png -------------------------------------------------------------------------------- /assets/llama_images/llama-7b-summ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/llama_images/llama-7b-summ.png -------------------------------------------------------------------------------- /assets/money.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/money.gif -------------------------------------------------------------------------------- /assets/overview_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/overview_diagram.png -------------------------------------------------------------------------------- /assets/progress.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/progress.gif -------------------------------------------------------------------------------- /assets/readme_images/redpajama_results/RedPajama-3B Classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/readme_images/redpajama_results/RedPajama-3B Classification.png -------------------------------------------------------------------------------- /assets/readme_images/redpajama_results/RedPajama-3B Summarization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/readme_images/redpajama_results/RedPajama-3B Summarization.png -------------------------------------------------------------------------------- /assets/readme_images/redpajama_results/RedPajama-7B Classification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/readme_images/redpajama_results/RedPajama-7B Classification.png -------------------------------------------------------------------------------- /assets/readme_images/redpajama_results/RedPajama-7B Summarization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/readme_images/redpajama_results/RedPajama-7B Summarization.png -------------------------------------------------------------------------------- /assets/repo-main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/repo-main.png 
-------------------------------------------------------------------------------- /assets/rocket.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/rocket.gif -------------------------------------------------------------------------------- /assets/time.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/time.gif -------------------------------------------------------------------------------- /assets/toolkit-animation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/assets/toolkit-animation.gif -------------------------------------------------------------------------------- /examples/ablation_config.yml: -------------------------------------------------------------------------------- 1 | save_dir: "./experiment/" 2 | 3 | ablation: 4 | use_ablate: false 5 | 6 | # Data Ingestion ------------------- 7 | data: 8 | file_type: "huggingface" # one of 'json', 'csv', 'huggingface' 9 | path: "yahma/alpaca-cleaned" 10 | prompt: 11 | >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data 12 | Below is an instruction that describes a task. 13 | Write a response that appropriately completes the request. 14 | ### Instruction: {instruction} 15 | ### Input: {input} 16 | ### Output: 17 | prompt_stub: 18 | >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present 19 | {output} 20 | test_size: 0.1 # Proportion of test as % of total; if integer then # of samples 21 | train_size: 0.9 # Proportion of train as % of total; if integer then # of samples 22 | train_test_split_seed: 42 23 | 24 | # Model Definition ------------------- 25 | model: 26 | hf_model_ckpt: ["NousResearch/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1"] 27 | quantize: true 28 | bitsandbytes: 29 | load_in_4bit: true 30 | bnb_4bit_compute_dtype: "bfloat16" 31 | bnb_4bit_quant_type: "nf4" 32 | 33 | # LoRA Params ------------------- 34 | lora: 35 | task_type: "CAUSAL_LM" 36 | r: [16, 32, 64] 37 | lora_dropout: [0.1, 0.25] 38 | target_modules: 39 | - q_proj 40 | - v_proj 41 | - k_proj 42 | - o_proj 43 | - up_proj 44 | - down_proj 45 | - gate_proj 46 | 47 | # Training ------------------- 48 | training: 49 | training_args: 50 | num_train_epochs: 5 51 | per_device_train_batch_size: 4 52 | gradient_accumulation_steps: 4 53 | gradient_checkpointing: True 54 | optim: "paged_adamw_32bit" 55 | logging_steps: 100 56 | learning_rate: 2.0e-4 57 | bf16: true # Set to true for mixed precision training on Newer GPUs 58 | tf32: true 59 | # fp16: false # Set to true for mixed precision training on Older GPUs 60 | max_grad_norm: 0.3 61 | warmup_ratio: 0.03 62 | lr_scheduler_type: "constant" 63 | sft_args: 64 | max_seq_length: 5000 65 | # neftune_noise_alpha: None 66 | 67 | inference: 68 | max_new_tokens: 1024 69 | use_cache: True 70 | do_sample: True 71 | top_p: 0.9 72 | temperature: 0.8 73 | -------------------------------------------------------------------------------- /examples/config.yml: -------------------------------------------------------------------------------- 1 | save_dir: "./experiment/" 2 | 3 | 
ablation: 4 | use_ablate: false 5 | 6 | # Data Ingestion ------------------- 7 | data: 8 | file_type: "huggingface" # one of 'json', 'csv', 'huggingface' 9 | path: "yahma/alpaca-cleaned" 10 | prompt: 11 | >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data 12 | Below is an instruction that describes a task. 13 | Write a response that appropriately completes the request. 14 | ### Instruction: {instruction} 15 | ### Input: {input} 16 | ### Output: 17 | prompt_stub: 18 | >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present 19 | {output} 20 | test_size: 0.1 # Proportion of test as % of total; if integer then # of samples 21 | train_size: 0.9 # Proportion of train as % of total; if integer then # of samples 22 | train_test_split_seed: 42 23 | 24 | # Model Definition ------------------- 25 | model: 26 | hf_model_ckpt: "NousResearch/Llama-2-7b-hf" 27 | quantize: true 28 | bitsandbytes: 29 | load_in_4bit: true 30 | bnb_4bit_compute_dtype: "bfloat16" 31 | bnb_4bit_quant_type: "nf4" 32 | 33 | # LoRA Params ------------------- 34 | lora: 35 | task_type: "CAUSAL_LM" 36 | r: 32 37 | lora_dropout: 0.1 38 | target_modules: 39 | - q_proj 40 | - v_proj 41 | - k_proj 42 | - o_proj 43 | - up_proj 44 | - down_proj 45 | - gate_proj 46 | 47 | # Training ------------------- 48 | training: 49 | training_args: 50 | num_train_epochs: 5 51 | per_device_train_batch_size: 4 52 | gradient_accumulation_steps: 4 53 | gradient_checkpointing: True 54 | optim: "paged_adamw_32bit" 55 | logging_steps: 100 56 | learning_rate: 2.0e-4 57 | bf16: true # Set to true for mixed precision training on Newer GPUs 58 | tf32: true 59 | # fp16: false # Set to true for mixed precision training on Older GPUs 60 | max_grad_norm: 0.3 61 | warmup_ratio: 0.03 62 | lr_scheduler_type: "constant" 63 | sft_args: 64 | max_seq_length: 5000 65 | # neftune_noise_alpha: None 66 | 67 | inference: 68 | max_new_tokens: 1024 69 | use_cache: True 70 | do_sample: True 71 | top_p: 0.9 72 | temperature: 0.8 -------------------------------------------------------------------------------- /examples/llama2_config.yml: -------------------------------------------------------------------------------- 1 | save_dir: "./experiment/" 2 | 3 | ablation: 4 | use_ablate: false 5 | 6 | # Data Ingestion ------------------- 7 | data: 8 | file_type: "huggingface" # one of 'json', 'csv', 'huggingface' 9 | path: "yahma/alpaca-cleaned" 10 | prompt: 11 | >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data 12 | Below is an instruction that describes a task. 13 | Write a response that appropriately completes the request. 
14 | ### Instruction: {instruction} 15 | ### Input: {input} 16 | ### Output: 17 | prompt_stub: 18 | >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present 19 | {output} 20 | test_size: 0.1 # Proportion of test as % of total; if integer then # of samples 21 | train_size: 0.9 # Proportion of train as % of total; if integer then # of samples 22 | train_test_split_seed: 42 23 | 24 | # Model Definition ------------------- 25 | model: 26 | hf_model_ckpt: "NousResearch/Llama-2-7b-hf" 27 | quantize: true 28 | bitsandbytes: 29 | load_in_4bit: true 30 | bnb_4bit_compute_dtype: "bfloat16" 31 | bnb_4bit_quant_type: "nf4" 32 | 33 | # LoRA Params ------------------- 34 | lora: 35 | task_type: "CAUSAL_LM" 36 | r: 32 37 | lora_dropout: 0.1 38 | target_modules: 39 | - q_proj 40 | - v_proj 41 | - k_proj 42 | - o_proj 43 | - up_proj 44 | - down_proj 45 | - gate_proj 46 | 47 | # Training ------------------- 48 | training: 49 | training_args: 50 | num_train_epochs: 5 51 | per_device_train_batch_size: 4 52 | gradient_accumulation_steps: 4 53 | gradient_checkpointing: True 54 | optim: "paged_adamw_32bit" 55 | logging_steps: 100 56 | learning_rate: 2.0e-4 57 | bf16: true # Set to true for mixed precision training on Newer GPUs 58 | tf32: true 59 | # fp16: false # Set to true for mixed precision training on Older GPUs 60 | max_grad_norm: 0.3 61 | warmup_ratio: 0.03 62 | lr_scheduler_type: "constant" 63 | sft_args: 64 | max_seq_length: 5000 65 | # neftune_noise_alpha: None 66 | 67 | inference: 68 | max_new_tokens: 1024 69 | use_cache: True 70 | do_sample: True 71 | top_p: 0.9 72 | temperature: 0.8 73 | -------------------------------------------------------------------------------- /examples/mistral_config.yml: -------------------------------------------------------------------------------- 1 | save_dir: "./experiment/" 2 | 3 | ablation: 4 | use_ablate: false 5 | 6 | # Data Ingestion ------------------- 7 | data: 8 | file_type: "huggingface" # one of 'json', 'csv', 'huggingface' 9 | path: "yahma/alpaca-cleaned" 10 | prompt: 11 | >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data 12 | Below is an instruction that describes a task. 13 | Write a response that appropriately completes the request. 
14 | ### Instruction: {instruction} 15 | ### Input: {input} 16 | ### Output: 17 | prompt_stub: 18 | >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present 19 | {output} 20 | test_size: 0.1 # Proportion of test as % of total; if integer then # of samples 21 | train_size: 0.9 # Proportion of train as % of total; if integer then # of samples 22 | train_test_split_seed: 42 23 | 24 | # Model Definition ------------------- 25 | model: 26 | hf_model_ckpt: "mistralai/Mistral-7B-v0.1" 27 | quantize: true 28 | bitsandbytes: 29 | load_in_4bit: true 30 | bnb_4bit_compute_dtype: "bfloat16" 31 | bnb_4bit_quant_type: "nf4" 32 | 33 | # LoRA Params ------------------- 34 | lora: 35 | task_type: "CAUSAL_LM" 36 | r: 32 37 | lora_dropout: 0.1 38 | target_modules: 39 | - q_proj 40 | - v_proj 41 | - k_proj 42 | - o_proj 43 | - up_proj 44 | - down_proj 45 | - gate_proj 46 | 47 | # Training ------------------- 48 | training: 49 | training_args: 50 | num_train_epochs: 5 51 | per_device_train_batch_size: 4 52 | gradient_accumulation_steps: 4 53 | gradient_checkpointing: True 54 | optim: "paged_adamw_32bit" 55 | logging_steps: 100 56 | learning_rate: 2.0e-4 57 | bf16: true # Set to true for mixed precision training on Newer GPUs 58 | tf32: true 59 | # fp16: false # Set to true for mixed precision training on Older GPUs 60 | max_grad_norm: 0.3 61 | warmup_ratio: 0.03 62 | lr_scheduler_type: "constant" 63 | sft_args: 64 | max_seq_length: 5000 65 | # neftune_noise_alpha: None 66 | 67 | inference: 68 | max_new_tokens: 1024 69 | use_cache: True 70 | do_sample: True 71 | top_p: 0.9 72 | temperature: 0.8 73 | -------------------------------------------------------------------------------- /examples/test_suite/dot_product_tests.csv: -------------------------------------------------------------------------------- 1 | "Test Type","prompt","ground_truth","threshold" 2 | "cosine_similarity","Output “the dog is brown”, word for word:","the dog is brown",0.95 3 | "cosine_similarity","You are a customer service chatbot, offering polite responses to customer questions. The company has a 30 day cancellation policy. Answer the following customer question. Question: “How long is the cancellation policy? Answer: ","30 days",0.2 4 | -------------------------------------------------------------------------------- /examples/test_suite/json_validity_tests.csv: -------------------------------------------------------------------------------- 1 | "Test Type","prompt" 2 | "json_valid","Hi my name is John. My age is 30 and I have blue eyes. 
Here is the same data in json format: " 3 | "json_valid","Here is an example of a simple json object: " 4 | -------------------------------------------------------------------------------- /llama2/baseline_inference.sh: -------------------------------------------------------------------------------- 1 | python llama2_baseline_inference.py --task_type classification --prompt_type zero-shot & wait 2 | python llama2_baseline_inference.py --task_type classification --prompt_type few-shot & wait 3 | python llama2_baseline_inference.py --task_type summarization --prompt_type zero-shot & wait 4 | python llama2_baseline_inference.py --task_type summarization --prompt_type few-shot 5 | -------------------------------------------------------------------------------- /llama2/llama2_baseline_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import evaluate 4 | import warnings 5 | import json 6 | import pandas as pd 7 | import pickle 8 | import torch 9 | import time 10 | 11 | from datasets import load_dataset 12 | from prompts import ( 13 | ZERO_SHOT_CLASSIFIER_PROMPT, 14 | FEW_SHOT_CLASSIFIER_PROMPT, 15 | ZERO_SHOT_SUMMARIZATION_PROMPT, 16 | FEW_SHOT_SUMMARIZATION_PROMPT, 17 | get_newsgroup_data, 18 | get_samsum_data, 19 | ) 20 | from transformers import ( 21 | AutoTokenizer, 22 | AutoModelForCausalLM, 23 | BitsAndBytesConfig, 24 | ) 25 | from sklearn.metrics import ( 26 | accuracy_score, 27 | f1_score, 28 | precision_score, 29 | recall_score, 30 | ) 31 | 32 | metric = evaluate.load("rouge") 33 | warnings.filterwarnings("ignore") 34 | 35 | 36 | def compute_metrics_decoded(decoded_labs, decoded_preds, args): 37 | if args.task_type == "summarization": 38 | rouge = metric.compute( 39 | predictions=decoded_preds, references=decoded_labs, use_stemmer=True 40 | ) 41 | metrics = {metric: round(rouge[metric] * 100.0, 3) for metric in rouge.keys()} 42 | 43 | elif args.task_type == "classification": 44 | metrics = { 45 | "micro_f1": f1_score(decoded_labs, decoded_preds, average="micro"), 46 | "macro_f1": f1_score(decoded_labs, decoded_preds, average="macro"), 47 | "precision": precision_score(decoded_labs, decoded_preds, average="micro"), 48 | "recall": recall_score(decoded_labs, decoded_preds, average="micro"), 49 | "accuracy": accuracy_score(decoded_labs, decoded_preds), 50 | } 51 | 52 | return metrics 53 | 54 | 55 | def main(args): 56 | # replace attention with flash attention 57 | if torch.cuda.get_device_capability()[0] >= 8: 58 | from llama_patch import replace_attn_with_flash_attn 59 | 60 | print("Using flash attention") 61 | replace_attn_with_flash_attn() 62 | use_flash_attention = True 63 | 64 | save_dir = os.path.join( 65 | "baseline_results", args.pretrained_ckpt, args.task_type, args.prompt_type 66 | ) 67 | if not os.path.exists(save_dir): 68 | os.makedirs(save_dir) 69 | 70 | if args.task_type == "classification": 71 | dataset = load_dataset("rungalileo/20_Newsgroups_Fixed") 72 | test_dataset = dataset["test"] 73 | test_data, test_labels = test_dataset["text"], test_dataset["label"] 74 | 75 | newsgroup_classes, few_shot_samples, _ = get_newsgroup_data() 76 | 77 | elif args.task_type == "summarization": 78 | dataset = load_dataset("samsum") 79 | test_dataset = dataset["test"] 80 | test_data, test_labels = test_dataset["dialogue"], test_dataset["summary"] 81 | 82 | few_shot_samples = get_samsum_data() 83 | 84 | if args.prompt_type == "zero-shot": 85 | if args.task_type == "classification": 86 | prompt = 
ZERO_SHOT_CLASSIFIER_PROMPT 87 | elif args.task_type == "summarization": 88 | prompt = ZERO_SHOT_SUMMARIZATION_PROMPT 89 | 90 | elif args.prompt_type == "few-shot": 91 | if args.task_type == "classification": 92 | prompt = FEW_SHOT_CLASSIFIER_PROMPT 93 | elif args.task_type == "summarization": 94 | prompt = FEW_SHOT_SUMMARIZATION_PROMPT 95 | 96 | # BitsAndBytesConfig int-4 config 97 | bnb_config = BitsAndBytesConfig( 98 | load_in_4bit=True, 99 | bnb_4bit_use_double_quant=True, 100 | bnb_4bit_quant_type="nf4", 101 | bnb_4bit_compute_dtype=torch.bfloat16, 102 | ) 103 | 104 | # Load model and tokenizer 105 | model = AutoModelForCausalLM.from_pretrained( 106 | args.pretrained_ckpt, 107 | quantization_config=bnb_config, 108 | use_cache=False, 109 | device_map="auto", 110 | ) 111 | model.config.pretraining_tp = 1 112 | 113 | if use_flash_attention: 114 | from llama_patch import forward 115 | 116 | assert ( 117 | model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__ 118 | ), "Model is not using flash attention" 119 | 120 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt) 121 | tokenizer.pad_token = tokenizer.eos_token 122 | tokenizer.padding_side = "right" 123 | 124 | results = [] 125 | good_data, good_labels = [], [] 126 | ctr = 0 127 | # for instruct, label in zip(instructions, labels): 128 | for data, label in zip(test_data, test_labels): 129 | if not isinstance(data, str): 130 | continue 131 | if not isinstance(label, str): 132 | continue 133 | 134 | # example = instruct[:-len(label)] # remove the answer from the example 135 | if args.prompt_type == "zero-shot": 136 | if args.task_type == "classification": 137 | example = prompt.format( 138 | newsgroup_classes=newsgroup_classes, 139 | sentence=data, 140 | ) 141 | elif args.task_type == "summarization": 142 | example = prompt.format( 143 | dialogue=data, 144 | ) 145 | 146 | elif args.prompt_type == "few-shot": 147 | if args.task_type == "classification": 148 | example = prompt.format( 149 | newsgroup_classes=newsgroup_classes, 150 | few_shot_samples=few_shot_samples, 151 | sentence=data, 152 | ) 153 | elif args.task_type == "summarization": 154 | example = prompt.format( 155 | few_shot_samples=few_shot_samples, 156 | dialogue=data, 157 | ) 158 | 159 | input_ids = tokenizer( 160 | example, return_tensors="pt", truncation=True 161 | ).input_ids.cuda() 162 | 163 | with torch.inference_mode(): 164 | outputs = model.generate( 165 | input_ids=input_ids, 166 | max_new_tokens=20 if args.task_type == "classification" else 50, 167 | do_sample=True, 168 | top_p=0.95, 169 | temperature=1e-3, 170 | ) 171 | result = tokenizer.batch_decode( 172 | outputs.detach().cpu().numpy(), skip_special_tokens=True 173 | )[0] 174 | 175 | # Extract the generated text, and do basic processing 176 | result = result[len(example) :].replace("\n", "").lstrip().rstrip() 177 | results.append(result) 178 | good_labels.append(label) 179 | good_data.append(data) 180 | 181 | print(f"Example {ctr}/{len(test_data)} | GT: {label} | Pred: {result}") 182 | ctr += 1 183 | 184 | metrics = compute_metrics_decoded(good_labels, results, args) 185 | print(metrics) 186 | metrics["predictions"] = results 187 | metrics["labels"] = good_labels 188 | metrics["data"] = good_data 189 | 190 | with open(os.path.join(save_dir, "metrics.pkl"), "wb") as handle: 191 | pickle.dump(metrics, handle) 192 | 193 | print(f"Completed experiment {save_dir}") 194 | print("----------------------------------------") 195 | 196 | 197 | if __name__ == "__main__": 198 | parser = 
argparse.ArgumentParser() 199 | parser.add_argument("--pretrained_ckpt", default="NousResearch/Llama-2-7b-hf") 200 | parser.add_argument("--prompt_type", default="zero-shot") 201 | parser.add_argument("--task_type", default="classification") 202 | args = parser.parse_args() 203 | 204 | main(args) 205 | -------------------------------------------------------------------------------- /llama2/llama2_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import pickle 7 | 8 | 9 | from peft import ( 10 | LoraConfig, 11 | prepare_model_for_kbit_training, 12 | get_peft_model, 13 | ) 14 | from transformers import ( 15 | AutoTokenizer, 16 | AutoModelForCausalLM, 17 | BitsAndBytesConfig, 18 | TrainingArguments, 19 | ) 20 | from trl import SFTTrainer 21 | 22 | from prompts import get_newsgroup_data_for_ft 23 | 24 | 25 | def main(args): 26 | train_dataset, test_dataset = get_newsgroup_data_for_ft( 27 | mode="train", train_sample_fraction=args.train_sample_fraction 28 | ) 29 | print(f"Sample fraction:{args.train_sample_fraction}") 30 | print(f"Training samples:{train_dataset.shape}") 31 | 32 | # BitsAndBytesConfig int-4 config 33 | bnb_config = BitsAndBytesConfig( 34 | load_in_4bit=True, 35 | bnb_4bit_use_double_quant=True, 36 | bnb_4bit_quant_type="nf4", 37 | bnb_4bit_compute_dtype=torch.bfloat16, 38 | ) 39 | 40 | # Load model and tokenizer 41 | model = AutoModelForCausalLM.from_pretrained( 42 | args.pretrained_ckpt, 43 | quantization_config=bnb_config, 44 | use_cache=False, 45 | device_map="auto", 46 | ) 47 | model.config.pretraining_tp = 1 48 | 49 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt) 50 | tokenizer.pad_token = tokenizer.eos_token 51 | tokenizer.padding_side = "right" 52 | 53 | # LoRA config based on QLoRA paper 54 | peft_config = LoraConfig( 55 | lora_alpha=16, 56 | lora_dropout=args.dropout, 57 | r=args.lora_r, 58 | bias="none", 59 | task_type="CAUSAL_LM", 60 | ) 61 | 62 | # prepare model for training 63 | model = prepare_model_for_kbit_training(model) 64 | model = get_peft_model(model, peft_config) 65 | 66 | results_dir = f"experiments/classification-sampleFraction-{args.train_sample_fraction}_epochs-{args.epochs}_rank-{args.lora_r}_dropout-{args.dropout}" 67 | 68 | training_args = TrainingArguments( 69 | output_dir=results_dir, 70 | logging_dir=f"{results_dir}/logs", 71 | num_train_epochs=args.epochs, 72 | per_device_train_batch_size=4,  # use 6 when the flash-attention patch is applied; this script does not apply it 73 | gradient_accumulation_steps=2, 74 | gradient_checkpointing=True, 75 | optim="paged_adamw_32bit", 76 | logging_steps=100, 77 | learning_rate=2e-4, 78 | bf16=True, 79 | tf32=True, 80 | max_grad_norm=0.3, 81 | warmup_ratio=0.03, 82 | lr_scheduler_type="constant", 83 | report_to="none", 84 | # disable_tqdm=True # disable tqdm since with packing the values are incorrect 85 | ) 86 | 87 | max_seq_length = 512 # max sequence length for model and packing of the dataset 88 | 89 | trainer = SFTTrainer( 90 | model=model, 91 | train_dataset=train_dataset, 92 | peft_config=peft_config, 93 | max_seq_length=max_seq_length, 94 | tokenizer=tokenizer, 95 | packing=True, 96 | args=training_args, 97 | dataset_text_field="instructions", 98 | ) 99 | 100 | trainer_stats = trainer.train() 101 | train_loss = trainer_stats.training_loss 102 | print(f"Training loss:{train_loss}") 103 | 104 | peft_model_id = f"{results_dir}/assets" 105 | trainer.model.save_pretrained(peft_model_id) 106 |
tokenizer.save_pretrained(peft_model_id) 107 | 108 | with open(f"{results_dir}/results.pkl", "wb") as handle: 109 | run_result = [ 110 | args.epochs, 111 | args.lora_r, 112 | args.dropout, 113 | train_loss, 114 | ] 115 | pickle.dump(run_result, handle) 116 | print("Experiment over") 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("--pretrained_ckpt", default="NousResearch/Llama-2-7b-hf") 122 | parser.add_argument("--lora_r", default=8, type=int) 123 | parser.add_argument("--epochs", default=5, type=int) 124 | parser.add_argument("--dropout", default=0.1, type=float) 125 | parser.add_argument("--train_sample_fraction", default=0.99, type=float) 126 | 127 | args = parser.parse_args() 128 | main(args) 129 | -------------------------------------------------------------------------------- /llama2/llama2_classification_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import pandas as pd 5 | import evaluate 6 | import pickle 7 | import warnings 8 | from tqdm import tqdm 9 | 10 | from llama_patch import unplace_flash_attn_with_attn 11 | from peft import AutoPeftModelForCausalLM 12 | from transformers import AutoTokenizer 13 | from sklearn.metrics import ( 14 | accuracy_score, 15 | f1_score, 16 | precision_score, 17 | recall_score, 18 | ) 19 | 20 | from prompts import get_newsgroup_data_for_ft 21 | 22 | metric = evaluate.load("rouge") 23 | warnings.filterwarnings("ignore") 24 | 25 | 26 | def main(args): 27 | _, test_dataset = get_newsgroup_data_for_ft(mode="inference") 28 | 29 | experiment = args.experiment_dir 30 | peft_model_id = f"{experiment}/assets" 31 | 32 | # unpatch flash attention 33 | unplace_flash_attn_with_attn() 34 | 35 | # load base LLM model and tokenizer 36 | model = AutoPeftModelForCausalLM.from_pretrained( 37 | peft_model_id, 38 | low_cpu_mem_usage=True, 39 | torch_dtype=torch.float16, 40 | load_in_4bit=True, 41 | ) 42 | model.eval() 43 | 44 | tokenizer = AutoTokenizer.from_pretrained(peft_model_id) 45 | 46 | results = [] 47 | oom_examples = [] 48 | instructions, labels = test_dataset["instructions"], test_dataset["labels"] 49 | 50 | for instruct, label in tqdm(zip(instructions, labels)): 51 | input_ids = tokenizer( 52 | instruct, return_tensors="pt", truncation=True 53 | ).input_ids.cuda() 54 | 55 | with torch.inference_mode(): 56 | try: 57 | outputs = model.generate( 58 | input_ids=input_ids, 59 | max_new_tokens=20, 60 | do_sample=True, 61 | top_p=0.95, 62 | temperature=1e-3, 63 | ) 64 | result = tokenizer.batch_decode( 65 | outputs.detach().cpu().numpy(), skip_special_tokens=True 66 | )[0] 67 | result = result[len(instruct) :] 68 | except: 69 | result = "" 70 | oom_examples.append(input_ids.shape[-1]) 71 | 72 | results.append(result) 73 | 74 | metrics = { 75 | "micro_f1": f1_score(labels, results, average="micro"), 76 | "macro_f1": f1_score(labels, results, average="macro"), 77 | "precision": precision_score(labels, results, average="micro"), 78 | "recall": recall_score(labels, results, average="micro"), 79 | "accuracy": accuracy_score(labels, results), 80 | "oom_examples": oom_examples, 81 | } 82 | print(metrics) 83 | 84 | save_dir = os.path.join(experiment, "metrics") 85 | if not os.path.exists(save_dir): 86 | os.makedirs(save_dir) 87 | 88 | with open(os.path.join(save_dir, "metrics.pkl"), "wb") as handle: 89 | pickle.dump(metrics, handle) 90 | 91 | print(f"Completed experiment {peft_model_id}") 92 | 
print("----------------------------------------") 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument( 98 | "--experiment_dir", 99 | default="experiments/classification-sampleFraction-0.1_epochs-5_rank-8_dropout-0.1", 100 | ) 101 | 102 | args = parser.parse_args() 103 | main(args) 104 | -------------------------------------------------------------------------------- /llama2/llama2_summarization.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import pickle 7 | import datasets 8 | from datasets import Dataset, load_dataset 9 | 10 | from peft import ( 11 | LoraConfig, 12 | prepare_model_for_kbit_training, 13 | get_peft_model, 14 | ) 15 | from transformers import ( 16 | AutoTokenizer, 17 | AutoModelForCausalLM, 18 | BitsAndBytesConfig, 19 | TrainingArguments, 20 | ) 21 | from trl import SFTTrainer 22 | 23 | from prompts import TRAINING_SUMMARIZATION_PROMPT_v2 24 | 25 | 26 | def prepare_instructions(dialogues, summaries): 27 | instructions = [] 28 | 29 | prompt = TRAINING_SUMMARIZATION_PROMPT_v2 30 | 31 | for dialogue, summary in zip(dialogues, summaries): 32 | example = prompt.format( 33 | dialogue=dialogue, 34 | summary=summary, 35 | ) 36 | instructions.append(example) 37 | 38 | return instructions 39 | 40 | 41 | def prepare_samsum_data(): 42 | dataset = load_dataset("samsum") 43 | train_dataset = dataset["train"] 44 | val_dataset = dataset["test"] 45 | 46 | dialogues = train_dataset["dialogue"] 47 | summaries = train_dataset["summary"] 48 | train_instructions = prepare_instructions(dialogues, summaries) 49 | train_dataset = datasets.Dataset.from_pandas( 50 | pd.DataFrame(data={"instructions": train_instructions}) 51 | ) 52 | 53 | return train_dataset 54 | 55 | 56 | def main(args): 57 | train_dataset = prepare_samsum_data() 58 | 59 | # BitsAndBytesConfig int-4 config 60 | bnb_config = BitsAndBytesConfig( 61 | load_in_4bit=True, 62 | bnb_4bit_use_double_quant=True, 63 | bnb_4bit_quant_type="nf4", 64 | bnb_4bit_compute_dtype=torch.bfloat16, 65 | ) 66 | 67 | # Load model and tokenizer 68 | model = AutoModelForCausalLM.from_pretrained( 69 | args.pretrained_ckpt, 70 | quantization_config=bnb_config, 71 | use_cache=False, 72 | device_map="auto", 73 | ) 74 | model.config.pretraining_tp = 1 75 | 76 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt) 77 | tokenizer.pad_token = tokenizer.eos_token 78 | tokenizer.padding_side = "right" 79 | 80 | # LoRA config based on QLoRA paper 81 | peft_config = LoraConfig( 82 | lora_alpha=16, 83 | lora_dropout=args.dropout, 84 | r=args.lora_r, 85 | bias="none", 86 | task_type="CAUSAL_LM", 87 | ) 88 | 89 | # prepare model for training 90 | model = prepare_model_for_kbit_training(model) 91 | model = get_peft_model(model, peft_config) 92 | 93 | results_dir = f"experiments/summarization_epochs-{args.epochs}_rank-{args.lora_r}_dropout-{args.dropout}" 94 | 95 | training_args = TrainingArguments( 96 | output_dir=results_dir, 97 | logging_dir=f"{results_dir}/logs", 98 | num_train_epochs=args.epochs, 99 | per_device_train_batch_size=6 if use_flash_attention else 4, 100 | gradient_accumulation_steps=2, 101 | gradient_checkpointing=True, 102 | optim="paged_adamw_32bit", 103 | logging_steps=100, 104 | learning_rate=2e-4, 105 | bf16=True, 106 | tf32=True, 107 | max_grad_norm=0.3, 108 | warmup_ratio=0.03, 109 | lr_scheduler_type="constant", 110 | report_to="none", 111 | # 
disable_tqdm=True # disable tqdm since with packing values are in correct 112 | ) 113 | 114 | max_seq_length = 512 # max sequence length for model and packing of the dataset 115 | 116 | trainer = SFTTrainer( 117 | model=model, 118 | train_dataset=train_dataset, 119 | peft_config=peft_config, 120 | max_seq_length=max_seq_length, 121 | tokenizer=tokenizer, 122 | packing=True, 123 | args=training_args, 124 | dataset_text_field="instructions", 125 | ) 126 | 127 | trainer_stats = trainer.train() 128 | train_loss = trainer_stats.training_loss 129 | print(f"Training loss:{train_loss}") 130 | 131 | peft_model_id = f"{results_dir}/assets" 132 | trainer.model.save_pretrained(peft_model_id) 133 | tokenizer.save_pretrained(peft_model_id) 134 | 135 | with open(f"{results_dir}/results.pkl", "wb") as handle: 136 | run_result = [ 137 | args.epochs, 138 | args.lora_r, 139 | args.dropout, 140 | train_loss, 141 | ] 142 | pickle.dump(run_result, handle) 143 | print("Experiment over") 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("--pretrained_ckpt", default="NousResearch/Llama-2-7b-hf") 149 | parser.add_argument("--lora_r", default=64, type=int) 150 | parser.add_argument("--epochs", default=1, type=int) 151 | parser.add_argument("--dropout", default=0.1, type=float) 152 | 153 | args = parser.parse_args() 154 | main(args) 155 | -------------------------------------------------------------------------------- /llama2/llama2_summarization_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import pandas as pd 5 | import evaluate 6 | from datasets import load_dataset 7 | import pickle 8 | import warnings 9 | 10 | from llama_patch import unplace_flash_attn_with_attn 11 | from peft import AutoPeftModelForCausalLM 12 | from transformers import AutoTokenizer 13 | 14 | from prompts import INFERENCE_SUMMARIZATION_PROMPT_v2 15 | 16 | metric = evaluate.load("rouge") 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | def prepare_instructions(dialogues, summaries): 21 | instructions = [] 22 | 23 | prompt = INFERENCE_SUMMARIZATION_PROMPT_v2 24 | 25 | for dialogue, summary in zip(dialogues, summaries): 26 | example = prompt.format( 27 | dialogue=dialogue, 28 | ) 29 | instructions.append(example) 30 | 31 | return instructions 32 | 33 | 34 | def prepare_samsum_data(): 35 | dataset = load_dataset("samsum") 36 | val_dataset = dataset["test"] 37 | 38 | dialogues = val_dataset["dialogue"] 39 | summaries = val_dataset["summary"] 40 | val_instructions = prepare_instructions(dialogues, summaries) 41 | 42 | return val_instructions, summaries 43 | 44 | 45 | def main(args): 46 | val_instructions, summaries = prepare_samsum_data() 47 | 48 | experiment = args.experiment_dir 49 | peft_model_id = f"{experiment}/assets" 50 | 51 | # unpatch flash attention 52 | unplace_flash_attn_with_attn() 53 | 54 | # load base LLM model and tokenizer 55 | model = AutoPeftModelForCausalLM.from_pretrained( 56 | peft_model_id, 57 | low_cpu_mem_usage=True, 58 | torch_dtype=torch.float16, 59 | load_in_4bit=True, 60 | ) 61 | tokenizer = AutoTokenizer.from_pretrained(peft_model_id) 62 | 63 | results = [] 64 | for instruct, summary in zip(val_instructions, summaries): 65 | input_ids = tokenizer( 66 | instruct, return_tensors="pt", truncation=True 67 | ).input_ids.cuda() 68 | with torch.inference_mode(): 69 | outputs = model.generate( 70 | input_ids=input_ids, 71 | max_new_tokens=100, 72 | 
do_sample=True, 73 | top_p=0.9, 74 | temperature=1e-2, 75 | ) 76 | result = tokenizer.batch_decode( 77 | outputs.detach().cpu().numpy(), skip_special_tokens=True 78 | )[0] 79 | result = result[len(instruct) :] 80 | results.append(result) 81 | print(f"Instruction:{instruct}") 82 | print(f"Summary:{summary}") 83 | print(f"Generated:{result}") 84 | print("----------------------------------------") 85 | 86 | # compute metric 87 | rouge = metric.compute(predictions=results, references=summaries, use_stemmer=True) 88 | 89 | metrics = {metric: round(rouge[metric] * 100, 2) for metric in rouge.keys()} 90 | 91 | save_dir = os.path.join(experiment, "metrics") 92 | if not os.path.exists(save_dir): 93 | os.makedirs(save_dir) 94 | 95 | with open(os.path.join(save_dir, "metrics.pkl"), "wb") as handle: 96 | pickle.dump(metrics, handle) 97 | 98 | print(f"Completed experiment {peft_model_id}") 99 | print("----------------------------------------") 100 | 101 | 102 | if __name__ == "__main__": 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument( 105 | "--experiment_dir", 106 | default="experiments/summarization_epochs-1_rank-64_dropout-0.1", 107 | ) 108 | 109 | args = parser.parse_args() 110 | main(args) 111 | -------------------------------------------------------------------------------- /llama2/llama_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | import warnings 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | from peft.tuners.lora import LoraLayer 9 | 10 | try: 11 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func 12 | from flash_attn.bert_padding import unpad_input, pad_input 13 | except Exception: 14 | raise ModuleNotFoundError( 15 | "Please install FlashAttention first, e.g., with pip install flash-attn --no-build-isolation, Learn more at https://github.com/Dao-AILab/flash-attention#installation-and-features" 16 | ) 17 | 18 | try: 19 | from einops import rearrange 20 | except Exception: 21 | raise ModuleNotFoundError( 22 | "Please install einops first, e.g., with pip install einops" 23 | ) 24 | 25 | 26 | # ADAPTED from https://github.com/allenai/open-instruct/blob/main/open_instruct/llama_flash_attn_monkey_patch.py 27 | # AND https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py 28 | # AND https://github.com/LAION-AI/Open-Assistant/blob/04fa9a24b2a58c8885b8aa6a2eb02b18de6b4961/model/model_training/models/patching_llama.py 29 | # AND Sourabh https://github.com/huggingface/transformers/commit/ee81bf5aee0d65f005d157c013777e3d27d8d6bf 30 | def forward( 31 | self, 32 | hidden_states: torch.Tensor, 33 | attention_mask: Optional[torch.Tensor] = None, 34 | position_ids: Optional[torch.Tensor] = None, 35 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 36 | output_attentions: bool = False, 37 | use_cache: bool = False, 38 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 39 | """Input shape: Batch x Time x Channel 40 | 41 | attention_mask: [bsz, q_len] 42 | """ 43 | if output_attentions: 44 | warnings.warn( 45 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
46 | ) 47 | 48 | bsz, q_len, _ = hidden_states.size() 49 | 50 | query_states = ( 51 | self.q_proj(hidden_states) 52 | .view(bsz, q_len, self.num_heads, self.head_dim) 53 | .transpose(1, 2) 54 | ) 55 | key_states = ( 56 | self.k_proj(hidden_states) 57 | .view(bsz, q_len, self.num_heads, self.head_dim) 58 | .transpose(1, 2) 59 | ) 60 | value_states = ( 61 | self.v_proj(hidden_states) 62 | .view(bsz, q_len, self.num_heads, self.head_dim) 63 | .transpose(1, 2) 64 | ) 65 | # [bsz, q_len, nh, hd] 66 | # [bsz, nh, q_len, hd] 67 | 68 | kv_seq_len = key_states.shape[-2] 69 | if past_key_value is not None: 70 | kv_seq_len += past_key_value[0].shape[-2] 71 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 72 | query_states, key_states = apply_rotary_pos_emb( 73 | query_states, key_states, cos, sin, position_ids 74 | ) 75 | 76 | # Past Key value support 77 | if past_key_value is not None: 78 | # reuse k, v, self_attention 79 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 80 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 81 | 82 | past_key_value = (key_states, value_states) if use_cache else None 83 | 84 | # Flash attention codes from 85 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 86 | 87 | # transform the data into the format required by flash attention 88 | qkv = torch.stack( 89 | [query_states, key_states, value_states], dim=2 90 | ) # [bsz, nh, 3, q_len, hd] 91 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 92 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 93 | # the attention_mask should be the same as the key_padding_mask 94 | key_padding_mask = attention_mask 95 | 96 | if key_padding_mask is None: 97 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 98 | max_s = q_len 99 | cu_q_lens = torch.arange( 100 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 101 | ) 102 | output = flash_attn_varlen_qkvpacked_func( 103 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 104 | ) 105 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 106 | else: 107 | nheads = qkv.shape[-2] 108 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 109 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 110 | x_unpad = rearrange( 111 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 112 | ) 113 | output_unpad = flash_attn_varlen_qkvpacked_func( 114 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 115 | ) 116 | output = rearrange( 117 | pad_input( 118 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 119 | ), 120 | "b s (h d) -> b s h d", 121 | h=nheads, 122 | ) 123 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value 124 | 125 | 126 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 127 | # requires the attention mask to be the same as the key_padding_mask 128 | def _prepare_decoder_attention_mask( 129 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 130 | ): 131 | # [bsz, seq_len] 132 | return attention_mask 133 | 134 | 135 | def replace_attn_with_flash_attn(): 136 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 137 | if cuda_major < 8: 138 | print( 139 | "Flash attention is only supported on Ampere or Hopper GPU during training due to head dim > 64 backward." 
140 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 141 | ) 142 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 143 | _prepare_decoder_attention_mask 144 | ) 145 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 146 | 147 | 148 | def unplace_flash_attn_with_attn(): 149 | import importlib 150 | import transformers 151 | 152 | print("Reloading llama model, unpatching flash attention") 153 | importlib.reload(transformers.models.llama.modeling_llama) 154 | 155 | 156 | # Adapted from https://github.com/tmm1/axolotl/blob/2eda9e02a9d15a7a3f92b41f257d9844d72fc220/src/axolotl/utils/models.py#L338 157 | def upcast_layer_for_flash_attention(model, torch_dtype): 158 | # LlamaRMSNorm layers are in fp32 after kbit_training, so we need to 159 | # convert them back to fp16/bf16 for flash-attn compatibility. 160 | for name, module in model.named_modules(): 161 | if isinstance(module, LoraLayer): 162 | module.to(torch_dtype) 163 | if "norm" in name: 164 | module.to(torch_dtype) 165 | if "lm_head" in name or "embed_tokens" in name: 166 | if hasattr(module, "weight"): 167 | module.to(torch_dtype) 168 | 169 | return model 170 | -------------------------------------------------------------------------------- /llama2/prompts.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import datasets 3 | from datasets import load_dataset 4 | from sklearn.model_selection import train_test_split 5 | 6 | 7 | ZERO_SHOT_CLASSIFIER_PROMPT = """Classify the sentence into one of 20 classes. The list of classes is provided below, where the classes are separated by commas: 8 | 9 | {newsgroup_classes} 10 | 11 | From the above list of classes, select only one class that the provided sentence can be classified into. The sentence will be delimited with triple backticks. Once again, only predict the class from the given list of classes. Do not predict anything else. 12 | 13 | ### Sentence: ```{sentence}``` 14 | ### Class: 15 | """ 16 | 17 | FEW_SHOT_CLASSIFIER_PROMPT = """Classify the sentence into one of 20 classes. The list of classes is provided below, where the classes are separated by commas: 18 | 19 | {newsgroup_classes} 20 | 21 | From the above list of classes, select only one class that the provided sentence can be classified into. Once again, only predict the class from the given list of classes. Do not predict anything else. The sentence will be delimited with triple backticks. To help you, examples are provided of sentence and the corresponding class they belong to. 22 | 23 | {few_shot_samples} 24 | 25 | ### Sentence: ```{sentence}``` 26 | ### Class: 27 | """ 28 | 29 | TRAINING_CLASSIFIER_PROMPT = """Classify the following sentence that is delimited with triple backticks. 30 | 31 | ### Sentence: ```{sentence}``` 32 | ### Class: {label} 33 | """ 34 | 35 | INFERENCE_CLASSIFIER_PROMPT = """Classify the following sentence that is delimited with triple backticks. 36 | 37 | ### Sentence: ```{sentence}``` 38 | ### Class: 39 | """ 40 | 41 | TRAINING_CLASSIFIER_PROMPT_v2 = """### Sentence:{sentence} ### Class:{label}""" 42 | INFERENCE_CLASSIFIER_PROMPT_v2 = """### Sentence:{sentence} ### Class:""" 43 | 44 | ZERO_SHOT_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. 
45 | 46 | ### Dialogue: ```{dialogue}``` 47 | ### Summary: 48 | """ 49 | 50 | FEW_SHOT_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. To help you, examples of summarization are provided. 51 | 52 | {few_shot_samples} 53 | 54 | ### Dialogue: ```{dialogue}``` 55 | ### Summary: 56 | """ 57 | 58 | TRAINING_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. 59 | 60 | ### Dialogue: ```{dialogue}``` 61 | ### Summary: {summary} 62 | """ 63 | 64 | TRAINING_SUMMARIZATION_PROMPT_v2 = """### Dialogue:{dialogue} ### Summary:{summary}""" 65 | INFERENCE_SUMMARIZATION_PROMPT_v2 = """### Dialogue:{dialogue} ### Summary:""" 66 | 67 | INFERENCE_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. 68 | 69 | ### Dialogue: ```{dialogue}``` 70 | ### Summary: 71 | """ 72 | 73 | 74 | def get_newsgroup_instruction_data(mode, texts, labels): 75 | if mode == "train": 76 | prompt = TRAINING_CLASSIFIER_PROMPT_v2 77 | elif mode == "inference": 78 | prompt = INFERENCE_CLASSIFIER_PROMPT_v2 79 | 80 | instructions = [] 81 | 82 | for text, label in zip(texts, labels): 83 | if mode == "train": 84 | example = prompt.format( 85 | sentence=text, 86 | label=label, 87 | ) 88 | elif mode == "inference": 89 | example = prompt.format( 90 | sentence=text, 91 | ) 92 | instructions.append(example) 93 | 94 | return instructions 95 | 96 | 97 | def clean_newsgroup_data(texts, labels): 98 | label2data = {} 99 | clean_data, clean_labels = [], [] 100 | for data, label in zip(texts, labels): 101 | if isinstance(data, str) and isinstance(label, str): 102 | clean_data.append(data) 103 | clean_labels.append(label) 104 | 105 | if label not in label2data: 106 | label2data[label] = data 107 | 108 | return label2data, clean_data, clean_labels 109 | 110 | 111 | def get_newsgroup_data_for_ft(mode="train", train_sample_fraction=0.99): 112 | newsgroup_dataset = load_dataset("rungalileo/20_Newsgroups_Fixed") 113 | train_data = newsgroup_dataset["train"]["text"] 114 | train_labels = newsgroup_dataset["train"]["label"] 115 | label2data, train_data, train_labels = clean_newsgroup_data( 116 | train_data, train_labels 117 | ) 118 | 119 | test_data = newsgroup_dataset["test"]["text"] 120 | test_labels = newsgroup_dataset["test"]["label"] 121 | _, test_data, test_labels = clean_newsgroup_data(test_data, test_labels) 122 | 123 | # sample n points from training data 124 | train_df = pd.DataFrame(data={"text": train_data, "label": train_labels}) 125 | train_df, _ = train_test_split( 126 | train_df, 127 | train_size=train_sample_fraction, 128 | stratify=train_df["label"], 129 | random_state=42, 130 | ) 131 | train_data = train_df["text"] 132 | train_labels = train_df["label"] 133 | 134 | train_instructions = get_newsgroup_instruction_data(mode, train_data, train_labels) 135 | test_instructions = get_newsgroup_instruction_data(mode, test_data, test_labels) 136 | 137 | train_dataset = datasets.Dataset.from_pandas( 138 | pd.DataFrame( 139 | data={ 140 | "instructions": train_instructions, 141 | "labels": train_labels, 142 | } 143 | ) 144 | ) 145 | test_dataset = datasets.Dataset.from_pandas( 146 | pd.DataFrame( 147 | data={ 148 | "instructions": test_instructions, 149 | "labels": test_labels, 150 | } 151 | ) 152 | ) 153 | 154 | return train_dataset, test_dataset 155 | 156 | 157 | def get_newsgroup_data(): 158 | newsgroup_dataset = load_dataset("rungalileo/20_Newsgroups_Fixed") 159 | train_data = 
newsgroup_dataset["train"]["text"] 160 | train_labels = newsgroup_dataset["train"]["label"] 161 | 162 | label2data, clean_data, clean_labels = clean_newsgroup_data( 163 | train_data, train_labels 164 | ) 165 | df = pd.DataFrame(data={"text": clean_data, "label": clean_labels}) 166 | 167 | newsgroup_classes = df["label"].unique() 168 | newsgroup_classes = ", ".join(newsgroup_classes) 169 | 170 | few_shot_samples = "" 171 | for label, data in label2data.items(): 172 | sample = f"Sentence: {data} \n Class: {label} \n\n" 173 | few_shot_samples += sample 174 | 175 | return newsgroup_classes, few_shot_samples, df 176 | 177 | 178 | def get_samsum_data(): 179 | samsum_dataset = load_dataset("samsum") 180 | train_dataset = samsum_dataset["train"] 181 | dialogues = train_dataset["dialogue"][:2] 182 | summaries = train_dataset["summary"][:2] 183 | 184 | few_shot_samples = "" 185 | for dialogue, summary in zip(dialogues, summaries): 186 | sample = f"Sentence: {dialogue} \n Summary: {summary} \n\n" 187 | few_shot_samples += sample 188 | 189 | return few_shot_samples 190 | -------------------------------------------------------------------------------- /llama2/run_lora.sh: -------------------------------------------------------------------------------- 1 | epochs=(2 5 10 20 30 50) 2 | lora_r=(2 4 8 16) 3 | dropout=(0.1 0.2 0.5) 4 | 5 | for (( epoch=0; epoch<6; epoch=epoch+1 )) do 6 | for ((r=0; r<4; r=r+1 )) do 7 | for (( d=0; d<3; d=d+1 )) do 8 | python llama2_summarization.py --lora_r ${lora_r[$r]} --epochs ${epochs[$epoch]} --dropout ${dropout[$d]} & wait 9 | done 10 | done 11 | done 12 | -------------------------------------------------------------------------------- /llama2/sample_ablate.sh: -------------------------------------------------------------------------------- 1 | sample_fraction=(0.025 0.05 0.1) 2 | 3 | for (( sf=0; sf<3; sf=sf+1 )) do 4 | python llama2_classification.py --train_sample_fraction ${sample_fraction[$sf]} & wait 5 | done 6 | -------------------------------------------------------------------------------- /llmtune/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Georgian Partners 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | __version__ = "0.0.0" 16 | -------------------------------------------------------------------------------- /llmtune/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/cli/__init__.py -------------------------------------------------------------------------------- /llmtune/cli/toolkit.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import shutil 3 | from pathlib import Path 4 | 5 | import torch 6 | import transformers 7 | import typer 8 | import yaml 9 | from pydantic import ValidationError 10 | from typing_extensions import Annotated 11 | 12 | import llmtune 13 | from llmtune.constants.files import EXAMPLE_CONFIG_FNAME 14 | from llmtune.data.dataset_generator import DatasetGenerator 15 | from llmtune.finetune.lora import LoRAFinetune 16 | from llmtune.inference.lora import LoRAInference 17 | from llmtune.pydantic_models.config_model import Config 18 | from llmtune.qa.metric_suite import LLMMetricSuite 19 | from llmtune.qa.qa_metrics import QaMetricRegistry 20 | from llmtune.qa.test_suite import LLMTestSuite 21 | from llmtune.ui.rich_ui import RichUI 22 | from llmtune.utils.ablation_utils import generate_permutations 23 | from llmtune.utils.save_utils import DirectoryHelper 24 | 25 | 26 | transformers.logging.set_verbosity(transformers.logging.CRITICAL) 27 | torch._logging.set_logs(all=logging.CRITICAL) 28 | logging.captureWarnings(True) 29 | 30 | 31 | app = typer.Typer() 32 | generate_app = typer.Typer() 33 | 34 | app.add_typer( 35 | generate_app, 36 | name="generate", 37 | help="Generate various artefacts, such as config files", 38 | ) 39 | 40 | 41 | def run_one_experiment(config: Config, config_path: Path) -> None: 42 | dir_helper = DirectoryHelper(config_path, config) 43 | 44 | # Loading Data ------------------------------- 45 | RichUI.before_dataset_creation() 46 | 47 | with RichUI.during_dataset_creation("Injecting Values into Prompt", "monkey"): 48 | dataset_generator = DatasetGenerator(**config.data.model_dump()) 49 | 50 | _ = dataset_generator.train_columns 51 | test_column = dataset_generator.test_column 52 | 53 | dataset_path = dir_helper.save_paths.dataset 54 | if not dataset_path.exists(): 55 | train, test = dataset_generator.get_dataset() 56 | dataset_generator.save_dataset(dataset_path) 57 | else: 58 | RichUI.dataset_found(dataset_path) 59 | train, test = dataset_generator.load_dataset_from_pickle(dataset_path) 60 | 61 | RichUI.dataset_display_one_example(train[0], test[0]) 62 | RichUI.after_dataset_creation(dataset_path, train, test) 63 | 64 | # Loading Model ------------------------------- 65 | RichUI.before_finetune() 66 | 67 | weights_path = dir_helper.save_paths.weights 68 | 69 | # model_loader = ModelLoader(config, console, dir_helper) 70 | if not weights_path.exists() or not any(weights_path.iterdir()): 71 | finetuner = LoRAFinetune(config, dir_helper) 72 | with RichUI.during_finetune(): 73 | finetuner.finetune(train) 74 | finetuner.save_model() 75 | RichUI.after_finetune() 76 | else: 77 | RichUI.finetune_found(weights_path) 78 | 79 | # Inference ------------------------------- 80 | RichUI.before_inference() 81 | results_path = dir_helper.save_paths.results 82 | results_file_path = dir_helper.save_paths.results_file 83 | if not results_file_path.exists(): 84 | inference_runner = LoRAInference(test, test_column, config, dir_helper) 85 | 
inference_runner.infer_all() 86 | RichUI.after_inference(results_path) 87 | else: 88 | RichUI.results_found(results_path) 89 | 90 | # Quality Assurance ------------------------- 91 | RichUI.before_qa() 92 | 93 | qa_folder_path = dir_helper.save_paths.qa 94 | if not qa_folder_path.exists(): 95 | # metrics 96 | llm_metrics = config.qa.llm_metrics 97 | metrics = QaMetricRegistry.create_metrics_from_list(llm_metrics) 98 | metric_suite = LLMMetricSuite.from_csv(results_file_path, metrics) 99 | qa_metric_file = dir_helper.save_paths.metric_file 100 | metric_suite.save_metric_results(qa_metric_file) 101 | metric_suite.print_metric_results() 102 | 103 | # testing suites 104 | inference_runner = LoRAInference(test, test_column, config, dir_helper) 105 | test_suite_path = config.qa.test_suite 106 | test_suite = LLMTestSuite.from_dir(test_suite_path) 107 | test_suite.run_inference(inference_runner) 108 | test_suite.save_test_results(dir_helper.save_paths.qa) 109 | test_suite.print_test_results() 110 | 111 | 112 | @app.command("run") 113 | def run(config_path: Annotated[str, typer.Argument(help="Path of the config yaml file")] = "./config.yml") -> None: 114 | """Run the entire exmperiment pipeline""" 115 | # Load YAML config 116 | with Path(config_path).open("r") as file: 117 | config = yaml.safe_load(file) 118 | configs = ( 119 | generate_permutations(config, Config) if config.get("ablation", {}).get("use_ablate", False) else [config] 120 | ) 121 | for config in configs: 122 | # validate data with pydantic 123 | try: 124 | config = Config(**config) 125 | except ValidationError as e: 126 | print(e.json()) 127 | 128 | dir_helper = DirectoryHelper(config_path, config) 129 | 130 | # Reload config from saved config 131 | with dir_helper.save_paths.config_file.open("r") as file: 132 | config = yaml.safe_load(file) 133 | config = Config(**config) 134 | 135 | run_one_experiment(config, config_path) 136 | 137 | 138 | @generate_app.command("config") 139 | def generate_config(): 140 | """ 141 | Generate an example `config.yml` file in current directory 142 | """ 143 | module_path = Path(llmtune.__file__) 144 | example_config_path = module_path.parent / EXAMPLE_CONFIG_FNAME 145 | destination = Path.cwd() 146 | shutil.copy(example_config_path, destination) 147 | RichUI.generate_config(EXAMPLE_CONFIG_FNAME) 148 | 149 | 150 | def cli(): 151 | app() 152 | -------------------------------------------------------------------------------- /llmtune/config.yml: -------------------------------------------------------------------------------- 1 | save_dir: "./experiment/" 2 | 3 | ablation: 4 | use_ablate: false 5 | 6 | # Data Ingestion ------------------- 7 | data: 8 | file_type: "huggingface" # one of 'json', 'csv', 'huggingface' 9 | path: "yahma/alpaca-cleaned" 10 | prompt: 11 | >- # prompt, make sure column inputs are enclosed in {} brackets and that they match your data 12 | Below is an instruction that describes a task. 13 | Write a response that appropriately completes the request. 
14 | ### Instruction: {instruction} 15 | ### Input: {input} 16 | ### Output: 17 | prompt_stub: 18 | >- # Stub to add for training at the end of prompt, for test set or inference, this is omitted; make sure only one variable is present 19 | {output} 20 | test_size: 25 # Proportion of test as % of total; if integer then # of samples 21 | train_size: 500 # Proportion of train as % of total; if integer then # of samples 22 | train_test_split_seed: 42 23 | 24 | # Model Definition ------------------- 25 | model: 26 | hf_model_ckpt: "facebook/opt-125m" 27 | torch_dtype: "bfloat16" 28 | #attn_implementation: "flash_attention_2" 29 | quantize: true 30 | bitsandbytes: 31 | load_in_4bit: true 32 | bnb_4bit_compute_dtype: "bfloat16" 33 | bnb_4bit_quant_type: "nf4" 34 | 35 | # LoRA Params ------------------- 36 | lora: 37 | task_type: "CAUSAL_LM" 38 | r: 32 39 | lora_alpha: 64 40 | lora_dropout: 0.1 41 | target_modules: "all-linear" 42 | # to target specific modules 43 | # target_modules: 44 | # - q_proj 45 | # - v_proj 46 | # - k_proj 47 | # - o_proj 48 | # - up_proj 49 | # - down_proj 50 | # - gate_proj 51 | 52 | # Training ------------------- 53 | training: 54 | training_args: 55 | num_train_epochs: 1 56 | per_device_train_batch_size: 4 57 | gradient_accumulation_steps: 4 58 | gradient_checkpointing: True 59 | optim: "paged_adamw_32bit" 60 | logging_steps: 1 61 | learning_rate: 2.0e-4 62 | bf16: true # [Ampere+] Set to true for mixed precision training on Newer GPUs 63 | tf32: true # [Ampere+] Set to true for mixed precision training on Newer GPUs 64 | # fp16: false # Set to true for mixed precision training on Older GPUs 65 | max_grad_norm: 0.3 66 | warmup_ratio: 0.03 67 | lr_scheduler_type: "constant" 68 | sft_args: 69 | max_seq_length: 1024 70 | # neftune_noise_alpha: None 71 | 72 | inference: 73 | max_new_tokens: 256 74 | use_cache: True 75 | do_sample: True 76 | top_p: 0.9 77 | temperature: 0.8 78 | 79 | qa: 80 | llm_metrics: 81 | - jaccard_similarity 82 | - dot_product 83 | - rouge_score 84 | - word_overlap 85 | - verb_percent 86 | - adjective_percent 87 | - noun_percent 88 | - summary_length 89 | test_suite: "examples/test_suite" 90 | -------------------------------------------------------------------------------- /llmtune/constants/files.py: -------------------------------------------------------------------------------- 1 | # Example config file 2 | EXAMPLE_CONFIG_FNAME = "config.yml" 3 | 4 | # DIRECTORY HELPER - HASH SETTING 5 | NUM_MD5_DIGITS_FOR_SQIDS = 2 6 | 7 | # DIRECTORY HELPER - DIRECTORY & FILE NAMES 8 | CONFIG_DIR_NAME = "config" 9 | CONFIG_FILE_NAME = "config.yml" 10 | 11 | DATASET_DIR_NAME = "dataset" 12 | 13 | WEIGHTS_DIR_NAME = "weights" 14 | 15 | RESULTS_DIR_NAME = "results" 16 | RESULTS_FILE_NAME = "results.csv" 17 | 18 | QA_DIR_NAME = "qa" 19 | METRIC_FILE_NAME = "qa_metrics_results.csv" 20 | -------------------------------------------------------------------------------- /llmtune/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/data/__init__.py -------------------------------------------------------------------------------- /llmtune/data/dataset_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | from functools import partial 5 | from os.path import exists, join 6 | from typing import Tuple, Union 7 | 8 | from 
datasets import Dataset 9 | 10 | from llmtune.data.ingestor import Ingestor, get_ingestor 11 | 12 | 13 | class DatasetGenerator: 14 | def __init__( 15 | self, 16 | file_type: str, 17 | path: str, 18 | prompt: str, 19 | prompt_stub: str, 20 | test_size: Union[float, int], 21 | train_size: Union[float, int], 22 | train_test_split_seed: int, 23 | ): 24 | self.ingestor: Ingestor = get_ingestor(file_type) 25 | self.ingestor: Ingestor = self.ingestor(path) 26 | 27 | self.dataset: Dataset = self.ingestor.to_dataset() 28 | self.prompt: str = prompt 29 | self.prompt_stub: str = prompt_stub 30 | self.test_size = test_size 31 | self.train_size = train_size 32 | self.train_test_split_seed: int = train_test_split_seed 33 | 34 | self.train_columns: list = self._get_train_columns() 35 | self.test_column: str = self._get_test_column() 36 | 37 | def _get_train_columns(self): 38 | pattern = r"\{([^}]*)\}" 39 | return re.findall(pattern, self.prompt) 40 | 41 | def _get_test_column(self): 42 | pattern = r"\{([^}]*)\}" 43 | return re.findall(pattern, self.prompt_stub)[0] 44 | 45 | # TODO: stratify_by_column 46 | def _train_test_split(self): 47 | self.dataset = self.dataset.train_test_split( 48 | test_size=self.test_size, 49 | train_size=self.train_size, 50 | seed=self.train_test_split_seed, 51 | ) 52 | 53 | def _format_one_prompt(self, example, is_test: bool = False): 54 | train_mapping = {var_name: example[var_name] for var_name in self.train_columns} 55 | example["formatted_prompt"] = self.prompt.format(**train_mapping) 56 | 57 | if not is_test: 58 | test_mapping = {self.test_column: example[self.test_column]} 59 | example["formatted_prompt"] += self.prompt_stub.format(**test_mapping) 60 | 61 | return example 62 | 63 | def _format_prompts(self): 64 | self.dataset["train"] = self.dataset["train"].map(partial(self._format_one_prompt, is_test=False)) 65 | self.dataset["test"] = self.dataset["test"].map(partial(self._format_one_prompt, is_test=True)) 66 | 67 | def get_dataset(self) -> Tuple[Dataset, Dataset]: 68 | self._train_test_split() 69 | self._format_prompts() 70 | 71 | return self.dataset["train"], self.dataset["test"] 72 | 73 | def save_dataset(self, save_dir: str): 74 | os.makedirs(save_dir, exist_ok=True) 75 | with open(join(save_dir, "dataset.pkl"), "wb") as f: 76 | pickle.dump(self.dataset, f) 77 | 78 | def load_dataset_from_pickle(self, save_dir: str): 79 | data_path = join(save_dir, "dataset.pkl") 80 | 81 | if not exists(data_path): 82 | raise FileNotFoundError(f"Train set pickle not found at {save_dir}") 83 | 84 | with open(data_path, "rb") as f: 85 | data = pickle.load(f) 86 | self.dataset = data 87 | 88 | return self.dataset["train"], self.dataset["test"] 89 | -------------------------------------------------------------------------------- /llmtune/data/ingestor.py: -------------------------------------------------------------------------------- 1 | import csv 2 | from abc import ABC, abstractmethod 3 | 4 | import ijson 5 | from datasets import Dataset, concatenate_datasets, load_dataset 6 | 7 | 8 | def get_ingestor(data_type: str): 9 | if data_type == "json": 10 | return JsonIngestor 11 | elif data_type == "jsonl": 12 | return JsonlIngestor 13 | elif data_type == "csv": 14 | return CsvIngestor 15 | elif data_type == "huggingface": 16 | return HuggingfaceIngestor 17 | else: 18 | raise ValueError(f"'type' must be one of 'json', 'jsonl', 'csv', or 'huggingface', you have {data_type}") 19 | 20 | 21 | class Ingestor(ABC): 22 | @abstractmethod 23 | def to_dataset(self) -> Dataset: 24 | pass 25 | 26 | 
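# Illustrative sketch (not part of the original ingestor.py): the `get_ingestor` factory above
# maps the config's `file_type` string to an Ingestor subclass, and every subclass only has to
# store a path and implement `to_dataset()` returning a `datasets.Dataset`. A hypothetical
# Parquet ingestor (name and branch assumed, not in the original toolkit) would follow the
# same contract:
#
#     class ParquetIngestor(Ingestor):
#         def __init__(self, path: str):
#             self.path = path
#
#         def to_dataset(self) -> Dataset:
#             # Hugging Face `datasets` reads parquet files directly; a single data_files
#             # path yields a DatasetDict with a default "train" split
#             return load_dataset("parquet", data_files=self.path)["train"]
#
# It would also need its own branch in `get_ingestor` above so it can be selected from the
# `data.file_type` field of the config.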
27 | class JsonIngestor(Ingestor): 28 | def __init__(self, path: str): 29 | self.path = path 30 | 31 | def _json_generator(self): 32 | with open(self.path, "rb") as f: 33 | for item in ijson.items(f, "item"): 34 | yield item 35 | 36 | def to_dataset(self) -> Dataset: 37 | return Dataset.from_generator(self._json_generator) 38 | 39 | 40 | class JsonlIngestor(Ingestor): 41 | def __init__(self, path: str): 42 | self.path = path 43 | 44 | def _jsonl_generator(self): 45 | with open(self.path, "rb") as f: 46 | for item in ijson.items(f, "", multiple_values=True): 47 | yield item 48 | 49 | def to_dataset(self) -> Dataset: 50 | return Dataset.from_generator(self._jsonl_generator) 51 | 52 | 53 | class CsvIngestor(Ingestor): 54 | def __init__(self, path: str): 55 | self.path = path 56 | 57 | def _csv_generator(self): 58 | with open(self.path) as csvfile: 59 | reader = csv.DictReader(csvfile) 60 | for row in reader: 61 | yield row 62 | 63 | def to_dataset(self) -> Dataset: 64 | return Dataset.from_generator(self._csv_generator) 65 | 66 | 67 | class HuggingfaceIngestor(Ingestor): 68 | def __init__(self, path: str): 69 | self.path = path 70 | 71 | def to_dataset(self) -> Dataset: 72 | ds = load_dataset(self.path) 73 | return concatenate_datasets(ds.values()) 74 | -------------------------------------------------------------------------------- /llmtune/finetune/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/finetune/__init__.py -------------------------------------------------------------------------------- /llmtune/finetune/generics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Finetune(ABC): 5 | @abstractmethod 6 | def finetune(self): 7 | pass 8 | 9 | @abstractmethod 10 | def save_model(self): 11 | pass 12 | -------------------------------------------------------------------------------- /llmtune/finetune/lora.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | from datasets import Dataset 4 | from peft import ( 5 | LoraConfig, 6 | get_peft_model, 7 | prepare_model_for_kbit_training, 8 | ) 9 | from transformers import ( 10 | AutoModelForCausalLM, 11 | AutoTokenizer, 12 | BitsAndBytesConfig, 13 | ProgressCallback, 14 | TrainingArguments, 15 | ) 16 | from trl import SFTTrainer 17 | 18 | from llmtune.finetune.generics import Finetune 19 | from llmtune.pydantic_models.config_model import Config 20 | from llmtune.ui.rich_ui import RichUI 21 | from llmtune.utils.save_utils import DirectoryHelper 22 | 23 | 24 | class LoRAFinetune(Finetune): 25 | def __init__(self, config: Config, directory_helper: DirectoryHelper): 26 | self.config = config 27 | 28 | self._model_config = config.model 29 | self._training_args = config.training.training_args 30 | self._sft_args = config.training.sft_args 31 | self._lora_config = LoraConfig(**config.lora.model_dump()) 32 | self._directory_helper = directory_helper 33 | self._weights_path = self._directory_helper.save_paths.weights 34 | self._trainer = None 35 | 36 | self.model = None 37 | self.tokenizer = None 38 | 39 | self.device_map = self._model_config.device_map 40 | 41 | self._load_model_and_tokenizer() 42 | self._inject_lora() 43 | 44 | def _load_model_and_tokenizer(self): 45 | ckpt = self._model_config.hf_model_ckpt 46 | 
RichUI.on_basemodel_load(ckpt) 47 | model = self._get_model() 48 | tokenizer = self._get_tokenizer() 49 | RichUI.after_basemodel_load(ckpt) 50 | 51 | self.model = model 52 | self.tokenizer = tokenizer 53 | 54 | def _get_model(self): 55 | model = AutoModelForCausalLM.from_pretrained( 56 | self._model_config.hf_model_ckpt, 57 | quantization_config=BitsAndBytesConfig(**self._model_config.bitsandbytes.model_dump()), 58 | use_cache=False, 59 | device_map=self.device_map, 60 | torch_dtype=self._model_config.casted_torch_dtype, 61 | attn_implementation=self._model_config.attn_implementation, 62 | ) 63 | 64 | model.config.pretraining_tp = 1 65 | 66 | return model 67 | 68 | def _get_tokenizer(self): 69 | tokenizer = AutoTokenizer.from_pretrained(self._model_config.hf_model_ckpt) 70 | tokenizer.pad_token = tokenizer.eos_token 71 | tokenizer.padding_side = "right" 72 | 73 | return tokenizer 74 | 75 | def _inject_lora(self): 76 | self.model.gradient_checkpointing_enable() 77 | self.model = prepare_model_for_kbit_training(self.model) 78 | self.model = get_peft_model(self.model, self._lora_config) 79 | 80 | def finetune(self, train_dataset: Dataset): 81 | logging_dir = join(self._weights_path, "/logs") 82 | training_args = TrainingArguments( 83 | output_dir=self._weights_path, 84 | logging_dir=logging_dir, 85 | report_to="none", 86 | **self._training_args.model_dump(), 87 | ) 88 | 89 | progress_callback = ProgressCallback() 90 | 91 | self._trainer = SFTTrainer( 92 | model=self.model, 93 | train_dataset=train_dataset, 94 | peft_config=self._lora_config, 95 | tokenizer=self.tokenizer, 96 | packing=True, 97 | args=training_args, 98 | dataset_text_field="formatted_prompt", # TODO: maybe move consts to a dedicated folder 99 | callbacks=[progress_callback], 100 | **self._sft_args.model_dump(), 101 | ) 102 | 103 | self._trainer.train() 104 | 105 | def save_model(self) -> None: 106 | self._trainer.model.save_pretrained(self._weights_path) 107 | self.tokenizer.save_pretrained(self._weights_path) 108 | -------------------------------------------------------------------------------- /llmtune/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/inference/__init__.py -------------------------------------------------------------------------------- /llmtune/inference/generics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Inference(ABC): 5 | @abstractmethod 6 | def infer_one(self, prompt: str): 7 | pass 8 | 9 | @abstractmethod 10 | def infer_all(self): 11 | pass 12 | -------------------------------------------------------------------------------- /llmtune/inference/lora.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from os.path import join 4 | from threading import Thread 5 | 6 | import torch 7 | from datasets import Dataset 8 | from peft import AutoPeftModelForCausalLM 9 | from rich.text import Text 10 | from transformers import AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer 11 | 12 | from llmtune.inference.generics import Inference 13 | from llmtune.pydantic_models.config_model import Config 14 | from llmtune.ui.rich_ui import RichUI 15 | from llmtune.utils.save_utils import DirectoryHelper 16 | 17 | 18 | # TODO: Add type hints please! 
19 | class LoRAInference(Inference): 20 | def __init__( 21 | self, 22 | test_dataset: Dataset, 23 | label_column_name: str, 24 | config: Config, 25 | dir_helper: DirectoryHelper, 26 | ): 27 | self.test_dataset = test_dataset 28 | self.label_column = label_column_name 29 | self.config = config 30 | 31 | self.save_dir = dir_helper.save_paths.results 32 | self.save_path = join(self.save_dir, "results.csv") 33 | self.device_map = self.config.model.device_map 34 | self._weights_path = dir_helper.save_paths.weights 35 | 36 | self.model, self.tokenizer = self._get_merged_model(dir_helper.save_paths.weights) 37 | 38 | def _get_merged_model(self, weights_path: str): 39 | # purge VRAM 40 | torch.cuda.empty_cache() 41 | 42 | # Load from path 43 | 44 | self.model = AutoPeftModelForCausalLM.from_pretrained( 45 | weights_path, 46 | torch_dtype=self.config.model.casted_torch_dtype, 47 | quantization_config=BitsAndBytesConfig(**self.config.model.bitsandbytes.model_dump()), 48 | device_map=self.device_map, 49 | attn_implementation=self.config.model.attn_implementation, 50 | ) 51 | 52 | model = self.model.merge_and_unload() 53 | 54 | tokenizer = AutoTokenizer.from_pretrained(self._weights_path, device_map=self.device_map) 55 | 56 | return model, tokenizer 57 | 58 | def infer_all(self): 59 | results = [] 60 | prompts = self.test_dataset["formatted_prompt"] 61 | labels = self.test_dataset[self.label_column] 62 | 63 | # inference loop 64 | for idx, (prompt, label) in enumerate(zip(prompts, labels)): 65 | RichUI.inference_ground_truth_display(f"Generating on test set: {idx+1}/{len(prompts)}", prompt, label) 66 | 67 | try: 68 | result = self.infer_one(prompt) 69 | except Exception: 70 | continue 71 | results.append((prompt, label, result)) 72 | 73 | # TODO: seperate this into another method 74 | header = ["Prompt", "Ground Truth", "Predicted"] 75 | os.makedirs(self.save_dir, exist_ok=True) 76 | with open(self.save_path, "w", newline="") as f: 77 | writer = csv.writer(f) 78 | writer.writerow(header) 79 | for row in results: 80 | writer.writerow(row) 81 | 82 | def infer_one(self, prompt: str) -> str: 83 | input_ids = self.tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda() 84 | 85 | # stream processor 86 | streamer = TextIteratorStreamer( 87 | self.tokenizer, 88 | skip_prompt=True, 89 | decode_kwargs={"skip_special_tokens": True}, 90 | timeout=60, # 60 sec timeout for generation; to handle OOM errors 91 | ) 92 | 93 | generation_kwargs = dict(input_ids=input_ids, streamer=streamer, **self.config.inference.model_dump()) 94 | 95 | thread = Thread(target=self.model.generate, kwargs=generation_kwargs) 96 | thread.start() 97 | 98 | result = Text() 99 | with RichUI.inference_stream_display(result) as live: 100 | for new_text in streamer: 101 | result.append(new_text) 102 | live.update(result) 103 | 104 | return str(result) 105 | -------------------------------------------------------------------------------- /llmtune/pydantic_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/pydantic_models/__init__.py -------------------------------------------------------------------------------- /llmtune/qa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/qa/__init__.py 
-------------------------------------------------------------------------------- /llmtune/qa/metric_suite.py: -------------------------------------------------------------------------------- 1 | import statistics 2 | from pathlib import Path 3 | from typing import Dict, List, Union 4 | 5 | import pandas as pd 6 | 7 | from llmtune.qa.qa_metrics import LLMQaMetric 8 | from llmtune.ui.rich_ui import RichUI 9 | 10 | 11 | class LLMMetricSuite: 12 | """ 13 | Represents and runs a suite of metrics on a set of prompts, 14 | golden responses, and model predictions. 15 | """ 16 | 17 | def __init__( 18 | self, 19 | metrics: List[LLMQaMetric], 20 | prompts: List[str], 21 | ground_truths: List[str], 22 | model_preds: List[str], 23 | ) -> None: 24 | self.metrics = metrics 25 | self.prompts = prompts 26 | self.ground_truths = ground_truths 27 | self.model_preds = model_preds 28 | 29 | self._results: Dict[str, List[Union[float, int]]] = {} 30 | 31 | @staticmethod 32 | def from_csv( 33 | file_path: str, 34 | metrics: List[LLMQaMetric], 35 | prompt_col: str = "Prompt", 36 | gold_col: str = "Ground Truth", 37 | pred_col="Predicted", 38 | ) -> "LLMMetricSuite": 39 | results_df = pd.read_csv(file_path) 40 | prompts = results_df[prompt_col].tolist() 41 | ground_truths = results_df[gold_col].tolist() 42 | model_preds = results_df[pred_col].tolist() 43 | return LLMMetricSuite(metrics, prompts, ground_truths, model_preds) 44 | 45 | def compute_metrics(self) -> Dict[str, List[Union[float, int]]]: 46 | results = {} 47 | for metric in self.metrics: 48 | metric_results = [] 49 | for prompt, ground_truth, model_pred in zip(self.prompts, self.ground_truths, self.model_preds): 50 | metric_results.append(metric.get_metric(prompt, ground_truth, model_pred)) 51 | results[metric.metric_name] = metric_results 52 | 53 | self._results = results 54 | return results 55 | 56 | @property 57 | def metric_results(self) -> Dict[str, List[Union[float, int]]]: 58 | return self._results if self._results else self.compute_metrics() 59 | 60 | def print_metric_results(self): 61 | result_dictionary = self.metric_results 62 | column_data = {key: list(result_dictionary[key]) for key in result_dictionary} 63 | mean_values = {key: statistics.mean(column_data[key]) for key in column_data} 64 | median_values = {key: statistics.median(column_data[key]) for key in column_data} 65 | stdev_values = {key: statistics.stdev(column_data[key]) for key in column_data} 66 | # Use the RichUI class to display the table 67 | RichUI.qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values) 68 | 69 | def save_metric_results(self, path: str): 70 | # TODO: save these! 
71 | path = Path(path) 72 | dir = path.parent 73 | 74 | if not dir.exists(): 75 | dir.mkdir(parents=True, exist_ok=True) 76 | 77 | resultant_dataframe = pd.DataFrame(self.metric_results) 78 | resultant_dataframe.to_csv(path, index=False) 79 | -------------------------------------------------------------------------------- /llmtune/qa/qa_metrics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, Union 3 | 4 | import nltk 5 | import numpy as np 6 | import torch 7 | from langchain.evaluation import JsonValidityEvaluator 8 | from nltk import pos_tag 9 | from nltk.corpus import stopwords 10 | from nltk.tokenize import word_tokenize 11 | from rouge_score import rouge_scorer 12 | from transformers import DistilBertModel, DistilBertTokenizer 13 | 14 | 15 | json_validity_evaluator = JsonValidityEvaluator() 16 | 17 | nltk.download("stopwords") 18 | nltk.download("punkt") 19 | nltk.download("averaged_perceptron_tagger") 20 | 21 | 22 | class LLMQaMetric(ABC): 23 | """ 24 | Abstract base class for a metric. A metric can be computed over a single 25 | data instance, and outputs a scalar value (integer or float). 26 | """ 27 | 28 | @property 29 | @abstractmethod 30 | def metric_name(self) -> str: 31 | pass 32 | 33 | @abstractmethod 34 | def get_metric(self, prompt: str, grount_truth: str, model_pred: str) -> Union[float, int]: 35 | pass 36 | 37 | 38 | class QaMetricRegistry: 39 | """Provides a registry that maps metric names to metric classes. 40 | A user can provide a list of metrics by name, and the registry will convert 41 | that into a list of metric objects. 42 | """ 43 | 44 | registry = {} 45 | 46 | @classmethod 47 | def register(cls, *names): 48 | def inner_wrapper(wrapped_class): 49 | for name in names: 50 | cls.registry[name] = wrapped_class 51 | return wrapped_class 52 | 53 | return inner_wrapper 54 | 55 | @classmethod 56 | def create_metrics_from_list(cls, metric_names: List[str]) -> List[LLMQaMetric]: 57 | return [cls.registry[metric]() for metric in metric_names] 58 | 59 | 60 | @QaMetricRegistry.register("summary_length") 61 | class LengthMetric(LLMQaMetric): 62 | @property 63 | def metric_name(self) -> str: 64 | return "summary_length" 65 | 66 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: 67 | return abs(len(ground_truth) - len(model_prediction)) 68 | 69 | 70 | @QaMetricRegistry.register("jaccard_similarity") 71 | class JaccardSimilarityMetric(LLMQaMetric): 72 | @property 73 | def metric_name(self) -> str: 74 | return "jaccard_similarity" 75 | 76 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: 77 | set_ground_truth = set(ground_truth.lower()) 78 | set_model_prediction = set(model_prediction.lower()) 79 | 80 | intersection_size = len(set_ground_truth.intersection(set_model_prediction)) 81 | union_size = len(set_ground_truth.union(set_model_prediction)) 82 | 83 | similarity = intersection_size / union_size if union_size != 0 else 0 84 | return float(similarity) 85 | 86 | 87 | @QaMetricRegistry.register("dot_product") 88 | class DotProductSimilarityMetric(LLMQaMetric): 89 | """Encodes both the ground truth and model prediction using DistilBERT, and 90 | computes the dot product similarity between the two embeddings.""" 91 | 92 | def __init__(self): 93 | model_name = "distilbert-base-uncased" 94 | self.tokenizer = DistilBertTokenizer.from_pretrained(model_name) 95 | self.model 
= DistilBertModel.from_pretrained(model_name) 96 | 97 | @property 98 | def metric_name(self) -> str: 99 | return "dot_product" 100 | 101 | def _encode_sentence(self, sentence): 102 | tokens = self.tokenizer(sentence, return_tensors="pt") 103 | with torch.no_grad(): 104 | outputs = self.model(**tokens) 105 | return outputs.last_hidden_state.mean(dim=1).squeeze().numpy() 106 | 107 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: 108 | embedding_ground_truth = self._encode_sentence(ground_truth) 109 | embedding_model_prediction = self._encode_sentence(model_prediction) 110 | dot_product_similarity = np.dot(embedding_ground_truth, embedding_model_prediction) 111 | return float(dot_product_similarity) 112 | 113 | 114 | @QaMetricRegistry.register("rouge_score") 115 | class RougeScoreMetric(LLMQaMetric): 116 | @property 117 | def metric_name(self) -> str: 118 | return "rouge_score" 119 | 120 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: 121 | scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) 122 | scores = scorer.score(model_prediction, ground_truth) 123 | return float(scores["rouge1"].precision) 124 | 125 | 126 | @QaMetricRegistry.register("word_overlap") 127 | class WordOverlapMetric(LLMQaMetric): 128 | @property 129 | def metric_name(self) -> str: 130 | return "word_overlap" 131 | 132 | def _remove_stopwords(self, text: str) -> str: 133 | stop_words = set(stopwords.words("english")) 134 | word_tokens = word_tokenize(text) 135 | filtered_text = [word for word in word_tokens if word.lower() not in stop_words] 136 | return " ".join(filtered_text) 137 | 138 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> Union[float, int, bool]: 139 | cleaned_model_prediction = self._remove_stopwords(model_prediction) 140 | cleaned_ground_truth = self._remove_stopwords(ground_truth) 141 | 142 | words_model_prediction = set(cleaned_model_prediction.split()) 143 | words_ground_truth = set(cleaned_ground_truth.split()) 144 | 145 | common_words = words_model_prediction.intersection(words_ground_truth) 146 | overlap_percentage = (len(common_words) / len(words_ground_truth)) * 100 147 | return float(overlap_percentage) 148 | 149 | 150 | @QaMetricRegistry.register("json_valid") 151 | class JSONValidityMetric(LLMQaMetric): 152 | """ 153 | Checks to see if valid json can be parsed from the model output, according 154 | to langchain_core.utils.json.parse_json_markdown 155 | The JSON can be wrapped in markdown and this test will still pass 156 | """ 157 | 158 | @property 159 | def metric_name(self) -> str: 160 | return "json_valid" 161 | 162 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: 163 | result = json_validity_evaluator.evaluate_strings(prediction=model_prediction) 164 | binary_res = result["score"] 165 | return float(binary_res) 166 | 167 | 168 | class PosCompositionMetric(LLMQaMetric): 169 | def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float: 170 | words = word_tokenize(text) 171 | tags = pos_tag(words) 172 | pos_words = [word for word, tag in tags if tag in pos_tags] 173 | total_words = len(text.split(" ")) 174 | return round(len(pos_words) / total_words, 2) 175 | 176 | 177 | @QaMetricRegistry.register("verb_percent") 178 | class VerbPercentMetric(PosCompositionMetric): 179 | @property 180 | def metric_name(self) -> str: 181 | return "verb_percent" 182 | 183 | def get_metric(self, prompt: str, 
ground_truth: str, model_prediction: str) -> float: 184 | return self._get_pos_percent(model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]) 185 | 186 | 187 | @QaMetricRegistry.register("adjective_percent") 188 | class AdjectivePercentMetric(PosCompositionMetric): 189 | @property 190 | def metric_name(self) -> str: 191 | return "adjective_percent" 192 | 193 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: 194 | return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"]) 195 | 196 | 197 | @QaMetricRegistry.register("noun_percent") 198 | class NounPercentMetric(PosCompositionMetric): 199 | @property 200 | def metric_name(self) -> str: 201 | return "noun_percent" 202 | 203 | def get_metric(self, prompt: str, ground_truth: str, model_prediction: str) -> float: 204 | return self._get_pos_percent(model_prediction, ["NN", "NNS", "NNP", "NNPS"]) 205 | 206 | 207 | # Instantiate tests 208 | # length_test = LengthMetric() 209 | # jaccard_similarity_test = JaccardSimilarityMetric() 210 | # dot_product_similarity_test = DotProductSimilarityMetric() 211 | # rouge_score_test = RougeScoreMetric() 212 | # word_overlap_test = WordOverlapMetric() 213 | # verb_percent_test = VerbPercentMetric() 214 | # adjective_percent_test = AdjectivePercentMetric() 215 | # noun_percent_test = NounPercentMetric() 216 | -------------------------------------------------------------------------------- /llmtune/qa/qa_tests.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from langchain.evaluation import JsonValidityEvaluator 7 | from transformers import DistilBertModel, DistilBertTokenizer 8 | 9 | 10 | class LLMQaTest(ABC): 11 | """ 12 | Abstract base class for a test. A test can be computed over a single 13 | data instance/llm response, and outputs a boolean value (pass or fail). 14 | """ 15 | 16 | @property 17 | @abstractmethod 18 | def test_name(self) -> str: 19 | pass 20 | 21 | @abstractmethod 22 | def test(self, prompt: str, grount_truth: str, model_pred: str) -> bool: 23 | pass 24 | 25 | 26 | # TODO this is the same as QaMetricRegistry, could be combined? 27 | class QaTestRegistry: 28 | """Provides a registry that maps metric names to metric classes. 29 | A user can provide a list of metrics by name, and the registry will convert 30 | that into a list of metric objects. 
31 | """ 32 | 33 | registry = {} 34 | 35 | @classmethod 36 | def register(cls, *names): 37 | def inner_wrapper(wrapped_class): 38 | for name in names: 39 | cls.registry[name] = wrapped_class 40 | return wrapped_class 41 | 42 | return inner_wrapper 43 | 44 | @classmethod 45 | def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]: 46 | return [cls.registry[test]() for test in test_names] 47 | 48 | @classmethod 49 | def from_name(cls, name: str) -> LLMQaTest: 50 | """Return a LLMQaTest object from a given name.""" 51 | return cls.registry[name]() 52 | 53 | 54 | @QaTestRegistry.register("json_valid") 55 | class JSONValidityTest(LLMQaTest): 56 | """ 57 | Checks to see if valid json can be parsed from the model output, according 58 | to langchain_core.utils.json.parse_json_markdown 59 | The JSON can be wrapped in markdown and this test will still pass 60 | """ 61 | 62 | def __init__(self): 63 | self.json_validity_evaluator = JsonValidityEvaluator() 64 | 65 | @property 66 | def test_name(self) -> str: 67 | return "json_valid" 68 | 69 | def test(self, model_pred: str) -> bool: 70 | result = self.json_validity_evaluator.evaluate_strings(prediction=model_pred) 71 | binary_res = result["score"] 72 | return bool(binary_res) 73 | 74 | 75 | @QaTestRegistry.register("cosine_similarity") 76 | class CosineSimilarityTest(LLMQaTest): 77 | """ 78 | Checks to see if the response of the LLM is within a certain cosine 79 | similarity to the gold-standard response. Uses a DistilBERT model to encode 80 | the responses into vectors. 81 | """ 82 | 83 | def __init__(self): 84 | model_name = "distilbert-base-uncased" 85 | self.tokenizer = DistilBertTokenizer.from_pretrained(model_name) 86 | self.model = DistilBertModel.from_pretrained(model_name) 87 | 88 | @property 89 | def test_name(self) -> str: 90 | return "cosine_similarity" 91 | 92 | def _encode_sentence(self, sentence: str) -> np.ndarray: 93 | """Encode a sentence into a vector using a language model.""" 94 | tokens = self.tokenizer(sentence, return_tensors="pt") 95 | with torch.no_grad(): 96 | outputs = self.model(**tokens) 97 | return outputs.last_hidden_state.mean(dim=1).squeeze().numpy() 98 | 99 | def test(self, model_pred: str, ground_truth: str, threshold: float = 0.8) -> bool: 100 | embedding_ground_truth = self._encode_sentence(ground_truth) 101 | embedding_model_prediction = self._encode_sentence(model_pred) 102 | dot_product = np.dot(embedding_ground_truth, embedding_model_prediction) 103 | norm_ground_truth = np.linalg.norm(embedding_ground_truth) 104 | norm_model_prediction = np.linalg.norm(embedding_model_prediction) 105 | cosine_similarity = dot_product / (norm_ground_truth * norm_model_prediction) 106 | return cosine_similarity >= threshold 107 | -------------------------------------------------------------------------------- /llmtune/qa/test_suite.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict, List 3 | 4 | import pandas as pd 5 | 6 | from llmtune.inference.lora import LoRAInference 7 | from llmtune.qa.qa_tests import LLMQaTest, QaTestRegistry 8 | from llmtune.ui.rich_ui import RichUI 9 | 10 | 11 | def all_same(items: List[Any]) -> bool: 12 | """Check if all items in a list are the same.""" 13 | if len(items) == 0: 14 | return False 15 | 16 | same = True 17 | for item in items: 18 | if item != items[0]: 19 | same = False 20 | break 21 | return same 22 | 23 | 24 | class TestBank: 25 | """A test bank is a collection of test cases 
for a single test type. 26 | Test banks can be specified using CSV files, and also save their results to CSV files. 27 | """ 28 | 29 | def __init__(self, test: LLMQaTest, cases: List[Dict[str, str]], file_name_stem: str) -> None: 30 | self.test = test 31 | self.cases = cases 32 | self.results: List[bool] = [] 33 | self.file_name = file_name_stem + "_results.csv" 34 | 35 | def generate_results(self, model: LoRAInference) -> None: 36 | """Generates pass/fail results for each test case, based on the model's predictions.""" 37 | self.results = [] # reset results 38 | for case in self.cases: 39 | prompt = case["prompt"] 40 | model_pred = model.infer_one(prompt) 41 | # run the test with the model prediction and additional args 42 | test_args = {k: v for k, v in case.items() if k != "prompt"} 43 | result = self.test.test(model_pred, **test_args) 44 | self.results.append(result) 45 | 46 | def save_test_results(self, output_dir: Path, result_col: str = "result") -> None: 47 | """ 48 | Re-saves the test results in a CSV file, with a results column. 49 | """ 50 | df = pd.DataFrame(self.cases) 51 | df[result_col] = self.results 52 | df.to_csv(output_dir / self.file_name, index=False) 53 | 54 | 55 | class LLMTestSuite: 56 | """ 57 | Represents and runs a suite of different tests for LLMs. 58 | """ 59 | 60 | def __init__( 61 | self, 62 | test_banks: List[TestBank], 63 | ) -> None: 64 | self.test_banks = test_banks 65 | 66 | @staticmethod 67 | def from_dir( 68 | dir_path: str, 69 | test_type_col: str = "Test Type", 70 | ) -> "LLMTestSuite": 71 | """Creates an LLMTestSuite from a directory of CSV files. 72 | Each CSV file is a test bank, which encodes test cases for a certain 73 | test type. 74 | """ 75 | 76 | csv_files = Path(dir_path).rglob("*.csv") 77 | 78 | test_banks = [] 79 | for file_name in csv_files: 80 | df = pd.read_csv(file_name) 81 | test_type_column = df[test_type_col].tolist() 82 | # everything that isn't the test type column is a test parameter 83 | params = list(set(df.columns.tolist()) - set([test_type_col])) # noqa: C405 84 | assert all_same( 85 | test_type_column 86 | ), f"All test cases in a test bank {file_name} must have the same test type." 
87 | test_type = test_type_column[0] 88 | test = QaTestRegistry.from_name(test_type) 89 | cases = [] 90 | # all rows are a test case, encode them all 91 | for _, row in df.iterrows(): 92 | case = {} 93 | for param in params: 94 | case[param] = row[param] 95 | cases.append(case) 96 | # get file name stub without extension or path 97 | test_banks.append(TestBank(test, cases, file_name.stem)) 98 | return LLMTestSuite(test_banks) 99 | 100 | def run_inference(self, model: LoRAInference) -> None: 101 | """Runs inference on all test cases in all the test banks.""" 102 | for test_bank in self.test_banks: 103 | test_bank.generate_results(model) 104 | 105 | def print_test_results(self) -> None: 106 | """Prints the results of the tests in the suite.""" 107 | test_names, num_passed, num_instances = [], [], [] 108 | for test_bank in self.test_banks: 109 | test_name = test_bank.test.test_name 110 | test_results = test_bank.results 111 | passed = sum(test_results) 112 | instances = len(test_results) 113 | test_names.append(test_name) 114 | num_passed.append(passed) 115 | num_instances.append(instances) 116 | 117 | RichUI.qa_display_test_table(test_names, num_passed, num_instances) 118 | 119 | def save_test_results(self, output_dir: Path) -> None: 120 | """Saves the results of the tests in a folder of CSV files.""" 121 | if not output_dir.exists(): 122 | output_dir.mkdir(parents=True, exist_ok=True) 123 | for test_bank in self.test_banks: 124 | test_bank.save_test_results(output_dir) 125 | -------------------------------------------------------------------------------- /llmtune/ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/ui/__init__.py -------------------------------------------------------------------------------- /llmtune/ui/generics.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractstaticmethod 2 | 3 | from datasets import Dataset 4 | from rich.text import Text 5 | 6 | 7 | class UI(ABC): 8 | """ 9 | DATASET 10 | """ 11 | 12 | # Lifecycle functions 13 | @abstractstaticmethod 14 | def before_dataset_creation(): 15 | pass 16 | 17 | @abstractstaticmethod 18 | def during_dataset_creation(message: str, spinner: str): 19 | pass 20 | 21 | @abstractstaticmethod 22 | def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset): 23 | pass 24 | 25 | @abstractstaticmethod 26 | def dataset_found(save_dir: str): 27 | pass 28 | 29 | # Display functions 30 | @abstractstaticmethod 31 | def dataset_display_one_example(train_row: dict, test_row: dict): 32 | pass 33 | 34 | """ 35 | FINETUNING 36 | """ 37 | 38 | # Lifecycle functions 39 | @abstractstaticmethod 40 | def before_finetune(): 41 | pass 42 | 43 | @abstractstaticmethod 44 | def on_basemodel_load(checkpoint: str): 45 | pass 46 | 47 | @abstractstaticmethod 48 | def after_basemodel_load(checkpoint: str): 49 | pass 50 | 51 | @abstractstaticmethod 52 | def during_finetune(): 53 | pass 54 | 55 | @abstractstaticmethod 56 | def after_finetune(): 57 | pass 58 | 59 | @abstractstaticmethod 60 | def finetune_found(weights_path: str): 61 | pass 62 | 63 | """ 64 | INFERENCE 65 | """ 66 | 67 | # Lifecycle functions 68 | @abstractstaticmethod 69 | def before_inference(): 70 | pass 71 | 72 | @abstractstaticmethod 73 | def during_inference(): 74 | pass 75 | 76 | @abstractstaticmethod 77 | def after_inference(results_path: str): 78 | 
pass 79 | 80 | @abstractstaticmethod 81 | def results_found(results_path: str): 82 | pass 83 | 84 | # Display functions 85 | @abstractstaticmethod 86 | def inference_ground_truth_display(title: str, prompt: str, label: str): 87 | pass 88 | 89 | @abstractstaticmethod 90 | def inference_stream_display(text: Text): 91 | pass 92 | 93 | """ 94 | QA 95 | """ 96 | 97 | # Lifecycle functions 98 | @abstractstaticmethod 99 | def before_qa(cls): 100 | pass 101 | 102 | @abstractstaticmethod 103 | def during_qa(cls): 104 | pass 105 | 106 | @abstractstaticmethod 107 | def after_qa(cls): 108 | pass 109 | 110 | @abstractstaticmethod 111 | def qa_found(cls): 112 | pass 113 | 114 | @abstractstaticmethod 115 | def qa_display_metric_table(cls): 116 | pass 117 | -------------------------------------------------------------------------------- /llmtune/ui/rich_ui.py: -------------------------------------------------------------------------------- 1 | from datasets import Dataset 2 | from rich.console import Console 3 | from rich.layout import Layout 4 | from rich.live import Live 5 | from rich.panel import Panel 6 | from rich.table import Table 7 | from rich.text import Text 8 | 9 | from llmtune.ui.generics import UI 10 | from llmtune.utils.rich_print_utils import inject_example_to_rich_layout 11 | 12 | 13 | console = Console() 14 | 15 | 16 | class StatusContext: 17 | def __init__(self, console, message, spinner): 18 | self.console = console 19 | self.message = message 20 | self.spinner = spinner 21 | 22 | def __enter__(self): 23 | self.task = self.console.status(self.message, spinner=self.spinner) 24 | self.task.__enter__() # Manually enter the console status context 25 | return self # This allows you to use variables from this context if needed 26 | 27 | def __exit__(self, exc_type, exc_val, exc_tb): 28 | self.task.__exit__(exc_type, exc_val, exc_tb) # Cleanly exit the console status context 29 | 30 | 31 | class LiveContext: 32 | def __init__(self, text: Text, refresh_per_second=4, vertical_overflow="visible"): 33 | self.console = console 34 | self.text = text 35 | self.refresh_per_second = refresh_per_second 36 | self.vertical_overflow = vertical_overflow 37 | 38 | def __enter__(self): 39 | self.task = Live( 40 | self.text, 41 | refresh_per_second=self.refresh_per_second, 42 | vertical_overflow=self.vertical_overflow, 43 | ) 44 | self.task.__enter__() # Manually enter the console status context 45 | return self # This allows you to use variables from this context if needed 46 | 47 | def __exit__(self, exc_type, exc_val, exc_tb): 48 | self.task.__exit__(exc_type, exc_val, exc_tb) # Cleanly exit the console status context 49 | 50 | def update(self, new_text: Text): 51 | self.task.update(new_text) 52 | 53 | 54 | class RichUI(UI): 55 | """ 56 | DATASET 57 | """ 58 | 59 | # Lifecycle functions 60 | @staticmethod 61 | def before_dataset_creation(): 62 | console.rule("[bold green]Loading Data") 63 | 64 | @staticmethod 65 | def during_dataset_creation(message: str, spinner: str): 66 | return StatusContext(console, message, spinner) 67 | 68 | @staticmethod 69 | def after_dataset_creation(save_dir: str, train: Dataset, test: Dataset): 70 | console.print(f"Dataset Saved at {save_dir}") 71 | console.print("Post-Split data size:") 72 | console.print(f"Train: {len(train)}") 73 | console.print(f"Test: {len(test)}") 74 | 75 | @staticmethod 76 | def dataset_found(save_dir: str): 77 | console.print(f"Loading formatted dataset from directory {save_dir}") 78 | 79 | # Display functions 80 | @staticmethod 81 | def 
dataset_display_one_example(train_row: dict, test_row: dict): 82 | layout = Layout() 83 | layout.split_row( 84 | Layout(Panel("Train Sample"), name="train"), 85 | Layout( 86 | Panel("Inference Sample"), 87 | name="inference", 88 | ), 89 | ) 90 | 91 | inject_example_to_rich_layout(layout["train"], "Train Example", train_row) 92 | inject_example_to_rich_layout(layout["inference"], "Inference Example", test_row) 93 | 94 | console.print(layout) 95 | 96 | """ 97 | FINETUNING 98 | """ 99 | 100 | # Lifecycle functions 101 | @staticmethod 102 | def before_finetune(): 103 | console.rule("[bold yellow]:smiley: Finetuning") 104 | 105 | @staticmethod 106 | def on_basemodel_load(checkpoint: str): 107 | console.print(f"Loading {checkpoint}...") 108 | 109 | @staticmethod 110 | def after_basemodel_load(checkpoint: str): 111 | console.print(f"{checkpoint} Loaded :smile:") 112 | 113 | @staticmethod 114 | def during_finetune(): 115 | return StatusContext(console, "Finetuning Model...", "runner") 116 | 117 | @staticmethod 118 | def after_finetune(): 119 | console.print("Finetuning complete!") 120 | 121 | @staticmethod 122 | def finetune_found(weights_path: str): 123 | console.print(f"Fine-Tuned Model Found at {weights_path}... skipping training") 124 | 125 | """ 126 | INFERENCE 127 | """ 128 | 129 | # Lifecycle functions 130 | @staticmethod 131 | def before_inference(): 132 | console.rule("[bold pink]:face_with_monocle: Running Inference") 133 | 134 | @staticmethod 135 | def during_inference(): 136 | pass 137 | 138 | @staticmethod 139 | def after_inference(results_path: str): 140 | console.print(f"Inference Results Saved at {results_path}") 141 | 142 | @staticmethod 143 | def results_found(results_path: str): 144 | console.print(f"Inference Results Found at {results_path}") 145 | 146 | # Display functions 147 | @staticmethod 148 | def inference_ground_truth_display(title: str, prompt: str, label: str): 149 | prompt = prompt.replace("[INST]", "").replace("[/INST]", "") 150 | label = label.replace("[INST]", "").replace("[/INST]", "") 151 | 152 | table = Table(title=title, show_lines=True) 153 | table.add_column("prompt") 154 | table.add_column("ground truth") 155 | table.add_row(prompt, label) 156 | console.print(table) 157 | 158 | @staticmethod 159 | def inference_stream_display(text: Text): 160 | console.print("[bold red]Prediction >") 161 | return LiveContext(text) 162 | 163 | """ 164 | QA 165 | """ 166 | 167 | # Lifecycle functions 168 | @staticmethod 169 | def before_qa(): 170 | pass 171 | 172 | @staticmethod 173 | def during_qa(): 174 | pass 175 | 176 | @staticmethod 177 | def after_qa(): 178 | pass 179 | 180 | @staticmethod 181 | def qa_found(): 182 | pass 183 | 184 | @staticmethod 185 | def qa_display_metric_table(result_dictionary, mean_values, median_values, stdev_values): 186 | # Create a table 187 | table = Table(show_header=True, header_style="bold", title="Test Set Metric Results") 188 | 189 | # Add columns to the table 190 | table.add_column("Metric", style="cyan") 191 | table.add_column("Mean", style="magenta") 192 | table.add_column("Median", style="green") 193 | table.add_column("Standard Deviation", style="yellow") 194 | 195 | # Add data rows to the table 196 | for key in result_dictionary: 197 | table.add_row( 198 | key, 199 | f"{mean_values[key]:.4f}", 200 | f"{median_values[key]:.4f}", 201 | f"{stdev_values[key]:.4f}", 202 | ) 203 | 204 | # Print the table 205 | console.print(table) 206 | 207 | @staticmethod 208 | def qa_display_test_table(test_names, num_passed, num_instances): 209 | # 
Create a table 210 | table = Table(show_header=True, header_style="bold", title="Test Suite Results") 211 | 212 | # Add columns to the table 213 | table.add_column("Test Suite", style="cyan") 214 | table.add_column("Passing", style="magenta") 215 | 216 | # Add data rows to the table 217 | for test_name, passed, total in zip(test_names, num_passed, num_instances): 218 | table.add_row(test_name, f"{passed}/{total}") 219 | 220 | # Print the table 221 | console.print(table) 222 | 223 | """ 224 | GENERATE 225 | """ 226 | 227 | @staticmethod 228 | def generate_config(file_name: str): 229 | console.print(f"Generated config at [bold green]./{file_name}[/]") 230 | -------------------------------------------------------------------------------- /llmtune/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/llmtune/utils/__init__.py -------------------------------------------------------------------------------- /llmtune/utils/ablation_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | from typing import Dict, Tuple, Union, get_args, get_origin 4 | 5 | 6 | # TODO: organize this a little bit. It's a bit of a mess rn. 7 | 8 | """ 9 | Helper functions to create multiple valid configs based on ablation (i.e. list of values) 10 | fron a single config yaml 11 | """ 12 | 13 | 14 | def get_types_from_dict(source_dict: dict, root="", type_dict={}) -> Dict[str, Tuple[type, type]]: 15 | for key, val in source_dict.items(): 16 | if not isinstance(val, dict): 17 | attr = f"{root}.{key}" if root else key 18 | tp = (type(val), None) if not isinstance(val, list) else (type(val), type(val[0])) 19 | type_dict[attr] = tp 20 | else: 21 | join_array = [root, key] if root else [key] 22 | new_root = ".".join(join_array) 23 | get_types_from_dict(val, new_root, type_dict) 24 | 25 | return type_dict 26 | 27 | 28 | def get_annotation(key: str, base_model): 29 | keys = key.split(".") 30 | model = base_model 31 | for key in keys: 32 | model = model.__annotations__[key] 33 | 34 | return model 35 | 36 | 37 | def get_model_field_type(annotation): 38 | origin = get_origin(annotation) 39 | if not origin: 40 | return annotation 41 | if origin is Union: 42 | annotations = get_args(annotation)[0] 43 | return get_model_field_type(annotations) 44 | if origin is list: 45 | return list 46 | 47 | 48 | def get_data_with_key(key, data): 49 | keys = key.split(".") 50 | for key in keys: 51 | data = data[key] 52 | return data 53 | 54 | 55 | def validate_and_get_ablations(type_dict, data, base_model): 56 | ablations = {} 57 | for key, (tp, subtype) in type_dict.items(): 58 | annotation = get_annotation(key, base_model) 59 | model_field_type = get_model_field_type(annotation) 60 | if (model_field_type is list) and (tp is list) and (subtype is list): 61 | # Handle both list and list of lists 62 | ablations[key] = get_data_with_key(key, data) 63 | elif model_field_type is not list and tp is list: 64 | # Handle single-level lists 65 | ablations[key] = get_data_with_key(key, data) 66 | 67 | return ablations 68 | 69 | 70 | def patch_with_permutation(old_dict, permutation_dict): 71 | # Create a deep copy of the old dictionary to avoid modifying the original 72 | updated_dict = copy.deepcopy(old_dict) 73 | 74 | # Iterate over each item in the permutation dictionary 75 | for dot_key, new_value in permutation_dict.items(): 
76 | # Split the dot-joined key into individual keys 77 | keys = dot_key.split(".") 78 | 79 | # Start from the root of the updated dictionary 80 | current_level = updated_dict 81 | 82 | # Traverse to the second-to-last key in the nested dictionary 83 | for key in keys[:-1]: 84 | current_level = current_level[key] 85 | 86 | # Update the value at the final key 87 | current_level[keys[-1]] = new_value 88 | 89 | return updated_dict 90 | 91 | 92 | def generate_permutations(yaml_dict, model): 93 | type_dict = get_types_from_dict(yaml_dict) 94 | 95 | ablations = validate_and_get_ablations(type_dict, yaml_dict, model) 96 | 97 | # get permutations 98 | lists = list(ablations.values()) 99 | permutations = list(itertools.product(*lists)) 100 | 101 | permutation_dicts = [] 102 | for perm in permutations: 103 | new_dict = dict(zip(ablations.keys(), perm)) 104 | permutation_dicts.append(new_dict) 105 | 106 | new_dicts = [] 107 | for perm in permutation_dicts: 108 | new_dicts.append(patch_with_permutation(yaml_dict, perm)) 109 | 110 | return new_dicts 111 | -------------------------------------------------------------------------------- /llmtune/utils/rich_print_utils.py: -------------------------------------------------------------------------------- 1 | from rich.layout import Layout 2 | from rich.panel import Panel 3 | from rich.table import Table 4 | from rich.text import Text 5 | 6 | 7 | def inject_example_to_rich_layout(layout: Layout, layout_name: str, example: dict): 8 | example = example.copy() 9 | 10 | # Crate Table 11 | table = Table(expand=True) 12 | colors = [ 13 | "navy_blue", 14 | "dark_green", 15 | "spring_green3", 16 | "turquoise2", 17 | "cyan", 18 | "blue_violet", 19 | "royal_blue1", 20 | "steel_blue1", 21 | "chartreuse1", 22 | "deep_pink4", 23 | "plum2", 24 | "red", 25 | ] 26 | 27 | # Crate Formatted Text 28 | formatted = example.pop("formatted_prompt", None) 29 | formatted_text = Text(formatted) 30 | 31 | print(example) 32 | for key, c in zip(example.keys(), colors): 33 | table.add_column(key, style=c) 34 | 35 | tgt_text = str(example[key]) 36 | start_idx = formatted.find(tgt_text) 37 | formatted_text.stylize(f"bold {c}", start_idx, start_idx + len(tgt_text)) 38 | 39 | table.add_row(*[str(v) for v in example.values()]) 40 | 41 | layout.split_column( 42 | Layout( 43 | Panel( 44 | table, 45 | title=f"{layout_name} - Raw", 46 | title_align="left", 47 | ) 48 | ), 49 | Layout( 50 | Panel( 51 | formatted_text, 52 | title=f"{layout_name} - Formatted", 53 | title_align="left", 54 | ) 55 | ), 56 | ) 57 | -------------------------------------------------------------------------------- /llmtune/utils/save_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions to help managing saving and loading of experiments: 3 | 1. Generate save directory name 4 | 2. 
Check if files are present at various experiment stages 5 | """ 6 | 7 | import hashlib 8 | import re 9 | from dataclasses import dataclass 10 | from functools import cached_property 11 | from pathlib import Path 12 | 13 | import yaml 14 | from sqids import Sqids 15 | 16 | from llmtune.constants.files import ( 17 | CONFIG_DIR_NAME, 18 | CONFIG_FILE_NAME, 19 | DATASET_DIR_NAME, 20 | METRIC_FILE_NAME, 21 | NUM_MD5_DIGITS_FOR_SQIDS, 22 | QA_DIR_NAME, 23 | RESULTS_DIR_NAME, 24 | RESULTS_FILE_NAME, 25 | WEIGHTS_DIR_NAME, 26 | ) 27 | from llmtune.pydantic_models.config_model import Config 28 | 29 | 30 | @dataclass 31 | class DirectoryList: 32 | save_dir: Path 33 | config_hash: str 34 | 35 | @property 36 | def experiment(self) -> Path: 37 | return self.save_dir / self.config_hash 38 | 39 | @property 40 | def config(self) -> Path: 41 | return self.experiment / CONFIG_DIR_NAME 42 | 43 | @property 44 | def config_file(self) -> Path: 45 | return self.config / CONFIG_FILE_NAME 46 | 47 | @property 48 | def dataset(self) -> Path: 49 | return self.experiment / DATASET_DIR_NAME 50 | 51 | @property 52 | def weights(self) -> Path: 53 | return self.experiment / WEIGHTS_DIR_NAME 54 | 55 | @property 56 | def results(self) -> Path: 57 | return self.experiment / RESULTS_DIR_NAME 58 | 59 | @property 60 | def results_file(self) -> Path: 61 | return self.results / RESULTS_FILE_NAME 62 | 63 | @property 64 | def qa(self) -> Path: 65 | return self.experiment / QA_DIR_NAME 66 | 67 | @property 68 | def metric_file(self) -> Path: 69 | return self.qa / METRIC_FILE_NAME 70 | 71 | 72 | class DirectoryHelper: 73 | def __init__(self, config_path: Path, config: Config): 74 | self.config_path: Path = config_path 75 | self.config: Config = config 76 | self.sqids: Sqids = Sqids() 77 | self.save_paths: DirectoryList = self._get_directory_state() 78 | 79 | self.save_paths.experiment.mkdir(parents=True, exist_ok=True) 80 | if not self.save_paths.config.exists(): 81 | self.save_config() 82 | 83 | @cached_property 84 | def config_hash(self) -> str: 85 | config_str = self.config.model_dump_json() 86 | config_str = re.sub(r"\s", "", config_str) 87 | hash = hashlib.md5(config_str.encode()).digest() 88 | return self.sqids.encode(hash[:NUM_MD5_DIGITS_FOR_SQIDS]) 89 | 90 | def _get_directory_state(self) -> DirectoryList: 91 | save_dir = ( 92 | Path(self.config.save_dir) 93 | if not self.config.ablation.use_ablate 94 | else Path(self.config.save_dir) / self.config.ablation.study_name 95 | ) 96 | return DirectoryList(save_dir, self.config_hash) 97 | 98 | def save_config(self) -> None: 99 | self.save_paths.config.mkdir(parents=True, exist_ok=True) 100 | model_dict = self.config.model_dump() 101 | 102 | with (self.save_paths.config / "config.yml").open("w") as file: 103 | yaml.dump(model_dict, file) 104 | -------------------------------------------------------------------------------- /mistral/README.md: -------------------------------------------------------------------------------- 1 | Important: Mistral has the option of using flash attention to speed up inference. In order to use flash attention, please do: 2 | 3 | ```shell 4 | pip install -U flash-attn --no-build-isolation 5 | ``` 6 | 7 | ```shell 8 | git clone https://github.com/huggingface/peft 9 | cd peft 10 | pip install . 
11 | ``` 12 | 13 | # Contents: 14 | 15 | - [Contents:](#contents) 16 | - [What is Mistral?](#what-is-mistral) 17 | - [Variations of Mistral and Parameters](#variations-of-mistral-and-parameters) 18 | - [What does this folder contain?](#what-does-this-folder-contain) 19 | - [Evaluation Framework](#evaluation-framework) 20 | - [ Performance ](#-performance-) 21 | - [Classification](#classification) 22 | - [Summarization](#summarization) 23 | 24 | 25 | ## What is Mistral? 26 | 27 | Mistral-7B-v0.1 is Mistral AI’s first Large Language Model (LLM). Mistral-7B-v0.1 is a decoder-based LM with the following architectural choices: (i) Sliding Window Attention, (ii) GQA (Grouped Query Attention) and (iii) Byte-fallback BPE tokenizer. 28 | 29 | 30 | ## Variations of Mistral and Parameters 31 | 32 | Mistral models come in two variations, and can be leveraged depending on the task at hand. 33 | 34 | | Mistral variation | Parameters | 35 | |:----------------:|:-----------:| 36 | |Base |7B | 37 | |Instruct |7B | 38 | 39 | In this repository, we have experimented with the Base 7B variation. 40 | 41 | ## What does this folder contain? 42 | 43 | This folder contains ready-to-use scripts, using which you can do the following: 44 | 45 | * Finetuning Mistral using PeFT methodology QLoRA: 46 |   * ```mistral_classification.py```: Finetune on News Group classification dataset 47 |   * ```mistral_summarization.py```: Finetune on Samsum summarization dataset 48 | * Prompts used: 49 |   * ```prompts.py```: Zero-shot, Few-shot and instruction tuning for classification and summarization 50 | * Perform hyperparameter optimization over a well-constrained search space: 51 |   * ```run_lora.sh```: Ablation study on LoRA's parameters 52 |   * ```sample_ablate.sh```: Ablation study over sample complexities 53 | * Infer Mistral using trained checkpoints: 54 |   * ```mistral_baseline_inference.py```: Infer in zero-shot and few-shot settings using Mistral-7B 55 |   * ```mistral_classification_inference.py```: Infer on News Group classification dataset 56 |   * ```mistral_summarization_inference.py```: Infer on Samsum summarization dataset 57 | * Infer across different settings: 58 |   * ```baseline_inference.sh```: Loop over all settings to perform zero-shot and few-shot prompting across classification and summarization tasks 59 | 60 | ## Evaluation Framework 61 | 62 | 63 | ### Performance 64 | 65 | We evaluated Mistral under the following conditions: 66 | 67 | * Tasks & Datasets: 68 |   * Classification: News Group dataset, which is a 20-way classification task. 69 |   * Summarization: Samsum dataset. 70 | * Experiments: 71 |   * Sample Efficiency vs Accuracy 72 |   * Zero-Shot prompting vs Few-Shot prompting vs PeFT QLoRA (for summarization) 73 | * Training config: 74 |   * Epochs: 5 (for classification) 75 |   * Epochs: 1 (for summarization) 76 |   * Mistral-7B: 77 |     * PeFT technique: QLoRA 78 |     * Learning rate: 2e-4 79 | * Hardware: 80 |   * Cloud provider: AWS EC2 81 |   * Instance: g5.2xlarge 82 | 83 | #### Classification #### 84 | 85 | Table 1: Sample Efficiency vs Accuracy 86 | 87 | |Training samples (fraction) | Mistral Base-7B | 88 | |:--------------------------:|:---------------:| 89 | |266 (2.5%) |49.30 | 90 | |533 (5%) |48.14 | 91 | |1066 (10%) |58.41 | 92 | |2666 (25%) |64.89 | 93 | |5332 (50%) |73.10 | 94 | |10664 (100%) |74.36 | 95 | 96 | 97 | The above table shows how the performance of Mistral-7B tracks with the number of training samples. The last row of the table demonstrates the performance when the entire dataset is used.
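The accuracy numbers in Table 1 come from comparing the decoded model predictions against the gold newsgroup labels; the inference scripts in this folder do this scoring with scikit-learn. A minimal sketch of that computation, with made-up label lists standing in for real model outputs, might look like this:

```python
# Hypothetical toy example of the classification scoring step behind Table 1;
# the label strings below are placeholders, not real model predictions.
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

labels = ["sci.space", "rec.autos", "talk.politics.misc", "sci.space"]
predictions = ["sci.space", "rec.autos", "sci.space", "sci.space"]

metrics = {
    "micro_f1": f1_score(labels, predictions, average="micro"),
    "macro_f1": f1_score(labels, predictions, average="macro"),
    "precision": precision_score(labels, predictions, average="micro"),
    "recall": recall_score(labels, predictions, average="micro"),
    "accuracy": accuracy_score(labels, predictions),
}
print(metrics)  # accuracy here is 0.75 (75%); the table reports percentages
```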
98 | 99 | 100 | 101 | #### Summarization #### 102 | 103 | Table 2: Zero-Shot prompting vs Few-Shot prompting vs Fine-Tuning QLoRA 104 | 105 | |Method | Mistral-7B Zero-Shot | Mistral-7B Few-Shot | Fine-Tuning + QLoRA | 106 | |:-------------:|:---------------------:|:--------------------:|:-------------------:| 107 | |ROUGE-1 (in %) |32.77 |38.87 |53.61 | 108 | |ROUGE-2 (in %) |10.64 |16.71 |29.28 | 109 | 110 | 111 | Looking at the ROUGE-1 and ROUGE-2 scores, we see that Mistral-7B’s performance increases from zero-shot to few-shot to fine-tuning settings. 112 | 113 | 114 | Table 3: Mistral vs Other LLMs 115 | 116 | |Model | Flan-T5-Base Full Fine-Tune | Flan-T5-Large | Falcon-7B | RP-3B | RP-7B | Llama2-7B | Llama2-13B | Mistral-7B | 117 | |:-------------:|:---------------------------:|:-------------:|:---------:|:-----:|:-----:|:---------:|:----------:|:----------:| 118 | |ROUGE-1 (in %) |47.23 |49.21 |52.18 |47.75 |49.96 |51.71 |52.97 |53.61 | 119 | |ROUGE-2 (in %) |21.01 |23.39 |27.84 |23.53 |25.94 |26.86 |28.32 |29.28 | 120 | 121 | 122 | Mistral-7B achieves the best results, even when compared with Falcon-7B and Llama2-7B. This makes Mistral-7B, in our opinion, the best model to leverage in the 7B parameter space. 123 | 124 | 125 | -------------------------------------------------------------------------------- /mistral/baseline_inference.sh: -------------------------------------------------------------------------------- 1 | #python mistral_baseline_inference.py --task_type classification --prompt_type zero-shot & wait 2 | #python mistral_baseline_inference.py --task_type classification --prompt_type few-shot & wait 3 | python mistral_baseline_inference.py --task_type summarization --prompt_type zero-shot --use_flash_attention & wait 4 | python mistral_baseline_inference.py --task_type summarization --prompt_type few-shot --use_flash_attention 5 | -------------------------------------------------------------------------------- /mistral/mistral_baseline_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import evaluate 4 | import warnings 5 | import json 6 | import pandas as pd 7 | import pickle 8 | import torch 9 | import time 10 | 11 | from datasets import load_dataset 12 | from prompts import ( 13 | ZERO_SHOT_CLASSIFIER_PROMPT, 14 | FEW_SHOT_CLASSIFIER_PROMPT, 15 | ZERO_SHOT_SUMMARIZATION_PROMPT, 16 | FEW_SHOT_SUMMARIZATION_PROMPT, 17 | get_newsgroup_data, 18 | get_samsum_data, 19 | ) 20 | from transformers import ( 21 | AutoTokenizer, 22 | AutoModelForCausalLM, 23 | BitsAndBytesConfig, 24 | ) 25 | from sklearn.metrics import ( 26 | accuracy_score, 27 | f1_score, 28 | precision_score, 29 | recall_score, 30 | ) 31 | 32 | metric = evaluate.load("rouge") 33 | warnings.filterwarnings("ignore") 34 | 35 | 36 | def compute_metrics_decoded(decoded_labs, decoded_preds, args): 37 | if args.task_type == "summarization": 38 | rouge = metric.compute( 39 | predictions=decoded_preds, references=decoded_labs, use_stemmer=True 40 | ) 41 | metrics = {metric: round(rouge[metric] * 100.0, 3) for metric in rouge.keys()} 42 | 43 | elif args.task_type == "classification": 44 | metrics = { 45 | "micro_f1": f1_score(decoded_labs, decoded_preds, average="micro"), 46 | "macro_f1": f1_score(decoded_labs, decoded_preds, average="macro"), 47 | "precision": precision_score(decoded_labs, decoded_preds, average="micro"), 48 | "recall": recall_score(decoded_labs, decoded_preds, average="micro"), 49 | "accuracy": 
accuracy_score(decoded_labs, decoded_preds), 50 | } 51 | 52 | return metrics 53 | 54 | 55 | def main(args): 56 | save_dir = os.path.join( 57 | "baseline_results", args.pretrained_ckpt, args.task_type, args.prompt_type 58 | ) 59 | if not os.path.exists(save_dir): 60 | os.makedirs(save_dir) 61 | 62 | if args.task_type == "classification": 63 | dataset = load_dataset("rungalileo/20_Newsgroups_Fixed") 64 | test_dataset = dataset["test"] 65 | test_data, test_labels = test_dataset["text"], test_dataset["label"] 66 | 67 | newsgroup_classes, few_shot_samples, _ = get_newsgroup_data() 68 | 69 | elif args.task_type == "summarization": 70 | dataset = load_dataset("samsum") 71 | test_dataset = dataset["test"] 72 | test_data, test_labels = test_dataset["dialogue"], test_dataset["summary"] 73 | 74 | few_shot_samples = get_samsum_data() 75 | 76 | if args.prompt_type == "zero-shot": 77 | if args.task_type == "classification": 78 | prompt = ZERO_SHOT_CLASSIFIER_PROMPT 79 | elif args.task_type == "summarization": 80 | prompt = ZERO_SHOT_SUMMARIZATION_PROMPT 81 | 82 | elif args.prompt_type == "few-shot": 83 | if args.task_type == "classification": 84 | prompt = FEW_SHOT_CLASSIFIER_PROMPT 85 | elif args.task_type == "summarization": 86 | prompt = FEW_SHOT_SUMMARIZATION_PROMPT 87 | 88 | # BitsAndBytesConfig int-4 config 89 | bnb_config = BitsAndBytesConfig( 90 | load_in_4bit=True, 91 | bnb_4bit_use_double_quant=True, 92 | bnb_4bit_quant_type="nf4", 93 | bnb_4bit_compute_dtype=torch.bfloat16, 94 | ) 95 | 96 | # Load model and tokenizer 97 | model = AutoModelForCausalLM.from_pretrained( 98 | args.pretrained_ckpt, 99 | quantization_config=bnb_config, 100 | use_cache=False, 101 | device_map="auto", 102 | use_flash_attention_2=args.use_flash_attention, 103 | ) 104 | model.config.pretraining_tp = 1 105 | 106 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt) 107 | tokenizer.pad_token = tokenizer.eos_token 108 | tokenizer.padding_side = "right" 109 | 110 | results = [] 111 | good_data, good_labels = [], [] 112 | ctr = 0 113 | # for instruct, label in zip(instructions, labels): 114 | for data, label in zip(test_data, test_labels): 115 | if not isinstance(data, str): 116 | continue 117 | if not isinstance(label, str): 118 | continue 119 | 120 | # example = instruct[:-len(label)] # remove the answer from the example 121 | if args.prompt_type == "zero-shot": 122 | if args.task_type == "classification": 123 | example = prompt.format( 124 | newsgroup_classes=newsgroup_classes, 125 | sentence=data, 126 | ) 127 | elif args.task_type == "summarization": 128 | example = prompt.format( 129 | dialogue=data, 130 | ) 131 | 132 | elif args.prompt_type == "few-shot": 133 | if args.task_type == "classification": 134 | example = prompt.format( 135 | newsgroup_classes=newsgroup_classes, 136 | few_shot_samples=few_shot_samples, 137 | sentence=data, 138 | ) 139 | elif args.task_type == "summarization": 140 | example = prompt.format( 141 | few_shot_samples=few_shot_samples, 142 | dialogue=data, 143 | ) 144 | 145 | input_ids = tokenizer( 146 | example, return_tensors="pt", truncation=True 147 | ).input_ids.cuda() 148 | 149 | with torch.inference_mode(): 150 | outputs = model.generate( 151 | input_ids=input_ids, 152 | max_new_tokens=20 if args.task_type == "classification" else 50, 153 | do_sample=True, 154 | top_p=0.95, 155 | temperature=1e-3, 156 | ) 157 | result = tokenizer.batch_decode( 158 | outputs.detach().cpu().numpy(), skip_special_tokens=True 159 | )[0] 160 | 161 | # Extract the generated text, and do basic 
processing 162 | result = result[len(example) :].replace("\n", "").lstrip().rstrip() 163 | results.append(result) 164 | good_labels.append(label) 165 | good_data.append(data) 166 | 167 | print(f"Example {ctr}/{len(test_data)} | GT: {label} | Pred: {result}") 168 | ctr += 1 169 | 170 | metrics = compute_metrics_decoded(good_labels, results, args) 171 | print(metrics) 172 | metrics["predictions"] = results 173 | metrics["labels"] = good_labels 174 | metrics["data"] = good_data 175 | 176 | with open(os.path.join(save_dir, "metrics.pkl"), "wb") as handle: 177 | pickle.dump(metrics, handle) 178 | 179 | print(f"Completed experiment {save_dir}") 180 | print("----------------------------------------") 181 | 182 | 183 | if __name__ == "__main__": 184 | parser = argparse.ArgumentParser() 185 | parser.add_argument("--pretrained_ckpt", default="mistralai/Mistral-7B-v0.1") 186 | parser.add_argument("--prompt_type", default="zero-shot") 187 | parser.add_argument("--task_type", default="classification") 188 | parser.add_argument("--use_flash_attention", action=argparse.BooleanOptionalAction) 189 | args = parser.parse_args() 190 | 191 | main(args) 192 | -------------------------------------------------------------------------------- /mistral/mistral_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import pickle 7 | 8 | 9 | from peft import ( 10 | LoraConfig, 11 | prepare_model_for_kbit_training, 12 | get_peft_model, 13 | ) 14 | from transformers import ( 15 | AutoTokenizer, 16 | AutoModelForCausalLM, 17 | BitsAndBytesConfig, 18 | TrainingArguments, 19 | ) 20 | from trl import SFTTrainer 21 | 22 | from prompts import get_newsgroup_data_for_ft 23 | 24 | 25 | def main(args): 26 | train_dataset, test_dataset = get_newsgroup_data_for_ft( 27 | mode="train", train_sample_fraction=args.train_sample_fraction 28 | ) 29 | print(f"Sample fraction:{args.train_sample_fraction}") 30 | print(f"Training samples:{train_dataset.shape}") 31 | 32 | # BitsAndBytesConfig int-4 config 33 | bnb_config = BitsAndBytesConfig( 34 | load_in_4bit=True, 35 | bnb_4bit_use_double_quant=True, 36 | bnb_4bit_quant_type="nf4", 37 | bnb_4bit_compute_dtype=torch.bfloat16, 38 | ) 39 | 40 | # Load model and tokenizer 41 | model = AutoModelForCausalLM.from_pretrained( 42 | args.pretrained_ckpt, 43 | quantization_config=bnb_config, 44 | use_cache=False, 45 | device_map="auto", 46 | ) 47 | model.config.pretraining_tp = 1 48 | 49 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt) 50 | tokenizer.pad_token = tokenizer.eos_token 51 | tokenizer.padding_side = "right" 52 | 53 | # LoRA config based on QLoRA paper 54 | peft_config = LoraConfig( 55 | lora_alpha=16, 56 | lora_dropout=args.dropout, 57 | r=args.lora_r, 58 | bias="none", 59 | task_type="CAUSAL_LM", 60 | ) 61 | 62 | # prepare model for training 63 | model = prepare_model_for_kbit_training(model) 64 | model = get_peft_model(model, peft_config) 65 | 66 | results_dir = f"experiments/classification-sampleFraction-{args.train_sample_fraction}_epochs-{args.epochs}_rank-{args.lora_r}_dropout-{args.dropout}" 67 | 68 | training_args = TrainingArguments( 69 | output_dir=results_dir, 70 | logging_dir=f"{results_dir}/logs", 71 | num_train_epochs=args.epochs, 72 | per_device_train_batch_size=4, 73 | gradient_accumulation_steps=2, 74 | gradient_checkpointing=True, 75 | optim="paged_adamw_32bit", 76 | logging_steps=100, 77 | learning_rate=2e-4, 
78 | bf16=True, 79 | tf32=True, 80 | max_grad_norm=0.3, 81 | warmup_ratio=0.03, 82 | lr_scheduler_type="constant", 83 | report_to="none", 84 | # disable_tqdm=True # disable tqdm since with packing values are in correct 85 | ) 86 | 87 | max_seq_length = 512 # max sequence length for model and packing of the dataset 88 | 89 | trainer = SFTTrainer( 90 | model=model, 91 | train_dataset=train_dataset, 92 | peft_config=peft_config, 93 | max_seq_length=max_seq_length, 94 | tokenizer=tokenizer, 95 | packing=True, 96 | args=training_args, 97 | dataset_text_field="instructions", 98 | ) 99 | 100 | trainer_stats = trainer.train() 101 | train_loss = trainer_stats.training_loss 102 | print(f"Training loss:{train_loss}") 103 | 104 | peft_model_id = f"{results_dir}/assets" 105 | trainer.model.save_pretrained(peft_model_id) 106 | tokenizer.save_pretrained(peft_model_id) 107 | 108 | with open(f"{results_dir}/results.pkl", "wb") as handle: 109 | run_result = [ 110 | args.epochs, 111 | args.lora_r, 112 | args.dropout, 113 | train_loss, 114 | ] 115 | pickle.dump(run_result, handle) 116 | print("Experiment over") 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("--pretrained_ckpt", default="mistralai/Mistral-7B-v0.1") 122 | parser.add_argument("--lora_r", default=8, type=int) 123 | parser.add_argument("--epochs", default=5, type=int) 124 | parser.add_argument("--dropout", default=0.1, type=float) 125 | parser.add_argument("--train_sample_fraction", default=0.99, type=float) 126 | 127 | args = parser.parse_args() 128 | main(args) 129 | -------------------------------------------------------------------------------- /mistral/mistral_classification_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import pandas as pd 5 | import evaluate 6 | import pickle 7 | import warnings 8 | from tqdm import tqdm 9 | 10 | from peft import AutoPeftModelForCausalLM 11 | from transformers import AutoTokenizer 12 | from sklearn.metrics import ( 13 | accuracy_score, 14 | f1_score, 15 | precision_score, 16 | recall_score, 17 | ) 18 | 19 | from prompts import get_newsgroup_data_for_ft 20 | 21 | metric = evaluate.load("rouge") 22 | warnings.filterwarnings("ignore") 23 | 24 | 25 | def main(args): 26 | _, test_dataset = get_newsgroup_data_for_ft(mode="inference") 27 | 28 | experiment = args.experiment_dir 29 | peft_model_id = f"{experiment}/assets" 30 | 31 | # load base LLM model and tokenizer 32 | model = AutoPeftModelForCausalLM.from_pretrained( 33 | peft_model_id, 34 | low_cpu_mem_usage=True, 35 | torch_dtype=torch.float16, 36 | load_in_4bit=True, 37 | ) 38 | 39 | model.eval() 40 | 41 | tokenizer = AutoTokenizer.from_pretrained(peft_model_id) 42 | 43 | results = [] 44 | oom_examples = [] 45 | instructions, labels = test_dataset["instructions"], test_dataset["labels"] 46 | 47 | for instruct, label in tqdm(zip(instructions, labels)): 48 | input_ids = tokenizer( 49 | instruct, return_tensors="pt", truncation=True 50 | ).input_ids.cuda() 51 | 52 | with torch.inference_mode(): 53 | try: 54 | outputs = model.generate( 55 | input_ids=input_ids, 56 | max_new_tokens=20, 57 | do_sample=True, 58 | top_p=0.95, 59 | temperature=1e-3, 60 | ) 61 | result = tokenizer.batch_decode( 62 | outputs.detach().cpu().numpy(), skip_special_tokens=True 63 | )[0] 64 | result = result[len(instruct) :] 65 | print(result) 66 | except: 67 | result = "" 68 | oom_examples.append(input_ids.shape[-1]) 69 | 70 
| results.append(result) 71 | 72 | metrics = { 73 | "micro_f1": f1_score(labels, results, average="micro"), 74 | "macro_f1": f1_score(labels, results, average="macro"), 75 | "precision": precision_score(labels, results, average="micro"), 76 | "recall": recall_score(labels, results, average="micro"), 77 | "accuracy": accuracy_score(labels, results), 78 | "oom_examples": oom_examples, 79 | } 80 | print(metrics) 81 | 82 | save_dir = os.path.join(experiment, "metrics") 83 | if not os.path.exists(save_dir): 84 | os.makedirs(save_dir) 85 | 86 | with open(os.path.join(save_dir, "metrics.pkl"), "wb") as handle: 87 | pickle.dump(metrics, handle) 88 | 89 | print(f"Completed experiment {peft_model_id}") 90 | print("----------------------------------------") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument( 96 | "--experiment_dir", 97 | default="experiments/classification-sampleFraction-0.1_epochs-5_rank-8_dropout-0.1", 98 | ) 99 | 100 | args = parser.parse_args() 101 | main(args) 102 | -------------------------------------------------------------------------------- /mistral/mistral_summarization.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import pickle 7 | import datasets 8 | from datasets import Dataset, load_dataset 9 | 10 | from peft import ( 11 | LoraConfig, 12 | prepare_model_for_kbit_training, 13 | get_peft_model, 14 | ) 15 | from transformers import ( 16 | AutoTokenizer, 17 | AutoModelForCausalLM, 18 | BitsAndBytesConfig, 19 | TrainingArguments, 20 | ) 21 | from trl import SFTTrainer 22 | 23 | from prompts import TRAINING_SUMMARIZATION_PROMPT_v2 24 | 25 | 26 | def prepare_instructions(dialogues, summaries): 27 | instructions = [] 28 | 29 | prompt = TRAINING_SUMMARIZATION_PROMPT_v2 30 | 31 | for dialogue, summary in zip(dialogues, summaries): 32 | example = prompt.format( 33 | dialogue=dialogue, 34 | summary=summary, 35 | ) 36 | instructions.append(example) 37 | 38 | return instructions 39 | 40 | 41 | def prepare_samsum_data(): 42 | dataset = load_dataset("samsum") 43 | train_dataset = dataset["train"] 44 | val_dataset = dataset["test"] 45 | 46 | dialogues = train_dataset["dialogue"] 47 | summaries = train_dataset["summary"] 48 | train_instructions = prepare_instructions(dialogues, summaries) 49 | train_dataset = datasets.Dataset.from_pandas( 50 | pd.DataFrame(data={"instructions": train_instructions}) 51 | ) 52 | 53 | return train_dataset 54 | 55 | 56 | def main(args): 57 | train_dataset = prepare_samsum_data() 58 | 59 | # BitsAndBytesConfig int-4 config 60 | bnb_config = BitsAndBytesConfig( 61 | load_in_4bit=True, 62 | bnb_4bit_use_double_quant=True, 63 | bnb_4bit_quant_type="nf4", 64 | bnb_4bit_compute_dtype=torch.bfloat16, 65 | ) 66 | 67 | # Load model and tokenizer 68 | model = AutoModelForCausalLM.from_pretrained( 69 | args.pretrained_ckpt, 70 | quantization_config=bnb_config, 71 | use_cache=False, 72 | device_map="auto", 73 | ) 74 | model.config.pretraining_tp = 1 75 | 76 | tokenizer = AutoTokenizer.from_pretrained(args.pretrained_ckpt) 77 | tokenizer.pad_token = tokenizer.eos_token 78 | tokenizer.padding_side = "right" 79 | 80 | # LoRA config based on QLoRA paper 81 | peft_config = LoraConfig( 82 | lora_alpha=16, 83 | lora_dropout=args.dropout, 84 | r=args.lora_r, 85 | bias="none", 86 | task_type="CAUSAL_LM", 87 | ) 88 | 89 | # prepare model for training 90 | model = 
prepare_model_for_kbit_training(model) 91 | model = get_peft_model(model, peft_config) 92 | 93 | results_dir = f"experiments/summarization_epochs-{args.epochs}_rank-{args.lora_r}_dropout-{args.dropout}" 94 | 95 | training_args = TrainingArguments( 96 | output_dir=results_dir, 97 | logging_dir=f"{results_dir}/logs", 98 | num_train_epochs=args.epochs, 99 | per_device_train_batch_size=4, 100 | gradient_accumulation_steps=2, 101 | gradient_checkpointing=True, 102 | optim="paged_adamw_32bit", 103 | logging_steps=100, 104 | learning_rate=2e-4, 105 | bf16=True, 106 | tf32=True, 107 | max_grad_norm=0.3, 108 | warmup_ratio=0.03, 109 | lr_scheduler_type="constant", 110 | report_to="none", 111 | # disable_tqdm=True # disable tqdm since with packing values are incorrect 112 | ) 113 | 114 | max_seq_length = 512 # max sequence length for model and packing of the dataset 115 | 116 | trainer = SFTTrainer( 117 | model=model, 118 | train_dataset=train_dataset, 119 | peft_config=peft_config, 120 | max_seq_length=max_seq_length, 121 | tokenizer=tokenizer, 122 | packing=True, 123 | args=training_args, 124 | dataset_text_field="instructions", 125 | ) 126 | 127 | trainer_stats = trainer.train() 128 | train_loss = trainer_stats.training_loss 129 | print(f"Training loss:{train_loss}") 130 | 131 | peft_model_id = f"{results_dir}/assets" 132 | trainer.model.save_pretrained(peft_model_id) 133 | tokenizer.save_pretrained(peft_model_id) 134 | 135 | with open(f"{results_dir}/results.pkl", "wb") as handle: 136 | run_result = [ 137 | args.epochs, 138 | args.lora_r, 139 | args.dropout, 140 | train_loss, 141 | ] 142 | pickle.dump(run_result, handle) 143 | print("Experiment over") 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser() 148 | parser.add_argument("--pretrained_ckpt", default="mistralai/Mistral-7B-v0.1") 149 | parser.add_argument("--lora_r", default=64, type=int) 150 | parser.add_argument("--epochs", default=1, type=int) 151 | parser.add_argument("--dropout", default=0.1, type=float) 152 | 153 | args = parser.parse_args() 154 | main(args) 155 | -------------------------------------------------------------------------------- /mistral/mistral_summarization_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import pandas as pd 5 | import evaluate 6 | from datasets import load_dataset 7 | import pickle 8 | import warnings 9 | 10 | from peft import AutoPeftModelForCausalLM 11 | from transformers import AutoTokenizer 12 | 13 | from prompts import INFERENCE_SUMMARIZATION_PROMPT_v2 14 | 15 | metric = evaluate.load("rouge") 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | def prepare_instructions(dialogues, summaries): 20 | instructions = [] 21 | 22 | prompt = INFERENCE_SUMMARIZATION_PROMPT_v2 23 | 24 | for dialogue, summary in zip(dialogues, summaries): 25 | example = prompt.format( 26 | dialogue=dialogue, 27 | ) 28 | instructions.append(example) 29 | 30 | return instructions 31 | 32 | 33 | def prepare_samsum_data(): 34 | dataset = load_dataset("samsum") 35 | val_dataset = dataset["test"] 36 | 37 | dialogues = val_dataset["dialogue"] 38 | summaries = val_dataset["summary"] 39 | val_instructions = prepare_instructions(dialogues, summaries) 40 | 41 | return val_instructions, summaries 42 | 43 | 44 | def main(args): 45 | val_instructions, summaries = prepare_samsum_data() 46 | 47 | experiment = args.experiment_dir 48 | peft_model_id = f"{experiment}/assets" 49 | 50 | # load base LLM model 
and tokenizer 51 | model = AutoPeftModelForCausalLM.from_pretrained( 52 | peft_model_id, 53 | low_cpu_mem_usage=True, 54 | torch_dtype=torch.float16, 55 | load_in_4bit=True, 56 | ) 57 | tokenizer = AutoTokenizer.from_pretrained(peft_model_id) 58 | 59 | results = [] 60 | for instruct, summary in zip(val_instructions, summaries): 61 | input_ids = tokenizer( 62 | instruct, return_tensors="pt", truncation=True 63 | ).input_ids.cuda() 64 | with torch.inference_mode(): 65 | outputs = model.generate( 66 | input_ids=input_ids, 67 | max_new_tokens=100, 68 | do_sample=True, 69 | top_p=0.9, 70 | temperature=1e-2, 71 | ) 72 | result = tokenizer.batch_decode( 73 | outputs.detach().cpu().numpy(), skip_special_tokens=True 74 | )[0] 75 | result = result[len(instruct) :] 76 | results.append(result) 77 | print(f"Instruction:{instruct}") 78 | print(f"Summary:{summary}") 79 | print(f"Generated:{result}") 80 | print("----------------------------------------") 81 | 82 | # compute metric 83 | rouge = metric.compute(predictions=results, references=summaries, use_stemmer=True) 84 | 85 | metrics = {metric: round(rouge[metric] * 100, 2) for metric in rouge.keys()} 86 | 87 | save_dir = os.path.join(experiment, "metrics") 88 | if not os.path.exists(save_dir): 89 | os.makedirs(save_dir) 90 | 91 | with open(os.path.join(save_dir, "metrics.pkl"), "wb") as handle: 92 | pickle.dump(metrics, handle) 93 | 94 | print(f"Completed experiment {peft_model_id}") 95 | print("----------------------------------------") 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser() 100 | parser.add_argument( 101 | "--experiment_dir", 102 | default="experiments/summarization_epochs-1_rank-64_dropout-0.1", 103 | ) 104 | 105 | args = parser.parse_args() 106 | main(args) 107 | -------------------------------------------------------------------------------- /mistral/prompts.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import datasets 3 | from datasets import load_dataset 4 | from sklearn.model_selection import train_test_split 5 | 6 | 7 | ZERO_SHOT_CLASSIFIER_PROMPT = """Classify the sentence into one of 20 classes. The list of classes is provided below, where the classes are separated by commas: 8 | 9 | {newsgroup_classes} 10 | 11 | From the above list of classes, select only one class that the provided sentence can be classified into. The sentence will be delimited with triple backticks. Once again, only predict the class from the given list of classes. Do not predict anything else. 12 | 13 | ### Sentence: ```{sentence}``` 14 | ### Class: 15 | """ 16 | 17 | FEW_SHOT_CLASSIFIER_PROMPT = """Classify the sentence into one of 20 classes. The list of classes is provided below, where the classes are separated by commas: 18 | 19 | {newsgroup_classes} 20 | 21 | From the above list of classes, select only one class that the provided sentence can be classified into. Once again, only predict the class from the given list of classes. Do not predict anything else. The sentence will be delimited with triple backticks. To help you, examples are provided of sentence and the corresponding class they belong to. 22 | 23 | {few_shot_samples} 24 | 25 | ### Sentence: ```{sentence}``` 26 | ### Class: 27 | """ 28 | 29 | TRAINING_CLASSIFIER_PROMPT = """Classify the following sentence that is delimited with triple backticks. 
30 | 31 | ### Sentence: ```{sentence}``` 32 | ### Class: {label} 33 | """ 34 | 35 | INFERENCE_CLASSIFIER_PROMPT = """Classify the following sentence that is delimited with triple backticks. 36 | 37 | ### Sentence: ```{sentence}``` 38 | ### Class: 39 | """ 40 | 41 | TRAINING_CLASSIFIER_PROMPT_v2 = """### Sentence:{sentence} ### Class:{label}""" 42 | INFERENCE_CLASSIFIER_PROMPT_v2 = """### Sentence:{sentence} ### Class:""" 43 | 44 | ZERO_SHOT_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. 45 | 46 | ### Dialogue: ```{dialogue}``` 47 | ### Summary: 48 | """ 49 | 50 | FEW_SHOT_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. To help you, examples of summarization are provided. 51 | 52 | {few_shot_samples} 53 | 54 | ### Dialogue: ```{dialogue}``` 55 | ### Summary: 56 | """ 57 | 58 | TRAINING_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. 59 | 60 | ### Dialogue: ```{dialogue}``` 61 | ### Summary: {summary} 62 | """ 63 | 64 | TRAINING_SUMMARIZATION_PROMPT_v2 = """### Dialogue:{dialogue} ### Summary:{summary}""" 65 | INFERENCE_SUMMARIZATION_PROMPT_v2 = """### Dialogue:{dialogue} ### Summary:""" 66 | 67 | INFERENCE_SUMMARIZATION_PROMPT = """Summarize the following dialogue that is delimited with triple backticks. 68 | 69 | ### Dialogue: ```{dialogue}``` 70 | ### Summary: 71 | """ 72 | 73 | 74 | def get_newsgroup_instruction_data(mode, texts, labels): 75 | if mode == "train": 76 | prompt = TRAINING_CLASSIFIER_PROMPT_v2 77 | elif mode == "inference": 78 | prompt = INFERENCE_CLASSIFIER_PROMPT_v2 79 | 80 | instructions = [] 81 | 82 | for text, label in zip(texts, labels): 83 | if mode == "train": 84 | example = prompt.format( 85 | sentence=text, 86 | label=label, 87 | ) 88 | elif mode == "inference": 89 | example = prompt.format( 90 | sentence=text, 91 | ) 92 | instructions.append(example) 93 | 94 | return instructions 95 | 96 | 97 | def clean_newsgroup_data(texts, labels): 98 | label2data = {} 99 | clean_data, clean_labels = [], [] 100 | for data, label in zip(texts, labels): 101 | if isinstance(data, str) and isinstance(label, str): 102 | clean_data.append(data) 103 | clean_labels.append(label) 104 | 105 | if label not in label2data: 106 | label2data[label] = data 107 | 108 | return label2data, clean_data, clean_labels 109 | 110 | 111 | def get_newsgroup_data_for_ft(mode="train", train_sample_fraction=0.99): 112 | newsgroup_dataset = load_dataset("rungalileo/20_Newsgroups_Fixed") 113 | train_data = newsgroup_dataset["train"]["text"] 114 | train_labels = newsgroup_dataset["train"]["label"] 115 | label2data, train_data, train_labels = clean_newsgroup_data( 116 | train_data, train_labels 117 | ) 118 | 119 | test_data = newsgroup_dataset["test"]["text"] 120 | test_labels = newsgroup_dataset["test"]["label"] 121 | _, test_data, test_labels = clean_newsgroup_data(test_data, test_labels) 122 | 123 | # sample n points from training data 124 | train_df = pd.DataFrame(data={"text": train_data, "label": train_labels}) 125 | train_df, _ = train_test_split( 126 | train_df, 127 | train_size=train_sample_fraction, 128 | stratify=train_df["label"], 129 | random_state=42, 130 | ) 131 | train_data = train_df["text"] 132 | train_labels = train_df["label"] 133 | 134 | train_instructions = get_newsgroup_instruction_data(mode, train_data, train_labels) 135 | test_instructions = get_newsgroup_instruction_data(mode, test_data, test_labels) 136 | 137 | train_dataset = 
datasets.Dataset.from_pandas( 138 | pd.DataFrame( 139 | data={ 140 | "instructions": train_instructions, 141 | "labels": train_labels, 142 | } 143 | ) 144 | ) 145 | test_dataset = datasets.Dataset.from_pandas( 146 | pd.DataFrame( 147 | data={ 148 | "instructions": test_instructions, 149 | "labels": test_labels, 150 | } 151 | ) 152 | ) 153 | 154 | return train_dataset, test_dataset 155 | 156 | 157 | def get_newsgroup_data(): 158 | newsgroup_dataset = load_dataset("rungalileo/20_Newsgroups_Fixed") 159 | train_data = newsgroup_dataset["train"]["text"] 160 | train_labels = newsgroup_dataset["train"]["label"] 161 | 162 | label2data, clean_data, clean_labels = clean_newsgroup_data( 163 | train_data, train_labels 164 | ) 165 | df = pd.DataFrame(data={"text": clean_data, "label": clean_labels}) 166 | 167 | newsgroup_classes = df["label"].unique() 168 | newsgroup_classes = ", ".join(newsgroup_classes) 169 | 170 | few_shot_samples = "" 171 | for label, data in label2data.items(): 172 | sample = f"Sentence: {data} \n Class: {label} \n\n" 173 | few_shot_samples += sample 174 | 175 | return newsgroup_classes, few_shot_samples, df 176 | 177 | 178 | def get_samsum_data(): 179 | samsum_dataset = load_dataset("samsum") 180 | train_dataset = samsum_dataset["train"] 181 | dialogues = train_dataset["dialogue"][:2] 182 | summaries = train_dataset["summary"][:2] 183 | 184 | few_shot_samples = "" 185 | for dialogue, summary in zip(dialogues, summaries): 186 | sample = f"Sentence: {dialogue} \n Summary: {summary} \n\n" 187 | few_shot_samples += sample 188 | 189 | return few_shot_samples 190 | -------------------------------------------------------------------------------- /mistral/run_lora.sh: -------------------------------------------------------------------------------- 1 | epochs=(2 5 10 20 30 50) 2 | lora_r=(2 4 8 16) 3 | dropout=(0.1 0.2 0.5) 4 | 5 | for (( epoch=0; epoch<6; epoch=epoch+1 )) do 6 | for ((r=0; r<4; r=r+1 )) do 7 | for (( d=0; d<3; d=d+1 )) do 8 | python mistral_summarization.py --lora_r ${lora_r[$r]} --epochs ${epochs[$epoch]} --dropout ${dropout[$d]} & wait 9 | done 10 | done 11 | done 12 | -------------------------------------------------------------------------------- /mistral/sample_ablate.sh: -------------------------------------------------------------------------------- 1 | sample_fraction=(0.025 0.05 0.1) 2 | 3 | for (( sf=0; sf<3; sf=sf+1 )) do 4 | python mistral_classification.py --train_sample_fraction ${sample_fraction[$sf]} & wait 5 | done 6 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "llm-toolkit" 3 | version = "0.0.0" 4 | description = "LLM Finetuning resource hub + toolkit" 5 | authors = ["Benjamin Ye "] 6 | license = "Apache 2.0" 7 | readme = "README.md" 8 | packages = [{include = "llmtune"}] 9 | repository = "https://github.com/georgian-io/LLM-Finetuning-Toolkit" 10 | # homepage = "" 11 | # documentation = "" 12 | keywords = ["llm", "finetuning", "language models", "machine learning", "deep learning"] 13 | classifiers = [ 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Science/Research", 16 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 17 | ] 18 | 19 | 20 | [tool.poetry.scripts] 21 | llmtune = "llmtune.cli.toolkit:cli" 22 | 23 | [tool.poetry-dynamic-versioning] 24 | enable = true 25 | vcs = "git" 26 | style = "semver" 27 | 28 | [tool.poetry-dynamic-versioning.substitution] 
29 | folders = [ 30 | { path = "llmtune" } 31 | ] 32 | 33 | [tool.poetry.dependencies] 34 | python = ">=3.9, <=3.12" 35 | transformers = "~4.40.2" 36 | datasets = "^2.17.0" 37 | peft = "^0.8.2" 38 | pandas = "^2.2.0" 39 | numpy = "^1.26.4" 40 | ipdb = "^0.13.13" 41 | evaluate = "^0.4.1" 42 | wandb = "^0.16.3" 43 | einops = "^0.7.0" 44 | bitsandbytes = ">=0.43.2" 45 | nltk = "^3.8.1" 46 | accelerate = "^0.27.0" 47 | trl = "~0.8.6" 48 | rouge-score = "^0.1.2" 49 | absl-py = "^2.1.0" 50 | py7zr = "^0.20.8" 51 | tiktoken = "^0.6.0" 52 | ninja = "^1.11.1.1" 53 | packaging = "^23.2" 54 | sentencepiece = "^0.1.99" 55 | protobuf = "^4.25.2" 56 | ai21 = "^2.0.3" 57 | openai = "^1.12.0" 58 | ujson = "^5.9.0" 59 | pyyaml = "^6.0.1" 60 | ijson = "^3.2.3" 61 | rich = "^13.7.0" 62 | sqids = "^0.4.1" 63 | pydantic = "^2.6.1" 64 | typer = "^0.10.0" 65 | shellingham = "^1.5.4" 66 | langchain = "^0.2.5" 67 | 68 | 69 | [tool.poetry.group.dev.dependencies] 70 | ruff = "~0.3.5" 71 | pytest = "^8.1.1" 72 | pytest-cov = "^5.0.0" 73 | pytest-mock = "^3.14.0" 74 | 75 | [build-system] 76 | requires = ["poetry-core", "poetry-dynamic-versioning>=1.0.0,<2.0.0"] 77 | build-backend = "poetry_dynamic_versioning.backend" 78 | 79 | [tool.ruff] 80 | lint.ignore = ["C901", "E501", "E741", "F402", "F823" ] 81 | lint.select = ["C", "E", "F", "I", "W"] 82 | line-length = 119 83 | exclude = [ 84 | "llama2", 85 | "mistral", 86 | ] 87 | 88 | 89 | [tool.ruff.lint.isort] 90 | lines-after-imports = 2 91 | known-first-party = ["llmtune"] 92 | 93 | [tool.ruff.format] 94 | quote-style = "double" 95 | indent-style = "space" 96 | skip-magic-trailing-comma = false 97 | line-ending = "auto" 98 | 99 | [tool.coverage.run] 100 | omit = [ 101 | # Ignore UI for now as this might change quite often 102 | "llmtune/ui/*", 103 | "llmtune/utils/rich_print_utils.py" 104 | ] 105 | 106 | [tool.coverage.report] 107 | skip_empty = true 108 | exclude_also = [ 109 | "pass", 110 | ] 111 | 112 | [tool.pytest.ini_options] 113 | addopts = "--cov=llmtune --cov-report term-missing" 114 | -------------------------------------------------------------------------------- /test_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/test_utils/__init__.py -------------------------------------------------------------------------------- /test_utils/test_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines a configuration that can be used for unit testing. 
3 | """ 4 | 5 | from llmtune.pydantic_models.config_model import ( 6 | AblationConfig, 7 | BitsAndBytesConfig, 8 | Config, 9 | DataConfig, 10 | InferenceConfig, 11 | LoraConfig, 12 | ModelConfig, 13 | QaConfig, 14 | SftArgs, 15 | TrainingArgs, 16 | TrainingConfig, 17 | ) 18 | 19 | 20 | def get_sample_config(): 21 | """Function to return a comprehensive Config object for testing.""" 22 | return Config( 23 | save_dir="./test", 24 | ablation=AblationConfig( 25 | use_ablate=False, 26 | ), 27 | model=ModelConfig( 28 | hf_model_ckpt="NousResearch/Llama-2-7b-hf", 29 | device_map="auto", 30 | torch_dtype="auto", 31 | quantize=False, 32 | bitsandbytes=BitsAndBytesConfig( 33 | load_in_8bit=False, 34 | load_in_4bit=False, 35 | bnb_4bit_compute_dtype="float32", 36 | bnb_4bit_quant_type="nf4", 37 | bnb_4bit_use_double_quant=True, 38 | ), 39 | ), 40 | lora=LoraConfig( 41 | r=8, 42 | task_type="CAUSAL_LM", 43 | lora_alpha=16, 44 | bias="none", 45 | lora_dropout=0.1, 46 | target_modules=None, 47 | fan_in_fan_out=False, 48 | ), 49 | training=TrainingConfig( 50 | training_args=TrainingArgs( 51 | num_train_epochs=1, 52 | per_device_train_batch_size=1, 53 | gradient_accumulation_steps=1, 54 | optim="adamw_8bit", 55 | learning_rate=2.0e-4, 56 | logging_steps=100, 57 | ), 58 | sft_args=SftArgs(max_seq_length=512, neftune_noise_alpha=None), 59 | ), 60 | inference=InferenceConfig( 61 | max_length=128, 62 | do_sample=False, 63 | num_beams=5, 64 | temperature=1.0, 65 | top_k=50, 66 | top_p=1.0, 67 | use_cache=True, 68 | ), 69 | data=DataConfig( 70 | file_type="json", 71 | path="path/to/dataset.json", 72 | prompt="Your prompt here {column_name}", 73 | prompt_stub="Stub for prompt {column_name}", 74 | train_size=0.9, 75 | test_size=0.1, 76 | train_test_split_seed=42, 77 | ), 78 | qa=QaConfig( 79 | llm_metrics=[ 80 | "jaccard_similarity", 81 | "dot_product", 82 | "rouge_score", 83 | "word_overlap", 84 | "verb_percent", 85 | "adjective_percent", 86 | "noun_percent", 87 | "summary_length", 88 | ] 89 | ), 90 | ) 91 | -------------------------------------------------------------------------------- /tests/data/test_dataset_generator.py: -------------------------------------------------------------------------------- 1 | # TODO 2 | -------------------------------------------------------------------------------- /tests/data/test_ingestor.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, mock_open 2 | 3 | import pytest 4 | from datasets import Dataset 5 | 6 | from llmtune.data.ingestor import ( 7 | CsvIngestor, 8 | HuggingfaceIngestor, 9 | JsonIngestor, 10 | JsonlIngestor, 11 | get_ingestor, 12 | ) 13 | 14 | 15 | def test_get_ingestor(): 16 | assert isinstance(get_ingestor("json")(""), JsonIngestor) 17 | assert isinstance(get_ingestor("jsonl")(""), JsonlIngestor) 18 | assert isinstance(get_ingestor("csv")(""), CsvIngestor) 19 | assert isinstance(get_ingestor("huggingface")(""), HuggingfaceIngestor) 20 | 21 | with pytest.raises(ValueError): 22 | get_ingestor("unsupported_type") 23 | 24 | 25 | def test_json_ingestor_to_dataset(mocker): 26 | mock_generator = mocker.patch("llmtune.data.ingestor.JsonIngestor._json_generator") 27 | mock_dataset = mocker.patch("llmtune.data.ingestor.Dataset") 28 | JsonIngestor("").to_dataset() 29 | 30 | mock_dataset.from_generator.assert_called_once_with(mock_generator) 31 | 32 | 33 | def test_jsonl_ingestor_to_dataset(mocker): 34 | mock_generator = mocker.patch("llmtune.data.ingestor.JsonlIngestor._jsonl_generator") 35 | 
mock_dataset = mocker.patch("llmtune.data.ingestor.Dataset") 36 | JsonlIngestor("").to_dataset() 37 | 38 | mock_dataset.from_generator.assert_called_once_with(mock_generator) 39 | 40 | 41 | def test_csv_ingestor_to_dataset(mocker): 42 | mock_generator = mocker.patch("llmtune.data.ingestor.CsvIngestor._csv_generator") 43 | mock_dataset = mocker.patch("llmtune.data.ingestor.Dataset") 44 | CsvIngestor("").to_dataset() 45 | 46 | mock_dataset.from_generator.assert_called_once_with(mock_generator) 47 | 48 | 49 | def test_huggingface_to_dataset(mocker): 50 | # Setup 51 | path = "some_path" 52 | ingestor = HuggingfaceIngestor(path) 53 | mock_concatenate_datasets = mocker.patch("llmtune.data.ingestor.concatenate_datasets") 54 | mock_load_dataset = mocker.patch("llmtune.data.ingestor.load_dataset") 55 | mock_dataset = mocker.patch("llmtune.data.ingestor.Dataset") 56 | 57 | # Configure the mock objects 58 | mock_dataset = MagicMock(spec=Dataset) 59 | mock_load_dataset.return_value = {"train": mock_dataset, "test": mock_dataset} 60 | mock_concatenate_datasets.return_value = mock_dataset 61 | 62 | # Execute 63 | result = ingestor.to_dataset() 64 | 65 | # Assert 66 | assert isinstance(result, Dataset) 67 | mock_load_dataset.assert_called_once_with(path) 68 | mock_concatenate_datasets.assert_called_once() 69 | 70 | 71 | @pytest.mark.parametrize( 72 | "file_content,expected_output", 73 | [ 74 | ( 75 | '[{"column1": "value1", "column2": "value2"}, {"column1": "value3", "column2": "value4"}]', 76 | [ 77 | {"column1": "value1", "column2": "value2"}, 78 | {"column1": "value3", "column2": "value4"}, 79 | ], 80 | ) 81 | ], 82 | ) 83 | def test_json_ingestor_generator(file_content, expected_output, mocker): 84 | mocker.patch("builtins.open", mock_open(read_data=file_content)) 85 | mocker.patch("ijson.items", side_effect=lambda f, prefix: iter(expected_output)) 86 | ingestor = JsonIngestor("dummy_path.json") 87 | 88 | assert list(ingestor._json_generator()) == expected_output 89 | 90 | 91 | @pytest.mark.parametrize( 92 | "file_content,expected_output", 93 | [ 94 | ( 95 | '{"column1": "value1", "column2": "value2"}\n{"column1": "value3", "column2": "value4"}', 96 | [ 97 | {"column1": "value1", "column2": "value2"}, 98 | {"column1": "value3", "column2": "value4"}, 99 | ], 100 | ) 101 | ], 102 | ) 103 | def test_jsonl_ingestor_generator(file_content, expected_output, mocker): 104 | mocker.patch("builtins.open", mock_open(read_data=file_content)) 105 | mocker.patch( 106 | "ijson.items", 107 | side_effect=lambda f, prefix, multiple_values: (iter(expected_output) if multiple_values else iter([])), 108 | ) 109 | ingestor = JsonlIngestor("dummy_path.jsonl") 110 | 111 | assert list(ingestor._jsonl_generator()) == expected_output 112 | 113 | 114 | @pytest.mark.parametrize( 115 | "file_content,expected_output", 116 | [ 117 | ( 118 | "column1,column2\nvalue1,value2\nvalue3,value4", 119 | [ 120 | {"column1": "value1", "column2": "value2"}, 121 | {"column1": "value3", "column2": "value4"}, 122 | ], 123 | ) 124 | ], 125 | ) 126 | def test_csv_ingestor_generator(file_content, expected_output, mocker): 127 | mocker.patch("builtins.open", mock_open(read_data=file_content)) 128 | ingestor = CsvIngestor("dummy_path.csv") 129 | 130 | assert list(ingestor._csv_generator()) == expected_output 131 | -------------------------------------------------------------------------------- /tests/finetune/test_finetune_generics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from 
llmtune.finetune.generics import Finetune 4 | 5 | 6 | class MockFinetune(Finetune): 7 | def finetune(self): 8 | return "finetuning complete" 9 | 10 | def save_model(self): 11 | return "model saved" 12 | 13 | 14 | def test_finetune_method(): 15 | mock_finetuner = MockFinetune() 16 | result = mock_finetuner.finetune() 17 | assert result == "finetuning complete" 18 | 19 | 20 | def test_save_model_method(): 21 | mock_finetuner = MockFinetune() 22 | result = mock_finetuner.save_model() 23 | assert result == "model saved" 24 | 25 | 26 | def test_finetune_abstract_class_instantiation(): 27 | with pytest.raises(TypeError): 28 | _ = Finetune() 29 | -------------------------------------------------------------------------------- /tests/finetune/test_finetune_lora.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | from transformers import BitsAndBytesConfig 4 | 5 | from llmtune.finetune.lora import LoRAFinetune 6 | from test_utils.test_config import get_sample_config 7 | 8 | 9 | def test_lora_finetune_initialization(mocker): 10 | """Test the initialization of LoRAFinetune with a sample configuration.""" 11 | # Mock dependencies that LoRAFinetune might call during initialization 12 | mocker.patch("llmtune.finetune.lora.AutoModelForCausalLM.from_pretrained") 13 | mocker.patch("llmtune.finetune.lora.AutoTokenizer.from_pretrained") 14 | mock_lora_config = mocker.patch("llmtune.finetune.lora.LoraConfig") 15 | mocker.patch( 16 | "llmtune.finetune.lora.LoRAFinetune._inject_lora", 17 | return_value=None, # _inject_lora doesn't return a value 18 | ) 19 | 20 | # Initialize LoRAFinetune with the sample configuration 21 | lora_finetune = LoRAFinetune(config=get_sample_config(), directory_helper=MagicMock()) 22 | # Assertions to ensure that LoRAFinetune is initialized as expected 23 | mock_lora_config.assert_called_once_with(**get_sample_config().lora.model_dump()) 24 | 25 | assert lora_finetune.config == get_sample_config(), "Configuration should match the input configuration" 26 | 27 | 28 | def test_model_and_tokenizer_loading(mocker): 29 | # Prepare the configuration 30 | sample_config = get_sample_config() 31 | 32 | mock_model = mocker.patch( 33 | "llmtune.finetune.lora.AutoModelForCausalLM.from_pretrained", 34 | return_value=MagicMock(), 35 | ) 36 | mock_tokenizer = mocker.patch("llmtune.finetune.lora.AutoTokenizer.from_pretrained", return_value=MagicMock()) 37 | mock_inject_lora = mocker.patch( 38 | "llmtune.finetune.lora.LoRAFinetune._inject_lora", 39 | return_value=None, # _inject_lora doesn't return a value 40 | ) 41 | directory_helper = MagicMock() 42 | LoRAFinetune(config=sample_config, directory_helper=directory_helper) 43 | 44 | mock_model.assert_called_once_with( 45 | sample_config.model.hf_model_ckpt, 46 | quantization_config=BitsAndBytesConfig(), 47 | use_cache=False, 48 | device_map=sample_config.model.device_map, 49 | torch_dtype=sample_config.model.casted_torch_dtype, 50 | attn_implementation=sample_config.model.attn_implementation, 51 | ) 52 | 53 | mock_tokenizer.assert_called_once_with(sample_config.model.hf_model_ckpt) 54 | mock_inject_lora.assert_called_once() 55 | 56 | 57 | def test_lora_injection(mocker): 58 | """Test the initialization of LoRAFinetune with a sample configuration.""" 59 | # Mock dependencies that LoRAFinetune might call during initialization 60 | mocker.patch( 61 | "llmtune.finetune.lora.AutoModelForCausalLM.from_pretrained", 62 | return_value=MagicMock(), 63 | ) 64 | mocker.patch( 65 | 
"llmtune.finetune.lora.AutoTokenizer.from_pretrained", 66 | return_value=MagicMock(), 67 | ) 68 | 69 | mock_kbit = mocker.patch("llmtune.finetune.lora.prepare_model_for_kbit_training") 70 | mock_get_peft = mocker.patch("llmtune.finetune.lora.get_peft_model") 71 | 72 | # Initialize LoRAFinetune with the sample configuration 73 | LoRAFinetune(config=get_sample_config(), directory_helper=MagicMock()) 74 | 75 | mock_kbit.assert_called_once() 76 | mock_get_peft.assert_called_once() 77 | 78 | 79 | def test_model_finetune(mocker): 80 | sample_config = get_sample_config() 81 | 82 | mocker.patch( 83 | "llmtune.finetune.lora.AutoModelForCausalLM.from_pretrained", 84 | return_value=MagicMock(), 85 | ) 86 | mocker.patch("llmtune.finetune.lora.AutoTokenizer.from_pretrained", return_value=MagicMock()) 87 | mocker.patch( 88 | "llmtune.finetune.lora.LoRAFinetune._inject_lora", 89 | return_value=None, # _inject_lora doesn't return a value 90 | ) 91 | 92 | mock_trainer = mocker.MagicMock() 93 | mock_sft_trainer = mocker.patch("llmtune.finetune.lora.SFTTrainer", return_value=mock_trainer) 94 | 95 | directory_helper = MagicMock() 96 | 97 | mock_training_args = mocker.patch( 98 | "llmtune.finetune.lora.TrainingArguments", 99 | return_value=MagicMock(), 100 | ) 101 | 102 | ft = LoRAFinetune(config=sample_config, directory_helper=directory_helper) 103 | 104 | mock_dataset = MagicMock() 105 | ft.finetune(mock_dataset) 106 | 107 | mock_training_args.assert_called_once_with( 108 | logging_dir="/logs", 109 | output_dir=ft._weights_path, 110 | report_to="none", 111 | **sample_config.training.training_args.model_dump(), 112 | ) 113 | 114 | mock_sft_trainer.assert_called_once_with( 115 | model=ft.model, 116 | train_dataset=mock_dataset, 117 | peft_config=ft._lora_config, 118 | tokenizer=ft.tokenizer, 119 | packing=True, 120 | args=mocker.ANY, # You can replace this with the expected TrainingArguments if needed 121 | dataset_text_field="formatted_prompt", 122 | callbacks=mocker.ANY, # You can replace this with the expected callbacks if needed 123 | **sample_config.training.sft_args.model_dump(), 124 | ) 125 | 126 | mock_trainer.train.assert_called_once() 127 | 128 | 129 | def test_save_model(mocker): 130 | # Prepare the configuration and directory helper 131 | sample_config = get_sample_config() 132 | 133 | # Mock dependencies that LoRAFinetune might call during initialization 134 | mocker.patch("llmtune.finetune.lora.AutoModelForCausalLM.from_pretrained") 135 | 136 | mock_tok = mocker.MagicMock() 137 | mocker.patch("llmtune.finetune.lora.AutoTokenizer.from_pretrained", return_value=mock_tok) 138 | mocker.patch( 139 | "llmtune.finetune.lora.LoRAFinetune._inject_lora", 140 | return_value=None, 141 | ) 142 | 143 | directory_helper = MagicMock() 144 | directory_helper.save_paths.weights = "/path/to/weights" 145 | 146 | mock_trainer = mocker.MagicMock() 147 | mocker.patch("llmtune.finetune.lora.SFTTrainer", return_value=mock_trainer) 148 | 149 | ft = LoRAFinetune(config=sample_config, directory_helper=directory_helper) 150 | 151 | mock_dataset = MagicMock() 152 | ft.finetune(mock_dataset) 153 | ft.save_model() 154 | 155 | mock_tok.save_pretrained.assert_called_once_with("/path/to/weights") 156 | mock_trainer.model.save_pretrained.assert_called_once_with("/path/to/weights") 157 | -------------------------------------------------------------------------------- /tests/inference/test_inference_generics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from 
llmtune.inference.generics import Inference 4 | 5 | 6 | class MockInference(Inference): 7 | def infer_one(self, prompt: str): 8 | return "inferred one" 9 | 10 | def infer_all(self): 11 | return "inferred all" 12 | 13 | 14 | def test_infer_one(): 15 | mock_inference = MockInference() 16 | result = mock_inference.infer_one("") 17 | assert result == "inferred one" 18 | 19 | 20 | def test_infer_all(): 21 | mock_inference = MockInference() 22 | result = mock_inference.infer_all() 23 | assert result == "inferred all" 24 | 25 | 26 | def test_inference_abstract_class_instantiation(): 27 | with pytest.raises(TypeError): 28 | _ = Inference() 29 | -------------------------------------------------------------------------------- /tests/inference/test_inference_lora.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock 2 | 3 | from datasets import Dataset 4 | from transformers import BitsAndBytesConfig 5 | 6 | from llmtune.inference.lora import LoRAInference 7 | from test_utils.test_config import get_sample_config # Adjust import path as needed 8 | 9 | 10 | def test_lora_inference_initialization(mocker): 11 | # Mock dependencies 12 | mock_model = mocker.patch( 13 | "llmtune.inference.lora.AutoPeftModelForCausalLM.from_pretrained", 14 | return_value=MagicMock(), 15 | ) 16 | mock_tokenizer = mocker.patch("llmtune.inference.lora.AutoTokenizer.from_pretrained", return_value=MagicMock()) 17 | 18 | # Mock configuration and directory helper 19 | config = get_sample_config() 20 | dir_helper = MagicMock(save_paths=MagicMock(results="results_dir", weights="weights_dir")) 21 | test_dataset = Dataset.from_dict( 22 | { 23 | "formatted_prompt": ["prompt1", "prompt2"], 24 | "label_column_name": ["label1", "label2"], 25 | } 26 | ) 27 | 28 | _ = LoRAInference( 29 | test_dataset=test_dataset, 30 | label_column_name="label_column_name", 31 | config=config, 32 | dir_helper=dir_helper, 33 | ) 34 | 35 | mock_model.assert_called_once_with( 36 | "weights_dir", 37 | torch_dtype=config.model.casted_torch_dtype, 38 | quantization_config=BitsAndBytesConfig(), 39 | device_map=config.model.device_map, 40 | attn_implementation=config.model.attn_implementation, 41 | ) 42 | mock_tokenizer.assert_called_once_with("weights_dir", device_map=config.model.device_map) 43 | 44 | 45 | def test_infer_all(mocker): 46 | mocker.patch( 47 | "llmtune.inference.lora.AutoPeftModelForCausalLM.from_pretrained", 48 | return_value=MagicMock(), 49 | ) 50 | mocker.patch("llmtune.inference.lora.AutoTokenizer.from_pretrained", return_value=MagicMock()) 51 | mocker.patch("os.makedirs") 52 | mock_open = mocker.patch("builtins.open", mocker.mock_open()) 53 | mock_csv_writer = mocker.patch("csv.writer") 54 | 55 | mock_infer_one = mocker.patch.object(LoRAInference, "infer_one", return_value="predicted") 56 | 57 | config = get_sample_config() 58 | dir_helper = MagicMock(save_paths=MagicMock(results="results_dir", weights="weights_dir")) 59 | test_dataset = Dataset.from_dict({"formatted_prompt": ["prompt1"], "label_column_name": ["label1"]}) 60 | 61 | inference = LoRAInference( 62 | test_dataset=test_dataset, 63 | label_column_name="label_column_name", 64 | config=config, 65 | dir_helper=dir_helper, 66 | ) 67 | inference.infer_all() 68 | 69 | mock_infer_one.assert_called_once_with("prompt1") 70 | mock_open.assert_called_once_with("results_dir/results.csv", "w", newline="") 71 | mock_csv_writer.assert_called() # You might want to add more specific assertions based on your CSV structure 72 | 
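# Editor's sketch (not part of the original test module): the comment above
# invites more specific assertions on the CSV output. A minimal, hedged way to
# do that is a small helper that collects every row the patched csv.writer
# received. The helper name _rows_written is illustrative only, and the exact
# column layout of results.csv is an assumption not shown in this snapshot.
def _rows_written(csv_writer_mock):
    """Collect rows passed to a patched csv.writer via writerow or writerows."""
    handle = csv_writer_mock.return_value
    rows = [c.args[0] for c in handle.writerow.call_args_list]
    for c in handle.writerows.call_args_list:
        rows.extend(c.args[0])
    return rows
# Example use at the end of test_infer_all, assuming (hypothetically) that the
# mocked prediction appears somewhere in a written row:
#     assert any("predicted" in row for row in _rows_written(mock_csv_writer))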
-------------------------------------------------------------------------------- /tests/qa/test_metric_suite.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pandas import DataFrame 3 | 4 | from llmtune.qa.metric_suite import LLMMetricSuite 5 | from llmtune.qa.qa_metrics import LLMQaMetric 6 | 7 | 8 | @pytest.fixture 9 | def mock_rich_ui(mocker): 10 | return mocker.patch("llmtune.ui.rich_ui.RichUI") 11 | 12 | 13 | @pytest.fixture 14 | def example_data(): 15 | data = { 16 | "Prompt": ["What is 2+2?", "What is the capital of France?"], 17 | "Ground Truth": ["4", "Paris"], 18 | "Predicted": ["3", "Paris"], 19 | } 20 | return DataFrame(data) 21 | 22 | 23 | @pytest.fixture 24 | def mock_csv(mocker, example_data): 25 | mocker.patch("pandas.read_csv", return_value=example_data) 26 | 27 | 28 | class MockQaMetric(LLMQaMetric): 29 | @property 30 | def metric_name(self): 31 | return "Mock Accuracy" 32 | 33 | def get_metric(self, prompt, ground_truth, model_pred) -> int: 34 | return int(ground_truth == model_pred) 35 | 36 | 37 | @pytest.fixture 38 | def mock_metrics(): 39 | return [MockQaMetric()] 40 | 41 | 42 | def test_from_csv(mock_metrics, mock_csv): 43 | suite = LLMMetricSuite.from_csv("dummy_path.csv", mock_metrics) 44 | assert len(suite.metrics) == 1 45 | assert suite.prompts[0] == "What is 2+2?" 46 | 47 | 48 | def test_compute_metrics(mock_metrics, mock_csv): 49 | suite = LLMMetricSuite.from_csv("dummy_path.csv", mock_metrics) 50 | results = suite.compute_metrics() 51 | assert results["Mock Accuracy"] == [0, 1] # Expected results from the mock test 52 | 53 | 54 | def test_save_metric_results(mock_metrics, mocker, mock_csv): 55 | mocker.patch("pandas.DataFrame.to_csv") 56 | test_suite = LLMMetricSuite.from_csv("dummy_path.csv", mock_metrics) 57 | test_suite.save_metric_results("dummy_save_path.csv") 58 | assert DataFrame.to_csv.called # Check if pandas DataFrame to_csv was called 59 | 60 | 61 | def test_print_metric_results(capfd, example_data): 62 | metrics = [MockQaMetric()] 63 | suite = LLMMetricSuite(metrics, example_data["Prompt"], example_data["Ground Truth"], example_data["Predicted"]) 64 | suite.print_metric_results() 65 | out, err = capfd.readouterr() 66 | 67 | assert "0.5000" in out 68 | assert "0.5000" in out 69 | assert "0.7071" in out 70 | -------------------------------------------------------------------------------- /tests/qa/test_qa_metrics.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from llmtune.qa.qa_metrics import ( 4 | AdjectivePercentMetric, 5 | DotProductSimilarityMetric, 6 | JaccardSimilarityMetric, 7 | JSONValidityMetric, 8 | LengthMetric, 9 | NounPercentMetric, 10 | RougeScoreMetric, 11 | VerbPercentMetric, 12 | WordOverlapMetric, 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "test_class,expected_type", 18 | [ 19 | (LengthMetric, int), 20 | (JaccardSimilarityMetric, float), 21 | (DotProductSimilarityMetric, float), 22 | (RougeScoreMetric, float), 23 | (WordOverlapMetric, float), 24 | (VerbPercentMetric, float), 25 | (AdjectivePercentMetric, float), 26 | (NounPercentMetric, float), 27 | (JSONValidityMetric, float), 28 | ], 29 | ) 30 | def test_metric_return_type(test_class, expected_type): 31 | test_instance = test_class() 32 | prompt = "This is a test prompt." 33 | ground_truth = "This is a ground truth sentence." 34 | model_prediction = "This is a model predicted sentence." 
35 | 36 | # Depending on the test class, the output could be different. 37 | metric_result = test_instance.get_metric(prompt, ground_truth, model_prediction) 38 | assert isinstance( 39 | metric_result, expected_type 40 | ), f"Expected return type {expected_type}, but got {type(metric_result)}." 41 | 42 | 43 | def test_length_metric(): 44 | test = LengthMetric() 45 | result = test.get_metric("prompt", "short text", "longer text") 46 | assert result == 1, "Length difference should be 1." 47 | 48 | 49 | def test_jaccard_similarity_metric(): 50 | test = JaccardSimilarityMetric() 51 | result = test.get_metric("prompt", "hello world", "world hello") 52 | assert result == 1.0, "Jaccard similarity should be 1.0 for the same words in different orders." 53 | 54 | 55 | def test_dot_product_similarity_metric(): 56 | test = DotProductSimilarityMetric() 57 | result = test.get_metric("prompt", "data", "data") 58 | assert result >= 0, "Dot product similarity should be non-negative." 59 | 60 | 61 | def test_rouge_score_metric(): 62 | test = RougeScoreMetric() 63 | result = test.get_metric("prompt", "the quick brown fox", "the quick brown fox jumps over the lazy dog") 64 | assert result >= 0, "ROUGE precision should be non-negative." 65 | 66 | 67 | def test_word_overlap_metric(): 68 | test = WordOverlapMetric() 69 | result = test.get_metric("prompt", "jump over the moon", "jump around the sun") 70 | assert result >= 0, "Word overlap percentage should be non-negative." 71 | 72 | 73 | def test_verb_percent_metric(): 74 | test = VerbPercentMetric() 75 | result = test.get_metric("prompt", "He eats", "He is eating") 76 | assert result >= 0, "Verb percentage should be non-negative." 77 | 78 | 79 | def test_adjective_percent_metric(): 80 | test = AdjectivePercentMetric() 81 | result = test.get_metric("prompt", "It is beautiful", "It is extremely beautiful") 82 | assert result >= 0, "Adjective percentage should be non-negative." 83 | 84 | 85 | def test_noun_percent_metric(): 86 | test = NounPercentMetric() 87 | result = test.get_metric("prompt", "The cat", "The cat and the dog") 88 | assert result >= 0, "Noun percentage should be non-negative." 89 | 90 | 91 | @pytest.mark.parametrize( 92 | "input_string,expected_value", 93 | [ 94 | ('{"Answer": "The cat"}', 1), 95 | ("{'Answer': 'The cat'}", 0), # Double quotes are required in json 96 | ('{"Answer": "The cat",}', 0), 97 | ('{"Answer": "The cat", "test": "case"}', 1), 98 | ('```json\n{"Answer": "The cat"}\n```', 1), # this json block can still be processed 99 | ('Here is an example of a JSON block: {"Answer": "The cat"}', 0), 100 | ], 101 | ) 102 | def test_json_valid_metric(input_string: str, expected_value: float): 103 | test = JSONValidityMetric() 104 | result = test.get_metric("prompt", "The cat", input_string) 105 | assert result == expected_value, f"JSON validity should be {expected_value} but got {result}." 106 | -------------------------------------------------------------------------------- /tests/qa/test_qa_tests.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from llmtune.qa.qa_tests import ( 4 | JSONValidityTest, 5 | ) 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "test_class", 10 | [ 11 | JSONValidityTest, 12 | ], 13 | ) 14 | def test_test_return_bool(test_class): 15 | """Test to ensure that all tests return pass/fail boolean value.""" 16 | test_instance = test_class() 17 | model_prediction = "This is a model predicted sentence." 
18 | 19 | metric_result = test_instance.test(model_prediction) 20 | assert isinstance(metric_result, bool), f"Expected return type bool, but got {type(metric_result)}." 21 | 22 | 23 | @pytest.mark.parametrize( 24 | "input_string,expected_value", 25 | [ 26 | ('{"Answer": "The cat"}', True), 27 | ("{'Answer': 'The cat'}", False), # Double quotes are required in json 28 | ('{"Answer": "The cat",}', False), # Trailing comma is not allowed 29 | ('{"Answer": "The cat", "test": "case"}', True), 30 | ('```json\n{"Answer": "The cat"}\n```', True), # this json block can still be processed 31 | ('Here is an example of a JSON block: {"Answer": "The cat"}', False), 32 | ], 33 | ) 34 | def test_json_valid_metric(input_string: str, expected_value: bool): 35 | test = JSONValidityTest() 36 | result = test.test(input_string) 37 | assert result == expected_value, f"JSON validity should be {expected_value} but got {result}." 38 | -------------------------------------------------------------------------------- /tests/qa/test_test_suite.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | from pandas import DataFrame 5 | 6 | from llmtune.qa.qa_tests import LLMQaTest 7 | from llmtune.qa.test_suite import LLMTestSuite, TestBank, all_same 8 | 9 | 10 | @pytest.fixture 11 | def mock_rich_ui(mocker): 12 | return mocker.patch("llmtune.ui.rich_ui.RichUI") 13 | 14 | 15 | @pytest.fixture 16 | def example_data(): 17 | data = { 18 | "Prompt": ["What is 2+2?", "What is the capital of France?"], 19 | "Ground Truth": ["4", "Paris"], 20 | "Predicted": ["3", "Paris"], 21 | } 22 | return DataFrame(data) 23 | 24 | 25 | @pytest.fixture 26 | def mock_csv(mocker, example_data): 27 | mocker.patch("pandas.read_csv", return_value=example_data) 28 | 29 | 30 | # mock a LoRAInference object that returns a value when .infer_one() is called 31 | class MockLoRAInference: 32 | def infer_one(self, prompt: str) -> str: 33 | return "Paris" 34 | 35 | 36 | @pytest.mark.parametrize( 37 | "data, expected", 38 | [ 39 | (["a", "a", "a"], True), 40 | (["a", "b", "a"], False), 41 | ([], False), 42 | ], 43 | ) 44 | def test_all_same(data, expected): 45 | assert all_same(data) == expected 46 | 47 | 48 | @pytest.fixture 49 | def mock_cases(): 50 | return [ 51 | {"prompt": "What is the capital of France?"}, 52 | {"prompt": "What is the capital of Germany?"}, 53 | ] 54 | 55 | 56 | class MockQaTest(LLMQaTest): 57 | @property 58 | def test_name(self): 59 | return "Mock Accuracy" 60 | 61 | def test(self, model_pred) -> bool: 62 | return model_pred == "Paris" 63 | 64 | 65 | @pytest.fixture 66 | def mock_test_banks(mock_cases): 67 | return [ 68 | TestBank(MockQaTest(), mock_cases, "mock_file_name_stem"), 69 | TestBank(MockQaTest(), mock_cases, "mock_file_name_stem"), 70 | ] 71 | 72 | 73 | def test_test_bank_save_test_results(mocker, mock_cases): 74 | mocker.patch("pandas.DataFrame.to_csv") 75 | test_bank = TestBank(MockQaTest(), mock_cases, "mock_file_name_stem") 76 | test_bank.generate_results(MockLoRAInference()) 77 | test_bank.save_test_results(Path("mock/dir/path")) 78 | assert DataFrame.to_csv.called # Check if pandas DataFrame to_csv was called 79 | 80 | 81 | def test_test_suite_save_test_results(mocker, mock_test_banks): 82 | mocker.patch("pandas.DataFrame.to_csv") 83 | ts = LLMTestSuite(mock_test_banks) 84 | ts.run_inference(MockLoRAInference()) 85 | ts.save_test_results(Path("mock/dir/path/doesnt/exist")) 86 | assert DataFrame.to_csv.called # Check if pandas DataFrame 
to_csv was called 87 | 88 | 89 | def test_test_suite_from_dir(): 90 | ts = LLMTestSuite.from_dir("examples/test_suite") 91 | ts.run_inference(MockLoRAInference()) 92 | 93 | 94 | def test_test_suite_print_results(capfd, mock_test_banks): 95 | ts = LLMTestSuite(mock_test_banks) 96 | ts.run_inference(MockLoRAInference()) 97 | ts.print_test_results() 98 | out, _ = capfd.readouterr() 99 | assert "Mock Accuracy" in out 100 | -------------------------------------------------------------------------------- /tests/test_ablation_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import BaseModel 3 | 4 | from llmtune.utils.ablation_utils import ( 5 | generate_permutations, 6 | get_annotation, 7 | get_model_field_type, 8 | get_types_from_dict, 9 | patch_with_permutation, 10 | ) 11 | 12 | 13 | # Mocks or fixtures for models and data if necessary 14 | class BarModel(BaseModel): 15 | baz: list 16 | qux: str 17 | 18 | 19 | class ConfigModel(BaseModel): 20 | foo: int 21 | bar: BarModel 22 | 23 | 24 | @pytest.fixture 25 | def example_yaml(): 26 | return {"foo": 10, "bar": {"baz": [[1, 2, 3], [4, 5]], "qux": ["hello", "world"]}} 27 | 28 | 29 | def test_get_types_from_dict(example_yaml): 30 | expected = {"foo": (int, None), "bar.baz": (list, list), "bar.qux": (list, str)} 31 | assert get_types_from_dict(example_yaml) == expected 32 | 33 | 34 | def test_get_annotation(): 35 | key = "foo" 36 | assert get_annotation(key, ConfigModel) == int 37 | 38 | key_nested = "bar.qux" 39 | # Assuming you adjust your FakeModel or real implementation for nested annotations correctly 40 | assert get_annotation(key_nested, ConfigModel) == str 41 | 42 | 43 | def test_get_model_field_type_from_typing_list(): 44 | from typing import List 45 | 46 | annotation = List[int] 47 | assert get_model_field_type(annotation) == list 48 | 49 | 50 | def test_get_model_field_type_from_union(): 51 | from typing import Union 52 | 53 | annotation = Union[int, str] 54 | # Assuming the first type is picked from Union for simplicity 55 | assert get_model_field_type(annotation) == int 56 | 57 | 58 | def test_patch_with_permutation(): 59 | old_dict = {"foo": {"bar": 10, "baz": 20}} 60 | permutation_dict = {"foo.bar": 100} 61 | expected = {"foo": {"bar": 100, "baz": 20}} 62 | assert patch_with_permutation(old_dict, permutation_dict) == expected 63 | 64 | 65 | def test_generate_permutations(example_yaml): 66 | results = generate_permutations(example_yaml, ConfigModel) 67 | assert isinstance(results, list) 68 | 69 | # Calculate expected permutations 70 | expected_permutation_count = len(example_yaml["bar"]["baz"]) * len(example_yaml["bar"]["qux"]) 71 | assert len(results) == expected_permutation_count 72 | 73 | for _, result_dict in enumerate(results): 74 | assert result_dict["foo"] == example_yaml["foo"] # 'foo' should remain unchanged 75 | assert result_dict["bar"]["baz"] in example_yaml["bar"]["baz"] 76 | assert result_dict["bar"]["qux"] in example_yaml["bar"]["qux"] 77 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import patch 2 | 3 | from typer.testing import CliRunner 4 | 5 | from llmtune.cli.toolkit import app, cli 6 | 7 | 8 | runner = CliRunner() 9 | 10 | 11 | def test_run_command(): 12 | # Test the `run` command 13 | with patch("llmtune.cli.toolkit.run_one_experiment") as mock_run_one_experiment: 14 | result = 
runner.invoke(app, ["run", "./llmtune/config.yml"]) 15 | assert result.exit_code == 0 16 | mock_run_one_experiment.assert_called_once() 17 | 18 | 19 | def test_generate_config_command(): 20 | # Test the `generate config` command 21 | with patch("llmtune.cli.toolkit.shutil.copy") as mock_copy: 22 | result = runner.invoke(app, ["generate", "config"]) 23 | assert result.exit_code == 0 24 | mock_copy.assert_called_once() 25 | 26 | 27 | def test_cli(): 28 | # Test the `cli` function 29 | with patch("llmtune.cli.toolkit.app") as mock_app: 30 | cli() 31 | mock_app.assert_called_once() 32 | -------------------------------------------------------------------------------- /tests/test_directory_helper.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/georgian-io/LLM-Finetuning-Toolkit/1593c3ca14a99ba98518c051eb22d80e51b625d7/tests/test_directory_helper.py -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | from llmtune import __version__ 2 | 3 | 4 | def test_version(): 5 | assert __version__ == "0.0.0" 6 | --------------------------------------------------------------------------------
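A closing note on tests/test_version.py: the "0.0.0" it asserts is the placeholder version declared under [tool.poetry], which poetry-dynamic-versioning (enabled in pyproject.toml, with substitution scoped to the llmtune folder) replaces with the real git-derived version only when a distribution is built. The sketch below shows how the package is assumed to expose that placeholder so the test passes on a plain source checkout; llmtune/__init__.py is not reproduced in this section, so treat the snippet as illustrative rather than as the file's actual contents.

# Hypothetical sketch of llmtune/__init__.py (not taken from this snapshot).
# "0.0.0" is the in-source placeholder that tests/test_version.py asserts on;
# poetry-dynamic-versioning substitutes the git-derived version at build time.
__version__ = "0.0.0"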