├── .github
│   └── workflows
│       ├── docker_build_and_push.yaml
│       └── pypi_publish.yaml
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── assets
│   ├── benchmarks-openchat-3.6-20240522.svg
│   ├── embeddings.svg
│   ├── logo_new.png
│   ├── openchat-3.6-20240522.png
│   ├── openchat-bench-0106.png
│   ├── openchat.png
│   ├── openchat_grok.png
│   ├── vicuna_gpt35.svg
│   └── vicuna_gpt4.svg
├── docker
│   └── serving
│       ├── Dockerfile
│       └── start.sh
├── ochat
│   ├── __init__.py
│   ├── config
│   │   ├── __init__.py
│   │   ├── conversation_template.py
│   │   └── model_config.py
│   ├── data
│   │   └── generate_dataset.py
│   ├── evaluation
│   │   ├── README.md
│   │   ├── conv_eval.py
│   │   ├── convert_to_evalplus.py
│   │   ├── eval_data
│   │   │   ├── coding
│   │   │   │   └── humaneval
│   │   │   │       └── humaneval.jsonl
│   │   │   ├── fs_cothub
│   │   │   │   ├── bbh
│   │   │   │   │   ├── boolean_expressions.jsonl
│   │   │   │   │   ├── causal_judgement.jsonl
│   │   │   │   │   ├── date_understanding.jsonl
│   │   │   │   │   ├── disambiguation_qa.jsonl
│   │   │   │   │   ├── dyck_languages.jsonl
│   │   │   │   │   ├── formal_fallacies.jsonl
│   │   │   │   │   ├── geometric_shapes.jsonl
│   │   │   │   │   ├── hyperbaton.jsonl
│   │   │   │   │   ├── logical_deduction_five_objects.jsonl
│   │   │   │   │   ├── logical_deduction_seven_objects.jsonl
│   │   │   │   │   ├── logical_deduction_three_objects.jsonl
│   │   │   │   │   ├── movie_recommendation.jsonl
│   │   │   │   │   ├── multistep_arithmetic_two.jsonl
│   │   │   │   │   ├── navigate.jsonl
│   │   │   │   │   ├── object_counting.jsonl
│   │   │   │   │   ├── penguins_in_a_table.jsonl
│   │   │   │   │   ├── reasoning_about_colored_objects.jsonl
│   │   │   │   │   ├── ruin_names.jsonl
│   │   │   │   │   ├── salient_translation_error_detection.jsonl
│   │   │   │   │   ├── snarks.jsonl
│   │   │   │   │   ├── sports_understanding.jsonl
│   │   │   │   │   ├── temporal_sequences.jsonl
│   │   │   │   │   ├── tracking_shuffled_objects_five_objects.jsonl
│   │   │   │   │   ├── tracking_shuffled_objects_seven_objects.jsonl
│   │   │   │   │   ├── tracking_shuffled_objects_three_objects.jsonl
│   │   │   │   │   ├── web_of_lies.jsonl
│   │   │   │   │   └── word_sorting.jsonl
│   │   │   │   ├── gsm8k
│   │   │   │   │   └── gsm8k.jsonl
│   │   │   │   ├── math
│   │   │   │   │   └── MATH.jsonl
│   │   │   │   └── mmlu
│   │   │   │       ├── abstract_algebra.jsonl
│   │   │   │       ├── anatomy.jsonl
│   │   │   │       ├── astronomy.jsonl
│   │   │   │       ├── business_ethics.jsonl
│   │   │   │       ├── clinical_knowledge.jsonl
│   │   │   │       ├── college_biology.jsonl
│   │   │   │       ├── college_chemistry.jsonl
│   │   │   │       ├── college_computer_science.jsonl
│   │   │   │       ├── college_mathematics.jsonl
│   │   │   │       ├── college_medicine.jsonl
│   │   │   │       ├── college_physics.jsonl
│   │   │   │       ├── computer_security.jsonl
│   │   │   │       ├── conceptual_physics.jsonl
│   │   │   │       ├── econometrics.jsonl
│   │   │   │       ├── electrical_engineering.jsonl
│   │   │   │       ├── elementary_mathematics.jsonl
│   │   │   │       ├── formal_logic.jsonl
│   │   │   │       ├── global_facts.jsonl
│   │   │   │       ├── high_school_biology.jsonl
│   │   │   │       ├── high_school_chemistry.jsonl
│   │   │   │       ├── high_school_computer_science.jsonl
│   │   │   │       ├── high_school_european_history.jsonl
│   │   │   │       ├── high_school_geography.jsonl
│   │   │   │       ├── high_school_government_and_politics.jsonl
│   │   │   │       ├── high_school_macroeconomics.jsonl
│   │   │   │       ├── high_school_mathematics.jsonl
│   │   │   │       ├── high_school_microeconomics.jsonl
│   │   │   │       ├── high_school_physics.jsonl
│   │   │   │       ├── high_school_psychology.jsonl
│   │   │   │       ├── high_school_statistics.jsonl
│   │   │   │       ├── high_school_us_history.jsonl
│   │   │   │       ├── high_school_world_history.jsonl
│   │   │   │       ├── human_aging.jsonl
│   │   │   │       ├── human_sexuality.jsonl
│   │   │   │       ├── international_law.jsonl
│   │   │   │       ├── jurisprudence.jsonl
│   │   │   │       ├── logical_fallacies.jsonl
│   │   │   │       ├── machine_learning.jsonl
│   │   │   │       ├── management.jsonl
│   │   │   │       ├── marketing.jsonl
│   │   │   │       ├── medical_genetics.jsonl
│   │   │   │       ├── miscellaneous.jsonl
│   │   │   │       ├── moral_disputes.jsonl
│   │   │   │       ├── moral_scenarios.jsonl
│   │   │   │       ├── nutrition.jsonl
│   │   │   │       ├── philosophy.jsonl
│   │   │   │       ├── prehistory.jsonl
│   │   │   │       ├── professional_accounting.jsonl
│   │   │   │       ├── professional_law.jsonl
│   │   │   │       ├── professional_medicine.jsonl
│   │   │   │       ├── professional_psychology.jsonl
│   │   │   │       ├── public_relations.jsonl
│   │   │   │       ├── security_studies.jsonl
│   │   │   │       ├── sociology.jsonl
│   │   │   │       ├── us_foreign_policy.jsonl
│   │   │   │       ├── virology.jsonl
│   │   │   │       └── world_religions.jsonl
│   │   │   └── zs
│   │   │       ├── agieval
│   │   │       │   ├── aqua-rat.zero-shot.jsonl
│   │   │       │   ├── logiqa-en.zero-shot.jsonl
│   │   │       │   ├── lsat-ar.zero-shot.jsonl
│   │   │       │   ├── lsat-lr.zero-shot.jsonl
│   │   │       │   ├── lsat-rc.zero-shot.jsonl
│   │   │       │   ├── sat-en-without-passage.zero-shot.jsonl
│   │   │       │   ├── sat-en.zero-shot.jsonl
│   │   │       │   └── sat-math.zero-shot.jsonl
│   │   │       ├── bbh_mc_orca
│   │   │       │   ├── boolean_expressions.jsonl
│   │   │       │   ├── causal_judgment.jsonl
│   │   │       │   ├── date_understanding.jsonl
│   │   │       │   ├── disambiguation_qa.jsonl
│   │   │       │   ├── formal_fallacies_syllogisms_negation.jsonl
│   │   │       │   ├── geometric_shapes.jsonl
│   │   │       │   ├── hyperbaton.jsonl
│   │   │       │   ├── logical_deduction_five_objects.jsonl
│   │   │       │   ├── logical_deduction_seven_objects.jsonl
│   │   │       │   ├── logical_deduction_three_objects.jsonl
│   │   │       │   ├── movie_recommendation.jsonl
│   │   │       │   ├── navigate.jsonl
│   │   │       │   ├── penguins_in_a_table.jsonl
│   │   │       │   ├── reasoning_about_colored_objects.jsonl
│   │   │       │   ├── ruin_names.jsonl
│   │   │       │   ├── salient_translation_error_detection.jsonl
│   │   │       │   ├── snarks.jsonl
│   │   │       │   ├── sports_understanding.jsonl
│   │   │       │   ├── temporal_sequences.jsonl
│   │   │       │   ├── tracking_shuffled_objects_five_objects.jsonl
│   │   │       │   ├── tracking_shuffled_objects_seven_objects.jsonl
│   │   │       │   ├── tracking_shuffled_objects_three_objects.jsonl
│   │   │       │   └── web_of_lies.jsonl
│   │   │       ├── gpqa
│   │   │       │   └── diamond.jsonl
│   │   │       └── truthfulqa_orca
│   │   │           └── truthfulqa_mc.jsonl
│   │   ├── grading
│   │   │   ├── math_grader.py
│   │   │   └── math_normalize.py
│   │   ├── match_answer.py
│   │   ├── run_eval.py
│   │   └── view_results.py
│   ├── experimental
│   │   ├── generate_dataset_old.py
│   │   ├── sharegpt.ipynb
│   │   ├── test_multipack_dataloader.ipynb
│   │   ├── text_length.ipynb
│   │   ├── train_alpaca.py
│   │   ├── verify_dataset.ipynb
│   │   └── verify_dataset_orca.ipynb
│   ├── models
│   │   ├── __init__.py
│   │   ├── unpadded_gemma.py
│   │   ├── unpadded_llama.py
│   │   └── unpadded_mistral.py
│   ├── scripts
│   │   ├── hf_add_tokens.py
│   │   ├── init_special_embedding_llama3.py
│   │   └── modify_eos_embedding.py
│   ├── serving
│   │   ├── async_tokenizer.py
│   │   ├── openai_api_protocol.py
│   │   └── openai_api_server.py
│   ├── tests
│   │   └── test_model_config.py
│   └── training_deepspeed
│       ├── deepspeed_config.json
│       ├── hf_hub.py
│       ├── multipack_sampler.py
│       ├── openchat_dataset.py
│       └── train.py
├── pyproject.toml
└── pytest.ini
/.github/workflows/docker_build_and_push.yaml:
--------------------------------------------------------------------------------
1 | name: Docker Build and Push
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Publish to PyPI.org"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | docker:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@v4
15 |
16 | - name: Set up QEMU
17 | uses: docker/setup-qemu-action@v3
18 |
19 | - name: Set up Docker Buildx
20 | uses: docker/setup-buildx-action@v3
21 |
22 | - name: Login to Docker Hub
23 | uses: docker/login-action@v3
24 | with:
25 | username: ${{ secrets.DOCKERHUB_USERNAME }}
26 | password: ${{ secrets.DOCKERHUB_TOKEN }}
27 |
28 | - name: Build and push
29 | uses: docker/build-push-action@v5
30 | with:
31 | context: ./docker/serving
32 | platforms: linux/amd64
33 | push: true
34 | tags: ${{ secrets.DOCKERHUB_TAG }}
35 |
--------------------------------------------------------------------------------
/.github/workflows/pypi_publish.yaml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI.org
2 | on:
3 | release:
4 | types: [published]
5 | jobs:
6 | pypi:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - name: Checkout
10 | uses: actions/checkout@v3
11 | with:
12 | fetch-depth: 0
13 | - run: python3 -m pip install --upgrade build && python3 -m build
14 | - name: Publish package
15 | uses: pypa/gh-action-pypi-publish@release/v1
16 | with:
17 | password: ${{ secrets.PYPI_API_TOKEN }}
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # VSCode
2 | .vscode/
3 |
4 | # WandB
5 | wandb/
6 |
7 | # Old
8 | old/
9 | temp/
10 | profiler/
11 |
12 | # Logs
13 | logs/
14 |
15 | # eval
16 | eval_results/
17 | evalplus_codegen/
18 |
19 | # All datasets
20 | dataset/
21 | dataset_processed/
22 | dataset_processed_*/
23 | tokenizer/
24 |
25 | # All evaluation results
26 | eval_baselines/
27 | eval_results/
28 | eval_results_temp/
29 |
30 | # Byte-compiled / optimized / DLL files
31 | __pycache__/
32 | *.py[cod]
33 | *$py.class
34 |
35 | # C extensions
36 | *.so
37 |
38 | # Distribution / packaging
39 | .Python
40 | build/
41 | develop-eggs/
42 | dist/
43 | downloads/
44 | eggs/
45 | .eggs/
46 | lib/
47 | lib64/
48 | parts/
49 | sdist/
50 | var/
51 | wheels/
52 | share/python-wheels/
53 | *.egg-info/
54 | .installed.cfg
55 | *.egg
56 | MANIFEST
57 |
58 | # PyInstaller
59 | # Usually these files are written by a python script from a template
60 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
61 | *.manifest
62 | *.spec
63 |
64 | # Installer logs
65 | pip-log.txt
66 | pip-delete-this-directory.txt
67 |
68 | # Unit test / coverage reports
69 | htmlcov/
70 | .tox/
71 | .nox/
72 | .coverage
73 | .coverage.*
74 | .cache
75 | nosetests.xml
76 | coverage.xml
77 | *.cover
78 | *.py,cover
79 | .hypothesis/
80 | .pytest_cache/
81 | cover/
82 |
83 | # Translations
84 | *.mo
85 | *.pot
86 |
87 | # Django stuff:
88 | *.log
89 | local_settings.py
90 | db.sqlite3
91 | db.sqlite3-journal
92 |
93 | # Flask stuff:
94 | instance/
95 | .webassets-cache
96 |
97 | # Scrapy stuff:
98 | .scrapy
99 |
100 | # Sphinx documentation
101 | docs/_build/
102 |
103 | # PyBuilder
104 | .pybuilder/
105 | target/
106 |
107 | # Jupyter Notebook
108 | .ipynb_checkpoints
109 |
110 | # IPython
111 | profile_default/
112 | ipython_config.py
113 |
114 | # pyenv
115 | # For a library or package, you might want to ignore these files since the code is
116 | # intended to run in multiple environments; otherwise, check them in:
117 | # .python-version
118 |
119 | # pipenv
120 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
121 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
122 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
123 | # install all needed dependencies.
124 | #Pipfile.lock
125 |
126 | # poetry
127 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
128 | # This is especially recommended for binary packages to ensure reproducibility, and is more
129 | # commonly ignored for libraries.
130 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
131 | #poetry.lock
132 |
133 | # pdm
134 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
135 | #pdm.lock
136 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
137 | # in version control.
138 | # https://pdm.fming.dev/#use-with-ide
139 | .pdm.toml
140 |
141 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
142 | __pypackages__/
143 |
144 | # Celery stuff
145 | celerybeat-schedule
146 | celerybeat.pid
147 |
148 | # SageMath parsed files
149 | *.sage.py
150 |
151 | # Environments
152 | .env
153 | .venv
154 | env/
155 | venv/
156 | ENV/
157 | env.bak/
158 | venv.bak/
159 |
160 | # Spyder project settings
161 | .spyderproject
162 | .spyproject
163 |
164 | # Rope project settings
165 | .ropeproject
166 |
167 | # mkdocs documentation
168 | /site
169 |
170 | # mypy
171 | .mypy_cache/
172 | .dmypy.json
173 | dmypy.json
174 |
175 | # Pyre type checker
176 | .pyre/
177 |
178 | # pytype static type analyzer
179 | .pytype/
180 |
181 | # Cython debug symbols
182 | cython_debug/
183 |
184 | # PyCharm
185 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
186 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
187 | # and can be added to the global gitignore or merged into this file. For a more nuclear
188 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
189 | #.idea/
190 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | ## Supported Versions
4 |
5 | | Version | Supported |
6 | |---------|--------------------|
7 | | 3.x | :white_check_mark: |
8 | | < 3.0 | :x: |
9 |
10 | ## Reporting a Vulnerability
11 |
12 | We take security vulnerabilities in our open-source project seriously and appreciate responsible disclosure from users.
13 |
14 | If you believe you have found a security vulnerability in our project, please report it to us by creating a GitHub issue with the label "security" or "vulnerability". Please do not publicly disclose the vulnerability until it has been addressed by our project team.
15 |
16 | We will acknowledge receipt of your vulnerability report and will keep you informed of our progress in addressing the vulnerability. If you would like to communicate with us about the vulnerability, please email [imonenext at gmail dot com].
17 |
18 | We will not take legal action against users who report vulnerabilities in good faith and in accordance with this disclosure policy.
19 |
20 | Thank you for helping us keep our open-source project secure!
21 |
--------------------------------------------------------------------------------
/assets/logo_new.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/logo_new.png
--------------------------------------------------------------------------------
/assets/openchat-3.6-20240522.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat-3.6-20240522.png
--------------------------------------------------------------------------------
/assets/openchat-bench-0106.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat-bench-0106.png
--------------------------------------------------------------------------------
/assets/openchat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat.png
--------------------------------------------------------------------------------
/assets/openchat_grok.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/assets/openchat_grok.png
--------------------------------------------------------------------------------
/assets/vicuna_gpt35.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/vicuna_gpt4.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docker/serving/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.11-slim-bookworm
2 |
3 | ######### Setup system
4 |
5 | RUN mkdir /workspace && mkdir /workspace/transformers_cache
6 | WORKDIR /workspace
7 |
8 | ENV HF_HOME /workspace/transformers_cache
9 |
10 | ######### Install system dependencies
11 | RUN apt update && apt install -y git bash curl wget libxml2
12 |
13 | # Install ssh server, remove all pre-generated ssh host keys, and disable password auth
14 | RUN apt install -y openssh-server && \
15 | rm -f /etc/ssh/ssh_host_* && \
16 | sed -i 's/#PasswordAuthentication yes/PasswordAuthentication no/g' /etc/ssh/sshd_config
17 |
18 | # Install CUDA (for FlashAttention 2)
19 | RUN wget -q --show-progress --progress=bar:force:noscroll -O cuda_installer https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda_12.1.1_530.30.02_linux.run && \
20 | chmod +x cuda_installer && \
21 | ./cuda_installer --silent --toolkit --override && \
22 | rm -f cuda_installer
23 |
24 | ######### Install OpenChat
25 | # Install OpenChat
26 | RUN pip3 install ninja packaging torch
27 | RUN pip3 install ochat
28 |
29 | ######### Install Cloudflared
30 | RUN wget -q --show-progress --progress=bar:force:noscroll -O /cloudflared https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 && chmod +x /cloudflared
31 |
32 | ######### Startup script
33 |
34 | COPY start.sh /start.sh
35 | ENTRYPOINT ["/start.sh"]
36 |
--------------------------------------------------------------------------------
/docker/serving/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # start ssh server
4 | if [ -n "$PUBLIC_KEY" ]; then
5 | mkdir -p ~/.ssh
6 | echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
7 | chmod 700 -R ~/.ssh
8 |
9 | dpkg-reconfigure openssh-server # generate ssh keys
10 | service ssh start
11 | fi
12 |
13 | # start cloudflare tunnel
14 | if [ -n "$CLOUDFLARED_TUNNEL_ARGS" ]; then
15 | /cloudflared $CLOUDFLARED_TUNNEL_ARGS &
16 | fi
17 |
18 | # start openchat server
19 | python3 -m ochat.serving.openai_api_server --model $MODEL --host 127.0.0.1 --port 18888 --engine-use-ray --worker-use-ray --disable-log-requests --disable-log-stats $ARGS &
20 |
21 | wait
22 |
--------------------------------------------------------------------------------
/ochat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/imoneoi/openchat/47a3596168ed90d8f948f63f458948c3db98e2b8/ochat/__init__.py
--------------------------------------------------------------------------------
/ochat/config/__init__.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import transformers
5 |
6 | from ochat.config.model_config import ModelConfig
7 | from ochat.config.conversation_template import Message, Conversation, ConversationTemplate
8 | import ochat.models
9 |
10 |
11 | _GEMMA_IT_PREFIXES = {
12 | "user": "user",
13 | "assistant": "model"
14 | }
15 |
16 |
17 | def _v3_2_role_prefix(from_role, condition):
18 | return f"{condition} {from_role.title()}:".strip()
19 |
20 |
21 | def _v3_6_role_prefix(from_role, condition, role_start_token, role_end_token):
22 | return role_start_token + f"{condition} {from_role.title()}".strip() + role_end_token
23 |
24 |
25 | MODEL_CONFIG_MAP = {
26 | # OpenChat V3.6 (llama 3)
27 | "openchat_3.6": ModelConfig(
28 | # Model
29 | model_max_context=8192,
30 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), # Llama 3 only has fast tokenizer
31 | model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
32 | low_cpu_mem_usage=True,
33 | torch_dtype=torch.bfloat16),
34 | # Conversation Template
35 | conversation_template=partial(ConversationTemplate,
36 | role_prefix=partial(_v3_6_role_prefix,
37 | role_start_token="<|start_header_id|>",
38 | role_end_token="<|end_header_id|>\n\n"),
39 | eot="<|eot_id|>",
40 | system_as_role=True,
41 | inference_condition="GPT4 Correct"),
42 | hf_chat_template="{{ bos_token }}{% for message in messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = '<|start_header_id|>GPT4 Correct ' + message['role'].title() + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% elif message['role'] == 'system' %}{% set content = '<|start_header_id|>System<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' %}{% else %}{{ raise_exception('Only user, assistant and system roles are supported!') }}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n' }}{% endif %}",
43 | ),
44 |
45 | # OpenChat V3.2
46 | "openchat_v3.2": ModelConfig(
47 | # Model
48 | model_max_context=4096,
49 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
50 | model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
51 | low_cpu_mem_usage=True,
52 | torch_dtype=torch.bfloat16),
53 |
54 | # Conversation Template
55 | conversation_template=partial(ConversationTemplate,
56 | role_prefix=_v3_2_role_prefix,
57 | eot="<|end_of_turn|>",
58 | inference_condition="GPT4")
59 | ),
60 |
61 | "openchat_v3.2_mistral": ModelConfig(
62 | serving_aliases=("openchat_3.5", ),
63 |
64 | # Model
65 | model_max_context=8192,
66 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True),
67 | model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained,
68 | low_cpu_mem_usage=True,
69 | torch_dtype=torch.bfloat16),
70 |
71 | # Conversation Template
72 | conversation_template=partial(ConversationTemplate,
73 | role_prefix=_v3_2_role_prefix,
74 | eot="<|end_of_turn|>",
75 | inference_condition="GPT4 Correct"),
76 | hf_chat_template="{{ bos_token }}{% for message in messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>' %}{% elif message['role'] == 'system' %}{% set content = message['content'] + '<|end_of_turn|>' %}{% else %}{{ raise_exception('Only user, assistant and system roles are supported!') }}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", ),
77 |
78 | "openchat_v3.2_gemma_new": ModelConfig(
79 | serving_aliases=("openchat_3.5_gemma_new", ),
80 |
81 | # Model
82 | model_max_context=8192,
83 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True),
84 | model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
85 | low_cpu_mem_usage=True,
86 | torch_dtype=torch.bfloat16),
87 |
88 | # Conversation Template
89 | conversation_template=partial(ConversationTemplate,
90 | role_prefix=_v3_2_role_prefix,
91 | eot="<end_of_turn>",
92 | inference_condition="GPT4 Correct"),
93 | hf_chat_template="{{ bos_token }}{% for message in messages %}{% if message['role'] in ['user', 'assistant'] %}{% set content = 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<end_of_turn>' %}{% elif message['role'] == 'system' %}{% set content = message['content'] + '<end_of_turn>' %}{% else %}{{ raise_exception('Only user, assistant and system roles are supported!') }}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
94 | ),
95 |
96 | ### Other models
97 | "chatml_8192": ModelConfig(
98 | # Model
99 | model_max_context=8192,
100 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True),
101 | model_create_for_training=lambda x: None,
102 |
103 | # Conversation Template
104 | conversation_template=partial(ConversationTemplate,
105 | role_prefix=lambda from_role, condition: f"<|im_start|>{from_role}\n",
106 | eot="<|im_end|>",
107 | inference_condition="")
108 | ),
109 | "zephyr_mistral": ModelConfig(
110 | # Model
111 | model_max_context=8192,
112 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
113 | model_create_for_training=partial(ochat.models.MistralForCausalLM.from_pretrained,
114 | low_cpu_mem_usage=True,
115 | torch_dtype=torch.bfloat16),
116 |
117 | # Conversation Template
118 | conversation_template=partial(ConversationTemplate,
119 | role_prefix=lambda from_role, condition: f"<|{from_role}|>\n",
120 | eot="</s>",
121 | inference_condition="")
122 | ),
123 | "gemma_it": ModelConfig(
124 | # Model
125 | model_max_context=8192,
126 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=False),
127 | model_create_for_training=partial(ochat.models.GemmaForCausalLM.from_pretrained,
128 | low_cpu_mem_usage=True,
129 | torch_dtype=torch.bfloat16),
130 |
131 | # Conversation Template
132 | conversation_template=partial(ConversationTemplate,
133 | role_prefix=lambda from_role, condition: f"<start_of_turn>{_GEMMA_IT_PREFIXES[from_role]}\n",
134 | eot="<end_of_turn>",
135 | inference_condition="")
136 | ),
137 | "llama3_instruct": ModelConfig(
138 | # Model
139 | model_max_context=8192,
140 | model_tokenizer_create=partial(transformers.AutoTokenizer.from_pretrained, use_fast=True), # Llama 3 only has fast tokenizer
141 | model_create_for_training=partial(ochat.models.LlamaForCausalLM.from_pretrained,
142 | low_cpu_mem_usage=True,
143 | torch_dtype=torch.bfloat16),
144 |
145 | # Conversation Template
146 | conversation_template=partial(ConversationTemplate,
147 | role_prefix=lambda from_role, condition: f"<|start_header_id|>{from_role}<|end_header_id|>\n\n",
148 | eot="<|eot_id|>",
149 | inference_condition="")
150 | ),
151 | }
152 |
--------------------------------------------------------------------------------
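The `hf_chat_template` strings above can be attached directly to a Hugging Face tokenizer. A minimal rendering sketch follows (illustrative only; the checkpoint name `openchat/openchat-3.5-0106` is an assumption, and any tokenizer whose special tokens match the chosen config works):

import transformers

from ochat.config import MODEL_CONFIG_MAP

config = MODEL_CONFIG_MAP["openchat_v3.2_mistral"]

# Attach the bundled Jinja template to the tokenizer and render a prompt.
tokenizer = transformers.AutoTokenizer.from_pretrained("openchat/openchat-3.5-0106")  # assumed checkpoint
tokenizer.chat_template = config.hf_chat_template

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    tokenize=False,
    add_generation_prompt=True,
)
# Expected shape: "<s>GPT4 Correct User: Hello!<|end_of_turn|>GPT4 Correct Assistant:"
print(prompt)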
/ochat/config/conversation_template.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable, Iterable, List
2 |
3 | from pydantic import BaseModel
4 |
5 |
6 | class Message(BaseModel):
7 | role: str
8 | content: str
9 |
10 | weight: Optional[float] = None
11 |
12 |
13 | class Conversation(BaseModel):
14 | items: List[Message]
15 |
16 | condition: str = ""
17 | system: str = ""
18 |
19 |
20 | class ConversationTemplate(BaseModel):
21 | tokenizer: Callable
22 |
23 | # Prompt
24 | role_prefix: Callable
25 | system_as_role: Optional[bool] = False
26 | eot: str
27 |
28 | inference_condition: Optional[str] = None
29 |
30 | # Private
31 | bos_tokens_: List[int]
32 | eot_tokens_: List[int]
33 |
34 | def __init__(self, **data):
35 | tokenizer = data["tokenizer"]
36 | eot = data["eot"]
37 | bos_tokens_ = tokenizer("").input_ids
38 | eot_tokens_ = tokenizer(eot, add_special_tokens=False).input_ids
39 | super().__init__(**data, bos_tokens_=bos_tokens_, eot_tokens_=eot_tokens_)
40 |
41 | def _tokenize(self, strings: Iterable[str], ignore_special: bool = True) -> List[List[int]]:
42 | if self.tokenizer.is_fast:
43 | # Support for fast tokenizer
44 | # https://github.com/huggingface/tokenizers/pull/1419
45 | self.tokenizer._tokenizer.encode_special_tokens = ignore_special
46 | result = self.tokenizer(strings, return_attention_mask=False, add_special_tokens=False).input_ids
47 | self.tokenizer._tokenizer.encode_special_tokens = False
48 | else:
49 | result = self.tokenizer(strings, split_special_tokens=ignore_special, return_attention_mask=False, add_special_tokens=False).input_ids
50 |
51 | return result
52 |
53 | def tokenize_conversations(self, conversations: Iterable[Conversation], inference: bool = False, seq_level_weight: bool = False):
54 | # Pre-tokenize all conversations
55 | default_condition = self.inference_condition if inference else ""
56 |
57 | sys_mappings = set()
58 | role_mappings = set()
59 | all_text = []
60 | for conv in conversations:
61 | sys_mappings.add(conv.system)
62 | for msg in conv.items:
63 | role_mappings.add((msg.role, conv.condition or default_condition))
64 | all_text.append(msg.content)
65 |
66 | system_role_tokens = []
67 | if self.system_as_role:
68 | system_role_tokens = self._tokenize(self.role_prefix("system", ""), ignore_special=False)
69 |
70 | sys_mappings = list(sys_mappings)
71 | role_mappings = list(role_mappings)
72 |
73 | sys_mappings = dict(zip(sys_mappings, self._tokenize(sys_mappings)))
74 | role_mappings = dict(zip(role_mappings, self._tokenize([self.role_prefix(*args) for args in role_mappings], ignore_special=False)))
75 | all_text = self._tokenize(all_text)
76 |
77 | # Convert
78 | result_tokens = []
79 | result_weights = []
80 | all_text_idx = 0
81 | for conv in conversations:
82 | tokens = []
83 | weights = []
84 |
85 | # bos tokens
86 | tokens.extend(self.bos_tokens_)
87 | weights.extend([0.] * len(self.bos_tokens_))
88 |
89 | # System
90 | if conv.system:
91 | tokens.extend(system_role_tokens)
92 | weights.extend([0.] * len(system_role_tokens))
93 |
94 | system = sys_mappings[conv.system]
95 | tokens.extend(system)
96 | weights.extend([0.] * len(system))
97 |
98 | tokens.extend(self.eot_tokens_)
99 | weights.extend([0.] * len(self.eot_tokens_))
100 |
101 | # Messages
102 | last_idx = len(conv.items) - 1
103 | for idx, msg in enumerate(conv.items):
104 | # Role Prefix
105 | role = role_mappings[(msg.role, conv.condition or default_condition)]
106 | tokens.extend(role)
107 | weights.extend([0.] * len(role))
108 |
109 | # Message
110 | text = all_text[all_text_idx]
111 | all_text_idx += 1
112 |
113 | # weight
114 | w = None
115 | if not inference:
116 | assert msg.weight is not None
117 |
118 | w = msg.weight
119 | if seq_level_weight:
120 | w /= len(text) + len(self.eot_tokens_)
121 |
122 | # Message tokens
123 | tokens.extend(text)
124 | weights.extend([w] * len(text))
125 |
126 | if not (inference and idx == last_idx): # Do not add EOT on last turn during inference
127 | tokens.extend(self.eot_tokens_)
128 | weights.extend([w] * len(self.eot_tokens_))
129 |
130 | # Append result
131 | result_tokens.append(tokens)
132 | result_weights.append(weights)
133 |
134 | # Sanity check
135 | assert all_text_idx == len(all_text)
136 |
137 | return result_tokens, result_weights
138 |
--------------------------------------------------------------------------------
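A minimal sketch of driving `ConversationTemplate.tokenize_conversations` the way the data pipeline does (illustrative only; the tokenizer path is a placeholder, and the per-message weights mark which turns contribute to the training loss):

from ochat.config import MODEL_CONFIG_MAP, Conversation, Message

config = MODEL_CONFIG_MAP["openchat_v3.2_mistral"]
tokenizer = config.model_tokenizer_create("openchat/openchat-3.5-0106")  # placeholder checkpoint
template = config.conversation_template(tokenizer=tokenizer)

conv = Conversation(
    system="You are a helpful assistant.",
    items=[
        Message(role="user", content="What is 1 + 1?", weight=0.0),        # prompt turn: no loss
        Message(role="assistant", content="1 + 1 equals 2.", weight=1.0),  # response turn: full loss
    ],
)

# inference=False requires every message to carry a weight; seq_level_weight=True
# would divide each token's weight by the turn length, making a turn's total
# loss weight independent of its length.
tokens, weights = template.tokenize_conversations([conv], inference=False)
assert len(tokens[0]) == len(weights[0])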
/ochat/config/model_config.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable, Iterable
2 |
3 | from pydantic import BaseModel, ConfigDict
4 |
5 |
6 | class ModelConfig(BaseModel):
7 | # Alias
8 | serving_aliases: Iterable[str] = ()
9 |
10 | # Model
11 | model_max_context: int
12 | model_tokenizer_create: Callable
13 | model_create_for_training: Callable
14 |
15 | # conversation template
16 | conversation_template: Callable
17 | hf_chat_template: Optional[str] = None
18 |
19 | model_config = ConfigDict(protected_namespaces=()) # Disables warnings for the model_ namespace used above
20 |
--------------------------------------------------------------------------------
/ochat/data/generate_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate training data based on conversations
3 |
4 | Usage: python -m ochat.data.generate_dataset --model-type MODEL_TYPE --model-path HF_REPO_NAME --in-prefix sharegpt_gpt4 --out-prefix sharegpt_gpt4
5 | """
6 |
7 | import argparse
8 | import os
9 | import gc
10 | import random
11 |
12 | import ray
13 | import orjson
14 | import pyarrow
15 | from pyarrow import parquet
16 |
17 |
18 | PAD_TOKEN_ID = 0
19 |
20 |
21 | def _split(a, n):
22 | # Split list a to n chunks
23 | # https://stackoverflow.com/questions/2130016/splitting-a-list-into-n-parts-of-approximately-equal-length
24 | k, m = divmod(len(a), n)
25 | return [a[i*k+min(i, m): (i+1)*k+min(i+1, m)] for i in range(n)]
26 |
27 |
28 | def truncate_trailing_zero_weighted(tokens, weights):
29 | non_zero_index = len(weights) - 1
30 | while non_zero_index >= 0 and weights[non_zero_index] == 0:
31 | non_zero_index -= 1
32 |
33 | return tokens[:non_zero_index + 1], weights[:non_zero_index + 1]
34 |
35 |
36 | def add_single_conv(output, tokens, weights):
37 | # truncate trailing zero weighted tokens
38 | tokens, weights = truncate_trailing_zero_weighted(tokens, weights)
39 | if not tokens:
40 | return
41 |
42 | # labels
43 | length = len(tokens)
44 | labels = [(t if w != 0 else PAD_TOKEN_ID) for t, w in zip(tokens, weights)]
45 |
46 | # populate results
47 | results = {
48 | "total_length": length,
49 |
50 | "seqlens": [length],
51 | "nz_input_ids": tokens,
52 | "nz_position_ids": list(range(length)),
53 |
54 | "nz_shifted_label_ids": labels[1:] + [PAD_TOKEN_ID],
55 | "nz_shifted_loss_weights": weights[1:] + [0.0]
56 | }
57 | results["num_seqs"] = sum(results["nz_shifted_loss_weights"])
58 |
59 | for k, v in results.items():
60 | output[k].append(v)
61 |
62 |
63 | @ray.remote
64 | def convert_conversation_batch(model_type: str, model_path: str, batch: list, schema: pyarrow.Schema, per_sequence_loss: bool):
65 | from ochat.config import MODEL_CONFIG_MAP, Conversation
66 |
67 | # Tokenization
68 | model_config = MODEL_CONFIG_MAP[model_type]
69 | tokenizer = model_config.model_tokenizer_create(model_path)
70 | conv_template = model_config.conversation_template(tokenizer=tokenizer)
71 |
72 | # Decode data
73 | print ("Decoding JSON ...")
74 | batch = [Conversation(**orjson.loads(json_line)) for json_line in batch]
75 |
76 | # Tokenize
77 | print ("Tokenizing ...")
78 | tokens_list, weights_list = conv_template.tokenize_conversations(batch, inference=False, seq_level_weight=per_sequence_loss)
79 |
80 | del batch
81 | gc.collect()
82 |
83 | # Generate data
84 | print ("Generating ...")
85 | max_context = model_config.model_max_context
86 |
87 | outputs = {k: [] for k in schema.names}
88 | for tokens, weights in zip(tokens_list, weights_list):
89 | assert len(tokens) == len(weights)
90 |
91 | # Truncate to specified tokens
92 | tokens = tokens[:max_context]
93 | weights = weights[:max_context]
94 |
95 | # Add to results
96 | add_single_conv(outputs, tokens, weights)
97 |
98 | del tokens_list, weights_list
99 | gc.collect()
100 |
101 | print ("To table ...")
102 | table = pyarrow.Table.from_pydict(outputs, schema=schema)
103 |
104 | del outputs
105 | gc.collect()
106 |
107 | print ("Chunk finish")
108 | return table
109 |
110 |
111 | def generate_epoch(seed: int, model_type: str, model_path: str, in_filename: str, out_filename: str, per_sequence_loss: bool):
112 | # schema
113 | metadata = {
114 | "model_type": model_type
115 | }
116 | schema = [
117 | pyarrow.field("total_length", pyarrow.int32()),
118 | pyarrow.field("num_seqs", pyarrow.float32()),
119 |
120 | pyarrow.field(f"seqlens", pyarrow.list_(pyarrow.int32())),
121 | pyarrow.field(f"nz_input_ids", pyarrow.list_(pyarrow.int32())),
122 | pyarrow.field(f"nz_position_ids", pyarrow.list_(pyarrow.int32())),
123 | pyarrow.field(f"nz_shifted_label_ids", pyarrow.list_(pyarrow.int32())),
124 | pyarrow.field(f"nz_shifted_loss_weights", pyarrow.list_(pyarrow.float32()))
125 | ]
126 |
127 | schema = pyarrow.schema(schema, metadata={"metadata_json": orjson.dumps(metadata)})
128 |
129 | # Load data
130 | with open(in_filename, "rb") as f:
131 | batches = f.readlines()
132 |
133 | random.seed(seed) # Randomized load balancing
134 | random.shuffle(batches)
135 |
136 | batches = _split(batches, int(ray.available_resources()["CPU"]))
137 |
138 | # launch remote workers
139 | handles = [convert_conversation_batch.remote(
140 | model_type=model_type, # type: ignore
141 | model_path=model_path,
142 | batch=batch,
143 | schema=schema,
144 | per_sequence_loss=per_sequence_loss
145 | ) for batch in batches]
146 |
147 | # write
148 | parquet.write_table(pyarrow.concat_tables([ray.get(handle) for handle in handles]), out_filename)
149 |
150 |
151 | def generate_dataset(model_type, model_path, in_prefix, out_prefix, per_sequence_loss, seed):
152 | # Initialize Ray
153 | if not ray.is_initialized():
154 | ray.init(ignore_reinit_error=True, num_cpus=os.cpu_count())
155 |
156 | # Load epochs and tokenize
157 | epoch = 0
158 | while True:
159 | in_filename = f"{in_prefix}.{epoch}.jsonl"
160 | if not os.path.exists(in_filename):
161 | break
162 |
163 | out_filename = f"{out_prefix}.{epoch}.parquet"
164 | generate_epoch(
165 | seed=seed + epoch,
166 | model_type=model_type,
167 | model_path=model_path,
168 | in_filename=in_filename,
169 | out_filename=out_filename,
170 | per_sequence_loss=per_sequence_loss
171 | )
172 | gc.collect()
173 |
174 | epoch += 1
175 |
176 |
177 | if __name__ == "__main__":
178 | parser = argparse.ArgumentParser()
179 | parser.add_argument("--model-type", type=str, required=True)
180 | parser.add_argument("--model-path", type=str, required=True)
181 |
182 | parser.add_argument("--in-prefix", type=str, required=True)
183 | parser.add_argument("--out-prefix", type=str, required=True)
184 |
185 | parser.add_argument("--per-sequence-loss", action="store_true")
186 | parser.add_argument("--seed", type=int, default=42)
187 | args = parser.parse_args()
188 |
189 | generate_dataset(**vars(args))
190 |
--------------------------------------------------------------------------------
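Each input line to `generate_dataset.py` must deserialize into `ochat.config.Conversation`, and inputs are split into per-epoch files named `<in_prefix>.<epoch>.jsonl` starting from epoch 0. A hypothetical line and invocation (field values are illustrative only):

import orjson

# One JSONL line: messages with per-message loss weights, plus the optional
# "system" and "condition" fields consumed by the conversation template.
line = orjson.dumps({
    "system": "You are a helpful assistant.",
    "condition": "GPT4 Correct",
    "items": [
        {"role": "user", "content": "Name a prime number.", "weight": 0.0},
        {"role": "assistant", "content": "2 is a prime number.", "weight": 1.0},
    ],
}) + b"\n"

with open("data.train.0.jsonl", "ab") as f:
    f.write(line)

# Then, for example:
#   python -m ochat.data.generate_dataset --model-type openchat_v3.2_mistral \
#       --model-path openchat/openchat-3.5-0106 --in-prefix data.train --out-prefix data.train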
/ochat/evaluation/README.md:
--------------------------------------------------------------------------------
1 | # vLLM Eval
2 |
3 | **Work in progress... Stay tuned!**
4 | 
5 | 🚀 Efficiently evaluate large language models (LLMs) in just 5 minutes (full eval suite)!
6 | 
7 | ## Included Suites
8 |
9 | ### Zero-shot Multiple-Choice
10 |
11 | - AGIEval
12 | - BBH
13 | - TruthfulQA
14 |
15 | ### Few-shot CoT
16 |
17 | - BBH
18 | - GSM8k
19 |
--------------------------------------------------------------------------------
/ochat/evaluation/conv_eval.py:
--------------------------------------------------------------------------------
1 | from typing import OrderedDict
2 | import signal
3 | import os
4 | import json
5 | import subprocess
6 | import argparse
7 | import time
8 | import requests
9 | import re
10 | import coolname
11 |
12 |
13 | MAX_CONTEXT = 4096
14 |
15 |
16 | def find_models(path, prefix, ep_filter):
17 | run_name = '_'.join(coolname.generate(2))
18 |
19 | def generate_model_name(root, ep_number):
20 | return f"{prefix}{os.path.basename(root)}_ep{ep_number}_{run_name}"
21 |
22 | models = {}
23 | for root, dirs, _ in os.walk(path):
24 | for d in dirs:
25 | ep_match = re.match(r"ep_(\d+)", d)
26 | if not ep_match:
27 | continue
28 |
29 | if ep_filter and ep_match.group(1) != ep_filter:
30 | continue
31 |
32 | model_name = generate_model_name(root, ep_match.group(1))
33 | models[model_name] = os.path.join(root, d)
34 |
35 | # Sort and return the models dictionary as an OrderedDict
36 | return OrderedDict(sorted(models.items(), reverse=True, key=lambda x: x[0].split("_ep")[::-1]))
37 |
38 |
39 | def run_mt_bench(mt_bench_path, model_name):
40 | working_dir = os.path.join(mt_bench_path, "fastchat", "llm_judge")
41 |
42 | # Skip if result exists
43 | if os.path.exists(os.path.join(working_dir, "data", "mt_bench", "model_answer", f"{model_name}.jsonl")):
44 | return
45 |
46 | # run mt bench
47 | commands = [
48 | f"python gen_api_answer.py --model {model_name} --max-tokens {MAX_CONTEXT} --parallel 128 --openai-api-base http://localhost:18888/v1",
49 | f"python gen_judgment.py --model-list {model_name} --parallel 8 --mode single",
50 | # f"python gen_judgment.py --model-list {model_name} --parallel 8 --mode pairwise-baseline",
51 | ]
52 |
53 | env = os.environ.copy()
54 | env["PYTHONPATH"] = f'{env.get("PYTHONPATH", "")}:{mt_bench_path}'
55 |
56 | for command in commands:
57 | subprocess.run(command, shell=True, cwd=working_dir, env=env)
58 |
59 |
60 | def run_vicuna_bench(mt_bench_path, model_name):
61 | working_dir = os.path.join(mt_bench_path, "fastchat", "llm_judge")
62 |
63 | # Skip if result exists
64 | if os.path.exists(os.path.join(working_dir, "data", "vicuna_bench", "model_answer", f"{model_name}.jsonl")):
65 | return
66 |
67 | # run vicuna bench
68 | commands = [
69 | f"python gen_api_answer.py --model {model_name} --max-tokens {MAX_CONTEXT} --parallel 128 --openai-api-base http://localhost:18888/v1 --bench-name vicuna_bench",
70 | f"python gen_judgment.py --model-list {model_name} --parallel 8 --mode pairwise-baseline --bench-name vicuna_bench",
71 | ]
72 |
73 | env = os.environ.copy()
74 | env["PYTHONPATH"] = f'{env.get("PYTHONPATH", "")}:{mt_bench_path}'
75 |
76 | for command in commands:
77 | subprocess.run(command, shell=True, cwd=working_dir, env=env)
78 |
79 |
80 | def create_alpaca_eval_config(alpacaeval_path, model_name):
81 | config_dir = os.path.join(alpacaeval_path, "src", "alpaca_eval", "models_configs", model_name.lower())
82 | os.makedirs(config_dir, exist_ok=True)
83 | config_path = os.path.join(config_dir, "configs.yaml")
84 |
85 | config_content = f"""{model_name.lower()}:
86 | prompt_template: "openchat-13b/prompt.txt"
87 | fn_completions: "openai_completions"
88 | completions_kwargs:
89 | openai_api_base: http://127.0.0.1:18888/v1
90 | requires_chatml: True
91 | sleep_time: 0
92 |
93 | model_name: "{model_name.lower()}"
94 | max_tokens: {MAX_CONTEXT}
95 |
96 | top_p: 1.0
97 | temperature: 0.7
98 |
99 | num_procs: 128
100 |
101 | pretty_name: "{model_name}"
102 | link: "https://github.com/imoneoi/openchat"
103 | """
104 |
105 | with open(config_path, "w") as f:
106 | f.write(config_content)
107 |
108 |
109 | def run_alpaca_eval(alpacaeval_path, model_name):
110 | # Skip if result exists
111 | if os.path.exists(os.path.join(alpacaeval_path, "results", model_name.lower())):
112 | return
113 |
114 | # Create config
115 | create_alpaca_eval_config(alpacaeval_path, model_name)
116 |
117 | # Run
118 | command = f"python -m alpaca_eval.main evaluate_from_model --model_configs {model_name.lower()} --annotators_config alpaca_eval_gpt4"
119 |
120 | env = os.environ.copy()
121 | env["PYTHONPATH"] = f'{env.get("PYTHONPATH", "")}:{os.path.join(alpacaeval_path, "src")}'
122 |
123 | subprocess.run(command, shell=True, cwd=alpacaeval_path, env=env)
124 |
125 |
126 | def wait_for_server(url):
127 | while True:
128 | try:
129 | response = requests.get(url)
130 | if response.status_code in [200, 404]:
131 | break
132 | except requests.exceptions.RequestException:
133 | pass
134 |
135 | time.sleep(1)
136 |
137 |
138 | def main(path, prefix, ep_filter, mt_bench_path, alpacaeval_path):
139 | models = find_models(path, prefix, ep_filter)
140 |
141 | for i, (model_name, model_path) in enumerate(models.items()):
142 | print(f"Processing model {i + 1}/{len(models)}: {model_name}")
143 |
144 | print("Starting server...")
145 | server_command = f"python -m ochat.serving.openai_api_server --model {model_path} --engine-use-ray --worker-use-ray"
146 | server_process = subprocess.Popen(server_command, shell=True, preexec_fn=os.setsid)
147 |
148 | wait_for_server("http://127.0.0.1:18888/v1")
149 | print("Server is ready.")
150 |
151 | print("Running MT-bench...")
152 | run_mt_bench(mt_bench_path, model_name)
153 |
154 | print("Running AlpacaEval...")
155 | # run_alpaca_eval(alpacaeval_path, model_name)
156 |
157 | # print("Running Vicuna-bench")
158 | # run_vicuna_bench(mt_bench_path, model_name)
159 |
160 | print("Terminating server...")
161 | os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
162 | server_process.wait()
163 |
164 |
165 | if __name__ == "__main__":
166 | parser = argparse.ArgumentParser()
167 | parser.add_argument("--path", default="/ML-A100/home/csj/trained_models/openchat_mistral/1017", help="Path to the models directory")
168 | parser.add_argument("--prefix", default="gpt4correct_")
169 | parser.add_argument("--ep_filter", default=None, help="Filter epochs")
170 |
171 | parser.add_argument("--mt_bench_path", default="/ML-A100/home/csj/one_benchmarks/FastChat", help="Path to the MT-bench directory")
172 | parser.add_argument("--alpacaeval_path", default="/ML-A100/home/csj/one_benchmarks/alpaca_eval", help="Path to the AlpacaEval directory")
173 | args = parser.parse_args()
174 |
175 | main(**vars(args))
176 |
--------------------------------------------------------------------------------
/ochat/evaluation/convert_to_evalplus.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import orjson
4 |
5 | from glob import glob
6 |
7 |
8 | def convert_to_evalplus(results_path: str, output_path: str):
9 | os.makedirs(output_path, exist_ok=True)
10 |
11 | for filename in glob(os.path.join(results_path, "*.json")):
12 | # read eval results
13 | with open(filename, "rb") as f:
14 | data = orjson.loads(f.read())
15 |
16 | # humaneval
17 | result = bytearray()
18 | for item in data:
19 | if item["task_type"] == "coding/humaneval":
20 | result.extend(orjson.dumps(item["answer"]))
21 | result.extend(b"\n")
22 |
23 | with open(os.path.join(output_path, os.path.splitext(os.path.basename(filename))[0] + ".jsonl"), "wb") as f:
24 | f.write(result)
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 |
30 | # Input / output
31 | parser.add_argument("--results_path", type=str, default="ochat/evaluation/eval_results")
32 | parser.add_argument("--output_path", type=str, default="ochat/evaluation/evalplus_codegen")
33 | args = parser.parse_args()
34 |
35 | convert_to_evalplus(**vars(args))
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
--------------------------------------------------------------------------------
/ochat/evaluation/grading/math_grader.py:
--------------------------------------------------------------------------------
1 | """
2 | Answer checker API that uses sympy to simplify expressions and check for equality.
3 |
4 | Call grade_answer(given_answer: str, ground_truth: str).
5 | """
6 | import re
7 | import sympy
8 | from pylatexenc import latex2text
9 | from sympy.parsing import sympy_parser
10 |
11 | from ochat.evaluation.grading import math_normalize
12 |
13 |
14 | # sympy might hang -- we don't care about trying to be lenient in these cases
15 | BAD_SUBSTRINGS = ["^{", "^("]
16 | BAD_REGEXES = ["\^[0-9]+\^", "\^[0-9][0-9]+"]
17 | TUPLE_CHARS = "()[]"
18 |
19 |
20 | def _sympy_parse(expr: str):
21 | """Parses an expression with sympy."""
22 | py_expr = expr.replace("^", "**")
23 | return sympy_parser.parse_expr(
24 | py_expr,
25 | transformations=(
26 | sympy_parser.standard_transformations
27 | + (sympy_parser.implicit_multiplication_application,)
28 | ),
29 | )
30 |
31 |
32 | def _parse_latex(expr: str) -> str:
33 | """Attempts to parse latex to an expression sympy can read."""
34 | expr = expr.replace("\\tfrac", "\\frac")
35 | expr = expr.replace("\\dfrac", "\\frac")
36 | expr = expr.replace("\\frac", " \\frac") # Play nice with mixed numbers.
37 | expr = latex2text.LatexNodes2Text().latex_to_text(expr)
38 |
39 | # Replace the specific characters that this parser uses.
40 | expr = expr.replace("√", "sqrt")
41 | expr = expr.replace("π", "pi")
42 | expr = expr.replace("∞", "inf")
43 | expr = expr.replace("∪", "U")
44 | expr = expr.replace("·", "*")
45 | expr = expr.replace("×", "*")
46 |
47 | return expr.strip()
48 |
49 |
50 | def _is_float(num: str) -> bool:
51 | try:
52 | float(num)
53 | return True
54 | except ValueError:
55 | return False
56 |
57 |
58 | def _is_int(x: float) -> bool:
59 | try:
60 | return abs(x - int(round(x))) <= 1e-7
61 | except:
62 | return False
63 |
64 |
65 | def _is_frac(expr: str) -> bool:
66 | return bool(re.search(r"^-?[0-9]+.?/0*[1-9][0-9]*.?$", expr))
67 |
68 |
69 | def _str_is_int(x: str) -> bool:
70 | try:
71 | x = _strip_properly_formatted_commas(x)
72 | x = float(x)
73 | return abs(x - int(round(x))) <= 1e-7
74 | except:
75 | return False
76 |
77 |
78 | def _str_to_int(x: str) -> int:
79 | x = x.replace(",", "")
80 | x = float(x)
81 | return int(x)
82 |
83 |
84 | def _inject_implicit_mixed_number(step: str):
85 | """
86 | Automatically make a mixed number evaluable
87 | e.g. 7 3/4 => 7+3/4
88 | """
89 | p1 = re.compile("([0-9]) +([0-9])")
90 | step = p1.sub("\\1+\\2", step) ## insert "+" between the whole and fractional parts
91 | return step
92 |
93 |
94 | def _strip_properly_formatted_commas(expr: str):
95 | # We want to be careful because we don't want to strip tuple commas
96 | p1 = re.compile("(\d)(,)(\d\d\d)($|\D)")
97 | while True:
98 | next_expr = p1.sub("\\1\\3\\4", expr)
99 | if next_expr == expr:
100 | break
101 | expr = next_expr
102 | return next_expr
103 |
104 |
105 | def _normalize(expr: str) -> str:
106 | """Normalize answer expressions."""
107 | if expr is None:
108 | return None
109 |
110 | # Remove enclosing `\text{}`.
111 | m = re.search("^\\\\text\{(?P<text>.+?)\}$", expr)
112 | if m is not None:
113 | expr = m.group("text")
114 |
115 | expr = expr.replace("\\%", "%")
116 | expr = expr.replace("\\$", "$")
117 | expr = expr.replace("$", "")
118 | expr = expr.replace("%", "")
119 | expr = expr.replace(" or ", " , ")
120 | expr = expr.replace(" and ", " , ")
121 |
122 | expr = expr.replace("million", "*10^6")
123 | expr = expr.replace("billion", "*10^9")
124 | expr = expr.replace("trillion", "*10^12")
125 |
126 | for unit in [
127 | "degree",
128 | "cm",
129 | "centimeter",
130 | "meter",
131 | "mile",
132 | "second",
133 | "minute",
134 | "hour",
135 | "day",
136 | "week",
137 | "month",
138 | "year",
139 | "foot",
140 | "feet",
141 | "inch",
142 | "yard",
143 | ]:
144 | expr = re.sub(f"{unit}(es)?(s)? *(\^[0-9]+)?", "", expr)
145 | expr = re.sub(f"\^ *\\\\circ", "", expr)
146 |
147 | if len(expr) > 0 and expr[0] == "{" and expr[-1] == "}":
148 | expr = expr[1:-1]
149 |
150 | expr = re.sub(",\\\\! *", "", expr)
151 | if _is_float(expr) and _is_int(float(expr)):
152 | expr = str(int(round(float(expr))))
153 | if "\\" in expr:
154 | try:
155 | expr = _parse_latex(expr)
156 | except:
157 | pass
158 |
159 | # edge case with mixed numbers and negative signs
160 | expr = re.sub("- *", "-", expr)
161 |
162 | expr = _inject_implicit_mixed_number(expr)
163 | expr = expr.replace(" ", "")
164 |
165 | # if we somehow still have latex braces here, just drop them
166 | expr = expr.replace("{", "")
167 | expr = expr.replace("}", "")
168 |
169 | # don't be case sensitive for text answers
170 | expr = expr.lower()
171 |
172 | if _str_is_int(expr):
173 | expr = str(_str_to_int(expr))
174 |
175 | return expr
176 |
177 |
178 | def count_unknown_letters_in_expr(expr: str):
179 | expr = expr.replace("sqrt", "")
180 | expr = expr.replace("frac", "")
181 | letters_in_expr = set([x for x in expr if x.isalpha()])
182 | return len(letters_in_expr)
183 |
184 |
185 | def should_allow_eval(expr: str):
186 | # we don't want to try parsing unknown text or functions of more than two variables
187 | if count_unknown_letters_in_expr(expr) > 2:
188 | return False
189 |
190 | for bad_string in BAD_SUBSTRINGS:
191 | if bad_string in expr:
192 | return False
193 |
194 | for bad_regex in BAD_REGEXES:
195 | if re.search(bad_regex, expr) is not None:
196 | return False
197 |
198 | return True
199 |
200 |
201 | def are_equal_under_sympy(ground_truth_normalized: str, given_normalized: str):
202 | are_equal = False
203 | try:
204 | expr = f"({ground_truth_normalized})-({given_normalized})"
205 | if should_allow_eval(expr):
206 | sympy_diff = _sympy_parse(expr)
207 | simplified = sympy.simplify(sympy_diff)
208 | if simplified == 0:
209 | are_equal = True
210 | except:
211 | pass
212 | return are_equal
213 |
214 |
215 | def split_tuple(expr: str):
216 | """
217 | Split the elements in a tuple/interval, while handling well-formatted commas in large numbers
218 | """
219 | expr = _strip_properly_formatted_commas(expr)
220 | if len(expr) == 0:
221 | return []
222 | if (
223 | len(expr) > 2
224 | and expr[0] in TUPLE_CHARS
225 | and expr[-1] in TUPLE_CHARS
226 | and all([ch not in expr[1:-1] for ch in TUPLE_CHARS])
227 | ):
228 | elems = [elem.strip() for elem in expr[1:-1].split(",")]
229 | else:
230 | elems = [expr]
231 | return elems
232 |
233 |
234 | def grade_answer(given_answer: str, ground_truth: str) -> bool:
235 | """
236 | The answer will be considered correct if:
237 | (a) it normalizes to the same string as the ground truth answer
238 | OR
239 | (b) sympy can simplify the difference between the expressions to 0
240 | """
241 | if given_answer is None:
242 | return False
243 |
244 | ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth)
245 | given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer)
246 |
247 | # be at least as lenient as mathd
248 | if ground_truth_normalized_mathd == given_answer_normalized_mathd:
249 | return True
250 |
251 | ground_truth_normalized = _normalize(ground_truth)
252 | given_normalized = _normalize(given_answer)
253 |
254 | if ground_truth_normalized is None:
255 | return False
256 |
257 | if ground_truth_normalized == given_normalized:
258 | return True
259 |
260 | if len(given_normalized) == 0:
261 | return False
262 |
263 | ground_truth_elems = split_tuple(ground_truth_normalized)
264 | given_elems = split_tuple(given_normalized)
265 |
266 | if len(ground_truth_elems) > 1 and (
267 | ground_truth_normalized[0] != given_normalized[0]
268 | or ground_truth_normalized[-1] != given_normalized[-1]
269 | ):
270 | is_correct = False
271 | elif len(ground_truth_elems) != len(given_elems):
272 | is_correct = False
273 | else:
274 | for ground_truth_elem, given_elem in zip(ground_truth_elems, given_elems):
275 | if _is_frac(ground_truth_elem) and _is_frac(given_elem):
276 | # if fractions aren't reduced, then shouldn't be marked as correct
277 | # so, we don't want to allow sympy.simplify in this case
278 | is_correct = ground_truth_elem == given_elem
279 | elif _str_is_int(ground_truth_elem) != _str_is_int(given_elem):
280 | # if the ground truth answer is an integer, we require the given answer to be a strict match (no sympy.simplify)
281 | is_correct = False
282 | else:
283 | is_correct = are_equal_under_sympy(ground_truth_elem, given_elem)
284 | if not is_correct:
285 | break
286 |
287 | return is_correct
288 |
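A minimal usage sketch of the grading API above (illustrative inputs; assumes sympy and pylatexenc are installed):

    from ochat.evaluation.grading.math_grader import grade_answer

    grade_answer("\\frac{1}{2}", "0.5")   # True: "0.5" normalizes to "\frac{1}{2}"
    grade_answer("1/2", "\\frac{1}{2}")   # True: "a/b" is rewritten as "\frac{a}{b}"
    grade_answer("x + 1", "1 + x")        # True: sympy simplifies the difference to 0
    grade_answer("2/4", "1/2")            # False: unreduced fractions are not accepted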
--------------------------------------------------------------------------------
/ochat/evaluation/grading/math_normalize.py:
--------------------------------------------------------------------------------
1 | """
2 | This logic is largely copied from the Hendrycks' MATH release (math_equivalence).
3 | """
4 | import re
5 | from typing import Optional
6 |
7 |
8 | def normalize_answer(answer: Optional[str]) -> Optional[str]:
9 | if answer is None:
10 | return None
11 | answer = answer.strip()
12 | try:
13 | # Remove enclosing `\text{}`.
14 | m = re.search("^\\\\text\{(?P<text>.+?)\}$", answer)
15 | if m is not None:
16 | answer = m.group("text").strip()
17 | return _strip_string(answer)
18 | except:
19 | return answer
20 |
21 |
22 | def _fix_fracs(string):
23 | substrs = string.split("\\frac")
24 | new_str = substrs[0]
25 | if len(substrs) > 1:
26 | substrs = substrs[1:]
27 | for substr in substrs:
28 | new_str += "\\frac"
29 | if substr[0] == "{":
30 | new_str += substr
31 | else:
32 | try:
33 | assert len(substr) >= 2
34 | except:
35 | return string
36 | a = substr[0]
37 | b = substr[1]
38 | if b != "{":
39 | if len(substr) > 2:
40 | post_substr = substr[2:]
41 | new_str += "{" + a + "}{" + b + "}" + post_substr
42 | else:
43 | new_str += "{" + a + "}{" + b + "}"
44 | else:
45 | if len(substr) > 2:
46 | post_substr = substr[2:]
47 | new_str += "{" + a + "}" + b + post_substr
48 | else:
49 | new_str += "{" + a + "}" + b
50 | string = new_str
51 | return string
52 |
53 |
54 | def _fix_a_slash_b(string):
55 | if len(string.split("/")) != 2:
56 | return string
57 | a = string.split("/")[0]
58 | b = string.split("/")[1]
59 | try:
60 | a = int(a)
61 | b = int(b)
62 | assert string == "{}/{}".format(a, b)
63 | new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
64 | return new_string
65 | except:
66 | return string
67 |
68 |
69 | def _remove_right_units(string):
70 | # "\\text{ " only ever occurs (at least in the val set) when describing units
71 | if "\\text{ " in string:
72 | splits = string.split("\\text{ ")
73 | assert len(splits) == 2
74 | return splits[0]
75 | else:
76 | return string
77 |
78 |
79 | def _fix_sqrt(string):
80 | if "\\sqrt" not in string:
81 | return string
82 | splits = string.split("\\sqrt")
83 | new_string = splits[0]
84 | for split in splits[1:]:
85 | if split[0] != "{":
86 | a = split[0]
87 | new_substr = "\\sqrt{" + a + "}" + split[1:]
88 | else:
89 | new_substr = "\\sqrt" + split
90 | new_string += new_substr
91 | return new_string
92 |
93 |
94 | def _strip_string(string):
95 | # linebreaks
96 | string = string.replace("\n", "")
97 | # print(string)
98 |
99 | # remove inverse spaces
100 | string = string.replace("\\!", "")
101 | # print(string)
102 |
103 | # replace \\ with \
104 | string = string.replace("\\\\", "\\")
105 | # print(string)
106 |
107 | # replace tfrac and dfrac with frac
108 | string = string.replace("tfrac", "frac")
109 | string = string.replace("dfrac", "frac")
110 | # print(string)
111 |
112 | # remove \left and \right
113 | string = string.replace("\\left", "")
114 | string = string.replace("\\right", "")
115 | # print(string)
116 |
117 | # Remove circ (degrees)
118 | string = string.replace("^{\\circ}", "")
119 | string = string.replace("^\\circ", "")
120 |
121 | # remove dollar signs
122 | string = string.replace("\\$", "")
123 |
124 | # remove units (on the right)
125 | string = _remove_right_units(string)
126 |
127 | # remove percentage
128 | string = string.replace("\\%", "")
129 | string = string.replace("\%", "")
130 |
131 | # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
132 | string = string.replace(" .", " 0.")
133 | string = string.replace("{.", "{0.")
134 | # if empty, return empty string
135 | if len(string) == 0:
136 | return string
137 | if string[0] == ".":
138 | string = "0" + string
139 |
140 | # to consider: get rid of e.g. "k = " or "q = " at beginning
141 | if len(string.split("=")) == 2:
142 | if len(string.split("=")[0]) <= 2:
143 | string = string.split("=")[1]
144 |
145 | # fix sqrt3 --> sqrt{3}
146 | string = _fix_sqrt(string)
147 |
148 | # remove spaces
149 | string = string.replace(" ", "")
150 |
151 | # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
152 | string = _fix_fracs(string)
153 |
154 | # manually change 0.5 --> \frac{1}{2}
155 | if string == "0.5":
156 | string = "\\frac{1}{2}"
157 |
158 | # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
159 | string = _fix_a_slash_b(string)
160 |
161 | return string
162 |
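A few examples of what the normalizer above does (illustrative inputs):

    from ochat.evaluation.grading.math_normalize import normalize_answer

    normalize_answer("\\dfrac{1}{2}")           # "\frac{1}{2}"  (tfrac/dfrac unified)
    normalize_answer("0.5")                     # "\frac{1}{2}"  (explicit special case)
    normalize_answer("\\left( 3, 4 \\right)")   # "(3,4)"        (\left/\right and spaces removed)
    normalize_answer("10 \\text{ cm}")          # "10"           (right-hand units dropped)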
--------------------------------------------------------------------------------
/ochat/evaluation/match_answer.py:
--------------------------------------------------------------------------------
1 | import re
2 | import ast
3 |
4 | from ochat.evaluation.grading.math_grader import grade_answer
5 |
6 |
7 | def zs_agieval_match_answer(task_data, response):
8 | # AGIEval match first capital letter, following original paper implementation
9 | # https://github.com/microsoft/AGIEval/blob/main/src/post_process.py
10 |
11 | letter_set = {"A", "B", "C", "D", "E", "F"}
12 | for c in response:
13 | if c in letter_set:
14 | return True, c
15 |
16 | return False, ""
17 |
18 |
19 | def zs_bbh_mc_orca_truthfulqa_orca_match_answer(task_data, response):
20 | # For BBH & TruthfulQA, match first option letter
21 |
22 | for c in response:
23 | if c in task_data["options"]:
24 | return True, c
25 |
26 | return False, ""
27 |
28 |
29 | def fs_cothub_math_match_answer(task_data, response, max_length=256):
30 | def _last_boxed_only_string(string):
31 | idx = string.rfind("\\boxed")
32 | if idx < 0:
33 | idx = string.rfind("\\fbox")
34 | if idx < 0:
35 | return None
36 |
37 | i = idx
38 | left_brace_idx = None
39 | right_brace_idx = None
40 | num_left_braces_open = 0
41 | while i < len(string):
42 | if string[i] == "{":
43 | num_left_braces_open += 1
44 | if left_brace_idx is None:
45 | left_brace_idx = i
46 | elif string[i] == "}":
47 | num_left_braces_open -= 1
48 | if num_left_braces_open == 0:
49 | right_brace_idx = i
50 | break
51 |
52 | i += 1
53 |
54 | if left_brace_idx is None or right_brace_idx is None:
55 | return None
56 |
57 | return string[left_brace_idx + 1: right_brace_idx].strip()
58 |
59 | # Match true answer
60 | ground_truth_answer = _last_boxed_only_string(task_data["_metadata"]["solution"])
61 | assert ground_truth_answer
62 |
63 | # Match model answer
64 | is_matched = False
65 |
66 | ans_line = response.split('The answer is')
67 | if len(ans_line) > 1:
68 | is_matched = True
69 | response = ans_line[-1].strip()
70 | else:
71 | ans_extracted = _last_boxed_only_string(response)
72 | if ans_extracted:
73 | is_matched = True
74 | response = ans_extracted
75 |
76 | # Grade
77 | response = response[:max_length] # To avoid sympy taking too long
78 | return is_matched, grade_answer(response, ground_truth_answer)
79 |
80 |
81 | def zs_gpqa_match_answer(task_data, response):
82 | # Expected to see answer field, otherwise return C.
83 | ans = response.split("The correct answer is")
84 |
85 | if len(ans) == 1:
86 | return False, "C"
87 |
88 | ans = ans[1]
89 |
90 | letter_set = {"A", "B", "C", "D"}
91 | for c in ans:
92 | if c in letter_set:
93 | return True, c
94 |
95 | return False, "C"
96 |
97 |
98 | def fs_cothub_bbh_match_answer(task_data, response):
99 | # CoT hub match answer for BBH
100 | # https://github.com/FranxYao/chain-of-thought-hub/blob/main/BBH/run_bbh_gpt_3.5_turbo.py
101 |
102 | ans_line = response.split('answer is ')
103 |
104 | # Expect to see 'answer is'. If not return whole string
105 | if len(ans_line) == 1:
106 | return False, response
107 | else:
108 | ans = ans_line[-1].strip()
109 |
110 | if task_data["options"]:
111 | # Multiple choice, find appearing letter
112 | options = ['(A)', '(B)', '(C)', '(D)', '(E)', '(F)', '(G)', '(H)', '(I)', '(J)', '(K)', '(L)', '(M)', '(N)', '(O)', '(P)', '(Q)', '(R)', '(S)', '(T)', '(U)', '(V)', '(W)', '(X)', '(Y)', '(Z)']
113 |
114 | for option in options:
115 | if option in ans:
116 | return True, option
117 |
118 | return False, ans
119 | else:
120 | # Free form, direct return
121 | if len(ans) and ans[-1] == '.':
122 | ans = ans[:-1]
123 |
124 | return True, ans
125 |
126 |
127 | def fs_cothub_gsm8k_match_answer(task_data, response):
128 | # CoT hub match answer for GSM8k, match last numeric value
129 | # https://github.com/FranxYao/chain-of-thought-hub/blob/main/gsm8k/gpt3.5turbo_gsm8k_complex.ipynb
130 |
131 | pattern = '\d*\.?\d+'
132 | pred = re.findall(pattern, response)
133 | if len(pred) >= 1:
134 | return True, pred[-1]
135 |
136 | return False, response
137 |
138 |
139 | def fs_cothub_mmlu_match_answer(task_data, response):
140 | ans_line = response.split('answer is')
141 |
142 | # Expect to see 'answer is'. If not return C
143 | if len(ans_line) == 1:
144 | return False, "(C)"
145 | else:
146 | ans = ans_line[-1].strip()
147 |
148 | options = ['(A)', '(B)', '(C)', '(D)']
149 | for option in options:
150 | if option in ans:
151 | return True, option
152 |
153 | return False, "(C)"
154 |
155 |
156 | def coding_humaneval_match_answer(task_data, response):
157 | # Matching utilities
158 | def _function_exists(code, func_name):
159 | tree = ast.parse(code)
160 | for node in ast.walk(tree):
161 | if isinstance(node, ast.FunctionDef) and node.name == func_name:
162 | return True
163 |
164 | return False
165 |
166 | def _try_match(content, prefix, entrypoint):
167 | # All markdown code blocks, as well as raw
168 | code_blocks = [m[1] for m in re.findall(r"(\`{3}.*?\n+)([\s\S]*?)(\n+\`{3})", content)] \
169 | + [content]
170 |
171 | for block in code_blocks:
172 | # Check syntax
173 | try:
174 | code_completion = prefix + block
175 | if _function_exists(code_completion, entrypoint):
176 | return code_completion
177 | except SyntaxError:
178 | pass
179 |
180 | # Try match with include prefix
181 | humaneval_task = task_data["_metadata"]
182 | include_prefix = humaneval_task['prompt'].split('def')[0].strip() + "\n\n"
183 |
184 | result = _try_match(response, include_prefix, humaneval_task["entry_point"])
185 | if result:
186 | return True, {"task_id": humaneval_task["task_id"], "completion": result}
187 |
188 | # If fail then match with function signature
189 | result = _try_match(response, humaneval_task["prompt"], humaneval_task["entry_point"])
190 | if result:
191 | return True, {"task_id": humaneval_task["task_id"], "completion": result}
192 |
193 | return False, {"task_id": humaneval_task["task_id"], "completion": response}
194 |
195 |
196 | MATCH_ANSWER_FUNCTION = {
197 | "zs/agieval": zs_agieval_match_answer,
198 | "zs/bbh_mc_orca": zs_bbh_mc_orca_truthfulqa_orca_match_answer,
199 | "zs/truthfulqa_orca": zs_bbh_mc_orca_truthfulqa_orca_match_answer,
200 | "zs/gpqa": zs_gpqa_match_answer,
201 |
202 | "fs_cothub/bbh": fs_cothub_bbh_match_answer,
203 | "fs_cothub/gsm8k": fs_cothub_gsm8k_match_answer,
204 | "fs_cothub/mmlu": fs_cothub_mmlu_match_answer,
205 | "fs_cothub/math": fs_cothub_math_match_answer,
206 |
207 | "coding/humaneval": coding_humaneval_match_answer
208 | }
209 |
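A minimal sketch of how these matchers are dispatched by task type (the response text is illustrative):

    from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION

    task = {"task_type": "fs_cothub/gsm8k"}
    is_matched, answer = MATCH_ANSWER_FUNCTION[task["task_type"]](
        task, "She has 3 * 6 = 18 apples in total. The answer is 18.")
    # is_matched == True, answer == "18" (the last numeric value in the response)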
--------------------------------------------------------------------------------
/ochat/evaluation/run_eval.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import argparse
3 | import os
4 | import asyncio
5 | from glob import glob
6 |
7 | import orjson
8 | import openai
9 | from tqdm import tqdm
10 | from openai import RateLimitError, InternalServerError, APIConnectionError
11 | from tenacity import retry, stop_after_attempt, wait_random_exponential, retry_if_exception_type
12 | from vllm import LLM, SamplingParams
13 |
14 | from transformers.utils.hub import cached_file
15 |
16 | from ochat.evaluation.match_answer import MATCH_ANSWER_FUNCTION
17 | from ochat.config import MODEL_CONFIG_MAP
18 |
19 |
20 | def _strip_first_space(s: str):
21 | if len(s) and s[0] == " ":
22 | return s[1:]
23 | return s
24 |
25 |
26 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(20), retry=retry_if_exception_type((RateLimitError, InternalServerError, APIConnectionError, )))
27 | async def _chat_completion_with_backoff(client, **kwargs):
28 | return await client.chat.completions.create(**kwargs)
29 |
30 |
31 | async def chat_completion_thread(model, progress_bar, queue):
32 | client = openai.AsyncOpenAI()
33 |
34 | while True:
35 | # Fetch task
36 | try:
37 | task = queue.get_nowait()
38 | except asyncio.QueueEmpty:
39 | break
40 |
41 | # Completion
42 | try:
43 | response = await _chat_completion_with_backoff(
44 | client,
45 | model=model,
46 | messages=[{"role": "user", "content": task["question"]}],
47 |
48 | temperature=0
49 | )
50 | task["response"] = response.choices[0].message.content # type: ignore
51 | except Exception as e:
52 | if hasattr(e, "last_attempt"):
53 | e = e.last_attempt
54 | if hasattr(e, "_exception"):
55 | e = e._exception
56 |
57 | print(type(e), str(e))
58 |
59 | # Progress
60 | progress_bar.update()
61 |
62 |
63 | async def get_openai_answers(
64 | model: str,
65 | questions: list,
66 | parallel: int
67 | ):
68 | # Complete in retry cycles
69 | last_to_complete_num = 0
70 |
71 | while True:
72 | # fill queue
73 | to_complete_num = 0
74 | queue = asyncio.Queue()
75 | for q in questions:
76 | if q["response"]:
77 | continue
78 |
79 | queue.put_nowait(q)
80 | to_complete_num += 1
81 |
82 | tqdm.write(f"New completion cycle. To complete {to_complete_num}, number of parallel calls {parallel}")
83 |
84 | # Create tasks
85 | progress_bar = tqdm(total=to_complete_num)
86 | async with asyncio.TaskGroup() as task_group:
87 | for _ in range(parallel):
88 | task_group.create_task(chat_completion_thread(model, progress_bar, queue))
89 |
90 | # Next retry cycle
91 | # Break if cannot complete more
92 | if (to_complete_num == last_to_complete_num) or (to_complete_num == 0):
93 | break
94 | last_to_complete_num = to_complete_num
95 |
96 | # Reduce parallel calls
97 | parallel = max(1, parallel // 2)
98 |
99 | return questions
100 |
101 |
102 | def tokenize_questions(model_config: object, conv_template: object, questions: list, condition: str, system_msg: str):
103 | from ochat.config import Conversation, Message
104 |
105 | # Construct conversation
106 | prompt_indices = []
107 | conversations = []
108 | for idx, q in enumerate(questions):
109 | if q["response"]:
110 | continue
111 |
112 | conversations.append(Conversation(
113 | items=[
114 | Message(role="user", content=q["question"]),
115 | Message(role="assistant", content="")
116 | ],
117 | condition=condition,
118 | system=system_msg
119 | ))
120 | prompt_indices.append(idx)
121 |
122 | # Tokenize
123 | conversations, _ = conv_template.tokenize_conversations(conversations, inference=True)
124 | conversations = [tokens[-model_config.model_max_context:] for tokens in conversations]
125 |
126 | return conversations, prompt_indices
127 |
128 |
129 | def get_model_answers(
130 | model: str,
131 | questions: list,
132 | condition: str,
133 | system_msg: str,
134 | model_type: str,
135 | tensor_parallel_size: int
136 | ):
137 | # Load model config
138 | if model_type is None:
139 | with open(cached_file(path_or_repo_id=model, filename="openchat.json"), "r") as f:
140 | model_type = orjson.loads(f.read())["model_type"]
141 |
142 | model_config = MODEL_CONFIG_MAP[model_type]
143 | tokenizer = model_config.model_tokenizer_create(model)
144 | conv_template = model_config.conversation_template(tokenizer=tokenizer)
145 |
146 | # Init vLLM engine
147 | engine = LLM(model,
148 | max_num_batched_tokens=model_config.model_max_context,
149 | max_model_len=model_config.model_max_context,
150 | tensor_parallel_size=tensor_parallel_size)
151 | sampling_params = SamplingParams(temperature=0,
152 | max_tokens=None,
153 | stop_token_ids=conv_template.eot_tokens_, # Override stop tokens
154 | ignore_eos=True)
155 |
156 | # Complete
157 | prompts, prompt_indices = tokenize_questions(model_config, conv_template, questions,
158 | condition=condition, system_msg=system_msg)
159 |
160 | # calculate & fill in responses
161 | responses = engine.generate(prompt_token_ids=prompts, sampling_params=sampling_params)
162 | for idx, resp in zip(prompt_indices, responses):
163 | questions[idx]["response"] = _strip_first_space(resp.outputs[0].text)
164 |
165 | return questions
166 |
167 |
168 | async def run_eval(
169 | model: str,
170 | condition: str,
171 | system_msg: str,
172 | model_type: str,
173 |
174 | data_path: str,
175 | eval_sets: list,
176 |
177 | continue_from: Optional[str],
178 | output_file: str,
179 |
180 | parallel: int,
181 | tensor_parallel_size: int
182 | ):
183 | print (f"Evaluating ({model_type})...\n\nCondition: {condition}\nSystem Prompt: {system_msg}\n")
184 |
185 | if continue_from is not None:
186 | # Load continue
187 | print (f"Continuing from {continue_from}...")
188 |
189 | with open(continue_from, "rb") as f:
190 | questions = orjson.loads(f.read())
191 | else:
192 | # Load questions
193 | questions = []
194 |
195 | for filename in glob(os.path.join(data_path, "**", "*.jsonl"), recursive=True):
196 | task_name = os.path.splitext(filename[len(data_path):])[0].strip("\\/")
197 | task_type = os.path.dirname(task_name)
198 |
199 | assert task_type in MATCH_ANSWER_FUNCTION
200 |
201 | # Filter eval sets
202 | if eval_sets and not sum([task_name.startswith(a) for a in eval_sets]):
203 | continue
204 |
205 | # Load task
206 | with open(filename, "r") as f:
207 | task_data = list(map(orjson.loads, f.readlines()))
208 |
209 | questions.extend([{**item, "task_name": task_name, "task_type": task_type, "response": ""} for item in task_data])
210 |
211 | # run completion
212 | if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
213 | questions = await get_openai_answers(model, questions, parallel)
214 | else:
215 | questions = get_model_answers(model, questions, condition, system_msg, model_type, tensor_parallel_size)
216 |
217 | # Calculate accuracy
218 | for q in questions:
219 | q["is_matched"], q["answer"] = MATCH_ANSWER_FUNCTION[q["task_type"]](q, q["response"])
220 | try:
221 | q["is_correct"] = q["answer"] in q["label"]
222 | except:
223 | q["is_correct"] = False
224 |
225 | # Write results
226 | if output_file is None:
227 | output_file = os.path.join(os.path.dirname(data_path), "eval_results", f"{os.path.basename(model)}_{condition}.json")
228 |
229 | os.makedirs(os.path.dirname(output_file), exist_ok=True)
230 | with open(output_file, "wb") as f:
231 | f.write(orjson.dumps(questions, option=orjson.OPT_INDENT_2))
232 |
233 |
234 | async def main():
235 | parser = argparse.ArgumentParser()
236 |
237 | # Input / output
238 | parser.add_argument("--model", type=str, default=None)
239 | parser.add_argument("--condition", type=str, default="")
240 | parser.add_argument("--system-msg", type=str, default="")
241 | parser.add_argument("--model-type", type=str, default=None)
242 |
243 | parser.add_argument("--data-path", type=str, default="ochat/evaluation/eval_data")
244 | parser.add_argument("--eval-sets", type=str, nargs="+", default=[])
245 |
246 | parser.add_argument("--continue-from", type=str, default=None)
247 | parser.add_argument("--output-file", type=str, default=None)
248 | parser.add_argument("--parallel", type=int, default=16)
249 | parser.add_argument("--tensor-parallel-size", type=int, default=1)
250 |
251 | args = parser.parse_args()
252 |
253 | await run_eval(**vars(args))
254 |
255 | if __name__ == "__main__":
256 | asyncio.run(main())
257 |
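Two typical invocations (the model names and eval sets are illustrative; the OpenAI path assumes OPENAI_API_KEY is set):

    # Evaluate a local/HF model with vLLM on selected eval sets
    python -m ochat.evaluation.run_eval --model openchat/openchat-3.5-0106 --eval-sets fs_cothub/gsm8k coding/humaneval

    # Evaluate an OpenAI baseline through the async client with retries
    python -m ochat.evaluation.run_eval --model gpt-3.5-turbo --parallel 16

Unless --output-file is given, results are written to ochat/evaluation/eval_results/<model basename>_<condition>.json.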
--------------------------------------------------------------------------------
/ochat/evaluation/view_results.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import orjson
6 | import pandas as pd
7 | from glob import glob
8 |
9 | def save_results(dfs, save_path: str):
10 | os.makedirs(os.path.dirname(save_path), exist_ok=True)
11 | with pd.ExcelWriter(save_path) as writer:
12 | for task_type, df_task in dfs.items():
13 | df_task.to_excel(writer, sheet_name=task_type)
14 |
15 | def view_results(result_path: str):
16 | # Read results
17 | eval_results = []
18 | for filename in glob(os.path.join(result_path, "*.json")):
19 | with open(filename, "rb") as f:
20 | questions = orjson.loads(f.read())
21 |
22 | eval_results.extend([{
23 | "model": Path(filename).stem,
24 | "task_type": q["task_type"],
25 | "task_name": os.path.relpath(q["task_name"], q["task_type"]),
26 | "accuracy": q["is_correct"],
27 | "unmatched": not q["is_matched"],
28 | } for q in questions])
29 | df = pd.DataFrame.from_records(eval_results)
30 | all_tables = dict()
31 | # Overall metrics table
32 | df_overall = df.pivot_table(index=["model"], columns=["task_type"], values=["accuracy", "unmatched"], aggfunc="mean")
33 | all_tables["overall"] = df_overall
34 | print(df_overall.to_string(float_format=lambda x: f"{x * 100:.1f}", na_rep="-"))
35 | # Print tables for each task
36 | for task_type in df["task_type"].unique():
37 | df_task = df[df["task_type"] == task_type].pivot_table(index=["task_name"], columns=["model"], values=["accuracy", "unmatched"], aggfunc="mean")
38 | all_tables[task_type.replace("/", "_")] = df_task
39 | print(f"\n### {task_type}\n")
40 | print(df_task.to_string(float_format=lambda x: f"{x * 100:.1f}", na_rep="-"))
41 | return all_tables
42 |
43 |
44 | def main():
45 | parser = argparse.ArgumentParser()
46 |
47 | # Input / output
48 | parser.add_argument("--result_path", type=str, default="ochat/evaluation/eval_results")
49 | parser.add_argument("--save_path", type=str, default="ochat/evaluation/eval_results/summary.xlsx")
50 | parser.add_argument("--save", "-s", action="store_true", help="Save the results to a file")
51 | args = parser.parse_args()
52 |
53 | all_tables = view_results(args.result_path)
54 | if args.save:
55 | save_results(all_tables, args.save_path)
56 |
57 |
58 | if __name__ == "__main__":
59 | main()
60 |
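To print the summary tables and also save them as a spreadsheet (using the default paths from the argparse above):

    python -m ochat.evaluation.view_results --result_path ochat/evaluation/eval_results --save

The "overall" sheet pivots accuracy and unmatched rate by model and task type; the per-task sheets pivot by task name and model.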
--------------------------------------------------------------------------------
/ochat/experimental/generate_dataset_old.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate training data based on conversations
3 |
4 | Usage: python -m ochat.data.generate_data --in-file sharegpt_gpt4.json --tokenizer-name HF_REPO_NAME --out-dir .
5 | """
6 |
7 | from typing import Optional
8 | from dataclasses import dataclass
9 | import argparse
10 | import json
11 | import os
12 | import random
13 |
14 | import numpy as np
15 | import transformers
16 | from transformers.trainer_pt_utils import LabelSmoother
17 | from ray.util.multiprocessing import Pool
18 |
19 |
20 | @dataclass
21 | class ModelDataConfig:
22 | name: str
23 |
24 | # Prompt
25 | system: str
26 |
27 | role_prefix: dict
28 | ai_role: str
29 | eot_token: str
30 | bos_token: Optional[str]
31 |
32 | # Tokenize
33 | max_tokens: int
34 | pad_token: int
35 | ignore_id: int
36 |
37 |
38 | CONFIG = ModelDataConfig(
39 | name="OChat",
40 |
41 | # Prompt
42 | system="",
43 |
44 | role_prefix={
45 | "human": "Human: ",
46 | "gpt": "Assistant: "
47 | },
48 | ai_role="gpt",
49 | eot_token="<|end_of_turn|>",
50 | bos_token="",
51 |
52 | # Tokenize
53 | max_tokens=8192,
54 | pad_token="",
55 | ignore_id=LabelSmoother.ignore_index
56 | )
57 |
58 |
59 | def generate_split(conversations: list, tokenizer: transformers.AutoTokenizer, split_name: str, out_dir: str):
60 | # Add prompt and tokenize conversation
61 | def _convert_single_conversation(c):
62 | tokens = []
63 | masks = []
64 |
65 | # begin of sentence (bos)
66 | if CONFIG.bos_token:
67 | t = tokenizer.convert_tokens_to_ids(CONFIG.bos_token)
68 | tokens.append(t)
69 | masks.append(False)
70 |
71 | # System
72 | if CONFIG.system:
73 | t = tokenizer(CONFIG.system, add_special_tokens=False)["input_ids"] + [tokenizer.convert_tokens_to_ids(CONFIG.eot_token)]
74 | tokens.extend(t)
75 | masks.extend([False] * len(t))
76 |
77 | # Messages
78 | for message in c["items"]:
79 | # Message
80 | message_text = CONFIG.role_prefix[message["from"]] + message["value"]
81 |
82 | t = tokenizer(message_text, add_special_tokens=False)["input_ids"] + [tokenizer.convert_tokens_to_ids(CONFIG.eot_token)]
83 | tokens.extend(t)
84 | masks.extend([message["from"] == CONFIG.ai_role] * len(t))
85 |
86 | return tokens, masks
87 |
88 | converted = Pool().map(_convert_single_conversation, conversations)
89 |
90 | # Pad and to numpy array
91 | pad_id = tokenizer.convert_tokens_to_ids(CONFIG.pad_token)
92 |
93 | all_input_ids = []
94 | all_labels = []
95 | all_attention_masks = []
96 | all_plain_texts = []
97 | for tokens, masks in converted:
98 | # Cut to length
99 | tokens = np.array(tokens[:CONFIG.max_tokens], np.int_)
100 | masks = np.array(masks[:CONFIG.max_tokens], np.bool_)
101 |
102 | # Pad
103 | input_ids = np.full(CONFIG.max_tokens, pad_id, np.int_)
104 | labels = np.full(CONFIG.max_tokens, CONFIG.ignore_id, np.int_)
105 | attention_masks = np.full(CONFIG.max_tokens, False, np.bool_)
106 |
107 | length = len(tokens)
108 |
109 | input_ids[:length] = tokens
110 | labels[:length] = np.where(masks, tokens, CONFIG.ignore_id)
111 | attention_masks[:length] = True
112 |
113 | all_input_ids.append(input_ids)
114 | all_labels.append(labels)
115 | all_attention_masks.append(attention_masks)
116 | all_plain_texts.append(tokens)
117 |
118 | # Output training data
119 | np.savez(os.path.join(out_dir, f"ochat.{split_name}.npz"),
120 | # Arrays
121 | input_ids=np.vstack(all_input_ids),
122 | labels=np.vstack(all_labels),
123 | attention_masks=np.vstack(all_attention_masks))
124 |
125 | # Output plain texts
126 | all_plain_texts = tokenizer.batch_decode(all_plain_texts)
127 |
128 | with open(os.path.join(out_dir, f"ochat.{split_name}.text.json"), "w") as f:
129 | json.dump(all_plain_texts, f)
130 |
131 |
132 | def generate_dataset(seed, in_file, tokenizer_name, out_dir, eval_ratio):
133 | # Load tokenizer
134 | tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)
135 |
136 | # Load conversations
137 | with open(in_file, "r") as f:
138 | conversations = json.load(f)
139 |
140 | # Train-test split
141 | random.seed(seed)
142 | random.shuffle(conversations)
143 | eval_num = int(eval_ratio * len(conversations))
144 |
145 | train_conversations = conversations[eval_num:]
146 | eval_conversations = conversations[:eval_num]
147 |
148 | generate_split(train_conversations, tokenizer, "train", out_dir)
149 | generate_split(eval_conversations, tokenizer, "eval", out_dir)
150 |
151 |
152 | if __name__ == "__main__":
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument("--seed", type=int, default=42)
155 | parser.add_argument("--in-file", type=str, required=True)
156 | parser.add_argument("--tokenizer-name", type=str, required=True)
157 | parser.add_argument("--out-dir", type=str, default=".")
158 | parser.add_argument("--eval-ratio", type=float, default=0.01)
159 | args = parser.parse_args()
160 |
161 | generate_dataset(**vars(args))
162 |
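The input JSON consumed by this old generator is a list of ShareGPT-style conversations; a minimal illustrative sample:

    [
      {
        "items": [
          {"from": "human", "value": "What is 2 + 2?"},
          {"from": "gpt",   "value": "2 + 2 = 4."}
        ]
      }
    ]

Only the "from" and "value" fields are used, and "from" must be one of the role_prefix keys ("human" or "gpt").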
--------------------------------------------------------------------------------
/ochat/experimental/test_multipack_dataloader.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import numpy as np\n",
10 | "import numba\n",
11 | "\n",
12 | "import json"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 12,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "from typing import Any, Optional, List, Callable\n",
22 | "\n",
23 | "import torch.distributed as dist\n",
24 | "\n",
25 | "import numpy as np\n",
26 | "import numba\n",
27 | "\n",
28 | "\n",
29 | "@numba.njit\n",
30 | "def ffd_check(a: np.ndarray, c: int, n: int):\n",
31 | " # First-fit-decreasing bin packing\n",
32 | " # Check if a[] could fit in n bins with capacity c\n",
33 | " # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing\n",
34 | "\n",
35 | " a = np.sort(a)[::-1]\n",
36 | " bins = np.full((n, ), c, dtype=a.dtype)\n",
37 | " for size in a:\n",
38 | " not_found = True\n",
39 | " for idx in range(n):\n",
40 | " if bins[idx] >= size:\n",
41 | " bins[idx] -= size\n",
42 | " not_found = False\n",
43 | " break\n",
44 | "\n",
45 | " if not_found:\n",
46 | " return False\n",
47 | "\n",
48 | " return True\n",
49 | "\n",
50 | "\n",
51 | "@numba.njit\n",
52 | "def ffd_with_result(a: np.ndarray, c: int, start_index: int):\n",
53 | " # First-fit-decreasing bin packing (with result return)\n",
54 | "\n",
55 | " indices = np.argsort(a)[::-1]\n",
56 | " a = a[indices]\n",
57 | "\n",
58 | " bins = []\n",
59 | " bins_result = []\n",
60 | " for a_id, size in enumerate(a):\n",
61 | " add_new = True\n",
62 | " for idx in range(len(bins)):\n",
63 | " if bins[idx] >= size:\n",
64 | " bins[idx] -= size\n",
65 | " bins_result[idx].append(indices[a_id] + start_index)\n",
66 | " add_new = False\n",
67 | " break\n",
68 | "\n",
69 | " if add_new:\n",
70 | " bins.append(c - size)\n",
71 | " bins_result.append([indices[a_id] + start_index])\n",
72 | "\n",
73 | " return bins_result\n",
74 | "\n",
75 | "\n",
76 | "@numba.njit\n",
77 | "def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int):\n",
78 | " # Dynamic batch allocator, similar to Multifit\n",
79 | " # https://en.wikipedia.org/wiki/Multifit_algorithm\n",
80 | " # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)\n",
81 | "\n",
82 | " s = 0\n",
83 | " start_index = 0\n",
84 | " result = []\n",
85 | " result_totseqs = []\n",
86 | "\n",
87 | " while True:\n",
88 | " # binary search [l, r)\n",
89 | " l = 1\n",
90 | " r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, \"right\")\n",
91 | "\n",
92 | " while r - l > 1:\n",
93 | " m = (l + r) // 2\n",
94 | " if ffd_check(lengths[start_index: start_index + m], c, n):\n",
95 | " l = m\n",
96 | " else:\n",
97 | " r = m\n",
98 | "\n",
99 | " # use length l\n",
100 | " batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index)\n",
101 | " if len(batch) < n:\n",
102 | " break\n",
103 | "\n",
104 | " start_index += l\n",
105 | " s = lengths_cumsum[start_index - 1]\n",
106 | "\n",
107 | " # add local rank\n",
108 | " result.append(batch[rank])\n",
109 | " # add total seqs for all ranks\n",
110 | " totseq = 0\n",
111 | " for indices in batch:\n",
112 | " for idx in indices:\n",
113 | " totseq += numseqs[idx]\n",
114 | " result_totseqs.append(totseq)\n",
115 | "\n",
116 | " return result, result_totseqs, s, len(result) * c * n\n",
117 | "\n",
118 | "\n",
119 | "class MultipackDistributedDataloader:\n",
120 | " \"\"\"Unpadded data loading using Multipack.\n",
121 | " Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.\"\"\"\n",
122 | " \n",
123 | " def __init__(\n",
124 | " self,\n",
125 | " dataset: Any,\n",
126 | " lengths: np.ndarray,\n",
127 | " numseqs: np.ndarray,\n",
128 | "\n",
129 | " batch_max_length: int,\n",
130 | " collate_fn: Callable,\n",
131 | "\n",
132 | " num_replicas: Optional[int] = None,\n",
133 | " rank: Optional[int] = None,\n",
134 | "\n",
135 | " seed: int = 0,\n",
136 | " ):\n",
137 | " # Dataset\n",
138 | " self.dataset = dataset\n",
139 | " self.lengths = lengths\n",
140 | " self.numseqs = numseqs\n",
141 | " assert isinstance(self.lengths, np.ndarray)\n",
142 | "\n",
143 | " self.batch_max_length = batch_max_length\n",
144 | " self.collate_fn = collate_fn\n",
145 | "\n",
146 | " # Get rank\n",
147 | " if num_replicas is None:\n",
148 | " if not dist.is_available():\n",
149 | " raise RuntimeError(\"Requires distributed package to be available\")\n",
150 | " num_replicas = dist.get_world_size()\n",
151 | " if rank is None:\n",
152 | " if not dist.is_available():\n",
153 | " raise RuntimeError(\"Requires distributed package to be available\")\n",
154 | " rank = dist.get_rank()\n",
155 | "\n",
156 | " self.num_replicas = num_replicas\n",
157 | " self.rank = rank\n",
158 | "\n",
159 | " # Seed\n",
160 | " self.seed = seed\n",
161 | "\n",
162 | " # Epoch\n",
163 | " self.epoch = 0\n",
164 | "\n",
165 | " # statistics\n",
166 | " self.eff_total_used = 0\n",
167 | " self.eff_total_slots = 0\n",
168 | "\n",
169 | " def set_epoch(self, epoch: int):\n",
170 | " self.epoch = epoch\n",
171 | "\n",
172 | " def generate_batches(self, set_stats=False):\n",
173 | " indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))\n",
174 | "\n",
175 | " lengths = self.lengths[indices]\n",
176 | " numseqs = self.numseqs[indices]\n",
177 | " lengths_cumsum = np.cumsum(lengths)\n",
178 | "\n",
179 | " batches, totseqs, total_used, total_slots = allocate(lengths=lengths,\n",
180 | " numseqs=numseqs,\n",
181 | " lengths_cumsum=lengths_cumsum,\n",
182 | " rank=self.rank,\n",
183 | " c=self.batch_max_length,\n",
184 | " n=self.num_replicas)\n",
185 | " \n",
186 | " curseqs = [np.sum(numseqs[batch]) for batch in batches]\n",
187 | " batches = [indices[batch] for batch in batches]\n",
188 | "\n",
189 | " # statistics\n",
190 | " if set_stats:\n",
191 | " self.eff_total_used += total_used\n",
192 | " self.eff_total_slots += total_slots\n",
193 | "\n",
194 | " return batches, totseqs, curseqs\n",
195 | " \n",
196 | " def __iter__(self):\n",
197 | " all_batches, all_totseqs, all_curseqs = self.generate_batches(set_stats=True)\n",
198 | "\n",
199 | " for batch, totseq, curseq in zip(all_batches, all_totseqs, all_curseqs):\n",
200 | " yield self.collate_fn(self.dataset[batch]), totseq, curseq\n",
201 | "\n",
202 | " def num_batches(self):\n",
203 | " batches, _, _ = self.generate_batches()\n",
204 | " return len(batches)\n",
205 | "\n",
206 | " def efficiency(self):\n",
207 | " return self.eff_total_used / self.eff_total_slots\n"
208 | ]
209 | },
210 | {
211 | "cell_type": "code",
212 | "execution_count": 3,
213 | "metadata": {},
214 | "outputs": [
215 | {
216 | "name": "stdout",
217 | "output_type": "stream",
218 | "text": [
219 | "[[3], [2, 0], [1, 4], [5]]\n"
220 | ]
221 | }
222 | ],
223 | "source": [
224 | "lengths = np.array([1, 5, 7, 8, 3, 2])\n",
225 | "lengths_cumsum = np.cumsum(lengths)\n",
226 | "\n",
227 | "print(ffd_with_result(lengths, 8, start_index=0))"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 16,
233 | "metadata": {},
234 | "outputs": [
235 | {
236 | "name": "stdout",
237 | "output_type": "stream",
238 | "text": [
239 | "[29, 29, 29, 29, 29, 29, 29, 29]\n",
240 | "Efficiency: [0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559, 0.9955281976408559]\n",
241 | "Overall Efficiency: 0.9955281976408559\n"
242 | ]
243 | }
244 | ],
245 | "source": [
246 | "DATASET = \"../../dataset_processed/openchat.train.json\"\n",
247 | "C = 14 * 2048\n",
248 | "N = 8\n",
249 | "EPOCHS = 10\n",
250 | "\n",
251 | "# Load dataset\n",
252 | "with open(DATASET, \"r\") as f:\n",
253 | " dataset = json.load(f)\n",
254 | "\n",
255 | "# Check allocator efficiency\n",
256 | "lengths = np.array([len(tokens) for tokens, masks, group in dataset])\n",
257 | "numseqs = np.random.randint(low=1, high=10, size=lengths.shape)\n",
258 | "# lengths = np.random.randint(0, 2048, (int(5e6))).astype(np.int32)\n",
259 | "\n",
260 | "# test sampler correctness & efficiency\n",
261 | "tot_len = 0\n",
262 | "tot_batches = 0\n",
263 | "\n",
264 | "dataloaders = [MultipackDistributedDataloader(dataset=np.arange(len(lengths)), lengths=lengths, numseqs=numseqs,\n",
265 | " batch_max_length=C, \n",
266 | " num_replicas=N, rank=rank,\n",
267 | " collate_fn=lambda x: x) for rank in range(N)]\n",
268 | "print([loader.num_batches() for loader in dataloaders])\n",
269 | "\n",
270 | "for epoch in range(EPOCHS):\n",
271 | " batches = []\n",
272 | " totseqs = []\n",
273 | " curseqs = []\n",
274 | "\n",
275 | " for loader in dataloaders:\n",
276 | " loader.set_epoch(epoch)\n",
277 | " totseqs.append([])\n",
278 | " curseqs.append([])\n",
279 | "\n",
280 | " for batch, totseq, curseq in loader:\n",
281 | " batches.extend(batch)\n",
282 | "\n",
283 | " gt_curseq = np.sum(numseqs[batch])\n",
284 | " # print (batch, curseq, gt_curseq)\n",
285 | " assert gt_curseq == curseq\n",
286 | "\n",
287 | " totseqs[-1].append(totseq)\n",
288 | " curseqs[-1].append(gt_curseq)\n",
289 | "\n",
290 | " # Check constraints\n",
291 | " overall_len = sum([lengths[x] for x in batch])\n",
292 | " assert overall_len <= C\n",
293 | "\n",
294 | " tot_len += overall_len\n",
295 | " tot_batches += 1\n",
296 | "\n",
297 | " # Check overall unique\n",
298 | " batches.sort()\n",
299 | " assert batches == list(set(batches)) # Unique\n",
300 | "\n",
301 | " # Check totseq accurate\n",
302 | " gt_totseqs = np.sum(curseqs, axis=0)\n",
303 | " for i in range(len(totseqs)):\n",
304 | " assert (totseqs[i] == gt_totseqs).all()\n",
305 | "\n",
306 | "# Check efficiency\n",
307 | "efficiency = [loader.efficiency() for loader in dataloaders]\n",
308 | "print(f\"Efficiency: {efficiency}\")\n",
309 | "\n",
310 | "print(f\"Overall Efficiency: {tot_len / (tot_batches * C)}\")"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {},
317 | "outputs": [
318 | {
319 | "data": {
320 | "text/plain": [
321 | "150.98552224214524"
322 | ]
323 | },
324 | "execution_count": 23,
325 | "metadata": {},
326 | "output_type": "execute_result"
327 | }
328 | ],
329 | "source": [
330 | "C * N / np.mean(lengths)"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": null,
336 | "metadata": {},
337 | "outputs": [
338 | {
339 | "data": {
340 | "text/plain": [
341 | "150.31034482758622"
342 | ]
343 | },
344 | "execution_count": 24,
345 | "metadata": {},
346 | "output_type": "execute_result"
347 | }
348 | ],
349 | "source": [
350 | "np.mean(gt_totseqs)"
351 | ]
352 | }
353 | ],
354 | "metadata": {
355 | "kernelspec": {
356 | "display_name": "torch",
357 | "language": "python",
358 | "name": "python3"
359 | },
360 | "language_info": {
361 | "codemirror_mode": {
362 | "name": "ipython",
363 | "version": 3
364 | },
365 | "file_extension": ".py",
366 | "mimetype": "text/x-python",
367 | "name": "python",
368 | "nbconvert_exporter": "python",
369 | "pygments_lexer": "ipython3",
370 | "version": "3.11.4"
371 | },
372 | "orig_nbformat": 4
373 | },
374 | "nbformat": 4,
375 | "nbformat_minor": 2
376 | }
377 |
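A quick check of the packing printed in the notebook cell above: with capacity 8, the FFD result [[3], [2, 0], [1, 4], [5]] groups lengths 8, 7+1, 5+3 and 2, so every bin fits:

    lengths = [1, 5, 7, 8, 3, 2]
    bins = [[3], [2, 0], [1, 4], [5]]
    assert all(sum(lengths[i] for i in b) <= 8 for b in bins)

`allocate` repeats this check inside a binary search to find the longest prefix of the shuffled lengths that still packs into one bin per rank, which is what yields the ~99.5% packing efficiency reported above.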
--------------------------------------------------------------------------------
/ochat/experimental/text_length.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 15,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%matplotlib inline\n",
10 | "\n",
11 | "import json\n",
12 | "import numpy as np\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "\n",
15 | "import sentencepiece\n",
16 | "from ray.util.multiprocessing import Pool\n",
17 | "\n",
18 | "\n",
19 | "DATASET = \"../../dataset_processed/sharegpt_gpt4.json\"\n",
20 | "TOKENIZER = \"../../tokenizer/llama_tokenizer.model\""
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {},
27 | "outputs": [
28 | {
29 | "name": "stderr",
30 | "output_type": "stream",
31 | "text": [
32 | "2023-05-24 20:04:38,940\tINFO worker.py:1625 -- Started a local Ray instance.\n"
33 | ]
34 | }
35 | ],
36 | "source": [
37 | "# Load dataset\n",
38 | "tokenizer = sentencepiece.SentencePieceProcessor(model_file=TOKENIZER)\n",
39 | "with open(DATASET, \"r\") as f:\n",
40 | " dataset = json.load(f)\n",
41 | "\n",
42 | "# Parallel tokenization\n",
43 | "def _tokenize(sample):\n",
44 | " for c in sample[\"items\"]:\n",
45 | " c[\"value\"] = tokenizer.tokenize(c[\"value\"])\n",
46 | "\n",
47 | " return sample\n",
48 | "\n",
49 | "dataset_tokenized = Pool().map(_tokenize, dataset)"
50 | ]
51 | },
52 | {
53 | "cell_type": "code",
54 | "execution_count": 20,
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "PROMPT_LEN = {\n",
59 | " \"human\": 4,\n",
60 | " \"gpt\": 5\n",
61 | "}\n",
62 | "\n",
63 | "dataset_len = np.array([sum(len(c[\"value\"]) + PROMPT_LEN[c[\"from\"]] for c in sample[\"items\"]) for sample in dataset_tokenized])"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 19,
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "data": {
73 | "text/plain": [
74 | "(array([4650., 896., 301., 154., 71., 33., 33., 17., 10.,\n",
75 | " 10.]),\n",
76 | " array([ 0., 5000., 10000., 15000., 20000., 25000., 30000., 35000.,\n",
77 | " 40000., 45000., 50000.]),\n",
78 | " )"
79 | ]
80 | },
81 | "execution_count": 19,
82 | "metadata": {},
83 | "output_type": "execute_result"
84 | },
85 | {
86 | "data": {
87 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAGdCAYAAAAMm0nCAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAiuklEQVR4nO3de3BU5cHH8V8u7BIuu+FiNkQSiYOCUcASNGy9vEVSVoxWK0zRUmUEtdDACFhulYLazoTBKgVBsKU1zlRF6AgqETATJFQJt2g04ZJqGxta3ASL2Q0UEiDP+4eTM6ygkhBInvj9zOwMOefZs895IJPvnOxZoowxRgAAABaJbu0JAAAANBUBAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrEDAAAMA6sa09gQuloaFBBw8eVNeuXRUVFdXa0wEAAOfAGKPa2lolJSUpOvrrr7O024A5ePCgkpOTW3saAACgGQ4cOKDevXt/7f52GzBdu3aV9OUCeDyeVp4NAAA4F+FwWMnJyc7P8a/TbgOm8ddGHo+HgAEAwDLf9vYP3sQLAACsQ8AAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA4BAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrxLb2BGzUZ3Zea0+hyT5dkNXaUwAAoMVwBQYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgnfMKmAULFigqKkpTp051th0/flzZ2dnq0aOHunTpolGjRqmqqirieZWVlcrKylKnTp2UkJCgGTNm6OTJkxFjtmzZosGDB8vtdqtv377Kzc09n6kCAIB2pNkBs2vXLj3//PMaOHBgxPZp06bpzTff1Jo1a1RYWKiDBw/q7rvvdvafOnVKWVlZqq+v17Zt2/Tiiy8qNzdX8+bNc8ZUVFQoKytLw4YNU0lJiaZOnaoHH3xQmzZtau50AQBAO9KsgDly5IjGjh2rP/7xj+rWrZuzPRQK6U9/+pOeeeYZ3XLLLUpPT9cLL7ygbdu2afv27ZKkt99+W3v37tVf/vIXXXvttRo5cqR+85vfaNmyZaqvr5ckrVixQqmpqXr66ad11VVXafLkyRo9erQWLVrUAqcMAABs16yAyc7OVlZWljIzMyO2FxcX68SJExHb+/fvr5SUFBUVFUmSioqKNGDAAPl8PmdMIBBQOBzWnj17nDFfPXYgEHCOcTZ1dXUKh8MRDwAA0D7FNvUJq1at0vvvv69du3adsS8YDMrlcik+Pj5iu8/nUzAYdMacHi+N+xv3fdOYcDisY8eOKS4u7ozXzsnJ0RNPPNHU0wEAABZq0hWYAwcO6JFHHtFLL72kjh07Xqg5NcucOXMUCoWcx4EDB1p7SgAA4AJpUsAUFxerurpagwcPVmxsrGJjY1VYWKglS5YoNjZWPp9P9fX1qqmpiXheVVWVEhMTJUmJiYln3JXU+PW3jfF4PGe9+iJJbrdbHo8n4gEAANqnJgXM8OHDVVpaqpKSEucxZMgQjR071vlzhw4dVFBQ4DynvLxclZWV8vv9kiS/36/S0lJVV1c7Y/Lz8+XxeJSWluaMOf0YjWMajwEAAL7bmvQemK5du+qaa66J2Na5c2f16NHD2T5hwgRNnz5d3bt3l8fj0ZQpU+T3+zV06FBJ0ogRI5SWlqb77rtPCxcuVDAY1Ny5c5WdnS232y1JmjhxopYuXaqZM2dq/Pjx2rx5s1avXq28vLyWOGcAAGC5Jr+J99ssWrRI0dHRGjVqlOrq6hQIBPTcc885+2NiYrR+/XpNmjRJfr9fnTt31rhx4/Tkk086Y1JTU5WXl6dp06Zp8eLF6t27t1auXKlAINDS0wUAABaKMsaY1p7EhRAOh+X1ehUKhVr8/TB9Ztt3JejTBVmtPQUAAL7Vuf785v9CAgAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWaVLALF++XAMHDpTH45HH45Hf79eGDRuc/cePH1d2drZ69OihLl26aNSoUaqqqoo4RmVlpbKystSpUyclJCRoxowZOnnyZMSYLVu2aPDgwXK73erbt69yc3Obf4YAAKDdaVLA9O7dWwsWLFBxcbF2796tW265RXfeeaf27NkjSZo2bZrefPNNrVmzRoWFhTp48KDuvvtu5/mnTp1SVlaW6uvrtW3bNr344ovKzc3VvHnznDEVFRXKysrSsGHDVFJSoqlTp+rBBx/Upk2bWuiUAQCA7aKMMeZ8DtC9e3c99dRTGj16tC655BK9/PLLGj16tCRp//79uuqqq1RUVKShQ4dqw4YNuv3223Xw4EH5fD5J0ooVKzRr1iwdOnRILpdLs2bNUl5ensrKypzXuOeee1RTU6ONGzee87zC4bC8Xq9CoZA8Hs/5nOIZ+szOa9HjXQyfLshq7SkAAPCtzvXnd7PfA3Pq1CmtWrVKR48eld/vV3FxsU6cOKHMzExnTP/+/ZWSkqKioiJJUlFRkQYMGODEiyQFAgGFw2HnKk5RUVHEMRrHNB7j69TV1SkcDkc8AABA+9TkgCktLVWXLl3kdrs1ceJErV27VmlpaQoGg3K5XIqPj48Y7/P5FAwGJUnBYDAiXhr3N+77pjHhcFjHjh372nnl5OTI6/U6j+Tk5KaeGgAAsESTA6Zfv34qKSnRjh07NGnSJI0bN0579+69EHNrkjlz5igUCjmPAwcOtPaUAADABRLb1Ce4XC717dtXkpSenq5du3Zp8eLFGjNmjOrr61VTUxNx
FaaqqkqJiYmSpMTERO3cuTPieI13KZ0+5qt3LlVVVcnj8SguLu5r5+V2u+V2u5t6OgAAwELn/TkwDQ0NqqurU3p6ujp06KCCggJnX3l5uSorK+X3+yVJfr9fpaWlqq6udsbk5+fL4/EoLS3NGXP6MRrHNB4DAACgSVdg5syZo5EjRyolJUW1tbV6+eWXtWXLFm3atEler1cTJkzQ9OnT1b17d3k8Hk2ZMkV+v19Dhw6VJI0YMUJpaWm67777tHDhQgWDQc2dO1fZ2dnO1ZOJEydq6dKlmjlzpsaPH6/Nmzdr9erVysuz784fAABwYTQpYKqrq3X//ffrs88+k9fr1cCBA7Vp0yb98Ic/lCQtWrRI0dHRGjVqlOrq6hQIBPTcc885z4+JidH69es1adIk+f1+de7cWePGjdOTTz7pjElNTVVeXp6mTZumxYsXq3fv3lq5cqUCgUALnTIAALDdeX8OTFvF58BE4nNgAAA2uOCfAwMAANBaCBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHWaFDA5OTm67rrr1LVrVyUkJOiuu+5SeXl5xJjjx48rOztbPXr0UJcuXTRq1ChVVVVFjKmsrFRWVpY6deqkhIQEzZgxQydPnowYs2XLFg0ePFhut1t9+/ZVbm5u884QAAC0O00KmMLCQmVnZ2v79u3Kz8/XiRMnNGLECB09etQZM23aNL355ptas2aNCgsLdfDgQd19993O/lOnTikrK0v19fXatm2bXnzxReXm5mrevHnOmIqKCmVlZWnYsGEqKSnR1KlT9eCDD2rTpk0tcMoAAMB2UcYY09wnHzp0SAkJCSosLNTNN9+sUCikSy65RC+//LJGjx4tSdq/f7+uuuoqFRUVaejQodqwYYNuv/12HTx4UD6fT5K0YsUKzZo1S4cOHZLL5dKsWbOUl5ensrIy57Xuuece1dTUaOPGjec0t3A4LK/Xq1AoJI/H09xTPKs+s/Na9HgXw6cLslp7CgAAfKtz/fl9Xu+BCYVCkqTu3btLkoqLi3XixAllZmY6Y/r376+UlBQVFRVJkoqKijRgwAAnXiQpEAgoHA5rz549zpjTj9E4pvEYZ1NXV6dwOBzxAAAA7VOzA6ahoUFTp07VDTfcoGuuuUaSFAwG5XK5FB8fHzHW5/MpGAw6Y06Pl8b9jfu+aUw4HNaxY8fOOp+cnBx5vV7nkZyc3NxTAwAAbVyzAyY7O1tlZWVatWpVS86n2ebMmaNQKOQ8Dhw40NpTAgAAF0hsc540efJkrV+/Xlu3blXv3r2d7YmJiaqvr1dNTU3EVZiqqiolJiY6Y3bu3BlxvMa7lE4f89U7l6qqquTxeBQXF3fWObndbrnd7uacDgAAsEyTrsAYYzR58mStXbtWmzdvVmpqasT+9PR0dejQQQUFBc628vJyVVZWyu/3S5L8fr9KS0tVXV3tjMnPz5fH41FaWpoz5vRjNI5pPAYAAPhua9IVmOzsbL388st6/fXX1bVrV+c9K16vV3FxcfJ6vZowYYKmT5+u7t27y+PxaMqUKfL7/Ro6dKgkacSIEUpLS9N9992nhQsXKhgMau7cucrOznauoEycOFFLly7VzJkzNX78eG3evFmrV69WXp59d/8AAICW16QrMMuXL1coFNIPfvAD9erVy3m8+uqrzphFixbp9ttv16hRo3TzzTcrMTFRr732mrM/JiZG69evV0xMjPx+v372s5/p/vvv15NPPumMSU1NVV5envLz8zVo0CA9/fTTWrlypQKBQAucMgAAsN15fQ5MW8bnwETic2AAADa4KJ8DAwAA0BoIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADWIWAAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1mlywGzdulV33HGHkpKSFBUVpXXr1kXsN8Zo3rx56tWrl+Li4pSZmamPP/44Yszhw4c1duxYeTwexcfHa8KECTpy5EjEmI8++kg33XSTOnbsqOTkZC1cuLDpZwcAANqlJgfM0aNHNWjQIC1btuys+xcuXKglS5ZoxYoV2rFjhzp37qxAIKDjx487Y8aOHas9e/YoPz9f69ev19atW/Xwww87+8PhsEaMGKHLLrtMxcXFeuqpp/T444/rD3/4QzNOEQAAtDdRxhjT7CdHRWnt2rW66667JH159SUpKUmPPvqofvnLX0qSQqGQfD6fcnNzdc8992jfvn1KS0vTrl27NGTIEEnSxo0bddttt+nf//63kpKStHz5cj322GMKBoNyuVySpNmzZ2vdunXav3//Oc0tHA7L6/UqFArJ4/E09xTPqs/svBY93sXw6YKs1p4CAADf6lx/frfoe2AqKioUDAaVmZnpbPN6vcrIyFBRUZEkqaioSPHx8U68SFJmZqaio6O1Y8cOZ8zNN9/sxIskBQIBlZeX64svvmjJKQMAAAvFtuTBgsGgJMnn80Vs9/l8zr5gMKiEhITIScTGqnv37hFjUlNTzzhG475u3bqd8dp1dXWqq6tzvg6Hw+d5NgAAoK1qN3ch5eTkyOv1Oo/k5OTWnhIAALhAWjRgEhMTJUlVVVUR26uqqpx9iYmJqq6ujth/8uRJHT58OGLM2Y5x+mt81Zw5cxQKhZzHgQMHzv+EAABAm9SiAZOamqrExEQVFBQ428LhsHbs2CG/3y9J8vv9qqmpUXFxsTN
m8+bNamhoUEZGhjNm69atOnHihDMmPz9f/fr1O+uvjyTJ7XbL4/FEPAAAQPvU5IA5cuSISkpKVFJSIunLN+6WlJSosrJSUVFRmjp1qn7729/qjTfeUGlpqe6//34lJSU5dypdddVVuvXWW/XQQw9p586deu+99zR58mTdc889SkpKkiT99Kc/lcvl0oQJE7Rnzx69+uqrWrx4saZPn95iJw4AAOzV5Dfx7t69W8OGDXO+boyKcePGKTc3VzNnztTRo0f18MMPq6amRjfeeKM2btyojh07Os956aWXNHnyZA0fPlzR0dEaNWqUlixZ4uz3er16++23lZ2drfT0dPXs2VPz5s2L+KwYAADw3XVenwPTlvE5MJH4HBgAgA1a5XNgAAAALoYW/RwYtF02XjWSuHIEADg7rsAAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA4BAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA4BAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA4BAwAArEPAAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA4BAwAArBPb2hMAvkmf2XmtPYUm+3RBVmtPAQDaPa7AAAAA6xAwAADAOgQMAACwDgEDAACsQ8AAAADrEDAAAMA6BAwAALAOAQMAAKxDwAAAAOsQMAAAwDoEDAAAsA7/FxLQwvj/mwDgwuMKDAAAsA4BAwAArEPAAAAA6xAwAADAOgQMAACwDnchAeDOKQDWadNXYJYtW6Y+ffqoY8eOysjI0M6dO1t7SgAAoA1oswHz6quvavr06Zo/f77ef/99DRo0SIFAQNXV1a09NQAA0MqijDGmtSdxNhkZGbruuuu0dOlSSVJDQ4OSk5M1ZcoUzZ49+1ufHw6H5fV6FQqF5PF4WnRuNl5uB4Dm4Fd1uNjO9ed3m3wPTH19vYqLizVnzhxnW3R0tDIzM1VUVHTW59TV1amurs75OhQKSfpyIVpaQ93/WvyYANAWpUxb09pTaLKyJwKtPQWch8af2992faVNBsznn3+uU6dOyefzRWz3+Xzav3//WZ+Tk5OjJ5544oztycnJF2SOAIC2yfv71p4BWkJtba28Xu/X7m+TAdMcc+bM0fTp052vGxoadPjwYfXo0UNRUVEt9jrhcFjJyck6cOBAi/9qCpFY64uDdb44WOeLg3W+OC7kOhtjVFtbq6SkpG8c1yYDpmfPnoqJiVFVVVXE9qqqKiUmJp71OW63W263O2JbfHz8hZqiPB4P3xwXCWt9cbDOFwfrfHGwzhfHhVrnb7ry0qhN3oXkcrmUnp6ugoICZ1tDQ4MKCgrk9/tbcWYAAKAtaJNXYCRp+vTpGjdunIYMGaLrr79ev//973X06FE98MADrT01AADQytpswIwZM0aHDh3SvHnzFAwGde2112rjxo1nvLH3YnO73Zo/f/4Zv65Cy2OtLw7W+eJgnS8O1vniaAvr3GY/BwYAAODrtMn3wAAAAHwTAgYAAFiHgAEAANYhYAAAgHUImCZatmyZ+vTpo44dOyojI0M7d+5s7Sm1GVu3btUdd9yhpKQkRUVFad26dRH7jTGaN2+eevXqpbi4OGVmZurjjz+OGHP48GGNHTtWHo9H8fHxmjBhgo4cORIx5qOPPtJNN92kjh07Kjk5WQsXLjxjLmvWrFH//v3VsWNHDRgwQG+99VaLn29rycnJ0XXXXaeuXbsqISFBd911l8rLyyPGHD9+XNnZ2erRo4e6dOmiUaNGnfHBkJWVlcrKylKnTp2UkJCgGTNm6OTJkxFjtmzZosGDB8vtdqtv377Kzc09Yz7t9Xti+fLlGjhwoPNBXX6/Xxs2bHD2s8YXxoIFCxQVFaWpU6c621jr8/f4448rKioq4tG/f39nv5VrbHDOVq1aZVwul/nzn/9s9uzZYx566CETHx9vqqqqWntqbcJbb71lHnvsMfPaa68ZSWbt2rUR+xcsWGC8Xq9Zt26d+fDDD82PfvQjk5qaao4dO+aMufXWW82gQYPM9u3bzd/+9jfTt29fc++99zr7Q6GQ8fl8ZuzYsaasrMy88sorJi4uzjz//PPOmPfee8/ExMSYhQsXmr1795q5c+eaDh06mNLS0gu+BhdDIBAwL7zwgikrKzMlJSXmtttuMykpKebIkSPOmIkTJ5rk5GRTUFBgdu/ebYYOHWq+//3vO/tPnjxprrnmGpOZmWk++OAD89Zbb5mePXuaOXPmOGP++c9/mk6dOpnp06ebvXv3mmeffdbExMSYjRs3OmPa8/fEG2+8YfLy8szf//53U15ebn71q1+ZDh06mLKyMmMMa3wh7Ny50/Tp08cMHDjQPPLII8521vr8zZ8/31x99dXms88+cx6HDh1y9tu4xgRME1x//fUmOzvb+frUqVMmKSnJ5OTktOKs2qavBkxDQ4NJTEw0Tz31lLOtpqbGuN1u88orrxhjjNm7d6+RZHbt2uWM2bBhg4mKijL/+c9/jDHGPPfcc6Zbt26mrq7OGTNr1izTr18/5+uf/OQnJisrK2I+GRkZ5uc//3mLnmNbUV1dbSSZwsJCY8yX69qhQwezZs0aZ8y+ffuMJFNUVGSM+TI2o6OjTTAYdMYsX77ceDweZ21nzpxprr766ojXGjNmjAkEAs7X37XviW7dupmVK1eyxhdAbW2tueKKK0x+fr75v//7PydgWOuWMX/+fDNo0KCz7rN1jfkV0jmqr69XcXGxMjMznW3R0dHKzMxUUVFRK87MDhUVFQoGgxHr5/V6lZGR4axfUVGR4uPjNWTIEGdMZmamoqOjtWPHDmfMzTffLJfL5YwJBAIqLy/XF1984Yw5/XUax7TXv6dQKCRJ6t69uySpuLhYJ06ciFiD/v37KyUlJWKtBwwYEPHBkIFAQOFwWHv27HHGfNM6fpe+J06dOqVVq1bp6NGj8vv9rPEFkJ2draysrDPWg7VuOR9//LGSkpJ0+eWXa+zYsaqsrJRk7xoTMOfo888/16lTp874JGCfz6dgMNhKs7JH4xp90/oFg0ElJCRE7I+NjVX37t0jxpztGKe/xteNaY9/Tw0NDZo6dapuuOEGXXPNNZK+PH+Xy3XGf2b61bVu7jqGw2EdO3bsO/E9UVpaqi5dusjtdmvixIlau3at0tLSWOMWtmrVKr3//vvKyck5Yx9r3TIyMjKUm5urjRs3avny5aqoqNBNN92k2tpaa9e4zf5XAgC+XXZ2tsrKyvTuu++29lTapX79+qmkpEShUEh//etfNW7cOBUWFrb2tNqVAwcO6JFHHlF+fr46duzY2tNpt0aOHOn8eeDAgcrIyNBll12m1atXKy4urhVn1nxcgTlHPXv2VExMzBnvyq6qqlJiYmIrzcoejW
v0TeuXmJio6urqiP0nT57U4cOHI8ac7Rinv8bXjWlvf0+TJ0/W+vXr9c4776h3797O9sTERNXX16umpiZi/FfXurnr6PF4FBcX9534nnC5XOrbt6/S09OVk5OjQYMGafHixaxxCyouLlZ1dbUGDx6s2NhYxcbGqrCwUEuWLFFsbKx8Ph9rfQHEx8fryiuv1CeffGLtv2cC5hy5XC6lp6eroKDA2dbQ0KCCggL5/f5WnJkdUlNTlZiYGLF+4XBYO3bscNbP7/erpqZGxcXFzpjNmzeroaFBGRkZzpitW7fqxIkTzpj8/Hz169dP3bp1c8ac/jqNY9rL35MxRpMnT9batWu1efNmpaamRuxPT09Xhw4dItagvLxclZWVEWtdWloaEYz5+fnyeDxKS0tzxnzTOn4XvycaGhpUV1fHGreg4cOHq7S0VCUlJc5jyJAhGjt2rPNn1rrlHTlyRP/4xz/Uq1cve/89N/ltv99hq1atMm632+Tm5pq9e/eahx9+2MTHx0e8K/u7rLa21nzwwQfmgw8+MJLMM888Yz744APzr3/9yxjz5W3U8fHx5vXXXzcfffSRufPOO896G/X3vvc9s2PHDvPuu++aK664IuI26pqaGuPz+cx9991nysrKzKpVq0ynTp3OuI06NjbW/O53vzP79u0z8+fPb1e3UU+aNMl4vV6zZcuWiFsi//e//zljJk6caFJSUszmzZvN7t27jd/vN36/39nfeEvkiBEjTElJidm4caO55JJLznpL5IwZM8y+ffvMsmXLznpLZHv9npg9e7YpLCw0FRUV5qOPPjKzZ882UVFR5u233zbGsMYX0ul3IRnDWreERx991GzZssVUVFSY9957z2RmZpqePXua6upqY4yda0zANNGzzz5rUlJSjMvlMtdff73Zvn17a0+pzXjnnXeMpDMe48aNM8Z8eSv1r3/9a+Pz+Yzb7TbDhw835eXlEcf473//a+69917TpUsX4/F4zAMPPGBqa2sjxnz44YfmxhtvNG6321x66aVmwYIFZ8xl9erV5sorrzQul8tcffXVJi8v74Kd98V2tjWWZF544QVnzLFjx8wvfvEL061bN9OpUyfz4x//2Hz22WcRx/n000/NyJEjTVxcnOnZs6d59NFHzYkTJyLGvPPOO+baa681LpfLXH755RGv0ai9fk+MHz/eXHbZZcblcplLLrnEDB8+3IkXY1jjC+mrAcNan78xY8aYXr16GZfLZS699FIzZswY88knnzj7bVzjKGOMafp1GwAAgNbDe2AAAIB1CBgAAGAdAgYAAFiHgAEAANYhYAAAgHUIGAAAYB0CBgAAWIeAAQAA1iFgAACAdQgYAABgHQIGAABYh4ABAADW+X9sek0zUsUJ6wAAAABJRU5ErkJggg==",
88 | "text/plain": [
89 | ""
90 | ]
91 | },
92 | "metadata": {},
93 | "output_type": "display_data"
94 | }
95 | ],
96 | "source": [
97 | "plt.hist(dataset_len, range=(0, 50000))"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 21,
103 | "metadata": {},
104 | "outputs": [
105 | {
106 | "data": {
107 | "text/plain": [
108 | "5334"
109 | ]
110 | },
111 | "execution_count": 21,
112 | "metadata": {},
113 | "output_type": "execute_result"
114 | }
115 | ],
116 | "source": [
117 | "np.sum(dataset_len < 8192)"
118 | ]
119 | }
120 | ],
121 | "metadata": {
122 | "kernelspec": {
123 | "display_name": "jax",
124 | "language": "python",
125 | "name": "python3"
126 | },
127 | "language_info": {
128 | "codemirror_mode": {
129 | "name": "ipython",
130 | "version": 3
131 | },
132 | "file_extension": ".py",
133 | "mimetype": "text/x-python",
134 | "name": "python",
135 | "nbconvert_exporter": "python",
136 | "pygments_lexer": "ipython3",
137 | "version": "3.10.11"
138 | },
139 | "orig_nbformat": 4
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 2
143 | }
144 |
--------------------------------------------------------------------------------
/ochat/experimental/train_alpaca.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import logging
17 | from dataclasses import dataclass, field
18 | from typing import Dict, Optional, Sequence
19 |
20 | import torch
21 | import transformers
22 | import utils
23 | from torch.utils.data import Dataset
24 | from transformers import Trainer
25 |
26 | IGNORE_INDEX = -100
27 | DEFAULT_PAD_TOKEN = "[PAD]"
28 | DEFAULT_EOS_TOKEN = ""
29 | DEFAULT_BOS_TOKEN = ""
30 | DEFAULT_UNK_TOKEN = ""
31 | PROMPT_DICT = {
32 | "prompt_input": (
33 | "Below is an instruction that describes a task, paired with an input that provides further context. "
34 | "Write a response that appropriately completes the request.\n\n"
35 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
36 | ),
37 | "prompt_no_input": (
38 | "Below is an instruction that describes a task. "
39 | "Write a response that appropriately completes the request.\n\n"
40 | "### Instruction:\n{instruction}\n\n### Response:"
41 | ),
42 | }
43 |
44 |
45 | @dataclass
46 | class ModelArguments:
47 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
48 |
49 |
50 | @dataclass
51 | class DataArguments:
52 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
53 |
54 |
55 | @dataclass
56 | class TrainingArguments(transformers.TrainingArguments):
57 | cache_dir: Optional[str] = field(default=None)
58 | optim: str = field(default="adamw_torch")
59 | model_max_length: int = field(
60 | default=512,
61 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
62 | )
63 |
64 |
65 | def smart_tokenizer_and_embedding_resize(
66 | special_tokens_dict: Dict,
67 | tokenizer: transformers.PreTrainedTokenizer,
68 | model: transformers.PreTrainedModel,
69 | ):
70 | """Resize tokenizer and embedding.
71 |
72 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
73 | """
74 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
75 | model.resize_token_embeddings(len(tokenizer))
76 |
77 | if num_new_tokens > 0:
78 | input_embeddings = model.get_input_embeddings().weight.data
79 | output_embeddings = model.get_output_embeddings().weight.data
80 |
81 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
82 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
83 |
84 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
85 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
86 |
87 |
88 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
89 | """Tokenize a list of strings."""
90 | tokenized_list = [
91 | tokenizer(
92 | text,
93 | return_tensors="pt",
94 | padding="longest",
95 | max_length=tokenizer.model_max_length,
96 | truncation=True,
97 | )
98 | for text in strings
99 | ]
100 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
101 | input_ids_lens = labels_lens = [
102 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
103 | ]
104 | return dict(
105 | input_ids=input_ids,
106 | labels=labels,
107 | input_ids_lens=input_ids_lens,
108 | labels_lens=labels_lens,
109 | )
110 |
111 |
112 | def preprocess(
113 | sources: Sequence[str],
114 | targets: Sequence[str],
115 | tokenizer: transformers.PreTrainedTokenizer,
116 | ) -> Dict:
117 | """Preprocess the data by tokenizing."""
118 | examples = [s + t for s, t in zip(sources, targets)]
119 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
120 | input_ids = examples_tokenized["input_ids"]
121 | labels = copy.deepcopy(input_ids)
122 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
123 | label[:source_len] = IGNORE_INDEX
124 | return dict(input_ids=input_ids, labels=labels)
125 |
126 |
127 | class SupervisedDataset(Dataset):
128 | """Dataset for supervised fine-tuning."""
129 |
130 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
131 | super(SupervisedDataset, self).__init__()
132 | logging.warning("Loading data...")
133 | list_data_dict = utils.jload(data_path)
134 |
135 | logging.warning("Formatting inputs...")
136 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
137 | sources = [
138 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
139 | for example in list_data_dict
140 | ]
141 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
142 |
143 | logging.warning("Tokenizing inputs... This may take some time...")
144 | data_dict = preprocess(sources, targets, tokenizer)
145 |
146 | self.input_ids = data_dict["input_ids"]
147 | self.labels = data_dict["labels"]
148 |
149 | def __len__(self):
150 | return len(self.input_ids)
151 |
152 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
153 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
154 |
155 |
156 | @dataclass
157 | class DataCollatorForSupervisedDataset(object):
158 | """Collate examples for supervised fine-tuning."""
159 |
160 | tokenizer: transformers.PreTrainedTokenizer
161 |
162 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
163 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
164 | input_ids = torch.nn.utils.rnn.pad_sequence(
165 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
166 | )
167 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
168 | return dict(
169 | input_ids=input_ids,
170 | labels=labels,
171 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
172 | )
173 |
174 |
175 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
176 | """Make dataset and collator for supervised fine-tuning."""
177 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
178 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
179 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
180 |
181 |
182 | def train():
183 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
184 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
185 |
186 | model = transformers.AutoModelForCausalLM.from_pretrained(
187 | model_args.model_name_or_path,
188 | cache_dir=training_args.cache_dir,
189 | )
190 |
191 | tokenizer = transformers.AutoTokenizer.from_pretrained(
192 | model_args.model_name_or_path,
193 | cache_dir=training_args.cache_dir,
194 | model_max_length=training_args.model_max_length,
195 | padding_side="right",
196 | use_fast=False,
197 | )
198 | special_tokens_dict = dict()
199 | if tokenizer.pad_token is None:
200 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
201 | if tokenizer.eos_token is None:
202 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
203 | if tokenizer.bos_token is None:
204 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
205 | if tokenizer.unk_token is None:
206 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
207 |
208 | smart_tokenizer_and_embedding_resize(
209 | special_tokens_dict=special_tokens_dict,
210 | tokenizer=tokenizer,
211 | model=model,
212 | )
213 |
214 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
215 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
216 | trainer.train()
217 | trainer.save_state()
218 | trainer.save_model(output_dir=training_args.output_dir)
219 |
220 |
221 | if __name__ == "__main__":
222 | train()
223 |
--------------------------------------------------------------------------------
/ochat/models/__init__.py:
--------------------------------------------------------------------------------
1 | from ochat.models.unpadded_llama import LlamaForCausalLM
2 | from ochat.models.unpadded_mistral import MistralForCausalLM
3 | from ochat.models.unpadded_gemma import GemmaForCausalLM
4 |
--------------------------------------------------------------------------------
/ochat/models/unpadded_gemma.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3 | #
4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5 | # and OPT implementations in this library. It has been modified from its
6 | # original forms to accommodate minor architectural differences compared
7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8 | #
9 | # Licensed under the Apache License, Version 2.0 (the "License");
10 | # you may not use this file except in compliance with the License.
11 | # You may obtain a copy of the License at
12 | #
13 | # http://www.apache.org/licenses/LICENSE-2.0
14 | #
15 | # Unless required by applicable law or agreed to in writing, software
16 | # distributed under the License is distributed on an "AS IS" BASIS,
17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 | # See the License for the specific language governing permissions and
19 | # limitations under the License.
20 | """ PyTorch Unpadded & Fused Gemma model. Compatible with HF. """
21 |
22 | from typing import Optional, Tuple
23 |
24 | import torch
25 | import torch.utils.checkpoint
26 | import torch.nn.functional as F
27 | from torch import nn
28 |
29 | from transformers.activations import ACT2FN
30 | from transformers.modeling_outputs import CausalLMOutputWithPast
31 | from transformers.modeling_utils import PreTrainedModel
32 | from transformers.utils import logging
33 | from transformers.models.gemma.configuration_gemma import GemmaConfig
34 |
35 | try:
36 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
37 | from flash_attn.bert_padding import pad_input
38 | except ImportError:
39 | print ("FlashAttention not found. Install it if you need to train models.")
40 |
41 |
42 | def rotate_half(x: torch.Tensor):
43 | """Rotates half the hidden dims of the input."""
44 | x1 = x[..., : x.shape[-1] // 2]
45 | x2 = x[..., x.shape[-1] // 2 :]
46 | return torch.cat((-x2, x1), dim=-1)
47 |
48 |
49 | def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
50 | # q, k: [nnz, num_heads, head_dim]
51 | # position_ids: [nnz]
52 | # cos, sin: [max_seq_len, head_dim]
53 | cos = cos[position_ids].unsqueeze(-2) # [nnz, 1, head_dim]
54 | sin = sin[position_ids].unsqueeze(-2) # [nnz, 1, head_dim]
55 | q_embed = (q * cos) + (rotate_half(q) * sin)
56 | k_embed = (k * cos) + (rotate_half(k) * sin)
57 | return q_embed, k_embed
58 |
59 |
60 | @torch.jit.script
61 | def lm_head_with_loss(embed_weights: torch.Tensor, hidden_states: torch.Tensor, nz_shifted_label_ids: torch.Tensor, nz_shifted_loss_weights: torch.Tensor):
62 | logits = F.linear(hidden_states, embed_weights)
63 |
64 | loss = (nz_shifted_loss_weights * torch.nn.functional.cross_entropy(logits, nz_shifted_label_ids, reduction="none")).sum()
65 | token_accuracy = (nz_shifted_loss_weights * (torch.argmax(logits.detach(), dim=-1) == nz_shifted_label_ids)).sum()
66 | return loss, token_accuracy
67 |
68 |
69 | # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Gemma
70 | RMS_NORM_TRACED = None
71 |
72 |
73 | def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, variance_epsilon: torch.Tensor):
74 | input_dtype = hidden_states.dtype
75 | hidden_states = hidden_states.to(torch.float32)
76 |
77 | variance = hidden_states.square().mean(-1, keepdim=True)
78 | hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
79 | return (1 + weight) * hidden_states.to(input_dtype)
80 |
81 |
82 | class UnpaddedGemmaRMSNorm(nn.Module):
83 | def __init__(self, hidden_size, eps):
84 | """
85 | UnpaddedGemmaRMSNorm is equivalent to T5LayerNorm
86 | """
87 | super().__init__()
88 |
89 | self.weight = nn.Parameter(torch.zeros(hidden_size))
90 | self.variance_epsilon = torch.tensor(eps, dtype=torch.get_default_dtype())
91 |
92 | global RMS_NORM_TRACED
93 | if RMS_NORM_TRACED is None:
94 | RMS_NORM_TRACED = torch.jit.trace(rms_norm, (torch.ones(hidden_size), torch.ones(hidden_size), self.variance_epsilon))
95 |
96 | def forward(self, hidden_states):
97 | global RMS_NORM_TRACED
98 | return RMS_NORM_TRACED(hidden_states, self.weight, self.variance_epsilon)
99 |
100 |
101 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Gemma
102 | class UnpaddedGemmaRotaryEmbedding(torch.nn.Module):
103 | def __init__(self, dim, max_position_embeddings, base, device=None):
104 | super().__init__()
105 |
106 | # RoPE
107 | inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
108 | t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
109 | freqs = torch.outer(t, inv_freq)
110 |
111 | # Different from paper, but it uses a different permutation in order to obtain the same calculation
112 | emb = torch.cat((freqs, freqs), dim=-1)
113 | dtype = torch.get_default_dtype()
114 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
115 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
116 |
117 | def forward(self):
118 | return self.cos_cached, self.sin_cached
119 |
120 |
121 | class UnpaddedGemmaMLP(nn.Module):
122 | def __init__(self, config: GemmaConfig):
123 | super().__init__()
124 |
125 | self.hidden_size = config.hidden_size
126 | self.intermediate_size = config.intermediate_size
127 |
128 | self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
129 | self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
130 | self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
131 | self.act_fn = ACT2FN[config.hidden_act]
132 |
133 | def forward(self, x):
134 | return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
135 |
136 |
137 | class UnpaddedGemmaAttention(nn.Module):
138 | """Multi-headed attention from 'Attention Is All You Need' paper"""
139 |
140 | def __init__(self, config: GemmaConfig):
141 | super().__init__()
142 |
143 | self.hidden_size = config.hidden_size
144 | self.num_heads = config.num_attention_heads
145 | self.head_dim = config.head_dim
146 | self.num_key_value_heads = config.num_key_value_heads
147 |
148 | if self.hidden_size % self.num_heads != 0:
149 | raise ValueError(
150 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
151 | f" and `num_heads`: {self.num_heads})."
152 | )
153 |
154 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
155 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
156 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
157 | self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
158 |
159 | def forward(
160 | self,
161 | cos_sin: Tuple[torch.Tensor, torch.Tensor],
162 | # Unpadded inputs
163 | nz_hidden_states: torch.Tensor,
164 | nz_position_ids: torch.LongTensor,
165 | cu_seqlens: torch.Tensor,
166 | max_seqlen: int
167 | ) -> torch.Tensor:
168 | # nz_hidden_states: [nnz, num_heads, head_dim]
169 | # nz_position_ids: [nnz]
170 | # cu_seqlens: [bs + 1]
171 |
172 | query_states = self.q_proj(nz_hidden_states).view(-1, self.num_heads, self.head_dim)
173 | key_states = self.k_proj(nz_hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
174 | value_states = self.v_proj(nz_hidden_states).view(-1, self.num_key_value_heads, self.head_dim)
175 |
176 | # RoPE
177 | cos, sin = cos_sin
178 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, nz_position_ids)
179 |
180 | # flash attn
181 | attn_output = flash_attn_varlen_func(
182 | q=query_states, k=key_states, v=value_states,
183 | cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
184 | max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
185 |
186 | dropout_p=0.0, causal=True)
187 |
188 | # attn_output: [total_nnz, num_heads, head_dim]
189 | attn_output = attn_output.view(-1, self.num_heads * self.head_dim) # type: ignore
190 | return self.o_proj(attn_output)
191 |
192 |
193 | class UnpaddedGemmaDecoderLayer(nn.Module):
194 | def __init__(self, config: GemmaConfig):
195 | super().__init__()
196 |
197 | self.hidden_size = config.hidden_size
198 | self.self_attn = UnpaddedGemmaAttention(config=config)
199 | self.mlp = UnpaddedGemmaMLP(config=config)
200 | self.input_layernorm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
201 | self.post_attention_layernorm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
202 |
203 | def forward(
204 | self,
205 | cos_sin: Tuple[torch.Tensor, torch.Tensor],
206 | # Unpadded inputs
207 | nz_hidden_states: torch.Tensor,
208 | nz_position_ids: torch.Tensor,
209 | cu_seqlens: torch.Tensor,
210 | max_seqlen: int
211 | ) -> torch.Tensor:
212 | # Self Attention
213 | residual = nz_hidden_states
214 |
215 | nz_hidden_states = self.input_layernorm(nz_hidden_states)
216 | nz_hidden_states = self.self_attn(
217 | cos_sin=cos_sin,
218 |
219 | nz_hidden_states=nz_hidden_states,
220 | nz_position_ids=nz_position_ids,
221 | cu_seqlens=cu_seqlens,
222 | max_seqlen=max_seqlen
223 | )
224 | nz_hidden_states = residual + nz_hidden_states
225 |
226 | # Fully Connected
227 | residual = nz_hidden_states
228 |
229 | nz_hidden_states = self.post_attention_layernorm(nz_hidden_states)
230 | nz_hidden_states = self.mlp(nz_hidden_states)
231 | nz_hidden_states = residual + nz_hidden_states
232 |
233 | return nz_hidden_states
234 |
235 |
236 | class UnpaddedGemmaPreTrainedModel(PreTrainedModel):
237 | config_class = GemmaConfig
238 | base_model_prefix = "model"
239 | supports_gradient_checkpointing = True
240 | _no_split_modules = ["UnpaddedGemmaDecoderLayer"]
241 |
242 | def _init_weights(self, module):
243 | std = self.config.initializer_range
244 | if isinstance(module, nn.Linear):
245 | module.weight.data.normal_(mean=0.0, std=std)
246 | if module.bias is not None:
247 | module.bias.data.zero_()
248 | elif isinstance(module, nn.Embedding):
249 | module.weight.data.normal_(mean=0.0, std=std)
250 | if module.padding_idx is not None:
251 | module.weight.data[module.padding_idx].zero_()
252 |
253 |
254 | class UnpaddedGemmaModel(UnpaddedGemmaPreTrainedModel):
255 | """
256 | Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`UnpaddedGemmaDecoderLayer`]
257 |
258 | Args:
259 | config: GemmaConfig
260 | """
261 |
262 | def __init__(self, config: GemmaConfig):
263 | super().__init__(config)
264 | self.padding_idx = config.pad_token_id
265 | self.vocab_size = config.vocab_size
266 | self.normalization_factor = config.hidden_size ** 0.5
267 |
268 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
269 | self.rotary_emb = UnpaddedGemmaRotaryEmbedding(config.head_dim,
270 | max_position_embeddings=config.max_position_embeddings,
271 | base=config.rope_theta)
272 |
273 | self.layers = nn.ModuleList([UnpaddedGemmaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
274 | self.norm = UnpaddedGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
275 |
276 | self.gradient_checkpointing = False
277 | # Initialize weights and apply final processing
278 | self.post_init()
279 |
280 | def get_input_embeddings(self):
281 | return self.embed_tokens
282 |
283 | def set_input_embeddings(self, value):
284 | self.embed_tokens = value
285 |
286 | def forward(
287 | self,
288 | # Unpadded inputs
289 | nz_input_ids: torch.Tensor,
290 | nz_position_ids: torch.Tensor,
291 | cu_seqlens: torch.Tensor,
292 | max_seqlen: int,
293 | ) -> torch.Tensor:
294 | nz_hidden_states = self.embed_tokens(nz_input_ids) * self.normalization_factor # Normalized
295 | cos_sin = self.rotary_emb()
296 |
297 | # decoder layers
298 | for decoder_layer in self.layers:
299 | if self.gradient_checkpointing and self.training:
300 | nz_hidden_states = self._gradient_checkpointing_func(
301 | decoder_layer.__call__,
302 |
303 | cos_sin,
304 | nz_hidden_states,
305 | nz_position_ids,
306 | cu_seqlens,
307 | max_seqlen
308 | )
309 | else:
310 | nz_hidden_states = decoder_layer(
311 | cos_sin,
312 |
313 | nz_hidden_states,
314 | nz_position_ids,
315 | cu_seqlens,
316 | max_seqlen
317 | )
318 |
319 | nz_hidden_states = self.norm(nz_hidden_states)
320 |
321 | return nz_hidden_states
322 |
323 |
324 | class GemmaForCausalLM(UnpaddedGemmaPreTrainedModel):
325 | def __init__(self, config):
326 | super().__init__(config)
327 | self.model = UnpaddedGemmaModel(config)
328 |
329 | # Initialize weights and apply final processing
330 | self.post_init()
331 |
332 | def get_input_embeddings(self):
333 | return self.model.embed_tokens
334 |
335 | def set_input_embeddings(self, value):
336 | self.model.embed_tokens = value
337 |
338 | def get_output_embeddings(self):
339 | return self.model.embed_tokens
340 |
341 | def set_output_embeddings(self, new_embeddings):
342 | self.model.embed_tokens = new_embeddings
343 |
344 | def set_decoder(self, decoder):
345 | self.model = decoder
346 |
347 | def get_decoder(self):
348 | return self.model
349 |
350 | def forward(
351 | self,
352 | # Unpadded inputs
353 | nz_input_ids: torch.Tensor,
354 | nz_position_ids: torch.Tensor,
355 | cu_seqlens: torch.Tensor,
356 | max_seqlen: int,
357 | # Unpadded labels
358 | nz_shifted_label_ids: Optional[torch.Tensor] = None,
359 | nz_shifted_loss_weights: Optional[torch.Tensor] = None
360 | ) -> CausalLMOutputWithPast:
361 | # Model logits
362 | hidden_states = self.model(
363 | nz_input_ids=nz_input_ids,
364 | nz_position_ids=nz_position_ids,
365 | cu_seqlens=cu_seqlens,
366 | max_seqlen=max_seqlen
367 | )
368 |
369 | # Loss
370 | loss = lm_head_with_loss(
371 | self.model.embed_tokens.weight, # Tied embeddings
372 | hidden_states,
373 | nz_shifted_label_ids,
374 | nz_shifted_loss_weights
375 | )
376 |
377 | return CausalLMOutputWithPast(
378 | loss=loss # type: ignore
379 | )
380 |
--------------------------------------------------------------------------------
/ochat/scripts/hf_add_tokens.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import transformers
4 | import torch
5 |
6 |
7 | def add_tokens_to_embedding(added_special_tokens, embedding):
8 | # Mean embedding, shape: [1, dim]
9 | new_token_embeddings = torch.mean(embedding.to(torch.float32), dim=0, keepdim=True).to(embedding.dtype)
10 | # Expand to [N, dim]
11 | new_token_embeddings = new_token_embeddings.expand(len(added_special_tokens), -1)
12 |
13 | return torch.cat([embedding, new_token_embeddings], dim=0)
14 |
15 |
16 | def hf_add_tokens(model_path, output_dir, added_special_tokens):
17 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
18 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path,
19 | low_cpu_mem_usage=True,
20 | torch_dtype=torch.bfloat16)
21 | # Add tokens (tokenizer)
22 | tokenizer.add_special_tokens({"additional_special_tokens": added_special_tokens})
23 |
24 | # Add tokens (embedding)
25 | assert model.model.embed_tokens.weight.requires_grad
26 | assert model.lm_head.weight.requires_grad
27 |
28 | model.model.embed_tokens.weight = torch.nn.Parameter(add_tokens_to_embedding(added_special_tokens, model.model.embed_tokens.weight), requires_grad=True)
29 | model.lm_head.weight = torch.nn.Parameter(add_tokens_to_embedding(added_special_tokens, model.lm_head.weight), requires_grad=True)
30 |
31 | model.config.vocab_size += len(added_special_tokens)
32 |
33 | # Fix model config (Mistral's actual token length is 8192)
34 | if "mistral" in model_path.lower():
35 | assert model.config.max_position_embeddings == 32768
36 | model.config.max_position_embeddings = 8192
37 |
38 | print ({k: v.shape for k, v in model.state_dict().items()})
39 |
40 | # Save
41 | tokenizer.save_pretrained(output_dir)
42 | model.save_pretrained(output_dir)
43 |
44 |
45 | def main():
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument(
48 | "--model-path",
49 | help="Location of Mistral model, or HuggingFace repo ID",
50 | )
51 | parser.add_argument(
52 | "--output-dir",
53 | help="Location to write resulting model and tokenizer",
54 | )
55 | parser.add_argument(
56 | "--added-special-tokens",
57 | type=str,
58 | nargs="+",
59 | help="Special token list to add"
60 | )
61 |
62 | hf_add_tokens(**vars(parser.parse_args()))
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/ochat/scripts/init_special_embedding_llama3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import transformers
4 | import torch
5 |
6 |
7 | def init_eot_embedding_llama3(model_path, output_dir, special_tokens=["<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"], mean_cutoff=128000, dtype=torch.bfloat16):
8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
9 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=dtype)
10 |
11 | assert model.model.embed_tokens.weight.shape[0] >= mean_cutoff
12 | assert model.lm_head.weight.shape[0] >= mean_cutoff
13 |
14 | with torch.no_grad():
15 | for token in special_tokens:
16 | token_id = tokenizer.convert_tokens_to_ids(token)
17 |
18 | print (f"Token {token} ID {token_id}")
19 |
20 | model.model.embed_tokens.weight[token_id] = torch.mean(model.model.embed_tokens.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
21 | model.lm_head.weight[token_id] = torch.mean(model.lm_head.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
22 |
23 | # Save
24 | tokenizer.save_pretrained(output_dir)
25 | model.save_pretrained(output_dir)
26 |
27 |
28 | def main():
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument(
31 | "--model-path",
32 | help="Location of model, or HuggingFace repo ID",
33 | )
34 | parser.add_argument(
35 | "--output-dir",
36 | help="Location to write resulting model and tokenizer",
37 | )
38 |
39 | init_eot_embedding_llama3(**vars(parser.parse_args()))
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/ochat/scripts/modify_eos_embedding.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import transformers
4 | import torch
5 |
6 |
7 | def modify_eos_embeddings(model_path, output_dir):
8 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
9 | model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
10 |
11 | eos_token_id = tokenizer.eos_token_id
12 |
13 | print (f"EOS Token {tokenizer.convert_ids_to_tokens(eos_token_id)} ID {eos_token_id}")
14 | with torch.no_grad():
15 | model.model.embed_tokens.weight[eos_token_id] = torch.mean(model.model.embed_tokens.weight, dim=0)
16 | model.lm_head.weight[eos_token_id] = torch.mean(model.lm_head.weight, dim=0)
17 |
18 | # Save
19 | tokenizer.save_pretrained(output_dir)
20 | model.save_pretrained(output_dir)
21 |
22 |
23 | def main():
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument(
26 | "--model-path",
27 | help="Location of model, or HuggingFace repo ID",
28 | )
29 | parser.add_argument(
30 | "--output-dir",
31 | help="Location to write resulting model and tokenizer",
32 | )
33 |
34 | modify_eos_embeddings(**vars(parser.parse_args()))
35 |
36 |
37 | if __name__ == "__main__":
38 | main()
39 |
--------------------------------------------------------------------------------
/ochat/serving/async_tokenizer.py:
--------------------------------------------------------------------------------
1 | import ray
2 |
3 | from ochat.config import Message, Conversation
4 |
5 |
6 | @ray.remote
7 | class AsyncTokenizer:
8 | def __init__(self, model_type: str, model_path: str) -> None:
9 | from ochat.config import MODEL_CONFIG_MAP
10 |
11 | config = MODEL_CONFIG_MAP[model_type]
12 | tokenizer = config.model_tokenizer_create(model_path)
13 |
14 | self.conv_template = config.conversation_template(tokenizer=tokenizer)
15 |
16 | def tokenize(self, messages, condition, enable_sys_prompt=False):
17 | # get system messages
18 | system_message = ""
19 | items = []
20 |
21 | for msg_raw in messages:
22 | msg = Message(**msg_raw)
23 | if msg.role == "system":
24 | # Use system prompt only when enabled
25 | if enable_sys_prompt:
26 | system_message = msg.content.strip()
27 |
28 | continue
29 |
30 | items.append(msg)
31 |
32 | assert len(items)
33 |
34 | # append ai role
35 | if items[-1].role != "assistant":
36 | items.append(Message(role="assistant", content=""))
37 |
38 | tokens, _ = self.conv_template.tokenize_conversations([Conversation(items=items, system=system_message, condition=condition)],
39 | inference=True)
40 | return tokens[0]
41 |
42 | def get_eot_tokens(self):
43 | assert len(self.conv_template.eot_tokens_) == 1
44 |
45 | return self.conv_template.eot_tokens_
46 |
--------------------------------------------------------------------------------
/ochat/serving/openai_api_protocol.py:
--------------------------------------------------------------------------------
1 | # Adapted from
2 | # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
3 | import time
4 | from typing import Dict, List, Literal, Optional, Union
5 |
6 | from pydantic import BaseModel, Field
7 |
8 | from vllm.utils import random_uuid
9 |
10 |
11 | class ErrorResponse(BaseModel):
12 | object: str = "error"
13 | message: str
14 | type: str
15 | param: Optional[str] = None
16 | code: Optional[str] = None
17 |
18 |
19 | class ModelPermission(BaseModel):
20 | id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
21 | object: str = "model_permission"
22 | created: int = Field(default_factory=lambda: int(time.time()))
23 | allow_create_engine: bool = False
24 | allow_sampling: bool = True
25 | allow_logprobs: bool = False
26 | allow_search_indices: bool = False
27 | allow_view: bool = True
28 | allow_fine_tuning: bool = False
29 | organization: str = "*"
30 | group: Optional[str] = None
31 | is_blocking: str = False
32 |
33 |
34 | class ModelCard(BaseModel):
35 | id: str
36 | object: str = "model"
37 | created: int = Field(default_factory=lambda: int(time.time()))
38 | owned_by: str = "openchat"
39 | root: Optional[str] = None
40 | parent: Optional[str] = None
41 | permission: List[ModelPermission] = Field(default_factory=list)
42 |
43 |
44 | class ModelList(BaseModel):
45 | object: str = "list"
46 | data: List[ModelCard] = Field(default_factory=list)
47 |
48 |
49 | class UsageInfo(BaseModel):
50 | prompt_tokens: int = 0
51 | total_tokens: int = 0
52 | completion_tokens: Optional[int] = 0
53 |
54 |
55 | class ChatCompletionRequest(BaseModel):
56 | model: str
57 | messages: Union[str, List[Dict[str, str]]]
58 | condition: Optional[str] = ""
59 | temperature: Optional[float] = 0.7
60 | top_p: Optional[float] = 1.0
61 | n: Optional[int] = 1
62 | max_tokens: Optional[int] = None
63 | seed: Optional[int] = None
64 | stop: Optional[Union[str, List[str]]] = None
65 | stream: Optional[bool] = False
66 | presence_penalty: Optional[float] = 0.0
67 | frequency_penalty: Optional[float] = 0.0
68 | logit_bias: Optional[Dict[str, float]] = None
69 | user: Optional[str] = None
70 |
71 |
72 | class ChatMessage(BaseModel):
73 | role: str
74 | content: str
75 |
76 |
77 | class ChatCompletionResponseChoice(BaseModel):
78 | index: int
79 | message: ChatMessage
80 | finish_reason: Optional[Literal["stop", "length"]] = None
81 |
82 |
83 | class ChatCompletionResponse(BaseModel):
84 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
85 | object: str = "chat.completion"
86 | created: int = Field(default_factory=lambda: int(time.time()))
87 | model: str
88 | choices: List[ChatCompletionResponseChoice]
89 | usage: UsageInfo
90 |
91 |
92 | class DeltaMessage(BaseModel):
93 | role: Optional[str] = None
94 | content: Optional[str] = None
95 |
96 |
97 | class ChatCompletionResponseStreamChoice(BaseModel):
98 | index: int
99 | delta: DeltaMessage
100 | finish_reason: Optional[Literal["stop", "length"]] = None
101 |
102 |
103 | class ChatCompletionStreamResponse(BaseModel):
104 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
105 | object: str = "chat.completion.chunk"
106 | created: int = Field(default_factory=lambda: int(time.time()))
107 | model: str
108 | choices: List[ChatCompletionResponseStreamChoice]
109 |
110 |
111 | class LoggingRecord(BaseModel):
112 | time: int
113 | request: ChatCompletionRequest
114 | outputs: List[str]
115 |
--------------------------------------------------------------------------------
/ochat/serving/openai_api_server.py:
--------------------------------------------------------------------------------
1 | # Adapted from
2 | # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
3 |
4 | import argparse
5 | import asyncio
6 | from http import HTTPStatus
7 | import json
8 | import time
9 | import logging
10 | from logging.handlers import RotatingFileHandler
11 | from typing import AsyncGenerator, Optional
12 | from dataclasses import dataclass
13 |
14 | import fastapi
15 | from fastapi import BackgroundTasks, Request
16 | from fastapi.exceptions import RequestValidationError
17 | from fastapi.middleware.cors import CORSMiddleware
18 | from fastapi.responses import JSONResponse, StreamingResponse
19 | from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
20 |
21 | import uvicorn
22 | import ray
23 |
24 | from vllm.engine.arg_utils import AsyncEngineArgs
25 | from vllm.engine.async_llm_engine import AsyncLLMEngine
26 | from vllm.outputs import RequestOutput
27 | from vllm.sampling_params import SamplingParams
28 | from vllm.utils import random_uuid
29 |
30 | from ochat.config import MODEL_CONFIG_MAP
31 | from ochat.serving import openai_api_protocol, async_tokenizer
32 |
33 | from transformers.utils.hub import cached_file
34 |
35 |
36 | TIMEOUT_KEEP_ALIVE = 5 # seconds
37 |
38 |
39 | @dataclass
40 | class ModelConfig:
41 | names: set = None
42 |
43 | max_length: int = None
44 | stream_period: int = None
45 | eot_tokens: list = None
46 |
47 | enable_sys_prompt: bool = None
48 | api_keys: list = None
49 |
50 |
51 | logger = None
52 | app = fastapi.FastAPI()
53 |
54 | model = ModelConfig()
55 | tokenizer = None
56 |
57 |
58 | def _strip_first_space(s: str):
59 | if s[0] == " ":
60 | return s[1:]
61 | return s
62 |
63 |
64 | def log_request(created_time: int, request: openai_api_protocol.ChatCompletionRequest, output: RequestOutput):
65 | if logger is not None:
66 | logger.info(openai_api_protocol.LoggingRecord(
67 | time=created_time,
68 | request=request,
69 | outputs=[o.text for o in output.outputs]
70 | ).model_dump_json(exclude_unset=True))
71 |
72 |
73 | def create_error_response(status_code: HTTPStatus,
74 | message: str) -> JSONResponse:
75 | return JSONResponse(openai_api_protocol.ErrorResponse(message=message,
76 | type="invalid_request_error").dict(),
77 | status_code=status_code.value)
78 |
79 |
80 | def check_model(request) -> Optional[JSONResponse]:
81 | if request.model in model.names:
82 | return
83 |
84 | return create_error_response(
85 | HTTPStatus.NOT_FOUND,
86 | f"The model `{request.model}` does not exist.",
87 | )
88 |
89 |
90 | @app.exception_handler(RequestValidationError)
91 | async def validation_exception_handler(request, exc): # pylint: disable=unused-argument
92 | return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))
93 |
94 |
95 | async def check_api_key(
96 | auth: Optional[HTTPAuthorizationCredentials] = fastapi.Depends(HTTPBearer(auto_error=False)),
97 | ):
98 | if not model.api_keys:
99 | return
100 |
101 | if auth is None or auth.credentials not in model.api_keys:
102 | raise fastapi.HTTPException(
103 | status_code=401,
104 | detail={
105 | "error": {
106 | "message": "",
107 | "type": "invalid_request_error",
108 | "param": None,
109 | "code": "invalid_api_key",
110 | }
111 | },
112 | )
113 |
114 |
115 | @app.get("/v1/models", dependencies=[fastapi.Depends(check_api_key)])
116 | async def show_available_models():
117 | """Show available models. Right now we only have one model."""
118 | return openai_api_protocol.ModelList(data=[
119 | openai_api_protocol.ModelCard(id=name,
120 | root=name,
121 | permission=[openai_api_protocol.ModelPermission()])
122 | for name in model.names])
123 |
124 |
125 | @app.post("/v1/chat/completions", dependencies=[fastapi.Depends(check_api_key)])
126 | async def create_chat_completion(raw_request: Request, background_tasks: BackgroundTasks):
127 | """Completion API similar to OpenAI's API.
128 |
129 | See https://platform.openai.com/docs/api-reference/chat/create
130 | for the API specification. This API mimics the OpenAI ChatCompletion API.
131 |
132 | NOTE: Currently we do not support the following features:
133 | - function_call (Users should implement this by themselves)
134 | - logit_bias (to be supported by vLLM engine)
135 | """
136 |
137 | request = openai_api_protocol.ChatCompletionRequest(**await raw_request.json())
138 |
139 | error_check_ret = check_model(request)
140 | if error_check_ret is not None:
141 | return error_check_ret
142 |
143 | if request.logit_bias is not None and len(request.logit_bias) > 0:
144 | # TODO: support logit_bias in vLLM engine.
145 | return create_error_response(HTTPStatus.BAD_REQUEST,
146 | "logit_bias is not currently supported")
147 |
148 | # input ids
149 | input_ids = await tokenizer.tokenize.remote(request.messages, condition=request.condition,
150 | enable_sys_prompt=model.enable_sys_prompt)
151 | input_num_tokens = len(input_ids)
152 |
153 | # check length
154 | if request.max_tokens is None:
155 | request.max_tokens = model.max_length - input_num_tokens
156 |
157 | if input_num_tokens + request.max_tokens > model.max_length:
158 | return input_ids, create_error_response(
159 | HTTPStatus.BAD_REQUEST,
160 | f"This model's maximum context length is {model.max_length} tokens. "
161 | f"However, you requested {input_num_tokens + request.max_tokens} tokens "
162 | f"({input_num_tokens} in the messages, "
163 | f"{request.max_tokens} in the completion). "
164 | f"Please reduce the length of the messages or completion.",
165 | )
166 |
167 | # completion
168 | model_name = request.model
169 | request_id = f"cmpl-{random_uuid()}"
170 | created_time = int(time.time())
171 |
172 | try:
173 | sampling_params = SamplingParams(
174 | n=request.n,
175 | presence_penalty=request.presence_penalty,
176 | frequency_penalty=request.frequency_penalty,
177 | temperature=request.temperature,
178 | top_p=request.top_p,
179 | max_tokens=request.max_tokens,
180 | seed=request.seed,
181 | # Override stop tokens
182 | stop_token_ids=model.eot_tokens,
183 | ignore_eos=True
184 | )
185 | except ValueError as e:
186 | return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
187 |
188 | result_generator = engine.generate(prompt=None,
189 | prompt_token_ids=input_ids,
190 | sampling_params=sampling_params,
191 | request_id=request_id)
192 |
193 | def create_stream_response_json(
194 | index: int,
195 | text: str,
196 | finish_reason: Optional[str] = None,
197 | ) -> str:
198 | choice_data = openai_api_protocol.ChatCompletionResponseStreamChoice(
199 | index=index,
200 | delta=openai_api_protocol.DeltaMessage(content=text),
201 | finish_reason=finish_reason,
202 | )
203 | response = openai_api_protocol.ChatCompletionStreamResponse(
204 | id=request_id,
205 | choices=[choice_data],
206 | model=model_name,
207 | )
208 |
209 | return response.model_dump_json(exclude_unset=True)
210 |
211 | async def completion_stream_generator() -> AsyncGenerator[str, None]:
212 | # First chunk with role
213 | for i in range(request.n):
214 | choice_data = openai_api_protocol.ChatCompletionResponseStreamChoice(
215 | index=i,
216 | delta=openai_api_protocol.DeltaMessage(role="assistant"),
217 | finish_reason=None,
218 | )
219 | chunk = openai_api_protocol.ChatCompletionStreamResponse(id=request_id,
220 | choices=[choice_data],
221 | model=model_name)
222 |
223 | yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
224 |
225 | previous_texts = [""] * request.n
226 | previous_num_tokens = [0] * request.n
227 |
228 | stream_index = 0
229 | final_res = None
230 | is_first = True
231 | async for res in result_generator:
232 | stream_index += 1
233 | final_res = res
234 |
235 | for output in res.outputs:
236 | # stream on end or every stream_period
237 | if (stream_index % model.stream_period == 0) or (output.finish_reason is not None):
238 | i = output.index
239 | delta_text = output.text[len(previous_texts[i]):]
240 | if "\ufffd" not in delta_text:
241 | previous_texts[i] = output.text
242 | previous_num_tokens[i] = len(output.token_ids)
243 |
244 | if is_first:
245 | # Strip first space
246 | is_first = False
247 | delta_text = _strip_first_space(delta_text)
248 |
249 | yield f"data: {create_stream_response_json(index=i, text=delta_text)}\n\n"
250 | if output.finish_reason is not None:
251 | yield f"data: {create_stream_response_json(index=i, text='', finish_reason=output.finish_reason)}\n\n"
252 |
253 | yield "data: [DONE]\n\n"
254 |
255 | # Log request
256 | background_tasks.add_task(log_request, created_time, request, final_res)
257 |
258 | # Streaming response
259 | if request.stream:
260 | return StreamingResponse(completion_stream_generator(),
261 | media_type="text/event-stream")
262 |
263 | # Non-streaming response
264 | final_res: RequestOutput = None
265 | async for res in result_generator:
266 | if await raw_request.is_disconnected():
267 | # Abort the request if the client disconnects.
268 | await engine.abort(request_id)
269 | return create_error_response(HTTPStatus.BAD_REQUEST,
270 | "Client disconnected")
271 | final_res = res
272 | assert final_res is not None
273 | choices = []
274 | for output in final_res.outputs:
275 | choice_data = openai_api_protocol.ChatCompletionResponseChoice(
276 | index=output.index,
277 | message=openai_api_protocol.ChatMessage(role="assistant", content=_strip_first_space(output.text)),
278 | finish_reason=output.finish_reason,
279 | )
280 | choices.append(choice_data)
281 |
282 | num_prompt_tokens = len(final_res.prompt_token_ids)
283 | num_generated_tokens = sum(
284 | len(output.token_ids) for output in final_res.outputs)
285 | usage = openai_api_protocol.UsageInfo(
286 | prompt_tokens=num_prompt_tokens,
287 | completion_tokens=num_generated_tokens,
288 | total_tokens=num_prompt_tokens + num_generated_tokens,
289 | )
290 | response = openai_api_protocol.ChatCompletionResponse(
291 | id=request_id,
292 | created=created_time,
293 | model=model_name,
294 | choices=choices,
295 | usage=usage,
296 | )
297 |
298 | # Log request
299 | background_tasks.add_task(log_request, created_time, request, final_res)
300 |
301 | return response
302 |
303 |
304 | if __name__ == "__main__":
305 | parser = argparse.ArgumentParser(description="OpenChat OpenAI-Compatible RESTful API server.")
306 |
307 | # Model
308 | parser.add_argument("--model-type", type=str, default=None, help="Model type. Leave empty to auto-detect.")
309 |
310 | parser.add_argument("--stream-period", type=int, default=6, help="Number of tokens per stream event")
311 | parser.add_argument("--api-keys", type=str, nargs="*", default=[], help="Allowed API Keys. Leave blank to not verify")
312 | parser.add_argument("--enable-sys-prompt", default=False, action="store_true")
313 |
314 | # Server
315 | parser.add_argument("--host", type=str, default="localhost", help="Host name")
316 | parser.add_argument("--port", type=int, default=18888, help="Port number")
317 | parser.add_argument("--allow-credentials", action="store_true", help="Allow credentials")
318 | parser.add_argument("--allowed-origins", type=json.loads, default=["*"], help="Allowed origins")
319 | parser.add_argument("--allowed-methods", type=json.loads, default=["*"], help="Allowed methods")
320 | parser.add_argument("--allowed-headers", type=json.loads, default=["*"], help="Allowed headers")
321 |
322 | # Logging
323 | parser.add_argument("--log-file", type=str, default=None, help="Log file. Leave blank to disable logging")
324 | parser.add_argument("--log-max-mb", type=int, default=128, help="Max log size in MB")
325 | parser.add_argument("--log-max-count", type=int, default=10, help="Max log file versions to keep")
326 |
327 | parser = AsyncEngineArgs.add_cli_args(parser)
328 | args = parser.parse_args()
329 |
330 | # App and logging
331 | app.add_middleware(
332 | CORSMiddleware,
333 | allow_origins=args.allowed_origins,
334 | allow_credentials=args.allow_credentials,
335 | allow_methods=args.allowed_methods,
336 | allow_headers=args.allowed_headers,
337 | )
338 |
339 | if args.log_file:
340 | logger = logging.getLogger(__name__)
341 |
342 | logger.setLevel(logging.INFO)
343 | logger.addHandler(RotatingFileHandler(
344 | args.log_file,
345 | maxBytes=args.log_max_mb * 1048576,
346 | backupCount=args.log_max_count)
347 | )
348 | logger.propagate = False
349 |
350 | # Load model type
351 | if args.model_type is None:
352 | with open(cached_file(path_or_repo_id=args.model, filename="openchat.json"), "r") as f:
353 | args.model_type = json.load(f)["model_type"]
354 |
355 | # Load tokenizer
356 | tokenizer = async_tokenizer.AsyncTokenizer.remote(args.model_type, args.model)
357 |
358 | # Model config
359 | model.names = set(list(MODEL_CONFIG_MAP[args.model_type].serving_aliases) + [args.model_type])
360 | model.max_length = MODEL_CONFIG_MAP[args.model_type].model_max_context
361 | model.eot_tokens = ray.get(tokenizer.get_eot_tokens.remote())
362 |
363 | model.enable_sys_prompt = args.enable_sys_prompt
364 | model.stream_period = args.stream_period
365 | model.api_keys = args.api_keys
366 |
367 | # Set max num batched tokens
368 | args.max_num_batched_tokens = max(args.max_num_batched_tokens or model.max_length, model.max_length)
369 | args.max_model_len = model.max_length
370 |
371 | # Load model engine
372 | engine_args = AsyncEngineArgs.from_cli_args(args)
373 | engine = AsyncLLMEngine.from_engine_args(engine_args)
374 | engine_model_config = asyncio.run(engine.get_model_config())
375 |
376 | # Run
377 | uvicorn.run(app,
378 | host=args.host,
379 | port=args.port,
380 | log_level="info",
381 | access_log=False,
382 | timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
383 |
--------------------------------------------------------------------------------
/ochat/training_deepspeed/deepspeed_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": true
4 | },
5 |
6 | "zero_optimization": {
7 | "stage": 2
8 | },
9 |
10 | "gradient_clipping": 1.0,
11 | "gradient_accumulation_steps": 1,
12 | "train_micro_batch_size_per_gpu": 1,
13 |
14 | "steps_per_print": 100,
15 | "wall_clock_breakdown": false
16 | }
--------------------------------------------------------------------------------
/ochat/training_deepspeed/hf_hub.py:
--------------------------------------------------------------------------------
1 | import shlex
2 | import subprocess
3 |
4 | from huggingface_hub import HfApi
5 |
6 |
7 | def hub_upload_check(push_to_hub: str):
8 | if push_to_hub is not None:
9 | # Try creating a test repo
10 | test_repo_name = f"{push_to_hub}-dummy-test"
11 | try:
12 | HfApi().create_repo(
13 | repo_id=test_repo_name,
14 | repo_type="model",
15 | private=True
16 | )
17 | HfApi().delete_repo(
18 | repo_id=test_repo_name,
19 | repo_type="model"
20 | )
21 | except Exception as e:
22 | raise RuntimeError(f"Failed to push test repo {test_repo_name} to HuggingFace Hub. Please check your permissions and network connection. Use `huggingface-cli login` to log in your account.\n\n{str(e)}")
23 |
24 |
25 | def hub_upload_model_async(push_to_hub: str, push_to_hub_delete_local: bool, save_path: str, epoch: int):
26 | if push_to_hub is not None:
27 | safe_repo_name = shlex.quote(f"{push_to_hub}-ep-{epoch}")
28 | safe_save_path = shlex.quote(save_path)
29 |
30 | command = f"huggingface-cli upload --quiet --repo-type model --private {safe_repo_name} {safe_save_path}"
31 | if push_to_hub_delete_local:
32 | command += f" && rm -rf {safe_save_path}"
33 |
34 | subprocess.Popen(command, shell=True)
35 |
--------------------------------------------------------------------------------
/ochat/training_deepspeed/multipack_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import numba
3 |
4 |
5 | @numba.njit
6 | def ffd_check(a: np.ndarray, c: int, n: int):
7 | # First-fit-decreasing bin packing
8 | # Check if a[] could fit in n bins with capacity c
9 | # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
10 |
11 | a = np.sort(a)[::-1]
12 | bins = np.full((n, ), c, dtype=a.dtype)
13 | for size in a:
14 | not_found = True
15 | for idx in range(n):
16 | if bins[idx] >= size:
17 | bins[idx] -= size
18 | not_found = False
19 | break
20 |
21 | if not_found:
22 | return False
23 |
24 | return True
25 |
26 |
27 | @numba.njit
28 | def ffd_with_result(a: np.ndarray, c: int, start_index: int):
29 | # First-fit-decreasing bin packing (with result return)
30 |
31 | indices = np.argsort(a)[::-1]
32 | a = a[indices]
33 |
34 | bins = []
35 | bins_result = []
36 | for a_id, size in enumerate(a):
37 | add_new = True
38 | for idx in range(len(bins)):
39 | if bins[idx] >= size:
40 | bins[idx] -= size
41 | bins_result[idx].append(indices[a_id] + start_index)
42 | add_new = False
43 | break
44 |
45 | if add_new:
46 | bins.append(c - size)
47 | bins_result.append([indices[a_id] + start_index])
48 |
49 | return bins_result
50 |
51 |
52 | @numba.njit
53 | def allocate(lengths: np.ndarray, numseqs: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int):
54 | # Dynamic batch allocator, similar to Multifit
55 | # https://en.wikipedia.org/wiki/Multifit_algorithm
56 | # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
57 |
58 | s = 0
59 | start_index = 0
60 | result = []
61 | result_totseqs = []
62 |
63 | while True:
64 | # binary search [l, r)
65 | l = 1
66 | r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
67 |
68 | while r - l > 1:
69 | m = (l + r) // 2
70 | if ffd_check(lengths[start_index: start_index + m], c, n):
71 | l = m
72 | else:
73 | r = m
74 |
75 | # use length l
76 | batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index) # type: ignore
77 | if len(batch) < n:
78 | break
79 |
80 | start_index += l
81 | s = lengths_cumsum[start_index - 1]
82 |
83 | # add local rank
84 | result.append(batch[rank])
85 | # add total seqs for all ranks
86 | totseq = 0
87 | for indices in batch:
88 | for idx in indices:
89 | totseq += numseqs[idx]
90 | result_totseqs.append(totseq)
91 |
92 | return result, result_totseqs, s, len(result) * c * n
93 |
94 |
95 | class MultipackDistributedSampler:
96 | """Unpadded data loading using Multipack.
97 | Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard."""
98 |
99 | def __init__(
100 | self,
101 | lengths: np.ndarray,
102 | numseqs: np.ndarray,
103 |
104 | batch_max_length: int,
105 |
106 | num_replicas: int,
107 | rank: int,
108 |
109 | seed: int,
110 | ):
111 | # Dataset
112 | self.lengths = lengths
113 | self.numseqs = numseqs
114 | assert isinstance(self.lengths, np.ndarray)
115 |
116 | self.batch_max_length = batch_max_length
117 |
118 | # Get rank
119 | self.num_replicas = num_replicas
120 | self.rank = rank
121 |
122 | # Seed
123 | self.seed = seed
124 |
125 | # statistics
126 | self.eff_total_used = 0
127 | self.eff_total_slots = 0
128 |
129 | def generate_batches(self, epoch, set_stats=False):
130 | indices = np.random.default_rng(seed=self.seed + epoch).permutation(len(self.lengths))
131 |
132 | lengths = self.lengths[indices]
133 | numseqs = self.numseqs[indices]
134 | lengths_cumsum = np.cumsum(lengths)
135 |
136 | batches, totseqs, total_used, total_slots = allocate(lengths=lengths,
137 | numseqs=numseqs,
138 | lengths_cumsum=lengths_cumsum,
139 | rank=self.rank,
140 | c=self.batch_max_length,
141 | n=self.num_replicas)
142 |
143 | curseqs = [np.sum(numseqs[batch]) for batch in batches]
144 | batches = [indices[batch] for batch in batches]
145 |
146 | # statistics
147 | if set_stats:
148 | self.eff_total_used += total_used
149 | self.eff_total_slots += total_slots
150 |
151 | return batches, totseqs, curseqs
152 |
153 | def iter(self, epoch):
154 | all_batches, all_totseqs, all_curseqs = self.generate_batches(epoch, set_stats=True)
155 |
156 | for batch, totseq, curseq in zip(all_batches, all_totseqs, all_curseqs):
157 | yield batch, totseq, curseq
158 |
159 | def estimate_num_batches(self):
160 | batches, _, _ = self.generate_batches(epoch=0)
161 | return len(batches)
162 |
163 | def efficiency(self):
164 | return self.eff_total_used / self.eff_total_slots
165 |
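
A minimal usage sketch of the sampler above on synthetic data (the lengths, the 12 * 2048 capacity, and the two replicas are illustrative assumptions, not values taken from this repository's training scripts):

    import numpy as np
    from ochat.training_deepspeed.multipack_sampler import MultipackDistributedSampler

    # Synthetic per-item token counts; each item here holds a single conversation
    lengths = np.random.default_rng(0).integers(100, 2048, size=1000)
    numseqs = np.ones_like(lengths)

    sampler = MultipackDistributedSampler(
        lengths=lengths, numseqs=numseqs,
        batch_max_length=12 * 2048,   # per-rank token capacity
        num_replicas=2, rank=0, seed=0,
    )

    for batch_indices, total_seqs, rank_seqs in sampler.iter(epoch=0):
        # Every packed batch fits within the per-rank capacity
        assert lengths[batch_indices].sum() <= 12 * 2048

    print(f"packing efficiency: {sampler.efficiency():.3f}")   # fraction of token slots actually used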
--------------------------------------------------------------------------------
/ochat/training_deepspeed/openchat_dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.utils.data import IterableDataset, get_worker_info
4 |
5 | import pyarrow.parquet as pq
6 | import orjson
7 |
8 | from ochat.training_deepspeed.multipack_sampler import MultipackDistributedSampler
9 |
10 |
11 | def _find_multiple(a, b):
12 |     return (-(a // -b)) * b  # round a up to the nearest multiple of b
13 |
14 |
15 | class OpenchatDataset(IterableDataset):
16 | def __init__(self, dataset_filename, batch_max_length, rank, num_replicas):
17 | super().__init__()
18 | # Init constants
19 | self.PAD_ID = 0
20 | self.PAD_MULTIPLE = 64
21 | self.BATCH_KEYS = {
22 | "seqlens": torch.int32,
23 | "nz_input_ids": torch.long,
24 | "nz_position_ids": torch.long,
25 | "nz_shifted_label_ids": torch.long,
26 |
27 | "nz_shifted_loss_weights": torch.bfloat16
28 | }
29 |
30 |         assert batch_max_length % self.PAD_MULTIPLE == 0, f"batch_max_length {batch_max_length} needs to be a multiple of {self.PAD_MULTIPLE}"
31 |
32 | # Load data
33 | # Convert parquet to numpy for fast random access
34 | table = pq.read_table(dataset_filename, memory_map=True)
35 | self.dataset = {k: v.to_numpy() for k, v in zip(table.column_names, table.columns)}
36 |
37 | # read metadata
38 | self.metadata = table.schema.metadata.get(b"metadata_json", None)
39 | if self.metadata is not None:
40 | self.metadata = orjson.loads(self.metadata)
41 |
42 | # Free table space
43 | del table
44 |
45 | # Create sampler
46 | self.sampler = MultipackDistributedSampler(
47 | lengths=self.dataset["total_length"],
48 | numseqs=self.dataset["num_seqs"],
49 |
50 | batch_max_length=batch_max_length,
51 |
52 | rank=rank,
53 | num_replicas=num_replicas,
54 | seed=0
55 | )
56 |
57 | # Init state
58 | self._epoch = 0
59 |
60 | def _load_batch(self, indices):
61 | batch = {k: v[indices] for k, v in self.dataset.items()}
62 |
63 | # Concat batches
64 | batch = {k: np.concatenate(batch[k], axis=0) for k in self.BATCH_KEYS.keys()}
65 |
66 |         # Pad with one unused item so the total length is a multiple of PAD_MULTIPLE, for faster GEMM
67 | total_seqlen = batch["nz_input_ids"].size
68 | pad_len = _find_multiple(total_seqlen, self.PAD_MULTIPLE) - total_seqlen
69 |
70 | if pad_len > 0:
71 | assert pad_len < self.PAD_MULTIPLE
72 |
73 |             # Padding spec per key: (count, fill value); adds one dummy sequence of length pad_len
74 | padding_specs = {
75 | "seqlens": (1, pad_len),
76 |
77 | "nz_input_ids": (pad_len, self.PAD_ID),
78 | "nz_position_ids": (pad_len, 0),
79 | "nz_shifted_label_ids": (pad_len, self.PAD_ID),
80 | "nz_shifted_loss_weights": (pad_len, 0),
81 | }
82 | for k, pad_spec in padding_specs.items():
83 | batch[k] = np.concatenate((batch[k], np.full(*pad_spec, dtype=batch[k].dtype)), axis=0)
84 |
85 | # to tensor
86 | batch_tensor = {}
87 | for k, dtype in self.BATCH_KEYS.items():
88 | batch_tensor[k] = torch.from_numpy(batch[k]).to(dtype)
89 |
90 | # cu seqlens
91 | batch_tensor["cu_seqlens"] = torch.nn.functional.pad(batch_tensor["seqlens"].cumsum(-1, dtype=torch.int32), (1, 0))
92 | # batch info
93 | batch_info = {"max_seqlen": torch.max(batch_tensor["seqlens"]).item()}
94 |
95 | # inputs
96 | del batch_tensor["seqlens"]
97 | return batch_tensor, batch_info
98 |
99 | def __iter__(self):
100 | worker_info = get_worker_info()
101 | assert worker_info is None or worker_info.num_workers == 1
102 |
103 | for indices, all_numseq, cur_numseq in self.sampler.iter(self._epoch):
104 | yield self._load_batch(indices), all_numseq, cur_numseq
105 |
106 | # Increase epoch count
107 | self._epoch += 1
108 |
109 | def estimate_num_batches(self):
110 | return self.sampler.estimate_num_batches()
111 |
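
To make the padding and cu_seqlens logic of _load_batch concrete, a small standalone sketch (the three sequence lengths are made up; find_multiple mirrors _find_multiple above):

    import numpy as np
    import torch

    PAD_MULTIPLE, PAD_ID = 64, 0

    def find_multiple(a, b):
        # Round a up to the nearest multiple of b
        return (-(a // -b)) * b

    # Three packed sequences of 100, 37 and 5 tokens -> 142 tokens in the flat stream
    seqlens = np.array([100, 37, 5], dtype=np.int32)
    nz_input_ids = np.zeros(seqlens.sum(), dtype=np.int64)

    # Pad the flat stream up to the next multiple of 64 (here 192) with one dummy sequence
    pad_len = find_multiple(nz_input_ids.size, PAD_MULTIPLE) - nz_input_ids.size
    if pad_len > 0:
        seqlens = np.concatenate((seqlens, np.full(1, pad_len, dtype=seqlens.dtype)))
        nz_input_ids = np.concatenate((nz_input_ids, np.full(pad_len, PAD_ID, dtype=nz_input_ids.dtype)))

    # cu_seqlens gives the cumulative sequence boundaries used by variable-length attention
    cu_seqlens = torch.nn.functional.pad(
        torch.from_numpy(seqlens).cumsum(-1, dtype=torch.int32), (1, 0))
    print(cu_seqlens)   # tensor([  0, 100, 137, 142, 192], dtype=torch.int32)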
--------------------------------------------------------------------------------
/ochat/training_deepspeed/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import math
4 | import json
5 | from functools import partial
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from torch.utils.data import DataLoader
10 |
11 | import tqdm
12 | import wandb
13 | import numpy as np
14 |
15 | from ochat.config import MODEL_CONFIG_MAP
16 | from ochat.training_deepspeed.openchat_dataset import OpenchatDataset
17 | from ochat.training_deepspeed.hf_hub import hub_upload_check, hub_upload_model_async
18 |
19 | try:
20 | import deepspeed
21 | except ImportError:
22 | raise ImportError("Please install deepspeed to train models.")
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser()
27 | # Distributed
28 | parser.add_argument("--local_rank", type=int, required=True)
29 |
30 | # Model type and data
31 | parser.add_argument("--model_path", type=str, required=True)
32 | parser.add_argument("--data_prefix", type=str, required=True)
33 | parser.add_argument("--save_path", type=str, required=True)
34 | parser.add_argument("--save_every", type=int, default=None)
35 | parser.add_argument("--push_to_hub", type=str, default=None,
36 | help="Specify repository prefix for pushing to HuggingFace Hub. "
37 | "For example, 'openchat/openchat-3.6' will create repositories "
38 | "like 'openchat/openchat-3.6-ep0', 'openchat/openchat-3.6-ep1', ..."
39 | "If not specified, will not push to Hub.")
40 | parser.add_argument("--push_to_hub_delete_local", action="store_true")
41 |
42 | # Hyperparameters
43 | parser.add_argument("--batch_max_len", type=int, default=81920)
44 | parser.add_argument("--epochs", type=int, default=5)
45 |
46 | # Set lr to None to automatically estimate from LLaMA pretraining parameters (e.g. lr ~ sqrt(batch_size))
47 | parser.add_argument("--lr", type=float, default=None)
48 | parser.add_argument("--lr_min_ratio", type=float, default=0.1)
49 | parser.add_argument("--lr_warmup_ratio", type=int, default=0.05)
50 |
51 | parser.add_argument("--weight_decay", type=float, default=0.1)
52 |
53 | parser.add_argument("--beta1", type=float, default=0.9)
54 | parser.add_argument("--beta2", type=float, default=0.95)
55 | parser.add_argument("--eps", type=float, default=1e-5)
56 |
57 | parser.add_argument("--wandb_entity", type=str, default=None)
58 | parser.add_argument("--wandb_project", type=str, default=None)
59 |
60 | # DeepSpeed parameters
61 | parser = deepspeed.add_config_arguments(parser)
62 |
63 | # Parse known args
64 | args, unknown = parser.parse_known_args()
65 | return args
66 |
67 |
68 | def create_dataset_and_dataloader(args, epoch: int):
69 | # Find data
70 | filename = f"{args.data_prefix}.{epoch}.parquet"
71 |
72 | # Create dataset and dataloader
73 | print(f"Loading epoch {epoch} data from {filename}...")
74 |
75 | dataset = OpenchatDataset(
76 | dataset_filename=filename,
77 |
78 | batch_max_length=args.batch_max_len,
79 | rank=dist.get_rank(),
80 | num_replicas=dist.get_world_size()
81 | )
82 | dataloader = DataLoader(
83 | dataset,
84 | batch_size=None,
85 |
86 | num_workers=1,
87 | prefetch_factor=8,
88 |
89 | pin_memory=True
90 | )
91 | return dataset, dataloader
92 |
93 |
94 | def create_model(args):
95 | print(f"Loading model {args.model_type} from {args.model_path}...")
96 |
97 | # Create model + optimizer + lr scheduler
98 | model = MODEL_CONFIG_MAP[args.model_type].model_create_for_training(args.model_path)
99 | # Model to assigned cuda device
100 | model = model.to(args.local_rank)
101 | # Enable gradient checkpointing
102 | model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=dict(
103 | use_reentrant=False
104 | ))
105 |
106 | # Optimizer
107 | optimizer = torch.optim.AdamW(model.parameters(),
108 | lr=args.lr,
109 | weight_decay=args.weight_decay,
110 | betas=(args.beta1, args.beta2),
111 | eps=args.eps,
112 | fused=True)
113 |
114 | # DeepSpeed model
115 | model_engine, optimizer, _, _ = deepspeed.initialize(args=args,
116 | model=model,
117 | model_parameters=model.parameters(),
118 | optimizer=optimizer)
119 |
120 | # Put deepspeed arguments
121 | args.device = model_engine.device
122 |
123 | return model_engine, optimizer
124 |
125 |
126 | def cosine_schedule_with_warmup_lr_lambda(
127 | current_step: int, *, num_warmup_steps: int, num_training_steps: int, min_ratio: float = 0.0, num_cycles: float = 0.5
128 | ):
129 | if current_step < num_warmup_steps:
130 | return float(current_step) / float(max(1, num_warmup_steps))
131 |
132 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
133 | return min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
134 |
135 |
136 | def create_lr_scheduler(args, train_total_steps):
137 | lr_scheduler = partial(
138 | cosine_schedule_with_warmup_lr_lambda,
139 |
140 | num_warmup_steps=round(args.lr_warmup_ratio * train_total_steps),
141 | num_training_steps=train_total_steps,
142 | min_ratio=args.lr_min_ratio
143 | )
144 |
145 | return lr_scheduler
146 |
147 |
148 | def save_tokenizer(args, save_path):
149 | model_config = MODEL_CONFIG_MAP[args.model_type]
150 | tokenizer = model_config.model_tokenizer_create(args.model_path)
151 | tokenizer.chat_template = model_config.hf_chat_template
152 | tokenizer.save_pretrained(save_path)
153 |
154 |
155 | def save_openchat_metadata(args, epoch, save_path):
156 | metadata = vars(args)
157 | metadata["epoch"] = epoch
158 |
159 | with open(os.path.join(save_path, "openchat.json"), "w") as f:
160 | json.dump(metadata, f, default=lambda o: "")
161 |
162 |
163 | def calculate_auto_lr(lr, batch_max_len, model_type, train_dataset):
164 | if lr is not None:
165 | return lr
166 |
167 | # Llama hyperparameters
168 | # FIXME: Only 7B/13B is supported
169 | base_lr = 3e-4
170 | base_bs = 4_000_000
171 | if "mistral" in model_type.lower():
172 | base_lr /= 6.0
173 | elif "gemma" in model_type.lower():
174 | base_lr /= 5.5 # NOTE(one): Maybe MLP and Attn layers are using different lr?
175 | elif "openchat_3.6" in model_type.lower(): # Llama 3 estimated hyperparams
176 | # NOTE(one): Estimated divisor: 1.5 * sqrt(25000 H100s / 2000 H100s)
177 | base_lr /= 5.3
178 |
179 | loss_weights = np.concatenate(train_dataset.dataset["nz_shifted_loss_weights"])
180 | supervised_ratio = np.sum(loss_weights != 0) / len(loss_weights)
181 |
182 | supervised_tokens = batch_max_len * dist.get_world_size() * supervised_ratio
183 | lr = base_lr * math.sqrt(supervised_tokens / base_bs)
184 |
185 | print(f"Use automatic learning rate {lr} (estimated from supervised ratio {supervised_ratio} effective batch size {supervised_tokens})")
186 | return lr
187 |
188 |
189 | def state_dict_to_cpu(item, device=torch.device('cpu')):
190 | # Move all tensors to CPU
191 | if torch.is_tensor(item):
192 | return item.detach().to(device)
193 | elif isinstance(item, list):
194 | return [state_dict_to_cpu(v, device) for v in item]
195 | elif isinstance(item, tuple):
196 | return tuple([state_dict_to_cpu(v, device) for v in item])
197 | elif isinstance(item, dict):
198 | return type(item)({k: state_dict_to_cpu(v, device) for k, v in item.items()})
199 | else:
200 | return item
201 |
202 |
203 | def train():
204 | deepspeed.init_distributed(dist_backend="nccl")
205 | RANK = dist.get_rank()
206 | WORLD_SIZE = dist.get_world_size()
207 |
208 | # Args
209 | args = parse_args()
210 |
211 | hub_upload_check(args.push_to_hub)
212 |
213 | # Dataset
214 | train_dataset, train_loader = create_dataset_and_dataloader(args, 0)
215 |
216 | if train_dataset is None:
217 | raise RuntimeError("Training data not found.")
218 |
219 | # Load model type
220 | args.model_type = train_dataset.metadata["model_type"]
221 |
222 | train_total_steps = args.epochs * train_dataset.estimate_num_batches()
223 |
224 | # Hyperparams
225 | args.lr = calculate_auto_lr(args.lr, args.batch_max_len, args.model_type, train_dataset)
226 |
227 | # Model
228 | model_engine, optimizer = create_model(args)
229 |
230 | # LR Scheduler
231 | lr_scheduler = create_lr_scheduler(args, train_total_steps)
232 |
233 | # Progress bar and logger
234 | progress_bar = None
235 | if RANK == 0:
236 | progress_bar = tqdm.tqdm(total=train_total_steps)
237 |
238 | wandb.init(project=args.wandb_project or os.path.basename(args.model_path), entity=args.wandb_entity, config=args)
239 |
240 | # Training Loop
241 | step = 0
242 | lr_this_step = None
243 | for epoch in range(args.epochs):
244 | print (f"[rank {RANK} of {WORLD_SIZE}]: Epoch {epoch}")
245 |
246 | ############ Load Dataset
247 | if epoch != 0:
248 | del train_dataset, train_loader
249 |
250 | train_dataset, train_loader = create_dataset_and_dataloader(args, epoch)
251 |
252 | ############ Train Epoch
253 | model_engine.train()
254 | for (batch_tensor, batch_info), all_numseq, cur_numseq in train_loader:
255 | step += 1
256 | if step > train_total_steps: # At most train_total_steps
257 | break
258 |
259 | # To device
260 | batch_tensor = {k: (v.to(args.device) if v is not None else None) for k, v in batch_tensor.items()}
261 |
262 | # Update
263 | loss, acc = model_engine(**batch_tensor, **batch_info).loss
264 | loss = (WORLD_SIZE / all_numseq) * loss
265 | acc = (WORLD_SIZE / all_numseq) * acc
266 |
267 | model_engine.backward(loss)
268 |
269 | if model_engine.is_gradient_accumulation_boundary():
270 | # Set LR
271 | lr_this_step = args.lr * lr_scheduler(step)
272 | for param_group in optimizer.param_groups:
273 | param_group['lr'] = lr_this_step
274 |
275 | model_engine.step()
276 |
277 | # Logging
278 | if RANK == 0:
279 | wandb.log({
280 | "train/loss": loss.item() * (all_numseq / (cur_numseq * WORLD_SIZE)),
281 | "train/acc": acc.item() * (all_numseq / (cur_numseq * WORLD_SIZE)),
282 | "train/lr": lr_this_step
283 | }, step=step)
284 | progress_bar.update() # type: ignore
285 |
286 | ############ Save Checkpoint
287 | # Save model with lean state dict
288 | # https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html
289 | if (epoch + 1 == args.epochs) or (args.save_every and ((epoch + 1) % args.save_every == 0)):
290 | if RANK == 0:
291 | save_path = os.path.join(args.save_path, f"ep_{epoch}")
292 |
293 | model_engine.module.save_pretrained(save_path,
294 | state_dict=state_dict_to_cpu(model_engine.module.state_dict())) # type: ignore
295 |
296 | # Also save tokenizer from base model
297 | save_tokenizer(args, save_path)
298 |
299 | # Write metadata
300 | save_openchat_metadata(args, epoch, save_path)
301 |
302 | # Upload to hub
303 | hub_upload_model_async(
304 | args.push_to_hub,
305 | args.push_to_hub_delete_local,
306 | save_path,
307 | epoch
308 | )
309 |
310 |
311 | if __name__ == "__main__":
312 | train()
313 |
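
The automatic learning-rate rule (calculate_auto_lr) and the warmup-cosine schedule (cosine_schedule_with_warmup_lr_lambda) can be checked numerically; a small sketch in which the world size of 8, supervised ratio of 0.5, and step counts are illustrative assumptions, and no model-specific base_lr divisor is applied:

    import math

    # lr ~ base_lr * sqrt(supervised_tokens / base_bs), as in calculate_auto_lr
    base_lr, base_bs = 3e-4, 4_000_000
    supervised_tokens = 81920 * 8 * 0.5   # batch_max_len * world_size * supervised_ratio
    lr = base_lr * math.sqrt(supervised_tokens / base_bs)
    print(f"auto lr: {lr:.2e}")           # about 8.6e-05

    # Warmup then cosine decay to min_ratio, mirroring the lambda above (num_cycles=0.5 folded in)
    def lr_lambda(step, warmup=50, total=1000, min_ratio=0.1):
        if step < warmup:
            return step / max(1, warmup)
        progress = (step - warmup) / max(1, total - warmup)
        return min_ratio + max(0.0, (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress)))

    for step in (0, 25, 50, 500, 1000):
        print(step, round(lr * lr_lambda(step), 8))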
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "ochat"
7 | description = "An efficient framework for training and serving top-tier, open-source conversational LLMs."
8 | readme = "README.md"
9 | requires-python = ">=3.8"
10 | dynamic = ["version"]
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "colorama",
17 | "beautifulsoup4",
18 | "markdownify",
19 | "pylatexenc",
20 | "sympy",
21 | "openai>=1",
22 | "tenacity",
23 | "tiktoken",
24 | "tqdm",
25 | "wandb",
26 | "numba",
27 | "datasets",
28 | "orjson",
29 | "torch",
30 | "packaging",
31 | "ninja",
32 | "flash-attn",
33 | "ray",
34 | "sentencepiece",
35 | "transformers>=4.40.1",
36 | "accelerate",
37 | "protobuf",
38 | "fastapi",
39 | "pydantic",
40 | "shortuuid",
41 | "uvicorn",
42 | "vllm>=0.4.0",
43 | "pytest"
44 | ]
45 |
46 | [project.urls]
47 | "Homepage" = "https://github.com/imoneoi/openchat"
48 | "Bug Tracker" = "https://github.com/imoneoi/openchat/issues"
49 |
50 | [tool.setuptools.packages.find]
51 | exclude = ["assets*", "ochat/experimental*"]
52 |
53 | [tool.wheel]
54 | exclude = ["assets*", "ochat/experimental*"]
55 |
56 | [tool.setuptools_scm]
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 | cpu: CPU Tests
4 | gpu: GPU Tests
5 |
--------------------------------------------------------------------------------