├── .github └── workflows │ ├── docker_build_and_push.yaml │ └── pypi_publish.yaml ├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── assets ├── benchmarks-openchat-3.6-20240522.svg ├── embeddings.svg ├── logo_new.png ├── openchat-3.6-20240522.png ├── openchat-bench-0106.png ├── openchat.png ├── openchat_grok.png ├── vicuna_gpt35.svg └── vicuna_gpt4.svg ├── docker └── serving │ ├── Dockerfile │ └── start.sh ├── ochat ├── __init__.py ├── config │ ├── __init__.py │ ├── conversation_template.py │ └── model_config.py ├── data │ └── generate_dataset.py ├── evaluation │ ├── README.md │ ├── conv_eval.py │ ├── convert_to_evalplus.py │ ├── eval_data │ │ ├── coding │ │ │ └── humaneval │ │ │ │ └── humaneval.jsonl │ │ ├── fs_cothub │ │ │ ├── bbh │ │ │ │ ├── boolean_expressions.jsonl │ │ │ │ ├── causal_judgement.jsonl │ │ │ │ ├── date_understanding.jsonl │ │ │ │ ├── disambiguation_qa.jsonl │ │ │ │ ├── dyck_languages.jsonl │ │ │ │ ├── formal_fallacies.jsonl │ │ │ │ ├── geometric_shapes.jsonl │ │ │ │ ├── hyperbaton.jsonl │ │ │ │ ├── logical_deduction_five_objects.jsonl │ │ │ │ ├── logical_deduction_seven_objects.jsonl │ │ │ │ ├── logical_deduction_three_objects.jsonl │ │ │ │ ├── movie_recommendation.jsonl │ │ │ │ ├── multistep_arithmetic_two.jsonl │ │ │ │ ├── navigate.jsonl │ │ │ │ ├── object_counting.jsonl │ │ │ │ ├── penguins_in_a_table.jsonl │ │ │ │ ├── reasoning_about_colored_objects.jsonl │ │ │ │ ├── ruin_names.jsonl │ │ │ │ ├── salient_translation_error_detection.jsonl │ │ │ │ ├── snarks.jsonl │ │ │ │ ├── sports_understanding.jsonl │ │ │ │ ├── temporal_sequences.jsonl │ │ │ │ ├── tracking_shuffled_objects_five_objects.jsonl │ │ │ │ ├── tracking_shuffled_objects_seven_objects.jsonl │ │ │ │ ├── tracking_shuffled_objects_three_objects.jsonl │ │ │ │ ├── web_of_lies.jsonl │ │ │ │ └── word_sorting.jsonl │ │ │ ├── gsm8k │ │ │ │ └── gsm8k.jsonl │ │ │ ├── math │ │ │ │ └── MATH.jsonl │ │ │ └── mmlu │ │ │ │ ├── abstract_algebra.jsonl │ │ │ │ ├── anatomy.jsonl │ │ │ │ ├── astronomy.jsonl │ │ │ │ ├── business_ethics.jsonl │ │ │ │ ├── clinical_knowledge.jsonl │ │ │ │ ├── college_biology.jsonl │ │ │ │ ├── college_chemistry.jsonl │ │ │ │ ├── college_computer_science.jsonl │ │ │ │ ├── college_mathematics.jsonl │ │ │ │ ├── college_medicine.jsonl │ │ │ │ ├── college_physics.jsonl │ │ │ │ ├── computer_security.jsonl │ │ │ │ ├── conceptual_physics.jsonl │ │ │ │ ├── econometrics.jsonl │ │ │ │ ├── electrical_engineering.jsonl │ │ │ │ ├── elementary_mathematics.jsonl │ │ │ │ ├── formal_logic.jsonl │ │ │ │ ├── global_facts.jsonl │ │ │ │ ├── high_school_biology.jsonl │ │ │ │ ├── high_school_chemistry.jsonl │ │ │ │ ├── high_school_computer_science.jsonl │ │ │ │ ├── high_school_european_history.jsonl │ │ │ │ ├── high_school_geography.jsonl │ │ │ │ ├── high_school_government_and_politics.jsonl │ │ │ │ ├── high_school_macroeconomics.jsonl │ │ │ │ ├── high_school_mathematics.jsonl │ │ │ │ ├── high_school_microeconomics.jsonl │ │ │ │ ├── high_school_physics.jsonl │ │ │ │ ├── high_school_psychology.jsonl │ │ │ │ ├── high_school_statistics.jsonl │ │ │ │ ├── high_school_us_history.jsonl │ │ │ │ ├── high_school_world_history.jsonl │ │ │ │ ├── human_aging.jsonl │ │ │ │ ├── human_sexuality.jsonl │ │ │ │ ├── international_law.jsonl │ │ │ │ ├── jurisprudence.jsonl │ │ │ │ ├── logical_fallacies.jsonl │ │ │ │ ├── machine_learning.jsonl │ │ │ │ ├── management.jsonl │ │ │ │ ├── marketing.jsonl │ │ │ │ ├── medical_genetics.jsonl │ │ │ │ ├── miscellaneous.jsonl │ │ │ │ ├── moral_disputes.jsonl │ │ │ │ ├── moral_scenarios.jsonl │ │ │ 
│ ├── nutrition.jsonl │ │ │ │ ├── philosophy.jsonl │ │ │ │ ├── prehistory.jsonl │ │ │ │ ├── professional_accounting.jsonl │ │ │ │ ├── professional_law.jsonl │ │ │ │ ├── professional_medicine.jsonl │ │ │ │ ├── professional_psychology.jsonl │ │ │ │ ├── public_relations.jsonl │ │ │ │ ├── security_studies.jsonl │ │ │ │ ├── sociology.jsonl │ │ │ │ ├── us_foreign_policy.jsonl │ │ │ │ ├── virology.jsonl │ │ │ │ └── world_religions.jsonl │ │ └── zs │ │ │ ├── agieval │ │ │ ├── aqua-rat.zero-shot.jsonl │ │ │ ├── logiqa-en.zero-shot.jsonl │ │ │ ├── lsat-ar.zero-shot.jsonl │ │ │ ├── lsat-lr.zero-shot.jsonl │ │ │ ├── lsat-rc.zero-shot.jsonl │ │ │ ├── sat-en-without-passage.zero-shot.jsonl │ │ │ ├── sat-en.zero-shot.jsonl │ │ │ └── sat-math.zero-shot.jsonl │ │ │ ├── bbh_mc_orca │ │ │ ├── boolean_expressions.jsonl │ │ │ ├── causal_judgment.jsonl │ │ │ ├── date_understanding.jsonl │ │ │ ├── disambiguation_qa.jsonl │ │ │ ├── formal_fallacies_syllogisms_negation.jsonl │ │ │ ├── geometric_shapes.jsonl │ │ │ ├── hyperbaton.jsonl │ │ │ ├── logical_deduction_five_objects.jsonl │ │ │ ├── logical_deduction_seven_objects.jsonl │ │ │ ├── logical_deduction_three_objects.jsonl │ │ │ ├── movie_recommendation.jsonl │ │ │ ├── navigate.jsonl │ │ │ ├── penguins_in_a_table.jsonl │ │ │ ├── reasoning_about_colored_objects.jsonl │ │ │ ├── ruin_names.jsonl │ │ │ ├── salient_translation_error_detection.jsonl │ │ │ ├── snarks.jsonl │ │ │ ├── sports_understanding.jsonl │ │ │ ├── temporal_sequences.jsonl │ │ │ ├── tracking_shuffled_objects_five_objects.jsonl │ │ │ ├── tracking_shuffled_objects_seven_objects.jsonl │ │ │ ├── tracking_shuffled_objects_three_objects.jsonl │ │ │ └── web_of_lies.jsonl │ │ │ ├── gpqa │ │ │ └── diamond.jsonl │ │ │ └── truthfulqa_orca │ │ │ └── truthfulqa_mc.jsonl │ ├── grading │ │ ├── math_grader.py │ │ └── math_normalize.py │ ├── match_answer.py │ ├── run_eval.py │ └── view_results.py ├── experimental │ ├── generate_dataset_old.py │ ├── sharegpt.ipynb │ ├── test_multipack_dataloader.ipynb │ ├── text_length.ipynb │ ├── train_alpaca.py │ ├── verify_dataset.ipynb │ └── verify_dataset_orca.ipynb ├── models │ ├── __init__.py │ ├── unpadded_gemma.py │ ├── unpadded_llama.py │ └── unpadded_mistral.py ├── scripts │ ├── hf_add_tokens.py │ ├── init_special_embedding_llama3.py │ └── modify_eos_embedding.py ├── serving │ ├── async_tokenizer.py │ ├── openai_api_protocol.py │ └── openai_api_server.py ├── tests │ └── test_model_config.py └── training_deepspeed │ ├── deepspeed_config.json │ ├── hf_hub.py │ ├── multipack_sampler.py │ ├── openchat_dataset.py │ └── train.py ├── pyproject.toml └── pytest.ini /.github/workflows/docker_build_and_push.yaml: -------------------------------------------------------------------------------- 1 | name: Docker Build and Push 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Publish to PyPI.org"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | docker: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v4 15 | 16 | - name: Set up QEMU 17 | uses: docker/setup-qemu-action@v3 18 | 19 | - name: Set up Docker Buildx 20 | uses: docker/setup-buildx-action@v3 21 | 22 | - name: Login to Docker Hub 23 | uses: docker/login-action@v3 24 | with: 25 | username: ${{ secrets.DOCKERHUB_USERNAME }} 26 | password: ${{ secrets.DOCKERHUB_TOKEN }} 27 | 28 | - name: Build and push 29 | uses: docker/build-push-action@v5 30 | with: 31 | context: ./docker/serving 32 | platforms: linux/amd64 33 | push: true 34 | tags: ${{ secrets.DOCKERHUB_TAG }} 35 | 
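For reference, a minimal sketch of running the serving image that this workflow publishes. The image tag and model name below are placeholders (the real tag comes from the `DOCKERHUB_TAG` secret), and the environment variables are the ones read by `docker/serving/start.sh` further down in this repository; note that the API server binds to 127.0.0.1:18888 inside the container, so it is reached through the container's SSH server or a cloudflared tunnel rather than a directly published port.

```bash
# Placeholder tag -- substitute the value of the DOCKERHUB_TAG secret
IMAGE=yourorg/openchat-serving:latest
docker pull "$IMAGE"

# MODEL is forwarded to ochat.serving.openai_api_server by start.sh;
# PUBLIC_KEY enables the SSH server, through which the API
# (listening on 127.0.0.1:18888 inside the container) can be reached.
docker run --gpus all \
    -e MODEL=openchat/openchat-3.6-8b-20240522 \
    -e PUBLIC_KEY="$(cat ~/.ssh/id_ed25519.pub)" \
    -p 2222:22 \
    "$IMAGE"
```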
-------------------------------------------------------------------------------- /.github/workflows/pypi_publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI.org 2 | on: 3 | release: 4 | types: [published] 5 | jobs: 6 | pypi: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - name: Checkout 10 | uses: actions/checkout@v3 11 | with: 12 | fetch-depth: 0 13 | - run: python3 -m pip install --upgrade build && python3 -m build 14 | - name: Publish package 15 | uses: pypa/gh-action-pypi-publish@release/v1 16 | with: 17 | password: ${{ secrets.PYPI_API_TOKEN }} 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # VSCode 2 | .vscode/ 3 | 4 | # WandB 5 | wandb/ 6 | 7 | # Old 8 | old/ 9 | temp/ 10 | profiler/ 11 | 12 | # Logs 13 | logs/ 14 | 15 | # eval 16 | eval_results/ 17 | evalplus_codegen/ 18 | 19 | # All datasets 20 | dataset/ 21 | dataset_processed/ 22 | dataset_processed_*/ 23 | tokenizer/ 24 | 25 | # All evaluation results 26 | eval_baselines/ 27 | eval_results/ 28 | eval_results_temp/ 29 | 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .nox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | *.py,cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | cover/ 82 | 83 | # Translations 84 | *.mo 85 | *.pot 86 | 87 | # Django stuff: 88 | *.log 89 | local_settings.py 90 | db.sqlite3 91 | db.sqlite3-journal 92 | 93 | # Flask stuff: 94 | instance/ 95 | .webassets-cache 96 | 97 | # Scrapy stuff: 98 | .scrapy 99 | 100 | # Sphinx documentation 101 | docs/_build/ 102 | 103 | # PyBuilder 104 | .pybuilder/ 105 | target/ 106 | 107 | # Jupyter Notebook 108 | .ipynb_checkpoints 109 | 110 | # IPython 111 | profile_default/ 112 | ipython_config.py 113 | 114 | # pyenv 115 | # For a library or package, you might want to ignore these files since the code is 116 | # intended to run in multiple environments; otherwise, check them in: 117 | # .python-version 118 | 119 | # pipenv 120 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 121 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 122 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 123 | # install all needed dependencies. 124 | #Pipfile.lock 125 | 126 | # poetry 127 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 128 | # This is especially recommended for binary packages to ensure reproducibility, and is more 129 | # commonly ignored for libraries. 
130 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 131 | #poetry.lock 132 | 133 | # pdm 134 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 135 | #pdm.lock 136 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 137 | # in version control. 138 | # https://pdm.fming.dev/#use-with-ide 139 | .pdm.toml 140 | 141 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 142 | __pypackages__/ 143 | 144 | # Celery stuff 145 | celerybeat-schedule 146 | celerybeat.pid 147 | 148 | # SageMath parsed files 149 | *.sage.py 150 | 151 | # Environments 152 | .env 153 | .venv 154 | env/ 155 | venv/ 156 | ENV/ 157 | env.bak/ 158 | venv.bak/ 159 | 160 | # Spyder project settings 161 | .spyderproject 162 | .spyproject 163 | 164 | # Rope project settings 165 | .ropeproject 166 | 167 | # mkdocs documentation 168 | /site 169 | 170 | # mypy 171 | .mypy_cache/ 172 | .dmypy.json 173 | dmypy.json 174 | 175 | # Pyre type checker 176 | .pyre/ 177 | 178 | # pytype static type analyzer 179 | .pytype/ 180 | 181 | # Cython debug symbols 182 | cython_debug/ 183 | 184 | # PyCharm 185 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 186 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 187 | # and can be added to the global gitignore or merged into this file. For a more nuclear 188 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 189 | #.idea/ 190 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | | Version | Supported | 6 | |---------|--------------------| 7 | | 3.x | :white_check_mark: | 8 | | < 3.0 | :x: | 9 | 10 | ## Reporting a Vulnerability 11 | 12 | We take security vulnerabilities in our open-source project seriously and appreciate responsible disclosure from users. 
13 | 14 | If you believe you have found a security vulnerability in our project, please report it to us by creating a Github issue with the label "security" or "vulnerability". Please do not publicly disclose the vulnerability until it has been addressed by our project team. 15 | 16 | We will acknowledge receipt of your vulnerability report and will keep you informed of our progress in addressing the vulnerability. If you would like to communicate with us about the vulnerability, please email [imonenext at gmail dot com]. 17 | 18 | We will not take legal action against users who report vulnerabilities in good faith and in accordance with this disclosure policy. 19 | 20 | Thank you for helping us keep our open-source project secure! 21 | -------------------------------------------------------------------------------- /assets/logo_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/logo_new.png -------------------------------------------------------------------------------- /assets/openchat-3.6-20240522.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat-3.6-20240522.png -------------------------------------------------------------------------------- /assets/openchat-bench-0106.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat-bench-0106.png -------------------------------------------------------------------------------- /assets/openchat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat.png -------------------------------------------------------------------------------- /assets/openchat_grok.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat_grok.png -------------------------------------------------------------------------------- /assets/vicuna_gpt35.svg: -------------------------------------------------------------------------------- 1 | 111.0108.1105.298.995.578.670.9020406080100OpenChatOpenChat8192OpenCoderPlusVicunaBardAlpacaLLaMA-13BAverage Score (%) -------------------------------------------------------------------------------- /assets/vicuna_gpt4.svg: -------------------------------------------------------------------------------- 1 | 106.6105.7102.592.489.565.751.0020406080100OpenChat8192OpenChatOpenCoderPlusVicunaBardAlpacaLLaMA-13BAverage Score (%) -------------------------------------------------------------------------------- /docker/serving/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.11-slim-bookworm 2 | 3 | ######### Setup system 4 | 5 | RUN mkdir /workspace && mkdir /workspace/transformers_cache 6 | WORKDIR /workspace 7 | 8 | ENV HF_HOME /workspace/transformers_cache 9 | 10 | ######### Install system dependencies 11 | RUN apt update && apt install -y git bash curl wget libxml2 12 | 13 | # Install ssh server, remove all pre-generated ssh host keys, and disable password auth 14 | RUN apt install -y openssh-server && 
\ 15 | rm -f /etc/ssh/ssh_host_* && \ 16 | sed -i 's/#PasswordAuthentication yes/PasswordAuthentication no/g' /etc/ssh/sshd_config 17 | 18 | # Install CUDA (for FlashAttention 2) 19 | RUN wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run && \ 20 | chmod +x cuda_installer && \ 21 | ./cuda_installer --silent --toolkit --override && \ 22 | rm -f cuda_installer 23 | 24 | ######### Install OpenChat 25 | # Install OpenChat 26 | RUN pip3 install ninja packaging torch 27 | RUN pip3 install ochat 28 | 29 | ######### Install Cloudflared 30 | RUN wget -q --show-progress --progress=bar:force:noscroll -O /cloudflared https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 && chmod +x /cloudflared 31 | 32 | ######### Startup script 33 | 34 | COPY start.sh /start.sh 35 | ENTRYPOINT ["/start.sh"] 36 | -------------------------------------------------------------------------------- /docker/serving/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # start ssh server 4 | if [ -n "$PUBLIC_KEY" ]; then 5 | mkdir -p ~/.ssh 6 | echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys 7 | chmod 700 -R ~/.ssh 8 | 9 | dpkg-reconfigure openssh-server # generate ssh keys 10 | service ssh start 11 | fi 12 | 13 | # start cloudflare tunnel 14 | if [ -n "$CLOUDFLARED_TUNNEL_ARGS" ]; then 15 | /cloudflared $CLOUDFLARED_TUNNEL_ARGS & 16 | fi 17 | 18 | # start openchat server 19 | python3 -m ochat.serving.openai_api_server --model $MODEL --host 127.0.0.1 --port 18888 --engine-use-ray --worker-use-ray --disable-log-requests --disable-log-stats $ARGS & 20 | 21 | wait 22 | -------------------------------------------------------------------------------- /ochat/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/ochat/__init__.py -------------------------------------------------------------------------------- /ochat/config/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import transformers 5 | 6 | from ochat.config.model_config import ModelConfig 7 | from ochat.config.conversation_template import Message, Conversation, ConversationTemplate 8 | import ochat.models 9 | 10 | 11 | _GEMMA_IT_PREFIXES = { 12 | "user": "user", 13 | "assistant": "model" 14 | } 15 | 16 | 17 | def _v3_2_role_prefix(from_role, condition): 18 | return f"{condition} {from_role.title()}:".strip() 19 | 20 | 21 | def _v3_6_role_prefix(from_role, condition, role_start_token, role_end_token): 22 | return role_start_token + f"{condition} {from_role.title()}".strip() + role_end_token 23 | 24 | 25 | MODEL_CONFIG_MAP = { 26 | # OpenChat V3.6 (llama 3) 27 | "openchat_3.6": ModelConfig( 28 | # Model 29 | model_max_context=8192, 30 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), # Llama 3 only has fast tokenizer 31 | model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained, 32 | low_cpu_mem_usage=True, 33 | torch_dtype=torch.bfloat16), 34 | # Conversation Template 35 | conversation_template=partial(ConversationTemplate, 36 | role_prefix=partial(_v3_6_role_prefix, 37 | role_start_token="<|start_header_id|>", 38 | role_end_token="<|end_header_id|>\n\n"), 39 
| eot="<|eot_id|>", 40 | system_as_role=True, 41 | inference_condition="GPT4 Correct"), 42 | hf_chat_template="{{ bos_token }}{% for message in messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = '<|start_header_id|>GPT4 Correct ' + message['role'].title() + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% elif message['role'] == 'system' %}{% set content = '<|start_header_id|>System<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% else %}{{ raise_exception('Only user, assistant and system roles are supported!') }}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n' }}{% endif %}", 43 | ), 44 | 45 | # OpenChat V3.2 46 | "openchat_v3.2": ModelConfig( 47 | # Model 48 | model_max_context=4096, 49 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False), 50 | model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained, 51 | low_cpu_mem_usage=True, 52 | torch_dtype=torch.bfloat16), 53 | 54 | # Conversation Template 55 | conversation_template=partial(ConversationTemplate, 56 | role_prefix=_v3_2_role_prefix, 57 | eot="<|end_of_turn|>", 58 | inference_condition="GPT4") 59 | ), 60 | 61 | "openchat_v3.2_mistral": ModelConfig( 62 | serving_aliases=("openchat_3.5", ), 63 | 64 | # Model 65 | model_max_context=8192, 66 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), 67 | model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained, 68 | low_cpu_mem_usage=True, 69 | torch_dtype=torch.bfloat16), 70 | 71 | # Conversation Template 72 | conversation_template=partial(ConversationTemplate, 73 | role_prefix=_v3_2_role_prefix, 74 | eot="<|end_of_turn|>", 75 | inference_condition="GPT4 Correct"), 76 | hf_chat_template="{{ bos_token }}{% for message in messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>' %}{% elif message['role'] == 'system' %}{% set content = message['content'] + '<|end_of_turn|>' %}{% else %}{{ raise_exception('Only user, assistant and system roles are supported!') }}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", ), 77 | 78 | "openchat_v3.2_gemma_new": ModelConfig( 79 | serving_aliases=("openchat_3.5_gemma_new", ), 80 | 81 | # Model 82 | model_max_context=8192, 83 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), 84 | model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained, 85 | low_cpu_mem_usage=True, 86 | torch_dtype=torch.bfloat16), 87 | 88 | # Conversation Template 89 | conversation_template=partial(ConversationTemplate, 90 | role_prefix=_v3_2_role_prefix, 91 | eot="", 92 | inference_condition="GPT4 Correct"), 93 | hf_chat_template="{{ bos_token }}{% for message in messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '' %}{% elif message['role'] == 'system' %}{% set content = message['content'] + '' %}{% else %}{{ raise_exception('Only user, assistant and system roles are supported!') }}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", 94 | ), 95 | 96 | ### Other models 97 | "chatml_8192": ModelConfig( 98 | # Model 99 | 
model_max_context=8192, 100 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), 101 | model_create_for_training=lambda x: None, 102 | 103 | # Conversation Template 104 | conversation_template=partial(ConversationTemplate, 105 | role_prefix=lambda from_role, condition: f"<|im_start|>{from_role}\n", 106 | eot="<|im_end|>", 107 | inference_condition="") 108 | ), 109 | "zephyr_mistral": ModelConfig( 110 | # Model 111 | model_max_context=8192, 112 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False), 113 | model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained, 114 | low_cpu_mem_usage=True, 115 | torch_dtype=torch.bfloat16), 116 | 117 | # Conversation Template 118 | conversation_template=partial(ConversationTemplate, 119 | role_prefix=lambda from_role, condition: f"<|{from_role}|>\n", 120 | eot="", 121 | inference_condition="") 122 | ), 123 | "gemma_it": ModelConfig( 124 | # Model 125 | model_max_context=8192, 126 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False), 127 | model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained, 128 | low_cpu_mem_usage=True, 129 | torch_dtype=torch.bfloat16), 130 | 131 | # Conversation Template 132 | conversation_template=partial(ConversationTemplate, 133 | role_prefix=lambda from_role, condition: f"{_GEMMA_IT_PREFIXES[from_role]}\n", 134 | eot="", 135 | inference_condition="") 136 | ), 137 | "llama3_instruct": ModelConfig( 138 | # Model 139 | model_max_context=8192, 140 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), # Llama 3 only has fast tokenizer 141 | model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained, 142 | low_cpu_mem_usage=True, 143 | torch_dtype=torch.bfloat16), 144 | 145 | # Conversation Template 146 | conversation_template=partial(ConversationTemplate, 147 | role_prefix=lambda from_role, condition: f"<|start_header_id|>{from_role}<|end_header_id|>\n\n", 148 | eot="<|eot_id|>", 149 | inference_condition="") 150 | ), 151 | } 152 | -------------------------------------------------------------------------------- /ochat/config/conversation_template.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Iterable, List 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class Message(BaseModel): 7 | role: str 8 | content: str 9 | 10 | weight: Optional[float] = None 11 | 12 | 13 | class Conversation(BaseModel): 14 | items: List[Message] 15 | 16 | condition: str = "" 17 | system: str = "" 18 | 19 | 20 | class ConversationTemplate(BaseModel): 21 | tokenizer: Callable 22 | 23 | # Prompt 24 | role_prefix: Callable 25 | system_as_role: Optional[bool] = False 26 | eot: str 27 | 28 | inference_condition: Optional[str] = None 29 | 30 | # Private 31 | bos_tokens_: List[int] 32 | eot_tokens_: List[int] 33 | 34 | def __init__(self, **data): 35 | tokenizer = data["tokenizer"] 36 | eot = data["eot"] 37 | bos_tokens_ = tokenizer("").input_ids 38 | eot_tokens_ = tokenizer(eot, add_special_tokens=False).input_ids 39 | super().__init__(**data, bos_tokens_=bos_tokens_, eot_tokens_=eot_tokens_) 40 | 41 | def _tokenize(self, strings: Iterable[str], ignore_special: bool = True) -> List[List[int]]: 42 | if self.tokenizer.is_fast: 43 | # Support for fast tokenizer 44 | # https://github.com/huggingface/tokenizers/pull/1419 45 | 
self.tokenizer._tokenizer.encode_special_tokens = ignore_special 46 | result = self.tokenizer(strings, return_attention_mask=False, add_special_tokens=False).input_ids 47 | self.tokenizer._tokenizer.encode_special_tokens = False 48 | else: 49 | result = self.tokenizer(strings, split_special_tokens=ignore_special, return_attention_mask=False, add_special_tokens=False).input_ids 50 | 51 | return result 52 | 53 | def tokenize_conversations(self, conversations: Iterable[Conversation], inference: bool = False, seq_level_weight: bool = False): 54 | # Pre-tokenize all conversations 55 | default_condition = self.inference_condition if inference else "" 56 | 57 | sys_mappings = set() 58 | role_mappings = set() 59 | all_text = [] 60 | for conv in conversations: 61 | sys_mappings.add(conv.system) 62 | for msg in conv.items: 63 | role_mappings.add((msg.role, conv.condition or default_condition)) 64 | all_text.append(msg.content) 65 | 66 | system_role_tokens = [] 67 | if self.system_as_role: 68 | system_role_tokens = self._tokenize(self.role_prefix("system", ""), ignore_special=False) 69 | 70 | sys_mappings = list(sys_mappings) 71 | role_mappings = list(role_mappings) 72 | 73 | sys_mappings = dict(zip(sys_mappings, self._tokenize(sys_mappings))) 74 | role_mappings = dict(zip(role_mappings, self._tokenize([self.role_prefix(*args) for args in role_mappings], ignore_special=False))) 75 | all_text = self._tokenize(all_text) 76 | 77 | # Convert 78 | result_tokens = [] 79 | result_weights = [] 80 | all_text_idx = 0 81 | for conv in conversations: 82 | tokens = [] 83 | weights = [] 84 | 85 | # bos tokens 86 | tokens.extend(self.bos_tokens_) 87 | weights.extend([0.] * len(self.bos_tokens_)) 88 | 89 | # System 90 | if conv.system: 91 | tokens.extend(system_role_tokens) 92 | weights.extend([0.] * len(system_role_tokens)) 93 | 94 | system = sys_mappings[conv.system] 95 | tokens.extend(system) 96 | weights.extend([0.] * len(system)) 97 | 98 | tokens.extend(self.eot_tokens_) 99 | weights.extend([0.] * len(self.eot_tokens_)) 100 | 101 | # Messages 102 | last_idx = len(conv.items) - 1 103 | for idx, msg in enumerate(conv.items): 104 | # Role Prefix 105 | role = role_mappings[(msg.role, conv.condition or default_condition)] 106 | tokens.extend(role) 107 | weights.extend([0.] 
* len(role)) 108 | 109 | # Message 110 | text = all_text[all_text_idx] 111 | all_text_idx += 1 112 | 113 | # weight 114 | w = None 115 | if not inference: 116 | assert msg.weight is not None 117 | 118 | w = msg.weight 119 | if seq_level_weight: 120 | w /= len(text) + len(self.eot_tokens_) 121 | 122 | # Message tokens 123 | tokens.extend(text) 124 | weights.extend([w] * len(text)) 125 | 126 | if not (inference and idx == last_idx): # Do not add EOT on last turn during inference 127 | tokens.extend(self.eot_tokens_) 128 | weights.extend([w] * len(self.eot_tokens_)) 129 | 130 | # Append result 131 | result_tokens.append(tokens) 132 | result_weights.append(weights) 133 | 134 | # Sanity check 135 | assert all_text_idx == len(all_text) 136 | 137 | return result_tokens, result_weights 138 | -------------------------------------------------------------------------------- /ochat/config/model_config.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Callable, Iterable 2 | 3 | from pydantic import BaseModel, ConfigDict 4 | 5 | 6 | class ModelConfig(BaseModel): 7 | # Alias 8 | serving_aliases: Iterable[str] = () 9 | 10 | # Model 11 | model_max_context: int 12 | model_tokenizer_create: Callable 13 | model_create_for_training: Callable 14 | 15 | # conversation template 16 | conversation_template: Callable 17 | hf_chat_template: Optional[str] = None 18 | 19 | model_config = ConfigDict(protected_namespaces=()) # Disables warnings for the model_ namespace used above 20 | -------------------------------------------------------------------------------- /ochat/data/generate_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate training data based on conversations 3 | 4 | Usage: python -m ochat.data.generate_dataset --model-type MODEL_TYPE --model-path MODEL_PATH --in-prefix IN_PREFIX --out-prefix OUT_PREFIX
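(The flags above mirror the argument parser at the bottom of this file; --per-sequence-loss and --seed are optional.)

Input: one file per epoch named {in_prefix}.{epoch}.jsonl (epoch = 0, 1, 2, ...). Each line is a JSON-encoded Conversation
(see ochat/config/conversation_template.py), e.g. with illustrative values only:
    {"items": [{"role": "user", "content": "Hi", "weight": 0.0}, {"role": "assistant", "content": "Hello!", "weight": 1.0}], "system": ""}
Output: one {out_prefix}.{epoch}.parquet file per epoch containing the tokenized, loss-weighted sequences.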
5 | """ 6 | 7 | import argparse 8 | import os 9 | import gc 10 | import random 11 | 12 | import ray 13 | import orjson 14 | import pyarrow 15 | from pyarrow import parquet 16 | 17 | 18 | PAD_TOKEN_ID = 0 19 | 20 | 21 | def _split(a, n): 22 | # Split list a to n chunks 23 | # https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length 24 | k, m = divmod(len(a), n) 25 | return [a[i*k+min(i, m): (i+1)*k+min(i+1, m)] for i in range(n)] 26 | 27 | 28 | def truncate_trailing_zero_weighted(tokens, weights): 29 | non_zero_index = len(weights) - 1 30 | while non_zero_index >= 0 and weights[non_zero_index] == 0: 31 | non_zero_index -= 1 32 | 33 | return tokens[:non_zero_index + 1], weights[:non_zero_index + 1] 34 | 35 | 36 | def add_single_conv(output, tokens, weights): 37 | # truncate trailing zero weighted tokens 38 | tokens, weights = truncate_trailing_zero_weighted(tokens, weights) 39 | if not tokens: 40 | return 41 | 42 | # labels 43 | length = len(tokens) 44 | labels = [(t if w != 0 else PAD_TOKEN_ID) for t, w in zip(tokens, weights)] 45 | 46 | # populate results 47 | results = { 48 | "total_length": length, 49 | 50 | "seqlens": [length], 51 | "nz_input_ids": tokens, 52 | "nz_position_ids": list(range(length)), 53 | 54 | "nz_shifted_label_ids": labels[1:] + [PAD_TOKEN_ID], 55 | "nz_shifted_loss_weights": weights[1:] + [0.0] 56 | } 57 | results["num_seqs"] = sum(results["nz_shifted_loss_weights"]) 58 | 59 | for k, v in results.items(): 60 | output[k].append(v) 61 | 62 | 63 | @ray.remote 64 | def convert_conversation_batch(model_type: str, model_path: str, batch: list, schema: pyarrow.Schema, per_sequence_loss: bool): 65 | from ochat.config import MODEL_CONFIG_MAP, Conversation 66 | 67 | # Tokenization 68 | model_config = MODEL_CONFIG_MAP[model_type] 69 | tokenizer = model_config.model_tokenizer_create(model_path) 70 | conv_template = model_config.conversation_template(tokenizer=tokenizer) 71 | 72 | # Decode data 73 | print ("Decoding JSON ...") 74 | batch = [Conversation(**orjson.loads(json_line)) for json_line in batch] 75 | 76 | # Tokenize 77 | print ("Tokenizing ...") 78 | tokens_list, weights_list = conv_template.tokenize_conversations(batch, inference=False, seq_level_weight=per_sequence_loss) 79 | 80 | del batch 81 | gc.collect() 82 | 83 | # Generate data 84 | print ("Generating ...") 85 | max_context = model_config.model_max_context 86 | 87 | outputs = {k: [] for k in schema.names} 88 | for tokens, weights in zip(tokens_list, weights_list): 89 | assert len(tokens) == len(weights) 90 | 91 | # Truncate to specified tokens 92 | tokens = tokens[:max_context] 93 | weights = weights[:max_context] 94 | 95 | # Add to results 96 | add_single_conv(outputs, tokens, weights) 97 | 98 | del tokens_list, weights_list 99 | gc.collect() 100 | 101 | print ("To table ...") 102 | table = pyarrow.Table.from_pydict(outputs, schema=schema) 103 | 104 | del outputs 105 | gc.collect() 106 | 107 | print ("Chunk finish") 108 | return table 109 | 110 | 111 | def generate_epoch(seed: int, model_type: str, model_path: str, in_filename: str, out_filename: str, per_sequence_loss: bool): 112 | # schema 113 | metadata = { 114 | "model_type": model_type 115 | } 116 | schema = [ 117 | pyarrow.field("total_length", pyarrow.int32()), 118 | pyarrow.field("num_seqs", pyarrow.float32()), 119 | 120 | pyarrow.field(f"seqlens", pyarrow.list_(pyarrow.int32())), 121 | pyarrow.field(f"nz_input_ids", pyarrow.list_(pyarrow.int32())), 122 | pyarrow.field(f"nz_position_ids", 
pyarrow.list_(pyarrow.int32())), 123 | pyarrow.field(f"nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())), 124 | pyarrow.field(f"nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32())) 125 | ] 126 | 127 | schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)}) 128 | 129 | # Load data 130 | with open(in_filename, "rb") as f: 131 | batches = f.readlines() 132 | 133 | random.seed(seed) # Randomized load balancing 134 | random.shuffle(batches) 135 | 136 | batches = _split(batches, int(ray.available_resources()["CPU"])) 137 | 138 | # launch remote workers 139 | handles = [convert_conversation_batch.remote( 140 | model_type=model_type, # type: ignore 141 | model_path=model_path, 142 | batch=batch, 143 | schema=schema, 144 | per_sequence_loss=per_sequence_loss 145 | ) for batch in batches] 146 | 147 | # write 148 | parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), out_filename) 149 | 150 | 151 | def generate_dataset(model_type, model_path, in_prefix, out_prefix, per_sequence_loss, seed): 152 | # Initialize Ray 153 | if not ray.is_initialized(): 154 | ray.init(ignore_reinit_error=True, num_cpus=os.cpu_count()) 155 | 156 | # Load epochs and tokenize 157 | epoch = 0 158 | while True: 159 | in_filename = f"{in_prefix}.{epoch}.jsonl" 160 | if not os.path.exists(in_filename): 161 | break 162 | 163 | out_filename = f"{out_prefix}.{epoch}.parquet" 164 | generate_epoch( 165 | seed=seed + epoch, 166 | model_type=model_type, 167 | model_path=model_path, 168 | in_filename=in_filename, 169 | out_filename=out_filename, 170 | per_sequence_loss=per_sequence_loss 171 | ) 172 | gc.collect() 173 | 174 | epoch += 1 175 | 176 | 177 | if __name__ == "__main__": 178 | parser = argparse.ArgumentParser() 179 | parser.add_argument("--model-type", type=str, required=True) 180 | parser.add_argument("--model-path", type=str, required=True) 181 | 182 | parser.add_argument("--in-prefix", type=str, required=True) 183 | parser.add_argument("--out-prefix", type=str, required=True) 184 | 185 | parser.add_argument("--per-sequence-loss", action="store_true") 186 | parser.add_argument("--seed", type=int, default=42) 187 | args = parser.parse_args() 188 | 189 | generate_dataset(**vars(args)) 190 | -------------------------------------------------------------------------------- /ochat/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # vLLM Eval 2 | 3 | **Work in progress... Stay tuned!** 4 | 5 | 🚀 Efficiently evaluate large language models (LLMs) in just 5 minutes (full eval suite)!
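As a rough orientation, a typical pass over the suite looks like the sketch below. The `run_eval` and `view_results` invocations are assumptions (their actual flags live in `run_eval.py` and `view_results.py`, which are not shown here); the `convert_to_evalplus` flags match its argument parser later in this repository, and the model name is only an example.

```bash
# 1. Generate answers for the suite (hypothetical flags -- check run_eval.py for the real interface)
python -m ochat.evaluation.run_eval --model openchat/openchat-3.6-8b-20240522

# 2. Convert HumanEval completions into EvalPlus format (flags match convert_to_evalplus.py)
python -m ochat.evaluation.convert_to_evalplus \
    --results_path ochat/evaluation/eval_results \
    --output_path ochat/evaluation/evalplus_codegen

# 3. Summarize scores (hypothetical -- see view_results.py)
python -m ochat.evaluation.view_results
```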
6 | 7 | ## Suite Included 8 | 9 | ### Zero-shot Multiple-Choice 10 | 11 | - AGIEval 12 | - BBH 13 | - TruthfulQA 14 | 15 | ### Few-shot CoT 16 | 17 | - BBH 18 | - GSM8k 19 | -------------------------------------------------------------------------------- /ochat/evaluation/conv_eval.py: -------------------------------------------------------------------------------- 1 | from typing import OrderedDict 2 | import signal 3 | import os 4 | import json 5 | import subprocess 6 | import argparse 7 | import time 8 | import requests 9 | import re 10 | import coolname 11 | 12 | 13 | MAX_CONTEXT = 4096 14 | 15 | 16 | def find_models(path, prefix, ep_filter): 17 | run_name = '_'.join(coolname.generate(2)) 18 | 19 | def generate_model_name(root, ep_number): 20 | return f"{prefix}{os.path.basename(root)}_ep{ep_number}_{run_name}" 21 | 22 | models = {} 23 | for root, dirs, _ in os.walk(path): 24 | for d in dirs: 25 | ep_match = re.match(r"ep_(\d+)", d) 26 | if not ep_match: 27 | continue 28 | 29 | if ep_filter and ep_match.group(1) != ep_filter: 30 | continue 31 | 32 | model_name = generate_model_name(root, ep_match.group(1)) 33 | models[model_name] = os.path.join(root, d) 34 | 35 | # Sort and return the models dictionary as an OrderedDict 36 | return OrderedDict(sorted(models.items(), reverse=True, key=lambda x: x[0].split("_ep")[::-1])) 37 | 38 | 39 | def run_mt_bench(mt_bench_path, model_name): 40 | working_dir = os.path.join(mt_bench_path, "fastchat", "llm_judge") 41 | 42 | # Skip if result exists 43 | if os.path.exists(os.path.join(working_dir, "data", "mt_bench", "model_answer", f"{model_name}.jsonl")): 44 | return 45 | 46 | # run mt bench 47 | commands = [ 48 | f"python gen_api_answer.py --model {model_name} --max-tokens {MAX_CONTEXT} --parallel 128 --openai-api-base http://localhost:18888/v1", 49 | f"python gen_judgment.py --model-list {model_name} --parallel 8 --mode single", 50 | # f"python gen_judgment.py --model-list {model_name} --parallel 8 --mode pairwise-baseline", 51 | ] 52 | 53 | env = os.environ.copy() 54 | env["PYTHONPATH"] = f'{env.get("PYTHONPATH", "")}:{mt_bench_path}' 55 | 56 | for command in commands: 57 | subprocess.run(command, shell=True, cwd=working_dir, env=env) 58 | 59 | 60 | def run_vicuna_bench(mt_bench_path, model_name): 61 | working_dir = os.path.join(mt_bench_path, "fastchat", "llm_judge") 62 | 63 | # Skip if result exists 64 | if os.path.exists(os.path.join(working_dir, "data", "vicuna_bench", "model_answer", f"{model_name}.jsonl")): 65 | return 66 | 67 | # run mt bench 68 | commands = [ 69 | f"python gen_api_answer.py --model {model_name} --max-tokens {MAX_CONTEXT} --parallel 128 --openai-api-base http://localhost:18888/v1 --bench-name vicuna_bench", 70 | f"python gen_judgment.py --model-list {model_name} --parallel 8 --mode pairwise-baseline --bench-name vicuna_bench", 71 | ] 72 | 73 | env = os.environ.copy() 74 | env["PYTHONPATH"] = f'{env.get("PYTHONPATH", "")}:{mt_bench_path}' 75 | 76 | for command in commands: 77 | subprocess.run(command, shell=True, cwd=working_dir, env=env) 78 | 79 | 80 | def create_alpaca_eval_config(alpacaeval_path, model_name): 81 | config_dir = os.path.join(alpacaeval_path, "src", "alpaca_eval", "models_configs", model_name.lower()) 82 | os.makedirs(config_dir, exist_ok=True) 83 | config_path = os.path.join(config_dir, "configs.yaml") 84 | 85 | config_content = f"""{model_name.lower()}: 86 | prompt_template: "openchat-13b/prompt.txt" 87 | fn_completions: "openai_completions" 88 | completions_kwargs: 89 | openai_api_base: 
http://127.0.0.1:18888/v1 90 | requires_chatml: True 91 | sleep_time: 0 92 | 93 | model_name: "{model_name.lower()}" 94 | max_tokens: {MAX_CONTEXT} 95 | 96 | top_p: 1.0 97 | temperature: 0.7 98 | 99 | num_procs: 128 100 | 101 | pretty_name: "{model_name}" 102 | link: "https://github.com/imoneoi/openchat" 103 | """ 104 | 105 | with open(config_path, "w") as f: 106 | f.write(config_content) 107 | 108 | 109 | def run_alpaca_eval(alpacaeval_path, model_name): 110 | # Skip if result exists 111 | if os.path.exists(os.path.join(alpacaeval_path, "results", model_name.lower())): 112 | return 113 | 114 | # Create config 115 | create_alpaca_eval_config(alpacaeval_path, model_name) 116 | 117 | # Run 118 | command = f"python -m alpaca_eval.main evaluate_from_model --model_configs {model_name.lower()} --annotators_config alpaca_eval_gpt4" 119 | 120 | env = os.environ.copy() 121 | env["PYTHONPATH"] = f'{env.get("PYTHONPATH", "")}:{os.path.join(alpacaeval_path, "src")}' 122 | 123 | subprocess.run(command, shell=True, cwd=alpacaeval_path, env=env) 124 | 125 | 126 | def wait_for_server(url): 127 | while True: 128 | try: 129 | response = requests.get(url) 130 | if response.status_code in [200, 404]: 131 | break 132 | except requests.exceptions.RequestException: 133 | pass 134 | 135 | time.sleep(1) 136 | 137 | 138 | def main(path, prefix, ep_filter, mt_bench_path, alpacaeval_path): 139 | models = find_models(path, prefix, ep_filter) 140 | 141 | for i, (model_name, model_path) in enumerate(models.items()): 142 | print(f"Processing model {i + 1}/{len(models)}: {model_name}") 143 | 144 | print("Starting server...") 145 | server_command = f"python -m ochat.serving.openai_api_server --model {model_path} --engine-use-ray --worker-use-ray" 146 | server_process = subprocess.Popen(server_command, shell=True, preexec_fn=os.setsid) 147 | 148 | wait_for_server("http://127.0.0.1:18888/v1") 149 | print("Server is ready.") 150 | 151 | print("Running MT-bench...") 152 | run_mt_bench(mt_bench_path, model_name) 153 | 154 | print("Running AlpacaEval...") 155 | # run_alpaca_eval(alpacaeval_path, model_name) 156 | 157 | # print("Running Vicuna-bench") 158 | # run_vicuna_bench(mt_bench_path, model_name) 159 | 160 | print("Terminating server...") 161 | os.killpg(os.getpgid(server_process.pid), signal.SIGTERM) 162 | server_process.wait() 163 | 164 | 165 | if __name__ == "__main__": 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument("--path", default="/ML-A100/home/csj/trained_models/openchat_mistral/1017", help="Path to the models directory") 168 | parser.add_argument("--prefix", default="gpt4correct_") 169 | parser.add_argument("--ep_filter", default=None, help="Filter epochs") 170 | 171 | parser.add_argument("--mt_bench_path", default="/ML-A100/home/csj/one_benchmarks/FastChat", help="Path to the MT-bench directory") 172 | parser.add_argument("--alpacaeval_path", default="/ML-A100/home/csj/one_benchmarks/alpaca_eval", help="Path to the AlpacaEval directory") 173 | args = parser.parse_args() 174 | 175 | main(**vars(args)) 176 | -------------------------------------------------------------------------------- /ochat/evaluation/convert_to_evalplus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import orjson 4 | 5 | from glob import glob 6 | 7 | 8 | def convert_to_evalplus(results_path: str, output_path: str): 9 | os.makedirs(output_path, exist_ok=True) 10 | 11 | for filename in glob(os.path.join(results_path, "*.json")): 12 | # read eval 
results 13 | with open(filename, "rb") as f: 14 | data = orjson.loads(f.read()) 15 | 16 | # humaneval 17 | result = bytearray() 18 | for item in data: 19 | if item["task_type"] == "coding/humaneval": 20 | result.extend(orjson.dumps(item["answer"])) 21 | result.extend(b"\n") 22 | 23 | with open(os.path.join(output_path, os.path.splitext(os.path.basename(filename))[0] + ".jsonl"), "wb") as f: 24 | f.write(result) 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | 30 | # Input / output 31 | parser.add_argument("--results_path", type=str, default="ochat/evaluation/eval_results") 32 | parser.add_argument("--output_path", type=str, default="ochat/evaluation/evalplus_codegen") 33 | args = parser.parse_args() 34 | 35 | convert_to_evalplus(**vars(args)) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /ochat/evaluation/grading/math_grader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Answer checker API that uses sympy to simplify expressions and check for equality. 3 | 4 | Call grade_answer(given_answer: str, ground_truth: str). 5 | """ 6 | import re 7 | import sympy 8 | from pylatexenc import latex2text 9 | from sympy.parsing import sympy_parser 10 | 11 | from ochat.evaluation.grading import math_normalize 12 | 13 | 14 | # sympy might hang -- we don't care about trying to be lenient in these cases 15 | BAD_SUBSTRINGS = ["^{", "^("] 16 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"] 17 | TUPLE_CHARS = "()[]" 18 | 19 | 20 | def _sympy_parse(expr: str): 21 | """Parses an expression with sympy.""" 22 | py_expr = expr.replace("^", "**") 23 | return sympy_parser.parse_expr( 24 | py_expr, 25 | transformations=( 26 | sympy_parser.standard_transformations 27 | + (sympy_parser.implicit_multiplication_application,) 28 | ), 29 | ) 30 | 31 | 32 | def _parse_latex(expr: str) -> str: 33 | """Attempts to parse latex to an expression sympy can read.""" 34 | expr = expr.replace("\\tfrac", "\\frac") 35 | expr = expr.replace("\\dfrac", "\\frac") 36 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers. 37 | expr = latex2text.LatexNodes2Text().latex_to_text(expr) 38 | 39 | # Replace the specific characters that this parser uses. 40 | expr = expr.replace("√", "sqrt") 41 | expr = expr.replace("π", "pi") 42 | expr = expr.replace("∞", "inf") 43 | expr = expr.replace("∪", "U") 44 | expr = expr.replace("·", "*") 45 | expr = expr.replace("×", "*") 46 | 47 | return expr.strip() 48 | 49 | 50 | def _is_float(num: str) -> bool: 51 | try: 52 | float(num) 53 | return True 54 | except ValueError: 55 | return False 56 | 57 | 58 | def _is_int(x: float) -> bool: 59 | try: 60 | return abs(x - int(round(x))) <= 1e-7 61 | except: 62 | return False 63 | 64 | 65 | def _is_frac(expr: str) -> bool: 66 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr)) 67 | 68 | 69 | def _str_is_int(x: str) -> bool: 70 | try: 71 | x = _strip_properly_formatted_commas(x) 72 | x = float(x) 73 | return abs(x - int(round(x))) <= 1e-7 74 | except: 75 | return False 76 | 77 | 78 | def _str_to_int(x: str) -> bool: 79 | x = x.replace(",", "") 80 | x = float(x) 81 | return int(x) 82 | 83 | 84 | def _inject_implicit_mixed_number(step: str): 85 | """ 86 | Automatically make a mixed number evalable 87 | e.g. 
7 3/4 => 7+3/4 88 | """ 89 | p1 = re.compile("([0-9]) +([0-9])") 90 | step = p1.sub("\\1+\\2", step) ## implicit mults 91 | return step 92 | 93 | 94 | def _strip_properly_formatted_commas(expr: str): 95 | # We want to be careful because we don't want to strip tuple commas 96 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)") 97 | while True: 98 | next_expr = p1.sub("\\1\\3\\4", expr) 99 | if next_expr == expr: 100 | break 101 | expr = next_expr 102 | return next_expr 103 | 104 | 105 | def _normalize(expr: str) -> str: 106 | """Normalize answer expressions.""" 107 | if expr is None: 108 | return None 109 | 110 | # Remove enclosing `\text{}`. 111 | m = re.search("^\\\\text\{(?P.+?)\}$", expr) 112 | if m is not None: 113 | expr = m.group("text") 114 | 115 | expr = expr.replace("\\%", "%") 116 | expr = expr.replace("\\$", "$") 117 | expr = expr.replace("$", "") 118 | expr = expr.replace("%", "") 119 | expr = expr.replace(" or ", " , ") 120 | expr = expr.replace(" and ", " , ") 121 | 122 | expr = expr.replace("million", "*10^6") 123 | expr = expr.replace("billion", "*10^9") 124 | expr = expr.replace("trillion", "*10^12") 125 | 126 | for unit in [ 127 | "degree", 128 | "cm", 129 | "centimeter", 130 | "meter", 131 | "mile", 132 | "second", 133 | "minute", 134 | "hour", 135 | "day", 136 | "week", 137 | "month", 138 | "year", 139 | "foot", 140 | "feet", 141 | "inch", 142 | "yard", 143 | ]: 144 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr) 145 | expr = re.sub(f"\^ *\\\\circ", "", expr) 146 | 147 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}": 148 | expr = expr[1:-1] 149 | 150 | expr = re.sub(",\\\\! *", "", expr) 151 | if _is_float(expr) and _is_int(float(expr)): 152 | expr = str(int(round(float(expr)))) 153 | if "\\" in expr: 154 | try: 155 | expr = _parse_latex(expr) 156 | except: 157 | pass 158 | 159 | # edge case with mixed numbers and negative signs 160 | expr = re.sub("- *", "-", expr) 161 | 162 | expr = _inject_implicit_mixed_number(expr) 163 | expr = expr.replace(" ", "") 164 | 165 | # if we somehow still have latex braces here, just drop them 166 | expr = expr.replace("{", "") 167 | expr = expr.replace("}", "") 168 | 169 | # don't be case sensitive for text answers 170 | expr = expr.lower() 171 | 172 | if _str_is_int(expr): 173 | expr = str(_str_to_int(expr)) 174 | 175 | return expr 176 | 177 | 178 | def count_unknown_letters_in_expr(expr: str): 179 | expr = expr.replace("sqrt", "") 180 | expr = expr.replace("frac", "") 181 | letters_in_expr = set([x for x in expr if x.isalpha()]) 182 | return len(letters_in_expr) 183 | 184 | 185 | def should_allow_eval(expr: str): 186 | # we don't want to try parsing unknown text or functions of more than two variables 187 | if count_unknown_letters_in_expr(expr) > 2: 188 | return False 189 | 190 | for bad_string in BAD_SUBSTRINGS: 191 | if bad_string in expr: 192 | return False 193 | 194 | for bad_regex in BAD_REGEXES: 195 | if re.search(bad_regex, expr) is not None: 196 | return False 197 | 198 | return True 199 | 200 | 201 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str): 202 | are_equal = False 203 | try: 204 | expr = f"({ground_truth_normalized})-({given_normalized})" 205 | if should_allow_eval(expr): 206 | sympy_diff = _sympy_parse(expr) 207 | simplified = sympy.simplify(sympy_diff) 208 | if simplified == 0: 209 | are_equal = True 210 | except: 211 | pass 212 | return are_equal 213 | 214 | 215 | def split_tuple(expr: str): 216 | """ 217 | Split the elements in a tuple/interval, while handling 
well-formatted commas in large numbers 218 | """ 219 | expr = _strip_properly_formatted_commas(expr) 220 | if len(expr) == 0: 221 | return [] 222 | if ( 223 | len(expr) > 2 224 | and expr[0] in TUPLE_CHARS 225 | and expr[-1] in TUPLE_CHARS 226 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS]) 227 | ): 228 | elems = [elem.strip() for elem in expr[1:-1].split(",")] 229 | else: 230 | elems = [expr] 231 | return elems 232 | 233 | 234 | def grade_answer(given_answer: str, ground_truth: str) -> bool: 235 | """ 236 | The answer will be considered correct if: 237 | (a) it normalizes to the same string as the ground truth answer 238 | OR 239 | (b) sympy can simplify the difference between the expressions to 0 240 | """ 241 | if given_answer is None: 242 | return False 243 | 244 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) 245 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) 246 | 247 | # be at least as lenient as mathd 248 | if ground_truth_normalized_mathd == given_answer_normalized_mathd: 249 | return True 250 | 251 | ground_truth_normalized = _normalize(ground_truth) 252 | given_normalized = _normalize(given_answer) 253 | 254 | if ground_truth_normalized is None: 255 | return False 256 | 257 | if ground_truth_normalized == given_normalized: 258 | return True 259 | 260 | if len(given_normalized) == 0: 261 | return False 262 | 263 | ground_truth_elems = split_tuple(ground_truth_normalized) 264 | given_elems = split_tuple(given_normalized) 265 | 266 | if len(ground_truth_elems) > 1 and ( 267 | ground_truth_normalized[0] != given_normalized[0] 268 | or ground_truth_normalized[-1] != given_normalized[-1] 269 | ): 270 | is_correct = False 271 | elif len(ground_truth_elems) != len(given_elems): 272 | is_correct = False 273 | else: 274 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems): 275 | if _is_frac(ground_truth_elem) and _is_frac(given_elem): 276 | # if fractions aren't reduced, then shouldn't be marked as correct 277 | # so, we don't want to allow sympy.simplify in this case 278 | is_correct = ground_truth_elem == given_elem 279 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem): 280 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify) 281 | is_correct = False 282 | else: 283 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem) 284 | if not is_correct: 285 | break 286 | 287 | return is_correct 288 | -------------------------------------------------------------------------------- /ochat/evaluation/grading/math_normalize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence). 3 | """ 4 | import re 5 | from typing import Optional 6 | 7 | 8 | def normalize_answer(answer: Optional[str]) -> Optional[str]: 9 | if answer is None: 10 | return None 11 | answer = answer.strip() 12 | try: 13 | # Remove enclosing `\text{}`. 
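        # e.g. (hypothetical input) "\text{7}" -> "7"; the stripped answer is then passed to _strip_string below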
14 | m = re.search("^\\\\text\{(?P.+?)\}$", answer) 15 | if m is not None: 16 | answer = m.group("text").strip() 17 | return _strip_string(answer) 18 | except: 19 | return answer 20 | 21 | 22 | def _fix_fracs(string): 23 | substrs = string.split("\\frac") 24 | new_str = substrs[0] 25 | if len(substrs) > 1: 26 | substrs = substrs[1:] 27 | for substr in substrs: 28 | new_str += "\\frac" 29 | if substr[0] == "{": 30 | new_str += substr 31 | else: 32 | try: 33 | assert len(substr) >= 2 34 | except: 35 | return string 36 | a = substr[0] 37 | b = substr[1] 38 | if b != "{": 39 | if len(substr) > 2: 40 | post_substr = substr[2:] 41 | new_str += "{" + a + "}{" + b + "}" + post_substr 42 | else: 43 | new_str += "{" + a + "}{" + b + "}" 44 | else: 45 | if len(substr) > 2: 46 | post_substr = substr[2:] 47 | new_str += "{" + a + "}" + b + post_substr 48 | else: 49 | new_str += "{" + a + "}" + b 50 | string = new_str 51 | return string 52 | 53 | 54 | def _fix_a_slash_b(string): 55 | if len(string.split("/")) != 2: 56 | return string 57 | a = string.split("/")[0] 58 | b = string.split("/")[1] 59 | try: 60 | a = int(a) 61 | b = int(b) 62 | assert string == "{}/{}".format(a, b) 63 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" 64 | return new_string 65 | except: 66 | return string 67 | 68 | 69 | def _remove_right_units(string): 70 | # "\\text{ " only ever occurs (at least in the val set) when describing units 71 | if "\\text{ " in string: 72 | splits = string.split("\\text{ ") 73 | assert len(splits) == 2 74 | return splits[0] 75 | else: 76 | return string 77 | 78 | 79 | def _fix_sqrt(string): 80 | if "\\sqrt" not in string: 81 | return string 82 | splits = string.split("\\sqrt") 83 | new_string = splits[0] 84 | for split in splits[1:]: 85 | if split[0] != "{": 86 | a = split[0] 87 | new_substr = "\\sqrt{" + a + "}" + split[1:] 88 | else: 89 | new_substr = "\\sqrt" + split 90 | new_string += new_substr 91 | return new_string 92 | 93 | 94 | def _strip_string(string): 95 | # linebreaks 96 | string = string.replace("\n", "") 97 | # print(string) 98 | 99 | # remove inverse spaces 100 | string = string.replace("\\!", "") 101 | # print(string) 102 | 103 | # replace \\ with \ 104 | string = string.replace("\\\\", "\\") 105 | # print(string) 106 | 107 | # replace tfrac and dfrac with frac 108 | string = string.replace("tfrac", "frac") 109 | string = string.replace("dfrac", "frac") 110 | # print(string) 111 | 112 | # remove \left and \right 113 | string = string.replace("\\left", "") 114 | string = string.replace("\\right", "") 115 | # print(string) 116 | 117 | # Remove circ (degrees) 118 | string = string.replace("^{\\circ}", "") 119 | string = string.replace("^\\circ", "") 120 | 121 | # remove dollar signs 122 | string = string.replace("\\$", "") 123 | 124 | # remove units (on the right) 125 | string = _remove_right_units(string) 126 | 127 | # remove percentage 128 | string = string.replace("\\%", "") 129 | string = string.replace("\%", "") 130 | 131 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string 132 | string = string.replace(" .", " 0.") 133 | string = string.replace("{.", "{0.") 134 | # if empty, return empty string 135 | if len(string) == 0: 136 | return string 137 | if string[0] == ".": 138 | string = "0" + string 139 | 140 | # to consider: get rid of e.g. 
"k = " or "q = " at beginning 141 | if len(string.split("=")) == 2: 142 | if len(string.split("=")[0]) <= 2: 143 | string = string.split("=")[1] 144 | 145 | # fix sqrt3 --> sqrt{3} 146 | string = _fix_sqrt(string) 147 | 148 | # remove spaces 149 | string = string.replace(" ", "") 150 | 151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} 152 | string = _fix_fracs(string) 153 | 154 | # manually change 0.5 --> \frac{1}{2} 155 | if string == "0.5": 156 | string = "\\frac{1}{2}" 157 | 158 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y 159 | string = _fix_a_slash_b(string) 160 | 161 | return string 162 | -------------------------------------------------------------------------------- /ochat/evaluation/match_answer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import ast 3 | 4 | from ochat.evaluation.grading.math_grader import grade_answer 5 | 6 | 7 | def zs_agieval_match_answer(task_data, response): 8 | # AGIEval match first capital letter, following original paper implementation 9 | # https://github.com/microsoft/AGIEval/blob/main/src/post_process.py 10 | 11 | letter_set = {"A", "B", "C", "D", "E", "F"} 12 | for c in response: 13 | if c in letter_set: 14 | return True, c 15 | 16 | return False, "" 17 | 18 | 19 | def zs_bbh_mc_orca_truthfulqa_orca_match_answer(task_data, response): 20 | # For BBH & TruthfulQA, match first option letter 21 | 22 | for c in response: 23 | if c in task_data["options"]: 24 | return True, c 25 | 26 | return False, "" 27 | 28 | 29 | def fs_cothub_math_match_answer(task_data, response, max_length=256): 30 | def _last_boxed_only_string(string): 31 | idx = string.rfind("\\boxed") 32 | if idx < 0: 33 | idx = string.rfind("\\fbox") 34 | if idx < 0: 35 | return None 36 | 37 | i = idx 38 | left_brace_idx = None 39 | right_brace_idx = None 40 | num_left_braces_open = 0 41 | while i < len(string): 42 | if string[i] == "{": 43 | num_left_braces_open += 1 44 | if left_brace_idx is None: 45 | left_brace_idx = i 46 | elif string[i] == "}": 47 | num_left_braces_open -= 1 48 | if num_left_braces_open == 0: 49 | right_brace_idx = i 50 | break 51 | 52 | i += 1 53 | 54 | if left_brace_idx is None or right_brace_idx is None: 55 | return None 56 | 57 | return string[left_brace_idx + 1: right_brace_idx].strip() 58 | 59 | # Match true answer 60 | ground_truth_answer = _last_boxed_only_string(task_data["_metadata"]["solution"]) 61 | assert ground_truth_answer 62 | 63 | # Match model answer 64 | is_matched = False 65 | 66 | ans_line = response.split('The answer is') 67 | if len(ans_line) > 1: 68 | is_matched = True 69 | response = ans_line[-1].strip() 70 | else: 71 | ans_extracted = _last_boxed_only_string(response) 72 | if ans_extracted: 73 | is_matched = True 74 | response = ans_extracted 75 | 76 | # Grade 77 | response = response[:max_length] # To avoid sympy taking too long 78 | return is_matched, grade_answer(response, ground_truth_answer) 79 | 80 | 81 | def zs_gpqa_match_answer(task_data, response): 82 | # Expected to see answer field, otherwise return C. 
83 | ans = response.split("The correct answer is") 84 | 85 | if len(ans) == 1: 86 | return False, "C" 87 | 88 | ans = ans[1] 89 | 90 | letter_set = {"A", "B", "C", "D"} 91 | for c in ans: 92 | if c in letter_set: 93 | return True, c 94 | 95 | return False, "C" 96 | 97 | 98 | def fs_cothub_bbh_match_answer(task_data, response): 99 | # CoT hub match answer for BBH 100 | # https://github.com/FranxYao/chain-of-thought-hub/blob/main/BBH/run_bbh_gpt_3.5_turbo.py 101 | 102 | ans_line = response.split('answer is ') 103 | 104 | # Expect to see 'answer is'. If not return whole string 105 | if len(ans_line) == 1: 106 | return False, response 107 | else: 108 | ans = ans_line[-1].strip() 109 | 110 | if task_data["options"]: 111 | # Multiple choice, find appearing letter 112 | options = ['(A)', '(B)', '(C)', '(D)', '(E)', '(F)', '(G)', '(H)', '(I)', '(J)', '(K)', '(L)', '(M)', '(N)', '(O)', '(P)', '(Q)', '(R)', '(S)', '(T)', '(U)', '(V)', '(W)', '(X)', '(Y)', '(Z)'] 113 | 114 | for option in options: 115 | if option in ans: 116 | return True, option 117 | 118 | return False, ans 119 | else: 120 | # Free form, direct return 121 | if len(ans) and ans[-1] == '.': 122 | ans = ans[:-1] 123 | 124 | return True, ans 125 | 126 | 127 | def fs_cothub_gsm8k_match_answer(task_data, response): 128 | # CoT hub match answer for GSM8k, match last numeric value 129 | # https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/gpt3.5turbo_gsm8k_complex.ipynb 130 | 131 | pattern = '\d*\.?\d+' 132 | pred = re.findall(pattern, response) 133 | if len(pred) >= 1: 134 | return True, pred[-1] 135 | 136 | return False, response 137 | 138 | 139 | def fs_cothub_mmlu_match_answer(task_data, response): 140 | ans_line = response.split('answer is') 141 | 142 | # Expect to see 'answer is'. 
If not return C 143 | if len(ans_line) == 1: 144 | return False, "(C)" 145 | else: 146 | ans = ans_line[-1].strip() 147 | 148 | options = ['(A)', '(B)', '(C)', '(D)'] 149 | for option in options: 150 | if option in ans: 151 | return True, option 152 | 153 | return False, "(C)" 154 | 155 | 156 | def coding_humaneval_match_answer(task_data, response): 157 | # Matching utilities 158 | def _function_exists(code, func_name): 159 | tree = ast.parse(code) 160 | for node in ast.walk(tree): 161 | if isinstance(node, ast.FunctionDef) and node.name == func_name: 162 | return True 163 | 164 | return False 165 | 166 | def _try_match(content, prefix, entrypoint): 167 | # All markdown code blocks, as well as raw 168 | code_blocks = [m[1] for m in re.findall(r"(\`{3}.*?\n+)([\s\S]*?)(\n+\`{3})", content)] \ 169 | + [content] 170 | 171 | for block in code_blocks: 172 | # Check syntax 173 | try: 174 | code_completion = prefix + block 175 | if _function_exists(code_completion, entrypoint): 176 | return code_completion 177 | except SyntaxError: 178 | pass 179 | 180 | # Try match with include prefix 181 | humaneval_task = task_data["_metadata"] 182 | include_prefix = humaneval_task['prompt'].split('def')[0].strip() + "\n\n" 183 | 184 | result = _try_match(response, include_prefix, humaneval_task["entry_point"]) 185 | if result: 186 | return True, {"task_id": humaneval_task["task_id"], "completion": result} 187 | 188 | # If fail then match with function signature 189 | result = _try_match(response, humaneval_task["prompt"], humaneval_task["entry_point"]) 190 | if result: 191 | return True, {"task_id": humaneval_task["task_id"], "completion": result} 192 | 193 | return False, {"task_id": humaneval_task["task_id"], "completion": response} 194 | 195 | 196 | MATCH_ANSWER_FUNCTION = { 197 | "zs/agieval": zs_agieval_match_answer, 198 | "zs/bbh_mc_orca": zs_bbh_mc_orca_truthfulqa_orca_match_answer, 199 | "zs/truthfulqa_orca": zs_bbh_mc_orca_truthfulqa_orca_match_answer, 200 | "zs/gpqa": zs_gpqa_match_answer, 201 | 202 | "fs_cothub/bbh": fs_cothub_bbh_match_answer, 203 | "fs_cothub/gsm8k": fs_cothub_gsm8k_match_answer, 204 | "fs_cothub/mmlu": fs_cothub_mmlu_match_answer, 205 | "fs_cothub/math": fs_cothub_math_match_answer, 206 | 207 | "coding/humaneval": coding_humaneval_match_answer 208 | } 209 | -------------------------------------------------------------------------------- /ochat/evaluation/run_eval.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import argparse 3 | import os 4 | import asyncio 5 | from glob import glob 6 | 7 | import orjson 8 | import openai 9 | from tqdm import tqdm 10 | from openai import RateLimitError, InternalServerError, APIConnectionError 11 | from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type 12 | from vllm import LLM, SamplingParams 13 | 14 | from transformers.utils.hub import cached_file 15 | 16 | from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION 17 | from ochat.config import MODEL_CONFIG_MAP 18 | 19 | 20 | def _strip_first_space(s: str): 21 | if len(s) and s[0] == " ": 22 | return s[1:] 23 | return s 24 | 25 | 26 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(20), retry=retry_if_exception_type((RateLimitError, InternalServerError, APIConnectionError, ))) 27 | async def _chat_completion_with_backoff(client, **kwargs): 28 | return await client.chat.completions.create(**kwargs) 29 | 30 | 31 | async def 
chat_completion_thread(model, progress_bar, queue): 32 | client = openai.AsyncOpenAI() 33 | 34 | while True: 35 | # Fetch task 36 | try: 37 | task = queue.get_nowait() 38 | except asyncio.QueueEmpty: 39 | break 40 | 41 | # Completion 42 | try: 43 | response = await _chat_completion_with_backoff( 44 | client, 45 | model=model, 46 | messages=[{"role": "user", "content": task["question"]}], 47 | 48 | temperature=0 49 | ) 50 | task["response"] = response.choices[0].message.content # type: ignore 51 | except Exception as e: 52 | if hasattr(e, "last_attempt"): 53 | e = e.last_attempt 54 | if hasattr(e, "_exception"): 55 | e = e._exception 56 | 57 | print(type(e), str(e)) 58 | 59 | # Progress 60 | progress_bar.update() 61 | 62 | 63 | async def get_openai_answers( 64 | model: str, 65 | questions: list, 66 | parallel: int 67 | ): 68 | # Complete in retry cycles 69 | last_to_complete_num = 0 70 | 71 | while True: 72 | # fill queue 73 | to_complete_num = 0 74 | queue = asyncio.Queue() 75 | for q in questions: 76 | if q["response"]: 77 | continue 78 | 79 | queue.put_nowait(q) 80 | to_complete_num += 1 81 | 82 | tqdm.write(f"New completion cycle. To complete {to_complete_num}, number of parallel calls {parallel}") 83 | 84 | # Create tasks 85 | progress_bar = tqdm(total=to_complete_num) 86 | async with asyncio.TaskGroup() as task_group: 87 | for _ in range(parallel): 88 | task_group.create_task(chat_completion_thread(model, progress_bar, queue)) 89 | 90 | # Next retry cycle 91 | # Break if cannot complete more 92 | if (to_complete_num == last_to_complete_num) or (to_complete_num == 0): 93 | break 94 | last_to_complete_num = to_complete_num 95 | 96 | # Reduce parallel calls 97 | parallel = max(1, parallel // 2) 98 | 99 | return questions 100 | 101 | 102 | def tokenize_questions(model_config: object, conv_template: object, questions: list, condition: str, system_msg: str): 103 | from ochat.config import Conversation, Message 104 | 105 | # Construct conversation 106 | prompt_indices = [] 107 | conversations = [] 108 | for idx, q in enumerate(questions): 109 | if q["response"]: 110 | continue 111 | 112 | conversations.append(Conversation( 113 | items=[ 114 | Message(role="user", content=q["question"]), 115 | Message(role="assistant", content="") 116 | ], 117 | condition=condition, 118 | system=system_msg 119 | )) 120 | prompt_indices.append(idx) 121 | 122 | # Tokenize 123 | conversations, _ = conv_template.tokenize_conversations(conversations, inference=True) 124 | conversations = [tokens[-model_config.model_max_context:] for tokens in conversations] 125 | 126 | return conversations, prompt_indices 127 | 128 | 129 | def get_model_answers( 130 | model: str, 131 | questions: list, 132 | condition: str, 133 | system_msg: str, 134 | model_type: str, 135 | tensor_parallel_size: int 136 | ): 137 | # Load model config 138 | if model_type is None: 139 | with open(cached_file(path_or_repo_id=model, filename="openchat.json"), "r") as f: 140 | model_type = orjson.loads(f.read())["model_type"] 141 | 142 | model_config = MODEL_CONFIG_MAP[model_type] 143 | tokenizer = model_config.model_tokenizer_create(model) 144 | conv_template = model_config.conversation_template(tokenizer=tokenizer) 145 | 146 | # Init vLLM engine 147 | engine = LLM(model, 148 | max_num_batched_tokens=model_config.model_max_context, 149 | max_model_len=model_config.model_max_context, 150 | tensor_parallel_size=tensor_parallel_size) 151 | sampling_params = SamplingParams(temperature=0, 152 | max_tokens=None, 153 | 
stop_token_ids=conv_template.eot_tokens_, # Override stop tokens 154 | ignore_eos=True) 155 | 156 | # Complete 157 | prompts, prompt_indices = tokenize_questions(model_config, conv_template, questions, 158 | condition=condition, system_msg=system_msg) 159 | 160 | # calculate & fill in responses 161 | responses = engine.generate(prompt_token_ids=prompts, sampling_params=sampling_params) 162 | for idx, resp in zip(prompt_indices, responses): 163 | questions[idx]["response"] = _strip_first_space(resp.outputs[0].text) 164 | 165 | return questions 166 | 167 | 168 | async def run_eval( 169 | model: str, 170 | condition: str, 171 | system_msg: str, 172 | model_type: str, 173 | 174 | data_path: str, 175 | eval_sets: list, 176 | 177 | continue_from: Optional[str], 178 | output_file: str, 179 | 180 | parallel: int, 181 | tensor_parallel_size: int 182 | ): 183 | print (f"Evaluating ({model_type})...\n\nCondition: {condition}\nSystem Prompt: {system_msg}\n") 184 | 185 | if continue_from is not None: 186 | # Load continue 187 | print (f"Continuing from {continue_from}...") 188 | 189 | with open(continue_from, "rb") as f: 190 | questions = orjson.loads(f.read()) 191 | else: 192 | # Load questions 193 | questions = [] 194 | 195 | for filename in glob(os.path.join(data_path, "**", "*.jsonl"), recursive=True): 196 | task_name = os.path.splitext(filename[len(data_path):])[0].strip("\\/") 197 | task_type = os.path.dirname(task_name) 198 | 199 | assert task_type in MATCH_ANSWER_FUNCTION 200 | 201 | # Filter eval sets 202 | if eval_sets and not sum([task_name.startswith(a) for a in eval_sets]): 203 | continue 204 | 205 | # Load task 206 | with open(filename, "r") as f: 207 | task_data = list(map(orjson.loads, f.readlines())) 208 | 209 | questions.extend([{**item, "task_name": task_name, "task_type": task_type, "response": ""} for item in task_data]) 210 | 211 | # run completion 212 | if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"): 213 | questions = await get_openai_answers(model, questions, parallel) 214 | else: 215 | questions = get_model_answers(model, questions, condition, system_msg, model_type, tensor_parallel_size) 216 | 217 | # Calculate accuracy 218 | for q in questions: 219 | q["is_matched"], q["answer"] = MATCH_ANSWER_FUNCTION[q["task_type"]](q, q["response"]) 220 | try: 221 | q["is_correct"] = q["answer"] in q["label"] 222 | except: 223 | q["is_correct"] = False 224 | 225 | # Write results 226 | if output_file is None: 227 | output_file = os.path.join(os.path.dirname(data_path), "eval_results", f"{os.path.basename(model)}_{condition}.json") 228 | 229 | os.makedirs(os.path.dirname(output_file), exist_ok=True) 230 | with open(output_file, "wb") as f: 231 | f.write(orjson.dumps(questions, option=orjson.OPT_INDENT_2)) 232 | 233 | 234 | async def main(): 235 | parser = argparse.ArgumentParser() 236 | 237 | # Input / output 238 | parser.add_argument("--model", type=str, default=None) 239 | parser.add_argument("--condition", type=str, default="") 240 | parser.add_argument("--system-msg", type=str, default="") 241 | parser.add_argument("--model-type", type=str, default=None) 242 | 243 | parser.add_argument("--data-path", type=str, default="ochat/evaluation/eval_data") 244 | parser.add_argument("--eval-sets", type=str, nargs="+", default=[]) 245 | 246 | parser.add_argument("--continue-from", type=str, default=None) 247 | parser.add_argument("--output-file", type=str, default=None) 248 | parser.add_argument("--parallel", type=int, default=16) 249 | 
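    # --parallel caps concurrent OpenAI API calls (get_openai_answers halves it on each retry cycle);
    # the flag added next sets how many GPUs the local vLLM engine shards across.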
parser.add_argument("--tensor-parallel-size", type=int, default=1) 250 | 251 | args = parser.parse_args() 252 | 253 | await run_eval(**vars(args)) 254 | 255 | if __name__ == "__main__": 256 | asyncio.run(main()) 257 | -------------------------------------------------------------------------------- /ochat/evaluation/view_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import orjson 6 | import pandas as pd 7 | from glob import glob 8 | 9 | def save_results(dfs, save_path: str): 10 | os.makedirs(os.path.dirname(save_path), exist_ok=True) 11 | with pd.ExcelWriter(save_path) as writer: 12 | for task_type, df_task in dfs.items(): 13 | df_task.to_excel(writer, sheet_name=task_type) 14 | 15 | def view_results(result_path: str): 16 | # Read results 17 | eval_results = [] 18 | for filename in glob(os.path.join(result_path, "*.json")): 19 | with open(filename, "rb") as f: 20 | questions = orjson.loads(f.read()) 21 | 22 | eval_results.extend([{ 23 | "model": Path(filename).stem, 24 | "task_type": q["task_type"], 25 | "task_name": os.path.relpath(q["task_name"], q["task_type"]), 26 | "accuracy": q["is_correct"], 27 | "unmatched": not q["is_matched"], 28 | } for q in questions]) 29 | df = pd.DataFrame.from_records(eval_results) 30 | all_tables = dict() 31 | # Overall metrics table 32 | df_overall = df.pivot_table(index=["model"], columns=["task_type"], values=["accuracy", "unmatched"], aggfunc="mean") 33 | all_tables["overall"] = df_overall 34 | print(df_overall.to_string(float_format=lambda x: f"{x * 100:.1f}", na_rep="-")) 35 | # Print tables for each task 36 | for task_type in df["task_type"].unique(): 37 | df_task = df[df["task_type"] == task_type].pivot_table(index=["task_name"], columns=["model"], values=["accuracy", "unmatched"], aggfunc="mean") 38 | all_tables[task_type.replace("/", "_")] = df_task 39 | print(f"\n### {task_type}\n") 40 | print(df_task.to_string(float_format=lambda x: f"{x * 100:.1f}", na_rep="-")) 41 | return all_tables 42 | 43 | 44 | def main(): 45 | parser = argparse.ArgumentParser() 46 | 47 | # Input / output 48 | parser.add_argument("--result_path", type=str, default="ochat/evaluation/eval_results") 49 | parser.add_argument("--save_path", type=str, default="ochat/evaluation/eval_results/summary.xlsx") 50 | parser.add_argument("--save", "-s", action="store_true", help="Save the results to a file") 51 | args = parser.parse_args() 52 | 53 | all_tables = view_results(args.result_path) 54 | if args.save: 55 | save_results(all_tables, args.save_path) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /ochat/experimental/generate_dataset_old.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate training data based on conversations 3 | 4 | Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.json --tokenizer-name HF_REPO_NAME --out-dir . 
5 | """ 6 | 7 | from typing import Optional 8 | from dataclasses import dataclass 9 | import argparse 10 | import json 11 | import os 12 | import random 13 | 14 | import numpy as np 15 | import transformers 16 | from transformers.trainer_pt_utils import LabelSmoother 17 | from ray.util.multiprocessing import Pool 18 | 19 | 20 | @dataclass 21 | class ModelDataConfig: 22 | name: str 23 | 24 | # Prompt 25 | system: str 26 | 27 | role_prefix: dict 28 | ai_role: str 29 | eot_token: str 30 | bos_token: Optional[str] 31 | 32 | # Tokenize 33 | max_tokens: int 34 | pad_token: int 35 | ignore_id: int 36 | 37 | 38 | CONFIG = ModelDataConfig( 39 | name="OChat", 40 | 41 | # Prompt 42 | system="", 43 | 44 | role_prefix={ 45 | "human": "Human: ", 46 | "gpt": "Assistant: " 47 | }, 48 | ai_role="gpt", 49 | eot_token="<|end_of_turn|>", 50 | bos_token="", 51 | 52 | # Tokenize 53 | max_tokens=8192, 54 | pad_token="", 55 | ignore_id=LabelSmoother.ignore_index 56 | ) 57 | 58 | 59 | def generate_split(conversations: list, tokenizer: transformers.AutoTokenizer, split_name: str, out_dir: str): 60 | # Add prompt and tokenize conversation 61 | def _convert_single_conversation(c): 62 | tokens = [] 63 | masks = [] 64 | 65 | # begin of sentence (bos) 66 | if CONFIG.bos_token: 67 | t = tokenizer.convert_tokens_to_ids(CONFIG.bos_token) 68 | tokens.append(t) 69 | masks.append(False) 70 | 71 | # System 72 | if CONFIG.system: 73 | t = tokenizer(CONFIG.system, add_special_tokens=False) + [tokenizer.convert_tokens_to_ids(CONFIG.eot_token)] 74 | tokens.extend(t) 75 | masks.extend([False] * len(t)) 76 | 77 | # Messages 78 | for message in c["items"]: 79 | # Message 80 | message_text = CONFIG.role_prefix[message["from"]] + message["value"] 81 | 82 | t = tokenizer(message_text, add_special_tokens=False) + [tokenizer.convert_tokens_to_ids(CONFIG.eot_token)] 83 | tokens.extend(t) 84 | masks.extend([message["from"] == CONFIG.ai_role] * len(t)) 85 | 86 | return tokens, masks 87 | 88 | converted = Pool().map(_convert_single_conversation, conversations) 89 | 90 | # Pad and to numpy array 91 | pad_id = tokenizer.convert_tokens_to_ids(CONFIG.pad_token) 92 | 93 | all_input_ids = [] 94 | all_labels = [] 95 | all_attention_masks = [] 96 | all_plain_texts = [] 97 | for tokens, masks in converted: 98 | # Cut to length 99 | tokens = np.array(tokens[:CONFIG.max_tokens], np.int_) 100 | masks = np.array(masks[:CONFIG.max_tokens], np.bool_) 101 | 102 | # Pad 103 | input_ids = np.full(CONFIG.max_tokens, pad_id, np.int_) 104 | labels = np.full(CONFIG.max_tokens, CONFIG.ignore_id, np.int_) 105 | attention_masks = np.full(CONFIG.max_tokens, False, np.bool_) 106 | 107 | length = len(tokens) 108 | 109 | input_ids[:length] = tokens 110 | labels[:length] = np.where(masks, tokens, CONFIG.ignore_id) 111 | attention_masks[:length] = True 112 | 113 | all_input_ids.append(input_ids) 114 | all_labels.append(labels) 115 | all_attention_masks.append(attention_masks) 116 | all_plain_texts.append(tokens) 117 | 118 | # Output training data 119 | np.savez(os.path.join(out_dir, f"ochat.{split_name}.npz"), 120 | # Arrays 121 | input_ids=np.vstack(all_input_ids), 122 | labels=np.vstack(all_labels), 123 | attention_masks=np.vstack(all_attention_masks)) 124 | 125 | # Output plain texts 126 | all_plain_texts = tokenizer.decode(all_plain_texts) 127 | 128 | with open(os.path.join(out_dir, f"ochat.{split_name}.text.json"), "w") as f: 129 | json.dump(all_plain_texts, f) 130 | 131 | 132 | def generate_dataset(seed, in_file, tokenizer_name, out_dir, eval_ratio): 133 | # Load 
tokenizer 134 | tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False) 135 | 136 | # Load conversations 137 | with open(in_file, "r") as f: 138 | conversations = json.load(f) 139 | 140 | # Train-test split 141 | random.seed(seed) 142 | random.shuffle(conversations) 143 | eval_num = int(eval_ratio * len(conversations)) 144 | 145 | train_conversations = conversations[eval_num:] 146 | eval_conversations = conversations[:eval_num] 147 | 148 | generate_split(train_conversations, tokenizer, "train", out_dir) 149 | generate_split(eval_conversations, tokenizer, "eval", out_dir) 150 | 151 | 152 | if __name__ == "__main__": 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument("--seed", type=int, default=42) 155 | parser.add_argument("--in-file", type=str, required=True) 156 | parser.add_argument("--tokenizer-name", type=str, required=True) 157 | parser.add_argument("--out-dir", type=str, default=".") 158 | parser.add_argument("--eval-ratio", type=float, default=0.01) 159 | args = parser.parse_args() 160 | 161 | generate_dataset(**vars(args)) 162 | -------------------------------------------------------------------------------- /ochat/experimental/test_multipack_dataloader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import numba\n", 11 | "\n", 12 | "import json" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 12, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from typing import Any, Optional, List, Callable\n", 22 | "\n", 23 | "import torch.distributed as dist\n", 24 | "\n", 25 | "import numpy as np\n", 26 | "import numba\n", 27 | "\n", 28 | "\n", 29 | "@numba.njit\n", 30 | "def ffd_check(a: np.ndarray, c: int, n: int):\n", 31 | " # First-fit-decreasing bin packing\n", 32 | " # Check if a[] could fit in n bins with capacity c\n", 33 | " # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing\n", 34 | "\n", 35 | " a = np.sort(a)[::-1]\n", 36 | " bins = np.full((n, ), c, dtype=a.dtype)\n", 37 | " for size in a:\n", 38 | " not_found = True\n", 39 | " for idx in range(n):\n", 40 | " if bins[idx] >= size:\n", 41 | " bins[idx] -= size\n", 42 | " not_found = False\n", 43 | " break\n", 44 | "\n", 45 | " if not_found:\n", 46 | " return False\n", 47 | "\n", 48 | " return True\n", 49 | "\n", 50 | "\n", 51 | "@numba.njit\n", 52 | "def ffd_with_result(a: np.ndarray, c: int, start_index: int):\n", 53 | " # First-fit-decreasing bin packing (with result return)\n", 54 | "\n", 55 | " indices = np.argsort(a)[::-1]\n", 56 | " a = a[indices]\n", 57 | "\n", 58 | " bins = []\n", 59 | " bins_result = []\n", 60 | " for a_id, size in enumerate(a):\n", 61 | " add_new = True\n", 62 | " for idx in range(len(bins)):\n", 63 | " if bins[idx] >= size:\n", 64 | " bins[idx] -= size\n", 65 | " bins_result[idx].append(indices[a_id] + start_index)\n", 66 | " add_new = False\n", 67 | " break\n", 68 | "\n", 69 | " if add_new:\n", 70 | " bins.append(c - size)\n", 71 | " bins_result.append([indices[a_id] + start_index])\n", 72 | "\n", 73 | " return bins_result\n", 74 | "\n", 75 | "\n", 76 | "@numba.njit\n", 77 | "def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int):\n", 78 | " # Dynamic batch allocator, similar to Multifit\n", 79 | " # 
https://en.wikipedia.org/wiki/Multifit_algorithm\n", 80 | " # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)\n", 81 | "\n", 82 | " s = 0\n", 83 | " start_index = 0\n", 84 | " result = []\n", 85 | " result_totseqs = []\n", 86 | "\n", 87 | " while True:\n", 88 | " # binary search [l, r)\n", 89 | " l = 1\n", 90 | " r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, \"right\")\n", 91 | "\n", 92 | " while r - l > 1:\n", 93 | " m = (l + r) // 2\n", 94 | " if ffd_check(lengths[start_index: start_index + m], c, n):\n", 95 | " l = m\n", 96 | " else:\n", 97 | " r = m\n", 98 | "\n", 99 | " # use length l\n", 100 | " batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index)\n", 101 | " if len(batch) < n:\n", 102 | " break\n", 103 | "\n", 104 | " start_index += l\n", 105 | " s = lengths_cumsum[start_index - 1]\n", 106 | "\n", 107 | " # add local rank\n", 108 | " result.append(batch[rank])\n", 109 | " # add total seqs for all ranks\n", 110 | " totseq = 0\n", 111 | " for indices in batch:\n", 112 | " for idx in indices:\n", 113 | " totseq += numseqs[idx]\n", 114 | " result_totseqs.append(totseq)\n", 115 | "\n", 116 | " return result, result_totseqs, s, len(result) * c * n\n", 117 | "\n", 118 | "\n", 119 | "class MultipackDistributedDataloader:\n", 120 | " \"\"\"Unpadded data loading using Multipack.\n", 121 | " Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.\"\"\"\n", 122 | " \n", 123 | " def __init__(\n", 124 | " self,\n", 125 | " dataset: Any,\n", 126 | " lengths: np.ndarray,\n", 127 | " numseqs: np.ndarray,\n", 128 | "\n", 129 | " batch_max_length: int,\n", 130 | " collate_fn: Callable,\n", 131 | "\n", 132 | " num_replicas: Optional[int] = None,\n", 133 | " rank: Optional[int] = None,\n", 134 | "\n", 135 | " seed: int = 0,\n", 136 | " ):\n", 137 | " # Dataset\n", 138 | " self.dataset = dataset\n", 139 | " self.lengths = lengths\n", 140 | " self.numseqs = numseqs\n", 141 | " assert isinstance(self.lengths, np.ndarray)\n", 142 | "\n", 143 | " self.batch_max_length = batch_max_length\n", 144 | " self.collate_fn = collate_fn\n", 145 | "\n", 146 | " # Get rank\n", 147 | " if num_replicas is None:\n", 148 | " if not dist.is_available():\n", 149 | " raise RuntimeError(\"Requires distributed package to be available\")\n", 150 | " num_replicas = dist.get_world_size()\n", 151 | " if rank is None:\n", 152 | " if not dist.is_available():\n", 153 | " raise RuntimeError(\"Requires distributed package to be available\")\n", 154 | " rank = dist.get_rank()\n", 155 | "\n", 156 | " self.num_replicas = num_replicas\n", 157 | " self.rank = rank\n", 158 | "\n", 159 | " # Seed\n", 160 | " self.seed = seed\n", 161 | "\n", 162 | " # Epoch\n", 163 | " self.epoch = 0\n", 164 | "\n", 165 | " # statistics\n", 166 | " self.eff_total_used = 0\n", 167 | " self.eff_total_slots = 0\n", 168 | "\n", 169 | " def set_epoch(self, epoch: int):\n", 170 | " self.epoch = epoch\n", 171 | "\n", 172 | " def generate_batches(self, set_stats=False):\n", 173 | " indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))\n", 174 | "\n", 175 | " lengths = self.lengths[indices]\n", 176 | " numseqs = self.numseqs[indices]\n", 177 | " lengths_cumsum = np.cumsum(lengths)\n", 178 | "\n", 179 | " batches, totseqs, total_used, total_slots = allocate(lengths=lengths,\n", 180 | " numseqs=numseqs,\n", 181 | " lengths_cumsum=lengths_cumsum,\n", 182 | " rank=self.rank,\n", 183 | " c=self.batch_max_length,\n", 184 | 
" n=self.num_replicas)\n", 185 | " \n", 186 | " curseqs = [np.sum(numseqs[batch]) for batch in batches]\n", 187 | " batches = [indices[batch] for batch in batches]\n", 188 | "\n", 189 | " # statistics\n", 190 | " if set_stats:\n", 191 | " self.eff_total_used += total_used\n", 192 | " self.eff_total_slots += total_slots\n", 193 | "\n", 194 | " return batches, totseqs, curseqs\n", 195 | " \n", 196 | " def __iter__(self):\n", 197 | " all_batches, all_totseqs, all_curseqs = self.generate_batches(set_stats=True)\n", 198 | "\n", 199 | " for batch, totseq, curseq in zip(all_batches, all_totseqs, all_curseqs):\n", 200 | " yield self.collate_fn(self.dataset[batch]), totseq, curseq\n", 201 | "\n", 202 | " def num_batches(self):\n", 203 | " batches, _, _ = self.generate_batches()\n", 204 | " return len(batches)\n", 205 | "\n", 206 | " def efficiency(self):\n", 207 | " return self.eff_total_used / self.eff_total_slots\n" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": 3, 213 | "metadata": {}, 214 | "outputs": [ 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "[[3], [2, 0], [1, 4], [5]]\n" 220 | ] 221 | } 222 | ], 223 | "source": [ 224 | "lengths = np.array([1, 5, 7, 8, 3, 2])\n", 225 | "lengths_cumsum = np.cumsum(lengths)\n", 226 | "\n", 227 | "print(ffd_with_result(lengths, 8, start_index=0))" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 16, 233 | "metadata": {}, 234 | "outputs": [ 235 | { 236 | "name": "stdout", 237 | "output_type": "stream", 238 | "text": [ 239 | "[29, 29, 29, 29, 29, 29, 29, 29]\n", 240 | "Efficiency: [0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559]\n", 241 | "Overall Efficiency: 0.9955281976408559\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "DATASET = \"../../dataset_processed/openchat.train.json\"\n", 247 | "C = 14 * 2048\n", 248 | "N = 8\n", 249 | "EPOCHS = 10\n", 250 | "\n", 251 | "# Load dataset\n", 252 | "with open(DATASET, \"r\") as f:\n", 253 | " dataset = json.load(f)\n", 254 | "\n", 255 | "# Check allocator efficiency\n", 256 | "lengths = np.array([len(tokens) for tokens, masks, group in dataset])\n", 257 | "numseqs = np.random.randint(low=1, high=10, size=lengths.shape)\n", 258 | "# lengths = np.random.randint(0, 2048, (int(5e6))).astype(np.int32)\n", 259 | "\n", 260 | "# test sampler correctness & efficiency\n", 261 | "tot_len = 0\n", 262 | "tot_batches = 0\n", 263 | "\n", 264 | "dataloaders = [MultipackDistributedDataloader(dataset=np.arange(len(lengths)), lengths=lengths, numseqs=numseqs,\n", 265 | " batch_max_length=C, \n", 266 | " num_replicas=N, rank=rank,\n", 267 | " collate_fn=lambda x: x) for rank in range(N)]\n", 268 | "print([loader.num_batches() for loader in dataloaders])\n", 269 | "\n", 270 | "for epoch in range(EPOCHS):\n", 271 | " batches = []\n", 272 | " totseqs = []\n", 273 | " curseqs = []\n", 274 | "\n", 275 | " for loader in dataloaders:\n", 276 | " loader.set_epoch(epoch)\n", 277 | " totseqs.append([])\n", 278 | " curseqs.append([])\n", 279 | "\n", 280 | " for batch, totseq, curseq in loader:\n", 281 | " batches.extend(batch)\n", 282 | "\n", 283 | " gt_curseq = np.sum(numseqs[batch])\n", 284 | " # print (batch, curseq, gt_curseq)\n", 285 | " assert gt_curseq == curseq\n", 286 | "\n", 287 | " totseqs[-1].append(totseq)\n", 288 | " curseqs[-1].append(gt_curseq)\n", 289 | "\n", 290 | " # Check constraints\n", 291 | " overall_len = 
sum([lengths[x] for x in batch])\n", 292 | " assert overall_len <= C\n", 293 | "\n", 294 | " tot_len += overall_len\n", 295 | " tot_batches += 1\n", 296 | "\n", 297 | " # Check overall unique\n", 298 | " batches.sort()\n", 299 | " assert batches == list(set(batches)) # Unique\n", 300 | "\n", 301 | " # Check totseq accurate\n", 302 | " gt_totseqs = np.sum(curseqs, axis=0)\n", 303 | " for i in range(len(totseqs)):\n", 304 | " assert (totseqs[i] == gt_totseqs).all()\n", 305 | "\n", 306 | "# Check efficiency\n", 307 | "efficiency = [loader.efficiency() for loader in dataloaders]\n", 308 | "print(f\"Efficiency: {efficiency}\")\n", 309 | "\n", 310 | "print(f\"Overall Efficiency: {tot_len / (tot_batches * C)}\")" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "data": { 320 | "text/plain": [ 321 | "150.98552224214524" 322 | ] 323 | }, 324 | "execution_count": 23, 325 | "metadata": {}, 326 | "output_type": "execute_result" 327 | } 328 | ], 329 | "source": [ 330 | "C * N / np.mean(lengths)" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": null, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "data": { 340 | "text/plain": [ 341 | "150.31034482758622" 342 | ] 343 | }, 344 | "execution_count": 24, 345 | "metadata": {}, 346 | "output_type": "execute_result" 347 | } 348 | ], 349 | "source": [ 350 | "np.mean(gt_totseqs)" 351 | ] 352 | } 353 | ], 354 | "metadata": { 355 | "kernelspec": { 356 | "display_name": "torch", 357 | "language": "python", 358 | "name": "python3" 359 | }, 360 | "language_info": { 361 | "codemirror_mode": { 362 | "name": "ipython", 363 | "version": 3 364 | }, 365 | "file_extension": ".py", 366 | "mimetype": "text/x-python", 367 | "name": "python", 368 | "nbconvert_exporter": "python", 369 | "pygments_lexer": "ipython3", 370 | "version": "3.11.4" 371 | }, 372 | "orig_nbformat": 4 373 | }, 374 | "nbformat": 4, 375 | "nbformat_minor": 2 376 | } 377 | -------------------------------------------------------------------------------- /ochat/experimental/text_length.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 15, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%matplotlib inline\n", 10 | "\n", 11 | "import json\n", 12 | "import numpy as np\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "\n", 15 | "import sentencepiece\n", 16 | "from ray.util.multiprocessing import Pool\n", 17 | "\n", 18 | "\n", 19 | "DATASET = \"../../dataset_processed/sharegpt_gpt4.json\"\n", 20 | "TOKENIZER = \"../../tokenizer/llama_tokenizer.model\"" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [ 28 | { 29 | "name": "stderr", 30 | "output_type": "stream", 31 | "text": [ 32 | "2023-05-24 20:04:38,940\tINFO worker.py:1625 -- Started a local Ray instance.\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "# Load dataset\n", 38 | "tokenizer = sentencepiece.SentencePieceProcessor(model_file=TOKENIZER)\n", 39 | "with open(DATASET, \"r\") as f:\n", 40 | " dataset = json.load(f)\n", 41 | "\n", 42 | "# Parallel tokenization\n", 43 | "def _tokenize(sample):\n", 44 | " for c in sample[\"items\"]:\n", 45 | " c[\"value\"] = tokenizer.tokenize(c[\"value\"])\n", 46 | "\n", 47 | " return sample\n", 48 | "\n", 49 | "dataset_tokenized = Pool().map(_tokenize, dataset)" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | 
"execution_count": 20, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "PROMPT_LEN = {\n", 59 | " \"human\": 4,\n", 60 | " \"gpt\": 5\n", 61 | "}\n", 62 | "\n", 63 | "dataset_len = np.array([sum(len(c[\"value\"]) + PROMPT_LEN[c[\"from\"]] for c in sample[\"items\"]) for sample in dataset_tokenized])" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 19, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "(array([4650., 896., 301., 154., 71., 33., 33., 17., 10.,\n", 75 | " 10.]),\n", 76 | " array([ 0., 5000., 10000., 15000., 20000., 25000., 30000., 35000.,\n", 77 | " 40000., 45000., 50000.]),\n", 78 | " )" 79 | ] 80 | }, 81 | "execution_count": 19, 82 | "metadata": {}, 83 | "output_type": "execute_result" 84 | }, 85 | { 86 | "data": { 87 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiuklEQVR4nO3de3BU5cHH8V8u7BIuu+FiNkQSiYOCUcASNGy9vEVSVoxWK0zRUmUEtdDACFhulYLazoTBKgVBsKU1zlRF6AgqETATJFQJt2g04ZJqGxta3ASL2Q0UEiDP+4eTM6ygkhBInvj9zOwMOefZs895IJPvnOxZoowxRgAAABaJbu0JAAAANBUBAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrEDAAAMA6sa09gQuloaFBBw8eVNeuXRUVFdXa0wEAAOfAGKPa2lolJSUpOvrrr7O024A5ePCgkpOTW3saAACgGQ4cOKDevXt/7f52GzBdu3aV9OUCeDyeVp4NAAA4F+FwWMnJyc7P8a/TbgOm8ddGHo+HgAEAwDLf9vYP3sQLAACsQ8AAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA4BAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrxLb2BGzUZ3Zea0+hyT5dkNXaUwAAoMVwBQYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgnfMKmAULFigqKkpTp051th0/flzZ2dnq0aOHunTpolGjRqmqqirieZWVlcrKylKnTp2UkJCgGTNm6OTJkxFjtmzZosGDB8vtdqtv377Kzc09n6kCAIB2pNkBs2vXLj3//PMaOHBgxPZp06bpzTff1Jo1a1RYWKiDBw/q7rvvdvafOnVKWVlZqq+v17Zt2/Tiiy8qNzdX8+bNc8ZUVFQoKytLw4YNU0lJiaZOnaoHH3xQmzZtau50AQBAO9KsgDly5IjGjh2rP/7xj+rWrZuzPRQK6U9/+pOeeeYZ3XLLLUpPT9cLL7ygbdu2afv27ZKkt99+W3v37tVf/vIXXXvttRo5cqR+85vfaNmyZaqvr5ckrVixQqmpqXr66ad11VVXafLkyRo9erQWLVrUAqcMAABs16yAyc7OVlZWljIzMyO2FxcX68SJExHb+/fvr5SUFBUVFUmSioqKNGDAAPl8PmdMIBBQOBzWnj17nDFfPXYgEHCOcTZ1dXUKh8MRDwAA0D7FNvUJq1at0vvvv69du3adsS8YDMrlcik+Pj5iu8/nUzAYdMacHi+N+xv3fdOYcDisY8eOKS4u7ozXzsnJ0RNPPNHU0wEAABZq0hWYAwcO6JFHHtFLL72kjh07Xqg5NcucOXMUCoWcx4EDB1p7SgAA4AJpUsAUFxerurpagwcPVmxsrGJjY1VYWKglS5YoNjZWPp9P9fX1qqmpiXheVVWVEhMTJUmJiYln3JXU+PW3jfF4PGe9+iJJbrdbHo8n4gEAANqnJgXM8OHDVVpaqpKSEucxZMgQjR071vlzhw4dVFBQ4DynvLxclZWV8vv9kiS/36/S0lJVV1c7Y/Lz8+XxeJSWluaMOf0YjWMajwEAAL7bmvQemK5du+qaa66J2Na5c2f16NHD2T5hwgRNnz5d3bt3l8fj0ZQpU+T3+zV06FBJ0ogRI5SWlqb77rtPCxcuVDAY1Ny5c5WdnS232y1JmjhxopYuXaqZM2dq/Pjx2rx5s1avXq28vLyWOGcAAGC5Jr+J99ssWrRI0dHRGjVqlOrq6hQIBPTcc885+2NiYrR+/XpNmjRJfr9fnTt31rhx4/Tkk086Y1JTU5WXl6dp06Zp8eLF6t27t1auXKlAINDS0wUAABaKMsaY1p7EhRAOh+X1ehUKhVr8/TB9Ztt3JejTBVmtPQUAAL7Vuf785v9CAgAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAAB
gHQIGAABYh4ABAADWaVLALF++XAMHDpTH45HH45Hf79eGDRuc/cePH1d2drZ69OihLl26aNSoUaqqqoo4RmVlpbKystSpUyclJCRoxowZOnnyZMSYLVu2aPDgwXK73erbt69yc3Obf4YAAKDdaVLA9O7dWwsWLFBxcbF2796tW265RXfeeaf27NkjSZo2bZrefPNNrVmzRoWFhTp48KDuvvtu5/mnTp1SVlaW6uvrtW3bNr344ovKzc3VvHnznDEVFRXKysrSsGHDVFJSoqlTp+rBBx/Upk2bWuiUAQCA7aKMMeZ8DtC9e3c99dRTGj16tC655BK9/PLLGj16tCRp//79uuqqq1RUVKShQ4dqw4YNuv3223Xw4EH5fD5J0ooVKzRr1iwdOnRILpdLs2bNUl5ensrKypzXuOeee1RTU6ONGzee87zC4bC8Xq9CoZA8Hs/5nOIZ+szOa9HjXQyfLshq7SkAAPCtzvXnd7PfA3Pq1CmtWrVKR48eld/vV3FxsU6cOKHMzExnTP/+/ZWSkqKioiJJUlFRkQYMGODEiyQFAgGFw2HnKk5RUVHEMRrHNB7j69TV1SkcDkc8AABA+9TkgCktLVWXLl3kdrs1ceJErV27VmlpaQoGg3K5XIqPj48Y7/P5FAwGJUnBYDAiXhr3N+77pjHhcFjHjh372nnl5OTI6/U6j+Tk5KaeGgAAsESTA6Zfv34qKSnRjh07NGnSJI0bN0579+69EHNrkjlz5igUCjmPAwcOtPaUAADABRLb1Ce4XC717dtXkpSenq5du3Zp8eLFGjNmjOrr61VTUxNxFaaqqkqJiYmSpMTERO3cuTPieI13KZ0+5qt3LlVVVcnj8SguLu5r5+V2u+V2u5t6OgAAwELn/TkwDQ0NqqurU3p6ujp06KCCggJnX3l5uSorK+X3+yVJfr9fpaWlqq6udsbk5+fL4/EoLS3NGXP6MRrHNB4DAACgSVdg5syZo5EjRyolJUW1tbV6+eWXtWXLFm3atEler1cTJkzQ9OnT1b17d3k8Hk2ZMkV+v19Dhw6VJI0YMUJpaWm67777tHDhQgWDQc2dO1fZ2dnO1ZOJEydq6dKlmjlzpsaPH6/Nmzdr9erVysuz784fAABwYTQpYKqrq3X//ffrs88+k9fr1cCBA7Vp0yb98Ic/lCQtWrRI0dHRGjVqlOrq6hQIBPTcc885z4+JidH69es1adIk+f1+de7cWePGjdOTTz7pjElNTVVeXp6mTZumxYsXq3fv3lq5cqUCgUALnTIAALDdeX8OTFvF58BE4nNgAAA2uOCfAwMAANBaCBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHWaFDA5OTm67rrr1LVrVyUkJOiuu+5SeXl5xJjjx48rOztbPXr0UJcuXTRq1ChVVVVFjKmsrFRWVpY6deqkhIQEzZgxQydPnowYs2XLFg0ePFhut1t9+/ZVbm5u884QAAC0O00KmMLCQmVnZ2v79u3Kz8/XiRMnNGLECB09etQZM23aNL355ptas2aNCgsLdfDgQd19993O/lOnTikrK0v19fXatm2bXnzxReXm5mrevHnOmIqKCmVlZWnYsGEqKSnR1KlT9eCDD2rTpk0tcMoAAMB2UcYY09wnHzp0SAkJCSosLNTNN9+sUCikSy65RC+//LJGjx4tSdq/f7+uuuoqFRUVaejQodqwYYNuv/12HTx4UD6fT5K0YsUKzZo1S4cOHZLL5dKsWbOUl5ensrIy57Xuuece1dTUaOPGjec0t3A4LK/Xq1AoJI/H09xTPKs+s/Na9HgXw6cLslp7CgAAfKtz/fl9Xu+BCYVCkqTu3btLkoqLi3XixAllZmY6Y/r376+UlBQVFRVJkoqKijRgwAAnXiQpEAgoHA5rz549zpjTj9E4pvEYZ1NXV6dwOBzxAAAA7VOzA6ahoUFTp07VDTfcoGuuuUaSFAwG5XK5FB8fHzHW5/MpGAw6Y06Pl8b9jfu+aUw4HNaxY8fOOp+cnBx5vV7nkZyc3NxTAwAAbVyzAyY7O1tlZWVatWpVS86n2ebMmaNQKOQ8Dhw40NpTAgAAF0hsc540efJkrV+/Xlu3blXv3r2d7YmJiaqvr1dNTU3EVZiqqiolJiY6Y3bu3BlxvMa7lE4f89U7l6qqquTxeBQXF3fWObndbrnd7uacDgAAsEyTrsAYYzR58mStXbtWmzdvVmpqasT+9PR0dejQQQUFBc628vJyVVZWyu/3S5L8fr9KS0tVXV3tjMnPz5fH41FaWpoz5vRjNI5pPAYAAPhua9IVmOzsbL388st6/fXX1bVrV+c9K16vV3FxcfJ6vZowYYKmT5+u7t27y+PxaMqUKfL7/Ro6dKgkacSIEUpLS9N9992nhQsXKhgMau7cucrOznauoEycOFFLly7VzJkzNX78eG3evFmrV69WXp59d/8AAICW16QrMMuXL1coFNIPfvAD9erVy3m8+uqrzphFixbp9ttv16hRo3TzzTcrMTFRr732mrM/JiZG69evV0xMjPx+v372s5/p/vvv15NPPumMSU1NVV5envLz8zVo0CA9/fTTWrlypQKBQAucMgAAsN15fQ5MW8bnwETic2AAADa4KJ8DAwAA0BoIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB
[... base64-encoded PNG output of the plt.hist cell omitted ...]", 88 | "text/plain": [ 89 | "
" 90 | ] 91 | }, 92 | "metadata": {}, 93 | "output_type": "display_data" 94 | } 95 | ], 96 | "source": [ 97 | "plt.hist(dataset_len, range=(0, 50000))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 21, 103 | "metadata": {}, 104 | "outputs": [ 105 | { 106 | "data": { 107 | "text/plain": [ 108 | "5334" 109 | ] 110 | }, 111 | "execution_count": 21, 112 | "metadata": {}, 113 | "output_type": "execute_result" 114 | } 115 | ], 116 | "source": [ 117 | "np.sum(dataset_len < 8192)" 118 | ] 119 | } 120 | ], 121 | "metadata": { 122 | "kernelspec": { 123 | "display_name": "jax", 124 | "language": "python", 125 | "name": "python3" 126 | }, 127 | "language_info": { 128 | "codemirror_mode": { 129 | "name": "ipython", 130 | "version": 3 131 | }, 132 | "file_extension": ".py", 133 | "mimetype": "text/x-python", 134 | "name": "python", 135 | "nbconvert_exporter": "python", 136 | "pygments_lexer": "ipython3", 137 | "version": "3.10.11" 138 | }, 139 | "orig_nbformat": 4 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 2 143 | } 144 | -------------------------------------------------------------------------------- /ochat/experimental/train_alpaca.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import copy 16 | import logging 17 | from dataclasses import dataclass, field 18 | from typing import Dict, Optional, Sequence 19 | 20 | import torch 21 | import transformers 22 | import utils 23 | from torch.utils.data import Dataset 24 | from transformers import Trainer 25 | 26 | IGNORE_INDEX = -100 27 | DEFAULT_PAD_TOKEN = "[PAD]" 28 | DEFAULT_EOS_TOKEN = "" 29 | DEFAULT_BOS_TOKEN = "" 30 | DEFAULT_UNK_TOKEN = "" 31 | PROMPT_DICT = { 32 | "prompt_input": ( 33 | "Below is an instruction that describes a task, paired with an input that provides further context. " 34 | "Write a response that appropriately completes the request.\n\n" 35 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 36 | ), 37 | "prompt_no_input": ( 38 | "Below is an instruction that describes a task. " 39 | "Write a response that appropriately completes the request.\n\n" 40 | "### Instruction:\n{instruction}\n\n### Response:" 41 | ), 42 | } 43 | 44 | 45 | @dataclass 46 | class ModelArguments: 47 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 48 | 49 | 50 | @dataclass 51 | class DataArguments: 52 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 53 | 54 | 55 | @dataclass 56 | class TrainingArguments(transformers.TrainingArguments): 57 | cache_dir: Optional[str] = field(default=None) 58 | optim: str = field(default="adamw_torch") 59 | model_max_length: int = field( 60 | default=512, 61 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 62 | ) 63 | 64 | 65 | def smart_tokenizer_and_embedding_resize( 66 | special_tokens_dict: Dict, 67 | tokenizer: transformers.PreTrainedTokenizer, 68 | model: transformers.PreTrainedModel, 69 | ): 70 | """Resize tokenizer and embedding. 71 | 72 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 73 | """ 74 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 75 | model.resize_token_embeddings(len(tokenizer)) 76 | 77 | if num_new_tokens > 0: 78 | input_embeddings = model.get_input_embeddings().weight.data 79 | output_embeddings = model.get_output_embeddings().weight.data 80 | 81 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 82 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 83 | 84 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 85 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 86 | 87 | 88 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 89 | """Tokenize a list of strings.""" 90 | tokenized_list = [ 91 | tokenizer( 92 | text, 93 | return_tensors="pt", 94 | padding="longest", 95 | max_length=tokenizer.model_max_length, 96 | truncation=True, 97 | ) 98 | for text in strings 99 | ] 100 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 101 | input_ids_lens = labels_lens = [ 102 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 103 | ] 104 | return dict( 105 | input_ids=input_ids, 106 | labels=labels, 107 | input_ids_lens=input_ids_lens, 108 | labels_lens=labels_lens, 109 | ) 110 | 111 | 112 | def preprocess( 113 | sources: Sequence[str], 114 | targets: Sequence[str], 115 | tokenizer: transformers.PreTrainedTokenizer, 116 | ) -> Dict: 117 | """Preprocess the data by tokenizing.""" 118 | examples = [s + t for s, t in zip(sources, targets)] 119 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 120 | input_ids = examples_tokenized["input_ids"] 121 | labels = copy.deepcopy(input_ids) 122 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 123 | label[:source_len] = IGNORE_INDEX 124 | return dict(input_ids=input_ids, labels=labels) 125 | 126 | 127 | class SupervisedDataset(Dataset): 128 | """Dataset for supervised fine-tuning.""" 129 | 130 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 131 | super(SupervisedDataset, self).__init__() 132 | logging.warning("Loading data...") 133 | list_data_dict = utils.jload(data_path) 134 | 135 | logging.warning("Formatting inputs...") 136 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 137 | sources = [ 138 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 139 | for example in list_data_dict 140 | ] 141 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 142 | 143 | logging.warning("Tokenizing inputs... 
This may take some time...") 144 | data_dict = preprocess(sources, targets, tokenizer) 145 | 146 | self.input_ids = data_dict["input_ids"] 147 | self.labels = data_dict["labels"] 148 | 149 | def __len__(self): 150 | return len(self.input_ids) 151 | 152 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 153 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 154 | 155 | 156 | @dataclass 157 | class DataCollatorForSupervisedDataset(object): 158 | """Collate examples for supervised fine-tuning.""" 159 | 160 | tokenizer: transformers.PreTrainedTokenizer 161 | 162 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 163 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 164 | input_ids = torch.nn.utils.rnn.pad_sequence( 165 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 166 | ) 167 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 168 | return dict( 169 | input_ids=input_ids, 170 | labels=labels, 171 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 172 | ) 173 | 174 | 175 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 176 | """Make dataset and collator for supervised fine-tuning.""" 177 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path) 178 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 179 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 180 | 181 | 182 | def train(): 183 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 184 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 185 | 186 | model = transformers.AutoModelForCausalLM.from_pretrained( 187 | model_args.model_name_or_path, 188 | cache_dir=training_args.cache_dir, 189 | ) 190 | 191 | tokenizer = transformers.AutoTokenizer.from_pretrained( 192 | model_args.model_name_or_path, 193 | cache_dir=training_args.cache_dir, 194 | model_max_length=training_args.model_max_length, 195 | padding_side="right", 196 | use_fast=False, 197 | ) 198 | special_tokens_dict = dict() 199 | if tokenizer.pad_token is None: 200 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 201 | if tokenizer.eos_token is None: 202 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 203 | if tokenizer.bos_token is None: 204 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 205 | if tokenizer.unk_token is None: 206 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 207 | 208 | smart_tokenizer_and_embedding_resize( 209 | special_tokens_dict=special_tokens_dict, 210 | tokenizer=tokenizer, 211 | model=model, 212 | ) 213 | 214 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 215 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 216 | trainer.train() 217 | trainer.save_state() 218 | trainer.save_model(output_dir=training_args.output_dir) 219 | 220 | 221 | if __name__ == "__main__": 222 | train() 223 | -------------------------------------------------------------------------------- /ochat/models/__init__.py: -------------------------------------------------------------------------------- 1 | from ochat.models.unpadded_llama import LlamaForCausalLM 2 | from ochat.models.unpadded_mistral import MistralForCausalLM 3 | from ochat.models.unpadded_gemma import GemmaForCausalLM 4 | 
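The exports above are "unpadded" model variants that take packed token streams instead of the usual padded [batch, seq] tensors; GemmaForCausalLM.forward in the next file shows the shared calling convention, and ochat/training_deepspeed/openchat_dataset.py later in this listing builds batches in the same layout. The sketch below is purely illustrative -- the token ids, labels, loss weights and the `model` handle are made up here, not taken from the repository -- and only demonstrates how such a packed batch is laid out:

import torch

# Two toy sequences (lengths 3 and 2) packed back to back -- no padding tokens anywhere.
seqs = [[5, 6, 7], [8, 9]]
seqlens = torch.tensor([len(s) for s in seqs], dtype=torch.int32)

nz_input_ids = torch.tensor([t for s in seqs for t in s], dtype=torch.long)          # shape [5]
nz_position_ids = torch.cat([torch.arange(len(s)) for s in seqs]).to(torch.long)     # 0,1,2,0,1
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(-1, dtype=torch.int32), (1, 0))  # 0,3,5
max_seqlen = int(seqlens.max())

# Toy labels: position i predicts token i+1 within its own sequence; the last position of
# each sequence gets zero loss weight in this sketch so it contributes nothing to the loss.
nz_shifted_label_ids = torch.tensor([6, 7, 0, 9, 0], dtype=torch.long)
nz_shifted_loss_weights = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0], dtype=torch.bfloat16)

# A causal LM from this package would then be invoked roughly as follows
# (not executed here -- the models require the flash-attn kernels):
# output = model(nz_input_ids=nz_input_ids, nz_position_ids=nz_position_ids,
#                cu_seqlens=cu_seqlens, max_seqlen=max_seqlen,
#                nz_shifted_label_ids=nz_shifted_label_ids,
#                nz_shifted_loss_weights=nz_shifted_loss_weights)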
-------------------------------------------------------------------------------- /ochat/models/unpadded_gemma.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ PyTorch Unpadded & Fused Gemma model. Compatible with HF. """ 21 | 22 | from typing import Optional, Tuple 23 | 24 | import torch 25 | import torch.utils.checkpoint 26 | import torch.nn.functional as F 27 | from torch import nn 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.modeling_outputs import CausalLMOutputWithPast 31 | from transformers.modeling_utils import PreTrainedModel 32 | from transformers.utils import logging 33 | from transformers.models.gemma.configuration_gemma import GemmaConfig 34 | 35 | try: 36 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 37 | from flash_attn.bert_padding import pad_input 38 | except ImportError: 39 | print ("FlashAttention not found. 
Install it if you need to train models.") 40 | 41 | 42 | def rotate_half(x: torch.Tensor): 43 | """Rotates half the hidden dims of the input.""" 44 | x1 = x[..., : x.shape[-1] // 2] 45 | x2 = x[..., x.shape[-1] // 2 :] 46 | return torch.cat((-x2, x1), dim=-1) 47 | 48 | 49 | def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor): 50 | # q, k: [nnz, num_heads, head_dim] 51 | # position_ids: [nnz] 52 | # cos, sin: [max_seq_len, head_dim] 53 | cos = cos[position_ids].unsqueeze(-2) # [nnz, 1, head_dim] 54 | sin = sin[position_ids].unsqueeze(-2) # [nnz, 1, head_dim] 55 | q_embed = (q * cos) + (rotate_half(q) * sin) 56 | k_embed = (k * cos) + (rotate_half(k) * sin) 57 | return q_embed, k_embed 58 | 59 | 60 | @torch.jit.script 61 | def lm_head_with_loss(embed_weights: torch.Tensor, hidden_states: torch.Tensor, nz_shifted_label_ids: torch.Tensor, nz_shifted_loss_weights: torch.Tensor): 62 | logits = F.linear(hidden_states, embed_weights) 63 | 64 | loss = (nz_shifted_loss_weights * torch.nn.functional.cross_entropy(logits, nz_shifted_label_ids, reduction="none")).sum() 65 | token_accuracy = (nz_shifted_loss_weights * (torch.argmax(logits.detach(), dim=-1) == nz_shifted_label_ids)).sum() 66 | return loss, token_accuracy 67 | 68 | 69 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Gemma 70 | RMS_NORM_TRACED = None 71 | 72 | 73 | def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: torch.Tensor): 74 | input_dtype = hidden_states.dtype 75 | hidden_states = hidden_states.to(torch.float32) 76 | 77 | variance = hidden_states.square().mean(-1, keepdim=True) 78 | hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon) 79 | return (1 + weight) * hidden_states.to(input_dtype) 80 | 81 | 82 | class UnpaddedGemmaRMSNorm(nn.Module): 83 | def __init__(self, hidden_size, eps): 84 | """ 85 | UnpaddedGemmaRMSNorm is equivalent to T5LayerNorm 86 | """ 87 | super().__init__() 88 | 89 | self.weight = nn.Parameter(torch.zeros(hidden_size)) 90 | self.variance_epsilon = torch.tensor(eps, dtype=torch.get_default_dtype()) 91 | 92 | global RMS_NORM_TRACED 93 | if RMS_NORM_TRACED is None: 94 | RMS_NORM_TRACED = torch.jit.trace(rms_norm, (torch.ones(hidden_size), torch.ones(hidden_size), self.variance_epsilon)) 95 | 96 | def forward(self, hidden_states): 97 | global RMS_NORM_TRACED 98 | return RMS_NORM_TRACED(hidden_states, self.weight, self.variance_epsilon) 99 | 100 | 101 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Gemma 102 | class UnpaddedGemmaRotaryEmbedding(torch.nn.Module): 103 | def __init__(self, dim, max_position_embeddings, base, device=None): 104 | super().__init__() 105 | 106 | # RoPE 107 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)) 108 | t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device) 109 | freqs = torch.outer(t, inv_freq) 110 | 111 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 112 | emb = torch.cat((freqs, freqs), dim=-1) 113 | dtype = torch.get_default_dtype() 114 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 115 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 116 | 117 | def forward(self): 118 | return self.cos_cached, self.sin_cached 119 | 120 | 121 | class UnpaddedGemmaMLP(nn.Module): 122 | def __init__(self, config: 
GemmaConfig): 123 | super().__init__() 124 | 125 | self.hidden_size = config.hidden_size 126 | self.intermediate_size = config.intermediate_size 127 | 128 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 129 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) 130 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) 131 | self.act_fn = ACT2FN[config.hidden_act] 132 | 133 | def forward(self, x): 134 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) 135 | 136 | 137 | class UnpaddedGemmaAttention(nn.Module): 138 | """Multi-headed attention from 'Attention Is All You Need' paper""" 139 | 140 | def __init__(self, config: GemmaConfig): 141 | super().__init__() 142 | 143 | self.hidden_size = config.hidden_size 144 | self.num_heads = config.num_attention_heads 145 | self.head_dim = config.head_dim 146 | self.num_key_value_heads = config.num_key_value_heads 147 | 148 | if self.hidden_size % self.num_heads != 0: 149 | raise ValueError( 150 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 151 | f" and `num_heads`: {self.num_heads})." 152 | ) 153 | 154 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) 155 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 156 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) 157 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) 158 | 159 | def forward( 160 | self, 161 | cos_sin: Tuple[torch.Tensor, torch.Tensor], 162 | # Unpadded inputs 163 | nz_hidden_states: torch.Tensor, 164 | nz_position_ids: torch.LongTensor, 165 | cu_seqlens: torch.Tensor, 166 | max_seqlen: int 167 | ) -> torch.Tensor: 168 | # nz_hidden_states: [nnz, num_heads, head_dim] 169 | # nz_position_ids: [nnz] 170 | # cu_seqlens: [bs + 1] 171 | 172 | query_states = self.q_proj(nz_hidden_states).view(-1, self.num_heads, self.head_dim) 173 | key_states = self.k_proj(nz_hidden_states).view(-1, self.num_key_value_heads, self.head_dim) 174 | value_states = self.v_proj(nz_hidden_states).view(-1, self.num_key_value_heads, self.head_dim) 175 | 176 | # RoPE 177 | cos, sin = cos_sin 178 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, nz_position_ids) 179 | 180 | # flash attn 181 | attn_output = flash_attn_varlen_func( 182 | q=query_states, k=key_states, v=value_states, 183 | cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, 184 | max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, 185 | 186 | dropout_p=0.0, causal=True) 187 | 188 | # attn_output: [total_nnz, num_heads, head_dim] 189 | attn_output = attn_output.view(-1, self.num_heads * self.head_dim) # type: ignore 190 | return self.o_proj(attn_output) 191 | 192 | 193 | class UnpaddedGemmaDecoderLayer(nn.Module): 194 | def __init__(self, config: GemmaConfig): 195 | super().__init__() 196 | 197 | self.hidden_size = config.hidden_size 198 | self.self_attn = UnpaddedGemmaAttention(config=config) 199 | self.mlp = UnpaddedGemmaMLP(config=config) 200 | self.input_layernorm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 201 | self.post_attention_layernorm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 202 | 203 | def forward( 204 | self, 205 | cos_sin: Tuple[torch.Tensor, torch.Tensor], 206 | # Unpadded inputs 
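# (the nz_* arguments are flat tensors over all packed tokens, with no padding; cu_seqlens holds the cumulative per-sequence boundaries passed to flash_attn_varlen_func, and max_seqlen is the length of the longest packed sequence)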
207 | nz_hidden_states: torch.Tensor, 208 | nz_position_ids: torch.Tensor, 209 | cu_seqlens: torch.Tensor, 210 | max_seqlen: int 211 | ) -> torch.Tensor: 212 | # Self Attention 213 | residual = nz_hidden_states 214 | 215 | nz_hidden_states = self.input_layernorm(nz_hidden_states) 216 | nz_hidden_states = self.self_attn( 217 | cos_sin=cos_sin, 218 | 219 | nz_hidden_states=nz_hidden_states, 220 | nz_position_ids=nz_position_ids, 221 | cu_seqlens=cu_seqlens, 222 | max_seqlen=max_seqlen 223 | ) 224 | nz_hidden_states = residual + nz_hidden_states 225 | 226 | # Fully Connected 227 | residual = nz_hidden_states 228 | 229 | nz_hidden_states = self.post_attention_layernorm(nz_hidden_states) 230 | nz_hidden_states = self.mlp(nz_hidden_states) 231 | nz_hidden_states = residual + nz_hidden_states 232 | 233 | return nz_hidden_states 234 | 235 | 236 | class UnpaddedGemmaPreTrainedModel(PreTrainedModel): 237 | config_class = GemmaConfig 238 | base_model_prefix = "model" 239 | supports_gradient_checkpointing = True 240 | _no_split_modules = ["UnpaddedGemmaDecoderLayer"] 241 | 242 | def _init_weights(self, module): 243 | std = self.config.initializer_range 244 | if isinstance(module, nn.Linear): 245 | module.weight.data.normal_(mean=0.0, std=std) 246 | if module.bias is not None: 247 | module.bias.data.zero_() 248 | elif isinstance(module, nn.Embedding): 249 | module.weight.data.normal_(mean=0.0, std=std) 250 | if module.padding_idx is not None: 251 | module.weight.data[module.padding_idx].zero_() 252 | 253 | 254 | class UnpaddedGemmaModel(UnpaddedGemmaPreTrainedModel): 255 | """ 256 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`UnpaddedGemmaDecoderLayer`] 257 | 258 | Args: 259 | config: GemmaConfig 260 | """ 261 | 262 | def __init__(self, config: GemmaConfig): 263 | super().__init__(config) 264 | self.padding_idx = config.pad_token_id 265 | self.vocab_size = config.vocab_size 266 | self.normalization_factor = config.hidden_size ** 0.5 267 | 268 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 269 | self.rotary_emb = UnpaddedGemmaRotaryEmbedding(config.head_dim, 270 | max_position_embeddings=config.max_position_embeddings, 271 | base=config.rope_theta) 272 | 273 | self.layers = nn.ModuleList([UnpaddedGemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) 274 | self.norm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) 275 | 276 | self.gradient_checkpointing = False 277 | # Initialize weights and apply final processing 278 | self.post_init() 279 | 280 | def get_input_embeddings(self): 281 | return self.embed_tokens 282 | 283 | def set_input_embeddings(self, value): 284 | self.embed_tokens = value 285 | 286 | def forward( 287 | self, 288 | # Unpadded inputs 289 | nz_input_ids: torch.Tensor, 290 | nz_position_ids: torch.Tensor, 291 | cu_seqlens: torch.Tensor, 292 | max_seqlen: int, 293 | ) -> torch.Tensor: 294 | nz_hidden_states = self.embed_tokens(nz_input_ids) * self.normalization_factor # Normalized 295 | cos_sin = self.rotary_emb() 296 | 297 | # decoder layers 298 | for decoder_layer in self.layers: 299 | if self.gradient_checkpointing and self.training: 300 | nz_hidden_states = self._gradient_checkpointing_func( 301 | decoder_layer.__call__, 302 | 303 | cos_sin, 304 | nz_hidden_states, 305 | nz_position_ids, 306 | cu_seqlens, 307 | max_seqlen 308 | ) 309 | else: 310 | nz_hidden_states = decoder_layer( 311 | cos_sin, 312 | 313 | nz_hidden_states, 314 | nz_position_ids, 315 | cu_seqlens, 
316 | max_seqlen 317 | ) 318 | 319 | nz_hidden_states = self.norm(nz_hidden_states) 320 | 321 | return nz_hidden_states 322 | 323 | 324 | class GemmaForCausalLM(UnpaddedGemmaPreTrainedModel): 325 | def __init__(self, config): 326 | super().__init__(config) 327 | self.model = UnpaddedGemmaModel(config) 328 | 329 | # Initialize weights and apply final processing 330 | self.post_init() 331 | 332 | def get_input_embeddings(self): 333 | return self.model.embed_tokens 334 | 335 | def set_input_embeddings(self, value): 336 | self.model.embed_tokens = value 337 | 338 | def get_output_embeddings(self): 339 | return self.model.embed_tokens 340 | 341 | def set_output_embeddings(self, new_embeddings): 342 | self.model.embed_tokens = new_embeddings 343 | 344 | def set_decoder(self, decoder): 345 | self.model = decoder 346 | 347 | def get_decoder(self): 348 | return self.model 349 | 350 | def forward( 351 | self, 352 | # Unpadded inputs 353 | nz_input_ids: torch.Tensor, 354 | nz_position_ids: torch.Tensor, 355 | cu_seqlens: torch.Tensor, 356 | max_seqlen: int, 357 | # Unpadded labels 358 | nz_shifted_label_ids: Optional[torch.Tensor] = None, 359 | nz_shifted_loss_weights: Optional[torch.Tensor] = None 360 | ) -> CausalLMOutputWithPast: 361 | # Model logits 362 | hidden_states = self.model( 363 | nz_input_ids=nz_input_ids, 364 | nz_position_ids=nz_position_ids, 365 | cu_seqlens=cu_seqlens, 366 | max_seqlen=max_seqlen 367 | ) 368 | 369 | # Loss 370 | loss = lm_head_with_loss( 371 | self.model.embed_tokens.weight, # Tied embeddings 372 | hidden_states, 373 | nz_shifted_label_ids, 374 | nz_shifted_loss_weights 375 | ) 376 | 377 | return CausalLMOutputWithPast( 378 | loss=loss # type: ignore 379 | ) 380 | -------------------------------------------------------------------------------- /ochat/scripts/hf_add_tokens.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import transformers 4 | import torch 5 | 6 | 7 | def add_tokens_to_embedding(added_special_tokens, embedding): 8 | # Mean embedding, shape: [1, dim] 9 | new_token_embeddings = torch.mean(embedding.to(torch.float32), dim=0, keepdim=True).to(embedding.dtype) 10 | # Expand to [N, dim] 11 | new_token_embeddings = new_token_embeddings.expand(len(added_special_tokens), -1) 12 | 13 | return torch.cat([embedding, new_token_embeddings], dim=0) 14 | 15 | 16 | def hf_add_tokens(model_path, output_dir, added_special_tokens): 17 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) 18 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path, 19 | low_cpu_mem_usage=True, 20 | torch_dtype=torch.bfloat16) 21 | # Add tokens (tokenizer) 22 | tokenizer.add_special_tokens({"additional_special_tokens": added_special_tokens}) 23 | 24 | # Add tokens (embedding) 25 | assert model.model.embed_tokens.weight.requires_grad 26 | assert model.lm_head.weight.requires_grad 27 | 28 | model.model.embed_tokens.weight = torch.nn.Parameter(add_tokens_to_embedding(added_special_tokens, model.model.embed_tokens.weight), requires_grad=True) 29 | model.lm_head.weight = torch.nn.Parameter(add_tokens_to_embedding(added_special_tokens, model.lm_head.weight), requires_grad=True) 30 | 31 | model.config.vocab_size += len(added_special_tokens) 32 | 33 | # Fix model config (Mistral's actual token length is 8192) 34 | if "mistral" in model_path.lower(): 35 | assert model.config.max_position_embeddings == 32768 36 | model.config.max_position_embeddings = 8192 37 | 38 | print ({k: v.shape for k, v in 
model.state_dict().items()}) 39 | 40 | # Save 41 | tokenizer.save_pretrained(output_dir) 42 | model.save_pretrained(output_dir) 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument( 48 | "--model-path", 49 | help="Location of Mistral model, or HuggingFace repo ID", 50 | ) 51 | parser.add_argument( 52 | "--output-dir", 53 | help="Location to write resulting model and tokenizer", 54 | ) 55 | parser.add_argument( 56 | "--added-special-tokens", 57 | type=str, 58 | nargs="+", 59 | help="Special token list to add" 60 | ) 61 | 62 | hf_add_tokens(**vars(parser.parse_args())) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /ochat/scripts/init_special_embedding_llama3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import transformers 4 | import torch 5 | 6 | 7 | def init_eot_embedding_llama3(model_path, output_dir, special_tokens=["<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"], mean_cutoff=128000, dtype=torch.bfloat16): 8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) 9 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=dtype) 10 | 11 | assert model.model.embed_tokens.weight.shape[0] >= mean_cutoff 12 | assert model.lm_head.weight.shape[0] >= mean_cutoff 13 | 14 | with torch.no_grad(): 15 | for token in special_tokens: 16 | token_id = tokenizer.convert_tokens_to_ids(token) 17 | 18 | print (f"Token {token} ID {token_id}") 19 | 20 | model.model.embed_tokens.weight[token_id] = torch.mean(model.model.embed_tokens.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype) 21 | model.lm_head.weight[token_id] = torch.mean(model.lm_head.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype) 22 | 23 | # Save 24 | tokenizer.save_pretrained(output_dir) 25 | model.save_pretrained(output_dir) 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument( 31 | "--model-path", 32 | help="Location of model, or HuggingFace repo ID", 33 | ) 34 | parser.add_argument( 35 | "--output-dir", 36 | help="Location to write resulting model and tokenizer", 37 | ) 38 | 39 | init_eot_embedding_llama3(**vars(parser.parse_args())) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /ochat/scripts/modify_eos_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import transformers 4 | import torch 5 | 6 | 7 | def modify_eos_embeddings(model_path, output_dir): 8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path) 9 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16) 10 | 11 | eos_token_id = tokenizer.eos_token_id 12 | 13 | print (f"EOS Token {tokenizer.convert_ids_to_tokens(eos_token_id)} ID {eos_token_id}") 14 | with torch.no_grad(): 15 | model.model.embed_tokens.weight[eos_token_id] = torch.mean(model.model.embed_tokens.weight, dim=0) 16 | model.lm_head.weight[eos_token_id] = torch.mean(model.lm_head.weight, dim=0) 17 | 18 | # Save 19 | tokenizer.save_pretrained(output_dir) 20 | model.save_pretrained(output_dir) 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--model-path", 27 | help="Location of model, or HuggingFace repo ID", 
28 | ) 29 | parser.add_argument( 30 | "--output-dir", 31 | help="Location to write resulting model and tokenizer", 32 | ) 33 | 34 | modify_eos_embeddings(**vars(parser.parse_args())) 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- /ochat/serving/async_tokenizer.py: -------------------------------------------------------------------------------- 1 | import ray 2 | 3 | from ochat.config import Message, Conversation 4 | 5 | 6 | @ray.remote 7 | class AsyncTokenizer: 8 | def __init__(self, model_type: str, model_path: str) -> None: 9 | from ochat.config import MODEL_CONFIG_MAP 10 | 11 | config = MODEL_CONFIG_MAP[model_type] 12 | tokenizer = config.model_tokenizer_create(model_path) 13 | 14 | self.conv_template = config.conversation_template(tokenizer=tokenizer) 15 | 16 | def tokenize(self, messages, condition, enable_sys_prompt=False): 17 | # get system messages 18 | system_message = "" 19 | items = [] 20 | 21 | for msg_raw in messages: 22 | msg = Message(**msg_raw) 23 | if msg.role == "system": 24 | # Use system prompt only when enabled 25 | if enable_sys_prompt: 26 | system_message = msg.content.strip() 27 | 28 | continue 29 | 30 | items.append(msg) 31 | 32 | assert len(items) 33 | 34 | # append ai role 35 | if items[-1].role != "assistant": 36 | items.append(Message(role="assistant", content="")) 37 | 38 | tokens, _ = self.conv_template.tokenize_conversations([Conversation(items=items, system=system_message, condition=condition)], 39 | inference=True) 40 | return tokens[0] 41 | 42 | def get_eot_tokens(self): 43 | assert len(self.conv_template.eot_tokens_) == 1 44 | 45 | return self.conv_template.eot_tokens_ 46 | -------------------------------------------------------------------------------- /ochat/serving/openai_api_protocol.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py 3 | import time 4 | from typing import Dict, List, Literal, Optional, Union 5 | 6 | from pydantic import BaseModel, Field 7 | 8 | from vllm.utils import random_uuid 9 | 10 | 11 | class ErrorResponse(BaseModel): 12 | object: str = "error" 13 | message: str 14 | type: str 15 | param: Optional[str] = None 16 | code: Optional[str] = None 17 | 18 | 19 | class ModelPermission(BaseModel): 20 | id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") 21 | object: str = "model_permission" 22 | created: int = Field(default_factory=lambda: int(time.time())) 23 | allow_create_engine: bool = False 24 | allow_sampling: bool = True 25 | allow_logprobs: bool = False 26 | allow_search_indices: bool = False 27 | allow_view: bool = True 28 | allow_fine_tuning: bool = False 29 | organization: str = "*" 30 | group: Optional[str] = None 31 | is_blocking: str = False 32 | 33 | 34 | class ModelCard(BaseModel): 35 | id: str 36 | object: str = "model" 37 | created: int = Field(default_factory=lambda: int(time.time())) 38 | owned_by: str = "openchat" 39 | root: Optional[str] = None 40 | parent: Optional[str] = None 41 | permission: List[ModelPermission] = Field(default_factory=list) 42 | 43 | 44 | class ModelList(BaseModel): 45 | object: str = "list" 46 | data: List[ModelCard] = Field(default_factory=list) 47 | 48 | 49 | class UsageInfo(BaseModel): 50 | prompt_tokens: int = 0 51 | total_tokens: int = 0 52 | completion_tokens: Optional[int] = 0 53 | 54 | 55 | class 
ChatCompletionRequest(BaseModel): 56 | model: str 57 | messages: Union[str, List[Dict[str, str]]] 58 | condition: Optional[str] = "" 59 | temperature: Optional[float] = 0.7 60 | top_p: Optional[float] = 1.0 61 | n: Optional[int] = 1 62 | max_tokens: Optional[int] = None 63 | seed: Optional[int] = None 64 | stop: Optional[Union[str, List[str]]] = None 65 | stream: Optional[bool] = False 66 | presence_penalty: Optional[float] = 0.0 67 | frequency_penalty: Optional[float] = 0.0 68 | logit_bias: Optional[Dict[str, float]] = None 69 | user: Optional[str] = None 70 | 71 | 72 | class ChatMessage(BaseModel): 73 | role: str 74 | content: str 75 | 76 | 77 | class ChatCompletionResponseChoice(BaseModel): 78 | index: int 79 | message: ChatMessage 80 | finish_reason: Optional[Literal["stop", "length"]] = None 81 | 82 | 83 | class ChatCompletionResponse(BaseModel): 84 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") 85 | object: str = "chat.completion" 86 | created: int = Field(default_factory=lambda: int(time.time())) 87 | model: str 88 | choices: List[ChatCompletionResponseChoice] 89 | usage: UsageInfo 90 | 91 | 92 | class DeltaMessage(BaseModel): 93 | role: Optional[str] = None 94 | content: Optional[str] = None 95 | 96 | 97 | class ChatCompletionResponseStreamChoice(BaseModel): 98 | index: int 99 | delta: DeltaMessage 100 | finish_reason: Optional[Literal["stop", "length"]] = None 101 | 102 | 103 | class ChatCompletionStreamResponse(BaseModel): 104 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") 105 | object: str = "chat.completion.chunk" 106 | created: int = Field(default_factory=lambda: int(time.time())) 107 | model: str 108 | choices: List[ChatCompletionResponseStreamChoice] 109 | 110 | 111 | class LoggingRecord(BaseModel): 112 | time: int 113 | request: ChatCompletionRequest 114 | outputs: List[str] 115 | -------------------------------------------------------------------------------- /ochat/serving/openai_api_server.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py 3 | 4 | import argparse 5 | import asyncio 6 | from http import HTTPStatus 7 | import json 8 | import time 9 | import logging 10 | from logging.handlers import RotatingFileHandler 11 | from typing import AsyncGenerator, Optional 12 | from dataclasses import dataclass 13 | 14 | import fastapi 15 | from fastapi import BackgroundTasks, Request 16 | from fastapi.exceptions import RequestValidationError 17 | from fastapi.middleware.cors import CORSMiddleware 18 | from fastapi.responses import JSONResponse, StreamingResponse 19 | from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer 20 | 21 | import uvicorn 22 | import ray 23 | 24 | from vllm.engine.arg_utils import AsyncEngineArgs 25 | from vllm.engine.async_llm_engine import AsyncLLMEngine 26 | from vllm.outputs import RequestOutput 27 | from vllm.sampling_params import SamplingParams 28 | from vllm.utils import random_uuid 29 | 30 | from ochat.config import MODEL_CONFIG_MAP 31 | from ochat.serving import openai_api_protocol, async_tokenizer 32 | 33 | from transformers.utils.hub import cached_file 34 | 35 | 36 | TIMEOUT_KEEP_ALIVE = 5 # seconds 37 | 38 | 39 | @dataclass 40 | class ModelConfig: 41 | names: set = None 42 | 43 | max_length: int = None 44 | stream_period: int = None 45 | eot_tokens: list = None 46 | 47 | enable_sys_prompt: bool = None 48 | api_keys: list = None 49 | 
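# Module-level server state: the fields of `model`, the tokenizer actor, the optional request logger and the vLLM engine are all filled in by the __main__ block at the bottom of this file before uvicorn starts.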
50 | 51 | logger = None 52 | app = fastapi.FastAPI() 53 | 54 | model = ModelConfig() 55 | tokenizer = None 56 | 57 | 58 | def _strip_first_space(s: str): 59 | if s[0] == " ": 60 | return s[1:] 61 | return s 62 | 63 | 64 | def log_request(created_time: int, request: openai_api_protocol.ChatCompletionRequest, output: RequestOutput): 65 | if logger is not None: 66 | logger.info(openai_api_protocol.LoggingRecord( 67 | time=created_time, 68 | request=request, 69 | outputs=[o.text for o in output.outputs] 70 | ).model_dump_json(exclude_unset=True)) 71 | 72 | 73 | def create_error_response(status_code: HTTPStatus, 74 | message: str) -> JSONResponse: 75 | return JSONResponse(openai_api_protocol.ErrorResponse(message=message, 76 | type="invalid_request_error").dict(), 77 | status_code=status_code.value) 78 | 79 | 80 | def check_model(request) -> Optional[JSONResponse]: 81 | if request.model in model.names: 82 | return 83 | 84 | return create_error_response( 85 | HTTPStatus.NOT_FOUND, 86 | f"The model `{request.model}` does not exist.", 87 | ) 88 | 89 | 90 | @app.exception_handler(RequestValidationError) 91 | async def validation_exception_handler(request, exc): # pylint: disable=unused-argument 92 | return create_error_response(HTTPStatus.BAD_REQUEST, str(exc)) 93 | 94 | 95 | async def check_api_key( 96 | auth: Optional[HTTPAuthorizationCredentials] = fastapi.Depends(HTTPBearer(auto_error=False)), 97 | ): 98 | if not model.api_keys: 99 | return 100 | 101 | if auth is None or auth.credentials not in model.api_keys: 102 | raise fastapi.HTTPException( 103 | status_code=401, 104 | detail={ 105 | "error": { 106 | "message": "", 107 | "type": "invalid_request_error", 108 | "param": None, 109 | "code": "invalid_api_key", 110 | } 111 | }, 112 | ) 113 | 114 | 115 | @app.get("/v1/models", dependencies=[fastapi.Depends(check_api_key)]) 116 | async def show_available_models(): 117 | """Show available models. Right now we only have one model.""" 118 | return openai_api_protocol.ModelList(data=[ 119 | openai_api_protocol.ModelCard(id=name, 120 | root=name, 121 | permission=[openai_api_protocol.ModelPermission()]) 122 | for name in model.names]) 123 | 124 | 125 | @app.post("/v1/chat/completions", dependencies=[fastapi.Depends(check_api_key)]) 126 | async def create_chat_completion(raw_request: Request, background_tasks: BackgroundTasks): 127 | """Completion API similar to OpenAI's API. 128 | 129 | See https://platform.openai.com/docs/api-reference/chat/create 130 | for the API specification. This API mimics the OpenAI ChatCompletion API. 131 | 132 | NOTE: Currently we do not support the following features: 133 | - function_call (Users should implement this by themselves) 134 | - logit_bias (to be supported by vLLM engine) 135 | """ 136 | 137 | request = openai_api_protocol.ChatCompletionRequest(**await raw_request.json()) 138 | 139 | error_check_ret = check_model(request) 140 | if error_check_ret is not None: 141 | return error_check_ret 142 | 143 | if request.logit_bias is not None and len(request.logit_bias) > 0: 144 | # TODO: support logit_bias in vLLM engine. 
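# Until then, such requests are rejected with HTTP 400 below.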
145 | return create_error_response(HTTPStatus.BAD_REQUEST, 146 | "logit_bias is not currently supported") 147 | 148 | # input ids 149 | input_ids = await tokenizer.tokenize.remote(request.messages, condition=request.condition, 150 | enable_sys_prompt=model.enable_sys_prompt) 151 | input_num_tokens = len(input_ids) 152 | 153 | # check length 154 | if request.max_tokens is None: 155 | request.max_tokens = model.max_length - input_num_tokens 156 | 157 | if input_num_tokens + request.max_tokens > model.max_length: 158 | return input_ids, create_error_response( 159 | HTTPStatus.BAD_REQUEST, 160 | f"This model's maximum context length is {model.max_length} tokens. " 161 | f"However, you requested {input_num_tokens + request.max_tokens} tokens " 162 | f"({input_num_tokens} in the messages, " 163 | f"{request.max_tokens} in the completion). " 164 | f"Please reduce the length of the messages or completion.", 165 | ) 166 | 167 | # completion 168 | model_name = request.model 169 | request_id = f"cmpl-{random_uuid()}" 170 | created_time = int(time.time()) 171 | 172 | try: 173 | sampling_params = SamplingParams( 174 | n=request.n, 175 | presence_penalty=request.presence_penalty, 176 | frequency_penalty=request.frequency_penalty, 177 | temperature=request.temperature, 178 | top_p=request.top_p, 179 | max_tokens=request.max_tokens, 180 | seed=request.seed, 181 | # Override stop tokens 182 | stop_token_ids=model.eot_tokens, 183 | ignore_eos=True 184 | ) 185 | except ValueError as e: 186 | return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) 187 | 188 | result_generator = engine.generate(prompt=None, 189 | prompt_token_ids=input_ids, 190 | sampling_params=sampling_params, 191 | request_id=request_id) 192 | 193 | def create_stream_response_json( 194 | index: int, 195 | text: str, 196 | finish_reason: Optional[str] = None, 197 | ) -> str: 198 | choice_data = openai_api_protocol.ChatCompletionResponseStreamChoice( 199 | index=index, 200 | delta=openai_api_protocol.DeltaMessage(content=text), 201 | finish_reason=finish_reason, 202 | ) 203 | response = openai_api_protocol.ChatCompletionStreamResponse( 204 | id=request_id, 205 | choices=[choice_data], 206 | model=model_name, 207 | ) 208 | 209 | return response.model_dump_json(exclude_unset=True) 210 | 211 | async def completion_stream_generator() -> AsyncGenerator[str, None]: 212 | # First chunk with role 213 | for i in range(request.n): 214 | choice_data = openai_api_protocol.ChatCompletionResponseStreamChoice( 215 | index=i, 216 | delta=openai_api_protocol.DeltaMessage(role="assistant"), 217 | finish_reason=None, 218 | ) 219 | chunk = openai_api_protocol.ChatCompletionStreamResponse(id=request_id, 220 | choices=[choice_data], 221 | model=model_name) 222 | 223 | yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" 224 | 225 | previous_texts = [""] * request.n 226 | previous_num_tokens = [0] * request.n 227 | 228 | stream_index = 0 229 | final_res = None 230 | is_first = True 231 | async for res in result_generator: 232 | stream_index += 1 233 | final_res = res 234 | 235 | for output in res.outputs: 236 | # stream on end or every stream_period 237 | if (stream_index % model.stream_period == 0) or (output.finish_reason is not None): 238 | i = output.index 239 | delta_text = output.text[len(previous_texts[i]):] 240 | if "\ufffd" not in delta_text: 241 | previous_texts[i] = output.text 242 | previous_num_tokens[i] = len(output.token_ids) 243 | 244 | if is_first: 245 | # Strip first space 246 | is_first = False 247 | delta_text = 
_strip_first_space(delta_text) 248 | 249 | yield f"data: {create_stream_response_json(index=i, text=delta_text)}\n\n" 250 | if output.finish_reason is not None: 251 | yield f"data: {create_stream_response_json(index=i, text='', finish_reason=output.finish_reason)}\n\n" 252 | 253 | yield "data: [DONE]\n\n" 254 | 255 | # Log request 256 | background_tasks.add_task(log_request, created_time, request, final_res) 257 | 258 | # Streaming response 259 | if request.stream: 260 | return StreamingResponse(completion_stream_generator(), 261 | media_type="text/event-stream") 262 | 263 | # Non-streaming response 264 | final_res: RequestOutput = None 265 | async for res in result_generator: 266 | if await raw_request.is_disconnected(): 267 | # Abort the request if the client disconnects. 268 | await engine.abort(request_id) 269 | return create_error_response(HTTPStatus.BAD_REQUEST, 270 | "Client disconnected") 271 | final_res = res 272 | assert final_res is not None 273 | choices = [] 274 | for output in final_res.outputs: 275 | choice_data = openai_api_protocol.ChatCompletionResponseChoice( 276 | index=output.index, 277 | message=openai_api_protocol.ChatMessage(role="assistant", content=_strip_first_space(output.text)), 278 | finish_reason=output.finish_reason, 279 | ) 280 | choices.append(choice_data) 281 | 282 | num_prompt_tokens = len(final_res.prompt_token_ids) 283 | num_generated_tokens = sum( 284 | len(output.token_ids) for output in final_res.outputs) 285 | usage = openai_api_protocol.UsageInfo( 286 | prompt_tokens=num_prompt_tokens, 287 | completion_tokens=num_generated_tokens, 288 | total_tokens=num_prompt_tokens + num_generated_tokens, 289 | ) 290 | response = openai_api_protocol.ChatCompletionResponse( 291 | id=request_id, 292 | created=created_time, 293 | model=model_name, 294 | choices=choices, 295 | usage=usage, 296 | ) 297 | 298 | # Log request 299 | background_tasks.add_task(log_request, created_time, request, final_res) 300 | 301 | return response 302 | 303 | 304 | if __name__ == "__main__": 305 | parser = argparse.ArgumentParser(description="OpenChat OpenAI-Compatible RESTful API server.") 306 | 307 | # Model 308 | parser.add_argument("--model-type", type=str, default=None, help="Model type. Leave empty to auto-detect.") 309 | 310 | parser.add_argument("--stream-period", type=int, default=6, help="Number of tokens per stream event") 311 | parser.add_argument("--api-keys", type=str, nargs="*", default=[], help="Allowed API Keys. Leave blank to not verify") 312 | parser.add_argument("--enable-sys-prompt", default=False, action="store_true") 313 | 314 | # Server 315 | parser.add_argument("--host", type=str, default="localhost", help="Host name") 316 | parser.add_argument("--port", type=int, default=18888, help="Port number") 317 | parser.add_argument("--allow-credentials", action="store_true", help="Allow credentials") 318 | parser.add_argument("--allowed-origins", type=json.loads, default=["*"], help="Allowed origins") 319 | parser.add_argument("--allowed-methods", type=json.loads, default=["*"], help="Allowed methods") 320 | parser.add_argument("--allowed-headers", type=json.loads, default=["*"], help="Allowed headers") 321 | 322 | # Logging 323 | parser.add_argument("--log-file", type=str, default=None, help="Log file. 
Leave blank to disable logging") 324 | parser.add_argument("--log-max-mb", type=int, default=128, help="Max log size in MB") 325 | parser.add_argument("--log-max-count", type=int, default=10, help="Max log file versions to keep") 326 | 327 | parser = AsyncEngineArgs.add_cli_args(parser) 328 | args = parser.parse_args() 329 | 330 | # App and logging 331 | app.add_middleware( 332 | CORSMiddleware, 333 | allow_origins=args.allowed_origins, 334 | allow_credentials=args.allow_credentials, 335 | allow_methods=args.allowed_methods, 336 | allow_headers=args.allowed_headers, 337 | ) 338 | 339 | if args.log_file: 340 | logger = logging.getLogger(__name__) 341 | 342 | logger.setLevel(logging.INFO) 343 | logger.addHandler(RotatingFileHandler( 344 | args.log_file, 345 | maxBytes=args.log_max_mb * 1048576, 346 | backupCount=args.log_max_count) 347 | ) 348 | logger.propagate = False 349 | 350 | # Load model type 351 | if args.model_type is None: 352 | with open(cached_file(path_or_repo_id=args.model, filename="openchat.json"), "r") as f: 353 | args.model_type = json.load(f)["model_type"] 354 | 355 | # Load tokenizer 356 | tokenizer = async_tokenizer.AsyncTokenizer.remote(args.model_type, args.model) 357 | 358 | # Model config 359 | model.names = set(list(MODEL_CONFIG_MAP[args.model_type].serving_aliases) + [args.model_type]) 360 | model.max_length = MODEL_CONFIG_MAP[args.model_type].model_max_context 361 | model.eot_tokens = ray.get(tokenizer.get_eot_tokens.remote()) 362 | 363 | model.enable_sys_prompt = args.enable_sys_prompt 364 | model.stream_period = args.stream_period 365 | model.api_keys = args.api_keys 366 | 367 | # Set max num batched tokens 368 | args.max_num_batched_tokens = max(args.max_num_batched_tokens or model.max_length, model.max_length) 369 | args.max_model_len = model.max_length 370 | 371 | # Load model engine 372 | engine_args = AsyncEngineArgs.from_cli_args(args) 373 | engine = AsyncLLMEngine.from_engine_args(engine_args) 374 | engine_model_config = asyncio.run(engine.get_model_config()) 375 | 376 | # Run 377 | uvicorn.run(app, 378 | host=args.host, 379 | port=args.port, 380 | log_level="info", 381 | access_log=False, 382 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE) 383 | -------------------------------------------------------------------------------- /ochat/training_deepspeed/deepspeed_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | 6 | "zero_optimization": { 7 | "stage": 2 8 | }, 9 | 10 | "gradient_clipping": 1.0, 11 | "gradient_accumulation_steps": 1, 12 | "train_micro_batch_size_per_gpu": 1, 13 | 14 | "steps_per_print": 100, 15 | "wall_clock_breakdown": false 16 | } -------------------------------------------------------------------------------- /ochat/training_deepspeed/hf_hub.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | 4 | from huggingface_hub import HfApi 5 | 6 | 7 | def hub_upload_check(push_to_hub: str): 8 | if push_to_hub is not None: 9 | # Try creating a test repo 10 | test_repo_name = f"{push_to_hub}-dummy-test" 11 | try: 12 | HfApi().create_repo( 13 | repo_id=test_repo_name, 14 | repo_type="model", 15 | private=True 16 | ) 17 | HfApi().delete_repo( 18 | repo_id=test_repo_name, 19 | repo_type="model" 20 | ) 21 | except Exception as e: 22 | raise RuntimeError(f"Failed to push test repo {test_repo_name} to HuggingFace Hub. Please check your permissions and network connection. 
Use `huggingface-cli login` to log in your account.\n\n{str(e)}") 23 | 24 | 25 | def hub_upload_model_async(push_to_hub: str, push_to_hub_delete_local: bool, save_path: str, epoch: int): 26 | if push_to_hub is not None: 27 | safe_repo_name = shlex.quote(f"{push_to_hub}-ep-{epoch}") 28 | safe_save_path = shlex.quote(save_path) 29 | 30 | command = f"huggingface-cli upload --quiet --repo-type model --private {safe_repo_name} {safe_save_path}" 31 | if push_to_hub_delete_local: 32 | command += f" && rm -rf {safe_save_path}" 33 | 34 | subprocess.Popen(command, shell=True) 35 | -------------------------------------------------------------------------------- /ochat/training_deepspeed/multipack_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba 3 | 4 | 5 | @numba.njit 6 | def ffd_check(a: np.ndarray, c: int, n: int): 7 | # First-fit-decreasing bin packing 8 | # Check if a[] could fit in n bins with capacity c 9 | # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing 10 | 11 | a = np.sort(a)[::-1] 12 | bins = np.full((n, ), c, dtype=a.dtype) 13 | for size in a: 14 | not_found = True 15 | for idx in range(n): 16 | if bins[idx] >= size: 17 | bins[idx] -= size 18 | not_found = False 19 | break 20 | 21 | if not_found: 22 | return False 23 | 24 | return True 25 | 26 | 27 | @numba.njit 28 | def ffd_with_result(a: np.ndarray, c: int, start_index: int): 29 | # First-fit-decreasing bin packing (with result return) 30 | 31 | indices = np.argsort(a)[::-1] 32 | a = a[indices] 33 | 34 | bins = [] 35 | bins_result = [] 36 | for a_id, size in enumerate(a): 37 | add_new = True 38 | for idx in range(len(bins)): 39 | if bins[idx] >= size: 40 | bins[idx] -= size 41 | bins_result[idx].append(indices[a_id] + start_index) 42 | add_new = False 43 | break 44 | 45 | if add_new: 46 | bins.append(c - size) 47 | bins_result.append([indices[a_id] + start_index]) 48 | 49 | return bins_result 50 | 51 | 52 | @numba.njit 53 | def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int): 54 | # Dynamic batch allocator, similar to Multifit 55 | # https://en.wikipedia.org/wiki/Multifit_algorithm 56 | # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) 57 | 58 | s = 0 59 | start_index = 0 60 | result = [] 61 | result_totseqs = [] 62 | 63 | while True: 64 | # binary search [l, r) 65 | l = 1 66 | r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") 67 | 68 | while r - l > 1: 69 | m = (l + r) // 2 70 | if ffd_check(lengths[start_index: start_index + m], c, n): 71 | l = m 72 | else: 73 | r = m 74 | 75 | # use length l 76 | batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index) # type: ignore 77 | if len(batch) < n: 78 | break 79 | 80 | start_index += l 81 | s = lengths_cumsum[start_index - 1] 82 | 83 | # add local rank 84 | result.append(batch[rank]) 85 | # add total seqs for all ranks 86 | totseq = 0 87 | for indices in batch: 88 | for idx in indices: 89 | totseq += numseqs[idx] 90 | result_totseqs.append(totseq) 91 | 92 | return result, result_totseqs, s, len(result) * c * n 93 | 94 | 95 | class MultipackDistributedSampler: 96 | """Unpadded data loading using Multipack. 
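Each step packs whole sequences into bins of at most batch_max_length tokens using first-fit-decreasing (see allocate() above), and each of the num_replicas ranks consumes the bin at its own index.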
97 | Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.""" 98 | 99 | def __init__( 100 | self, 101 | lengths: np.ndarray, 102 | numseqs: np.ndarray, 103 | 104 | batch_max_length: int, 105 | 106 | num_replicas: int, 107 | rank: int, 108 | 109 | seed: int, 110 | ): 111 | # Dataset 112 | self.lengths = lengths 113 | self.numseqs = numseqs 114 | assert isinstance(self.lengths, np.ndarray) 115 | 116 | self.batch_max_length = batch_max_length 117 | 118 | # Get rank 119 | self.num_replicas = num_replicas 120 | self.rank = rank 121 | 122 | # Seed 123 | self.seed = seed 124 | 125 | # statistics 126 | self.eff_total_used = 0 127 | self.eff_total_slots = 0 128 | 129 | def generate_batches(self, epoch, set_stats=False): 130 | indices = np.random.default_rng(seed=self.seed + epoch).permutation(len(self.lengths)) 131 | 132 | lengths = self.lengths[indices] 133 | numseqs = self.numseqs[indices] 134 | lengths_cumsum = np.cumsum(lengths) 135 | 136 | batches, totseqs, total_used, total_slots = allocate(lengths=lengths, 137 | numseqs=numseqs, 138 | lengths_cumsum=lengths_cumsum, 139 | rank=self.rank, 140 | c=self.batch_max_length, 141 | n=self.num_replicas) 142 | 143 | curseqs = [np.sum(numseqs[batch]) for batch in batches] 144 | batches = [indices[batch] for batch in batches] 145 | 146 | # statistics 147 | if set_stats: 148 | self.eff_total_used += total_used 149 | self.eff_total_slots += total_slots 150 | 151 | return batches, totseqs, curseqs 152 | 153 | def iter(self, epoch): 154 | all_batches, all_totseqs, all_curseqs = self.generate_batches(epoch, set_stats=True) 155 | 156 | for batch, totseq, curseq in zip(all_batches, all_totseqs, all_curseqs): 157 | yield batch, totseq, curseq 158 | 159 | def estimate_num_batches(self): 160 | batches, _, _ = self.generate_batches(epoch=0) 161 | return len(batches) 162 | 163 | def efficiency(self): 164 | return self.eff_total_used / self.eff_total_slots 165 | -------------------------------------------------------------------------------- /ochat/training_deepspeed/openchat_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.utils.data import IterableDataset, get_worker_info 4 | 5 | import pyarrow.parquet as pq 6 | import orjson 7 | 8 | from ochat.training_deepspeed.multipack_sampler import MultipackDistributedSampler 9 | 10 | 11 | def _find_multiple(a, b): 12 | return (-(a // -b)) * b 13 | 14 | 15 | class OpenchatDataset(IterableDataset): 16 | def __init__(self, dataset_filename, batch_max_length, rank, num_replicas): 17 | super().__init__() 18 | # Init constants 19 | self.PAD_ID = 0 20 | self.PAD_MULTIPLE = 64 21 | self.BATCH_KEYS = { 22 | "seqlens": torch.int32, 23 | "nz_input_ids": torch.long, 24 | "nz_position_ids": torch.long, 25 | "nz_shifted_label_ids": torch.long, 26 | 27 | "nz_shifted_loss_weights": torch.bfloat16 28 | } 29 | 30 | assert batch_max_length % self.PAD_MULTIPLE == 0, f"Batch size {batch_max_length} need to be multiples of {self.PAD_MULTIPLE}" 31 | 32 | # Load data 33 | # Convert parquet to numpy for fast random access 34 | table = pq.read_table(dataset_filename, memory_map=True) 35 | self.dataset = {k: v.to_numpy() for k, v in zip(table.column_names, table.columns)} 36 | 37 | # read metadata 38 | self.metadata = table.schema.metadata.get(b"metadata_json", None) 39 | if self.metadata is not None: 40 | self.metadata = orjson.loads(self.metadata) 41 | 42 | # Free table space 43 | del table 44 
| 45 | # Create sampler 46 | self.sampler = MultipackDistributedSampler( 47 | lengths=self.dataset["total_length"], 48 | numseqs=self.dataset["num_seqs"], 49 | 50 | batch_max_length=batch_max_length, 51 | 52 | rank=rank, 53 | num_replicas=num_replicas, 54 | seed=0 55 | ) 56 | 57 | # Init state 58 | self._epoch = 0 59 | 60 | def _load_batch(self, indices): 61 | batch = {k: v[indices] for k, v in self.dataset.items()} 62 | 63 | # Concat batches 64 | batch = {k: np.concatenate(batch[k], axis=0) for k in self.BATCH_KEYS.keys()} 65 | 66 | # Pad an unused item to reach multiple of PAD_MULTIPLE, for faster GEMM 67 | total_seqlen = batch["nz_input_ids"].size 68 | pad_len = _find_multiple(total_seqlen, self.PAD_MULTIPLE) - total_seqlen 69 | 70 | if pad_len > 0: 71 | assert pad_len < self.PAD_MULTIPLE 72 | 73 | # total length 74 | padding_specs = { 75 | "seqlens": (1, pad_len), 76 | 77 | "nz_input_ids": (pad_len, self.PAD_ID), 78 | "nz_position_ids": (pad_len, 0), 79 | "nz_shifted_label_ids": (pad_len, self.PAD_ID), 80 | "nz_shifted_loss_weights": (pad_len, 0), 81 | } 82 | for k, pad_spec in padding_specs.items(): 83 | batch[k] = np.concatenate((batch[k], np.full(*pad_spec, dtype=batch[k].dtype)), axis=0) 84 | 85 | # to tensor 86 | batch_tensor = {} 87 | for k, dtype in self.BATCH_KEYS.items(): 88 | batch_tensor[k] = torch.from_numpy(batch[k]).to(dtype) 89 | 90 | # cu seqlens 91 | batch_tensor["cu_seqlens"] = torch.nn.functional.pad(batch_tensor["seqlens"].cumsum(-1, dtype=torch.int32), (1, 0)) 92 | # batch info 93 | batch_info = {"max_seqlen": torch.max(batch_tensor["seqlens"]).item()} 94 | 95 | # inputs 96 | del batch_tensor["seqlens"] 97 | return batch_tensor, batch_info 98 | 99 | def __iter__(self): 100 | worker_info = get_worker_info() 101 | assert worker_info is None or worker_info.num_workers == 1 102 | 103 | for indices, all_numseq, cur_numseq in self.sampler.iter(self._epoch): 104 | yield self._load_batch(indices), all_numseq, cur_numseq 105 | 106 | # Increase epoch count 107 | self._epoch += 1 108 | 109 | def estimate_num_batches(self): 110 | return self.sampler.estimate_num_batches() 111 | -------------------------------------------------------------------------------- /ochat/training_deepspeed/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | import json 5 | from functools import partial 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch.utils.data import DataLoader 10 | 11 | import tqdm 12 | import wandb 13 | import numpy as np 14 | 15 | from ochat.config import MODEL_CONFIG_MAP 16 | from ochat.training_deepspeed.openchat_dataset import OpenchatDataset 17 | from ochat.training_deepspeed.hf_hub import hub_upload_check, hub_upload_model_async 18 | 19 | try: 20 | import deepspeed 21 | except ImportError: 22 | raise ImportError("Please install deepspeed to train models.") 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser() 27 | # Distributed 28 | parser.add_argument("--local_rank", type=int, required=True) 29 | 30 | # Model type and data 31 | parser.add_argument("--model_path", type=str, required=True) 32 | parser.add_argument("--data_prefix", type=str, required=True) 33 | parser.add_argument("--save_path", type=str, required=True) 34 | parser.add_argument("--save_every", type=int, default=None) 35 | parser.add_argument("--push_to_hub", type=str, default=None, 36 | help="Specify repository prefix for pushing to HuggingFace Hub. 
" 37 | "For example, 'openchat/openchat-3.6' will create repositories " 38 | "like 'openchat/openchat-3.6-ep0', 'openchat/openchat-3.6-ep1', ..." 39 | "If not specified, will not push to Hub.") 40 | parser.add_argument("--push_to_hub_delete_local", action="store_true") 41 | 42 | # Hyperparameters 43 | parser.add_argument("--batch_max_len", type=int, default=81920) 44 | parser.add_argument("--epochs", type=int, default=5) 45 | 46 | # Set lr to None to automatically estimate from LLaMA pretraining parameters (e.g. lr ~ sqrt(batch_size)) 47 | parser.add_argument("--lr", type=float, default=None) 48 | parser.add_argument("--lr_min_ratio", type=float, default=0.1) 49 | parser.add_argument("--lr_warmup_ratio", type=int, default=0.05) 50 | 51 | parser.add_argument("--weight_decay", type=float, default=0.1) 52 | 53 | parser.add_argument("--beta1", type=float, default=0.9) 54 | parser.add_argument("--beta2", type=float, default=0.95) 55 | parser.add_argument("--eps", type=float, default=1e-5) 56 | 57 | parser.add_argument("--wandb_entity", type=str, default=None) 58 | parser.add_argument("--wandb_project", type=str, default=None) 59 | 60 | # DeepSpeed parameters 61 | parser = deepspeed.add_config_arguments(parser) 62 | 63 | # Parse known args 64 | args, unknown = parser.parse_known_args() 65 | return args 66 | 67 | 68 | def create_dataset_and_dataloader(args, epoch: int): 69 | # Find data 70 | filename = f"{args.data_prefix}.{epoch}.parquet" 71 | 72 | # Create dataset and dataloader 73 | print(f"Loading epoch {epoch} data from {filename}...") 74 | 75 | dataset = OpenchatDataset( 76 | dataset_filename=filename, 77 | 78 | batch_max_length=args.batch_max_len, 79 | rank=dist.get_rank(), 80 | num_replicas=dist.get_world_size() 81 | ) 82 | dataloader = DataLoader( 83 | dataset, 84 | batch_size=None, 85 | 86 | num_workers=1, 87 | prefetch_factor=8, 88 | 89 | pin_memory=True 90 | ) 91 | return dataset, dataloader 92 | 93 | 94 | def create_model(args): 95 | print(f"Loading model {args.model_type} from {args.model_path}...") 96 | 97 | # Create model + optimizer + lr scheduler 98 | model = MODEL_CONFIG_MAP[args.model_type].model_create_for_training(args.model_path) 99 | # Model to assigned cuda device 100 | model = model.to(args.local_rank) 101 | # Enable gradient checkpointing 102 | model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict( 103 | use_reentrant=False 104 | )) 105 | 106 | # Optimizer 107 | optimizer = torch.optim.AdamW(model.parameters(), 108 | lr=args.lr, 109 | weight_decay=args.weight_decay, 110 | betas=(args.beta1, args.beta2), 111 | eps=args.eps, 112 | fused=True) 113 | 114 | # DeepSpeed model 115 | model_engine, optimizer, _, _ = deepspeed.initialize(args=args, 116 | model=model, 117 | model_parameters=model.parameters(), 118 | optimizer=optimizer) 119 | 120 | # Put deepspeed arguments 121 | args.device = model_engine.device 122 | 123 | return model_engine, optimizer 124 | 125 | 126 | def cosine_schedule_with_warmup_lr_lambda( 127 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5 128 | ): 129 | if current_step < num_warmup_steps: 130 | return float(current_step) / float(max(1, num_warmup_steps)) 131 | 132 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 133 | return min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 134 | 135 | 136 | def create_lr_scheduler(args, train_total_steps): 137 | 
lr_scheduler = partial( 138 | cosine_schedule_with_warmup_lr_lambda, 139 | 140 | num_warmup_steps=round(args.lr_warmup_ratio * train_total_steps), 141 | num_training_steps=train_total_steps, 142 | min_ratio=args.lr_min_ratio 143 | ) 144 | 145 | return lr_scheduler 146 | 147 | 148 | def save_tokenizer(args, save_path): 149 | model_config = MODEL_CONFIG_MAP[args.model_type] 150 | tokenizer = model_config.model_tokenizer_create(args.model_path) 151 | tokenizer.chat_template = model_config.hf_chat_template 152 | tokenizer.save_pretrained(save_path) 153 | 154 | 155 | def save_openchat_metadata(args, epoch, save_path): 156 | metadata = vars(args) 157 | metadata["epoch"] = epoch 158 | 159 | with open(os.path.join(save_path, "openchat.json"), "w") as f: 160 | json.dump(metadata, f, default=lambda o: "") 161 | 162 | 163 | def calculate_auto_lr(lr, batch_max_len, model_type, train_dataset): 164 | if lr is not None: 165 | return lr 166 | 167 | # Llama hyperparameters 168 | # FIXME: Only 7B/13B is supported 169 | base_lr = 3e-4 170 | base_bs = 4_000_000 171 | if "mistral" in model_type.lower(): 172 | base_lr /= 6.0 173 | elif "gemma" in model_type.lower(): 174 | base_lr /= 5.5 # NOTE(one): Maybe MLP and Attn layers are using different lr? 175 | elif "openchat_3.6" in model_type.lower(): # Llama 3 estimated hyperparams 176 | # NOTE(one): Estimated divisor: 1.5 * sqrt(25000 H100s / 2000 H100s) 177 | base_lr /= 5.3 178 | 179 | loss_weights = np.concatenate(train_dataset.dataset["nz_shifted_loss_weights"]) 180 | supervised_ratio = np.sum(loss_weights != 0) / len(loss_weights) 181 | 182 | supervised_tokens = batch_max_len * dist.get_world_size() * supervised_ratio 183 | lr = base_lr * math.sqrt(supervised_tokens / base_bs) 184 | 185 | print(f"Use automatic learning rate {lr} (estimated from supervised ratio {supervised_ratio} effective batch size {supervised_tokens})") 186 | return lr 187 | 188 | 189 | def state_dict_to_cpu(item, device=torch.device('cpu')): 190 | # Move all tensors to CPU 191 | if torch.is_tensor(item): 192 | return item.detach().to(device) 193 | elif isinstance(item, list): 194 | return [state_dict_to_cpu(v, device) for v in item] 195 | elif isinstance(item, tuple): 196 | return tuple([state_dict_to_cpu(v, device) for v in item]) 197 | elif isinstance(item, dict): 198 | return type(item)({k: state_dict_to_cpu(v, device) for k, v in item.items()}) 199 | else: 200 | return item 201 | 202 | 203 | def train(): 204 | deepspeed.init_distributed(dist_backend="nccl") 205 | RANK = dist.get_rank() 206 | WORLD_SIZE = dist.get_world_size() 207 | 208 | # Args 209 | args = parse_args() 210 | 211 | hub_upload_check(args.push_to_hub) 212 | 213 | # Dataset 214 | train_dataset, train_loader = create_dataset_and_dataloader(args, 0) 215 | 216 | if train_dataset is None: 217 | raise RuntimeError("Training data not found.") 218 | 219 | # Load model type 220 | args.model_type = train_dataset.metadata["model_type"] 221 | 222 | train_total_steps = args.epochs * train_dataset.estimate_num_batches() 223 | 224 | # Hyperparams 225 | args.lr = calculate_auto_lr(args.lr, args.batch_max_len, args.model_type, train_dataset) 226 | 227 | # Model 228 | model_engine, optimizer = create_model(args) 229 | 230 | # LR Scheduler 231 | lr_scheduler = create_lr_scheduler(args, train_total_steps) 232 | 233 | # Progress bar and logger 234 | progress_bar = None 235 | if RANK == 0: 236 | progress_bar = tqdm.tqdm(total=train_total_steps) 237 | 238 | wandb.init(project=args.wandb_project or os.path.basename(args.model_path), 
entity=args.wandb_entity, config=args) 239 | 240 | # Training Loop 241 | step = 0 242 | lr_this_step = None 243 | for epoch in range(args.epochs): 244 | print (f"[rank {RANK} of {WORLD_SIZE}]: Epoch {epoch}") 245 | 246 | ############ Load Dataset 247 | if epoch != 0: 248 | del train_dataset, train_loader 249 | 250 | train_dataset, train_loader = create_dataset_and_dataloader(args, epoch) 251 | 252 | ############ Train Epoch 253 | model_engine.train() 254 | for (batch_tensor, batch_info), all_numseq, cur_numseq in train_loader: 255 | step += 1 256 | if step > train_total_steps: # At most train_total_steps 257 | break 258 | 259 | # To device 260 | batch_tensor = {k: (v.to(args.device) if v is not None else None) for k, v in batch_tensor.items()} 261 | 262 | # Update 263 | loss, acc = model_engine(**batch_tensor, **batch_info).loss 264 | loss = (WORLD_SIZE / all_numseq) * loss 265 | acc = (WORLD_SIZE / all_numseq) * acc 266 | 267 | model_engine.backward(loss) 268 | 269 | if model_engine.is_gradient_accumulation_boundary(): 270 | # Set LR 271 | lr_this_step = args.lr * lr_scheduler(step) 272 | for param_group in optimizer.param_groups: 273 | param_group['lr'] = lr_this_step 274 | 275 | model_engine.step() 276 | 277 | # Logging 278 | if RANK == 0: 279 | wandb.log({ 280 | "train/loss": loss.item() * (all_numseq / (cur_numseq * WORLD_SIZE)), 281 | "train/acc": acc.item() * (all_numseq / (cur_numseq * WORLD_SIZE)), 282 | "train/lr": lr_this_step 283 | }, step=step) 284 | progress_bar.update() # type: ignore 285 | 286 | ############ Save Checkpoint 287 | # Save model with lean state dict 288 | # https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html 289 | if (epoch + 1 == args.epochs) or (args.save_every and ((epoch + 1) % args.save_every == 0)): 290 | if RANK == 0: 291 | save_path = os.path.join(args.save_path, f"ep_{epoch}") 292 | 293 | model_engine.module.save_pretrained(save_path, 294 | state_dict=state_dict_to_cpu(model_engine.module.state_dict())) # type: ignore 295 | 296 | # Also save tokenizer from base model 297 | save_tokenizer(args, save_path) 298 | 299 | # Write metadata 300 | save_openchat_metadata(args, epoch, save_path) 301 | 302 | # Upload to hub 303 | hub_upload_model_async( 304 | args.push_to_hub, 305 | args.push_to_hub_delete_local, 306 | save_path, 307 | epoch 308 | ) 309 | 310 | 311 | if __name__ == "__main__": 312 | train() 313 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ochat" 7 | description = "An efficient framework for training and serving top-tier, open-source conversational LLMs." 
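# Note: no version is pinned here; it is resolved at build time by setuptools_scm
# (declared through dynamic = ["version"] and the [tool.setuptools_scm] section below).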
8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | dynamic = ["version"] 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "colorama", 17 | "beautifulsoup4", 18 | "markdownify", 19 | "pylatexenc", 20 | "sympy", 21 | "openai>=1", 22 | "tenacity", 23 | "tiktoken", 24 | "tqdm", 25 | "wandb", 26 | "numba", 27 | "datasets", 28 | "orjson", 29 | "torch", 30 | "packaging", 31 | "ninja", 32 | "flash-attn", 33 | "ray", 34 | "sentencepiece", 35 | "transformers>=4.40.1", 36 | "accelerate", 37 | "protobuf", 38 | "fastapi", 39 | "pydantic", 40 | "shortuuid", 41 | "uvicorn", 42 | "vllm>=0.4.0", 43 | "pytest" 44 | ] 45 | 46 | [project.urls] 47 | "Homepage" = "https://github.com/imoneoi/openchat" 48 | "Bug Tracker" = "https://github.com/imoneoi/openchat/issues" 49 | 50 | [tool.setuptools.packages.find] 51 | exclude = ["assets*", "ochat/experimental*"] 52 | 53 | [tool.wheel] 54 | exclude = ["assets*", "ochat/experimental*"] 55 | 56 | [tool.setuptools_scm] -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | cpu: CPU Tests 4 | gpu: GPU Tests 5 | --------------------------------------------------------------------------------
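The cpu and gpu markers registered in pytest.ini let the test suite be filtered by hardware requirement. A minimal sketch of how a test module could use them follows; the test names and bodies are illustrative assumptions, not files from the repository:

import pytest
import torch

from ochat.training_deepspeed.openchat_dataset import _find_multiple


# Illustrative tests only -- not part of the repository.

@pytest.mark.cpu
def test_find_multiple_rounds_up():
    # _find_multiple(a, b) rounds a up to the next multiple of b,
    # the same helper used for GEMM-friendly padding in OpenchatDataset.
    assert _find_multiple(100, 64) == 128
    assert _find_multiple(128, 64) == 128


@pytest.mark.gpu
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
def test_cuda_device_visible():
    assert torch.cuda.device_count() >= 1

Run pytest -m cpu to select only the CPU-marked tests, or pytest -m gpu for the GPU-marked ones.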