├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── logo-final.png
├── examples
│   ├── pipelines
│   │   ├── combined_filter.py
│   │   ├── embed_datasets.py
│   │   ├── embed_datasets.sh
│   │   ├── filter_datasets.sh
│   │   ├── score_complexity.py
│   │   ├── score_complexity_dataset.py
│   │   ├── score_complexity_dataset.sh
│   │   ├── score_quality.py
│   │   ├── score_quality_dataset.sh
│   │   ├── score_vllm.py
│   │   └── utils.py
│   └── train
│       ├── dpo.sh
│       ├── sft.sh
│       └── train_scorers.sh
├── requirements.txt
├── sample_little.py
├── setup.py
└── src
    └── deita
        ├── __init__.py
        ├── alignment
        │   ├── __init__.py
        │   ├── constants.py
        │   ├── conversation.py
        │   ├── dpo_train.py
        │   ├── flash_attn
        │   │   ├── bloom_flash_attention.py
        │   │   └── triton_flash_attention.py
        │   ├── train.py
        │   └── train_scorers.py
        ├── data
        │   └── sample_ultrafeedback.py
        ├── ds_configs
        │   ├── deepspeed_config_zero2_no_offload.json
        │   ├── deepspped_llama_x.json
        │   └── stage3_no_offloading_accelerate.json
        ├── pipeline
        │   ├── __init__.py
        │   ├── base.py
        │   ├── embed_pipeline.py
        │   ├── filter_pipeline.py
        │   ├── score_pipeline.py
        │   └── utils.py
        └── selection
            ├── __init__.py
            ├── embedder
            │   ├── __init__.py
            │   ├── base.py
            │   ├── clm_embedder.py
            │   ├── conversation.py
            │   └── utils.py
            ├── filter
            │   ├── __init__.py
            │   ├── base.py
            │   ├── combined_filter.py
            │   └── utils.py
            └── scorer
                ├── __init__.py
                ├── base.py
                ├── llama_scorer.py
                └── mistral_scorer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | *egg-info*
3 | *build*
4 | *dist*
5 | *egg-info
6 | .history
7 | data/
8 | outputs/
9 | logs/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deita
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | [🤗 HF Repo](https://huggingface.co/collections/hkust-nlp/deita-6569c198c174808d94cf5bd4)
10 | [📄 Paper](https://arxiv.org/abs/2312.15685)
11 | [📚 6K Data](https://huggingface.co/datasets/hkust-nlp/deita-6k-v0)
12 | [📚 10K Data](https://huggingface.co/datasets/hkust-nlp/deita-10k-v0)
13 |
14 |
15 |
16 | Welcome to the Deita (**D**ata-**E**fficient **I**nstruction **T**uning for **A**lignment) project!
17 |
18 | We will continue to update this project, so please stay tuned!
19 |
20 |
21 | ## What is Deita?
22 | Deita is an open-source project designed to facilitate **Automatic Data Selection** for instruction tuning in Large Language Models (LLMs).
23 |
24 | It includes:
25 | - **Open-source toolkits** for automatic data selection in instruction tuning
26 | - **Deita Datasets**: A series of extremely *lightweight*, high-quality alignment SFT datasets. The first release includes a 6k-sample and a 10k-sample dataset
27 | - **Deita Models**: A series of powerful models on par with SOTA chat LLMs, obtained through an extremely efficient instruction tuning process. Deita models can be trained with 10x less instruction tuning data than other SOTA LLMs
28 |
29 | ## News
30 | - :fire: [03/2024] Our datasets have been used by Hugging Face to create the [Zephyr Gemma Model](https://huggingface.co/collections/HuggingFaceH4/zephyr-7b-gemma-65e1fd82d26b426e3e63d956).
31 | - 📄 [01/2024] The Deita paper [What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning](https://arxiv.org/abs/2312.15685) has been accepted to ICLR 2024!
32 | - :fire: [01/2024] [Deita pipelines](#deita-pipelines) have been released! With one line of code and a set of configurations, you can select a high-quality data subset for alignment.
33 | - 📚 [01/2024] Our scorer datasets [deita-complexity-scorer-data](https://huggingface.co/datasets/hkust-nlp/deita-complexity-scorer-data) and [deita-quality-scorer-data](https://huggingface.co/datasets/hkust-nlp/deita-quality-scorer-data) have been released.
34 | - :fire: [12/2023] We released the first collection of Deita resources [here](https://huggingface.co/collections/hkust-nlp/deita-6569c198c174808d94cf5bd4), which includes a series of extremely lightweight, effective SFT datasets, the data complexity/quality scorer models, and the resulting Deita chat models.
35 |
36 | ## Performance
37 | :bell: Still curious about how far a small amount of high-quality data can take LLMs?
38 |
39 | Deita may provide an answer for you:
40 |
41 | **🔦 Highlights**
42 | | Model | Align | Data Size | MT-Bench | AlpacaEval(%) |
43 | |------------------------------------------------|--------------|------------|----------|---------------|
44 | | Zephyr-7B-sft | SFT | 200K | 5.32 | 75.12 |
45 | | $\text{Zephyr-7B-}\beta$ | SFT + DPO | 200K SFT + 60K DPO | 7.34 | 90.60 |
46 | | OpenChat-3.5 | C-RLFT | >> 70K C-RLFT | 7.81 | 88.51 |
47 | | Starling-7B | C-RLFT + APA | >> 70K C-RLFT + 183K APA | 8.09 | 91.99 |
48 | | Tulu-2-13B | SFT | 326K | 6.70 | 78.90 |
49 | | Tulu-2-13B+DPO | SFT + DPO | 326K SFT + 60K DPO | 7.00 | 89.50 |
50 | | LLaMA2-13B-Chat | SFT + PPO | -- | 6.65 | 81.09 |
51 | | WizardLM-13B-v1.2 | SFT | >70K | 7.09 | 89.17 |
52 | | Vicuna-13B-v1.5 | SFT | >125K | 6.57 | 78.80 |
53 | | DEITA-7B-v1.0 (6K) | SFT | 6K | 7.22 | 80.78 |
54 | | DEITA-7B-v1.0-sft | SFT | 10K | 7.32 | 81.67 |
55 | | DEITA-7B-v1.0 | SFT + DPO | 6K SFT + 10K DPO | 7.55 | 90.06 |
56 |
57 | DEITA models are based on Mistral-7B-v0.1. :fire:
58 |
59 | Please refer to [this table](#chart\_with\_upwards\_trend-full-evaluations) for full evaluations, including the Open LLM Leaderboard, DEITA models built on LLaMA base models, and comparisons with other data selection approaches.
60 |
61 |
62 |
63 | ## :chart_with_upwards_trend: Full Evaluations
64 |
65 |
66 | See full evaluations
67 |
68 | | Model | Align | Data Size | MT-Bench | AlpacaEval(%) | OpenLLM (Avg.) |
69 | |------------------------------------------------|-----------|------------|----------|---------------|----------------|
70 | | **Proprietary Models** | | | | | |
71 | | GPT-4-Turbo | ? | -- | 9.32 | 97.70 | -- |
72 | | GPT-4 | SFT + PPO | -- | 8.99 | 95.03 | -- |
73 | | Claude-2 | SFT + PPO | -- | 8.06 | 91.36 | -- |
74 | | GPT-3.5-turbo | SFT + PPO | -- | 7.94 | 89.37 | -- |
75 | | **Open-sourced Models based on LLaMA-1-13B** | | | | | |
76 | | LIMA | SFT | 1K SFT | 4.29 | 41.98 | 59.82 |
77 | | WizardLM-13B | SFT | 70K SFT | 6.35 | 75.31 | 58.96 |
78 | | Vicuna-13B-v1.3 | SFT | 125K SFT | 6.39 | 82.11 | 60.01 |
79 | | Random | SFT | 10K SFT | 6.03 | 71.52 | 60.14 |
80 | | DEITA-LLaMA1-13B-v1.0-sft | SFT | 10K SFT | 6.60 | 78.01 | 64.27 |
81 | | **Open-sourced Models based on LLaMA-2-13B** | | | | | |
82 | | Tulu-2-13B | SFT | 326K SFT | 6.70 | 78.90 | -- |
83 | | Tulu-2-13B+DPO | SFT + DPO | 326K SFT + 60K DPO | 7.00 | 89.50 | -- |
84 | | LLaMA2-13B-Chat | SFT + PPO | -- | 6.65 | 81.09 | -- |
85 | | WizardLM-13B-v1.2 | SFT | >70K SFT | 7.09 | 89.17 | -- |
86 | | Vicuna-13B-v1.5 | SFT | 125K SFT | 6.57 | 78.80 | 61.63 |
87 | | Random | SFT | 10K SFT | 5.78 | 65.19 | 61.32 |
88 | | DEITA-LLaMA2-13B-v1.0-sft | SFT | 10K SFT | 6.79 | 81.09 | 62.71 |
89 | | **Open-sourced Models based on Mistral-7B** | | | | | |
90 | | Mistral-7B-Instruct-v0.1 | -- | -- | 6.84 | 69.65 | 60.45 |
91 | | Zephyr-7B-sft | SFT | 200K SFT | 5.32 | 75.12 | 60.93 |
92 | | $\text{Zephyr-7B-}\beta$ | SFT + DPO | 200K SFT + 60K DPO | 7.34 | 90.60 | 66.36 |
93 | | OpenChat-3.5 | C-RLFT | >> 70K C-RLFT | 7.81 | 88.51 | -- |
94 | | Starling-7B | C-RLFT + APA | >>70K C-RLFT + 183K APA | 8.09 | 91.99 | -- |
95 | | Random | SFT | 10K SFT | 5.89 | 56.90 | 61.72 |
96 | | DEITA-7B-v1.0-sft (6K) | SFT | 6K SFT | 7.22 | 80.78 | 64.94 |
97 | | DEITA-7B-v1.0-sft (10K) | SFT | 10K SFT | 7.32 | 81.67 | 64.00 |
98 | | DEITA-7B-v1.0 | SFT + DPO | 6K SFT + 10K DPO | 7.55 | 90.06 | 69.86 |
99 |
100 |
101 |
102 |
103 | ## :rocket: Deita Resources
104 |
105 | | Resource | Link | License |
106 | |------------------------------------------------|-----------|------------|
107 | | **Deita Datasets** | | |
108 | | deita-6k-v0 | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-6k-v0) | [MIT License](https://opensource.org/license/mit/) |
109 | | deita-10k-v0 | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-10k-v0) | [MIT License](https://opensource.org/license/mit/) |
110 | | deita-complexity-scorer-data | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-complexity-scorer-data) | [MIT License](https://opensource.org/license/mit/) |
111 | | deita-quality-scorer-data | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-quality-scorer-data) | [MIT License](https://opensource.org/license/mit/) |
112 | | deita-redundant-pool (100K) | [:hugs: HF Repo](https://huggingface.co/datasets/hkust-nlp/deita-redundant-pool-data) | [MIT License](https://opensource.org/license/mit/) |
113 | | deita-sota-pool (300K) | [:hugs: HF Repo](https://huggingface.co/datasets/AndrewZeng/deita_sota_pool) | [MIT License](https://opensource.org/license/mit/) |
114 | | **Scorers** | | |
115 | | deita-complexity-scorer | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-complexity-scorer) | [LLaMA License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)|
116 | | deita-quality-scorer | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-quality-scorer) | [LLaMA License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/)|
117 | | **Deita Models** | | |
118 | | DEITA-7B-v1.0-sft | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-7b-v1.0-sft) | [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) |
119 | | DEITA-7B-v1.0 | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-7B-v1.0) | [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) |
120 | | DEITA-LLaMA2-13B-v1.0-sft | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-llama2-13b-v1.0-sft) | [LLaMA 2 License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) |
121 | | DEITA-LLaMA1-13B-v1.0-sft | [:hugs: HF Repo](https://huggingface.co/hkust-nlp/deita-llama1-13b-v1.0-sft) | [LLaMA License](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) |
122 |
123 | ## :running_man: How to start?
124 |
125 |
126 | ### Installation
127 | ```bash
128 | git clone https://github.com/hkust-nlp/deita.git
129 | cd deita
130 | pip install -e .
131 | ```
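You can optionally verify the installation by importing the package and printing its version (`deita.__version__` is defined in `src/deita/__init__.py`):

```bash
python -c "import deita; print(deita.__version__)"
```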
132 |
133 | ### Data Sample Scoring
134 |
135 | If you wish to assess the **quality** of a response for a single sample, you can follow these steps:
136 | ```python
137 | from deita.selection.scorer import Llama_Scorer
138 |
139 | model_name_or_path = "hkust-nlp/deita-quality-scorer"
140 |
141 | scorer = Llama_Scorer(model_name_or_path)
142 |
143 | # example input
144 | input_text = "word to describe UI with helpful tooltips" # Example Input
145 | output_text = "User-friendly or intuitive UI" # Example Output
146 | quality_score = scorer.infer_quality(input_text, output_text)
147 |
148 | print(quality_score)
149 | # 2.0230105920381902
150 | ```
151 |
152 | Deita also supports vLLM for faster inference. To use vLLM, first install it:
153 |
154 | ```bash
155 | pip install vllm
156 | ```
157 |
158 | Then set ```is_vllm = True``` when initializing the scorer:
159 |
160 | ```python
161 | scorer = Llama_Scorer(model_name_or_path, is_vllm = True)
162 | ```
163 |
164 | To assess other dimensions of data samples, please refer to ```examples/pipelines```.
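For instance, a minimal sketch for scoring the **complexity** of an instruction, mirroring `examples/pipelines/score_complexity.py`:

```python
from deita.selection.scorer import Llama_Scorer

model_name_or_path = "hkust-nlp/deita-complexity-scorer"

scorer = Llama_Scorer(model_name_or_path)

# example input
input_text = "write a performance review for a junior data scientist"
complexity_score = scorer.infer_complexity(input_text)

print(complexity_score)
```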
165 |
166 | ### Deita Pipelines
167 |
168 | You can use Deita pipelines to perform a variety of operations on a dataset with only one line of code plus a few configurations.
169 |
170 | - **Dataset Scoring**
171 |
172 | ```python
173 | from deita.pipeline import Pipeline
174 |
175 | pipeline = Pipeline("score_pipeline",
176 | data_path = args.data_path, # json file with sharegpt format
177 | scorer = args.scorer, # [mistral, llama]
178 | scorer_name_or_path = args.scorer_name_or_path, # scorer name or path e.g. hkust-nlp/deita-complexity-scorer
179 | is_vllm = args.is_vllm, # launch with vllm [True, False]
180 | score_type = args.score_type, # [complexity, quality]
181 | output_path = args.output_path) # output path (json format)
182 |
183 | pipeline.run()
184 | ```
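The same pipeline can be launched from the command line. Below is a sketch based on `examples/pipelines/score_complexity_dataset.sh`, here pointed at the released `hkust-nlp/deita-complexity-scorer`; the data and output paths are placeholders you need to fill in:

```bash
SCORETYPE="complexity"                        # [complexity, quality]
DATAPATH=""                                   # PATH/TO/DATASET (sharegpt-format json)
OUTPUTPATH=""                                 # PATH/TO/OUTPUTS (json)
MODELPATH="hkust-nlp/deita-complexity-scorer" # scorer name or path
SCORER="llama"                                # [mistral, llama]

python examples/pipelines/score_complexity_dataset.py \
    --data_path $DATAPATH \
    --output_path $OUTPUTPATH \
    --score_type $SCORETYPE \
    --scorer $SCORER \
    --scorer_name_or_path $MODELPATH
```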
185 |
186 | - **Get Embeddings**
187 |
188 | We use Hugging Face Accelerate to enhance efficiency:
189 |
190 | ```python
191 | from deita.pipeline import Pipeline
192 |
193 | embed_pipeline = Pipeline("embed_pipeline",
194 | data_path = args.data_path, # json file with sharegpt format
195 | output_path = args.output_path, # output path (pickle format)
196 | model_name_or_path = args.model_name_or_path, # model name or path e.g. mistralai/Mistral-7B-v0.1
197 | max_length = args.max_length,
198 | use_flash_attention = args.use_flash_attention,
199 | batch_size_per_device = args.batch_size_per_device,
200 | conv_template = args.conv_template,
201 | only_answer = args.only_answer,
202 | random_shuffle = args.random_shuffle,
203 | bfloat16 = True
204 | )
205 |
206 | embed_pipeline.run()
207 | ```
208 |
209 | ```bash
210 | CUDA_VISIBLE_DEVICES=$GPUIDX accelerate launch \
211 | --mixed_precision bf16 \
212 | --num_processes $NUMPROCESS \
213 | --num_machines 1 \
214 | examples/pipelines/embed_datasets.py \
215 | --use_flash_attention true \
216 | --data_path $DATAPATH \
217 | --output_path $OUTPUTPATH \
218 | --batch_size_per_device $BSZ
219 | ```
220 |
221 | - **Score-first, Diversity-aware Selection**
222 |
223 | ```python
224 | from deita.pipeline import Pipeline
225 |
226 | filter_pipeline = Pipeline("filter_pipeline",
227 | data_path = args.data_path, # json file with sharegpt format
228 | other_data_path = args.other_data_path, # embedding file path (pickle format)
229 | threshold = args.threshold, # filter threshold default: 0.9
230 | data_size = args.data_size, # size of selected data
231 | chunk_size = args.chunk_size, # used for more efficient GPU computing default: 100000
232 | sort_key = args.sort_key, # default: "complexity_scores,quality_scores"
233 | output_path = args.output_path, # json format output path
234 | distance_metric = args.distance_metric, # default: cosine
235 | embedding_field = args.embedding_field, # default: embedding
236 | is_compression = args.is_compression, # default: False
237 | device = args.device # GPU IDX, default: 0
238 | )
239 |
240 | filter_pipeline.run()
241 | ```
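From the command line, the selection step can be launched with the provided example script. The following is a sketch based on `examples/pipelines/filter_datasets.sh`; the paths are placeholders and the data size (6k) is only an example:

```bash
DATAPATH=""    # PATH/TO/SCORED_DATASET (sharegpt-format json with complexity/quality scores)
OTHERDATA=""   # PATH/TO/EMBEDDING_FILE (pickle)
OUTPUTPATH=""  # PATH/TO/OUTPUTS (json)
THETA=0.9      # filter threshold
DATASIZE=6000  # number of samples to select

CUDA_VISIBLE_DEVICES=0 python examples/pipelines/combined_filter.py \
    --data_path $DATAPATH \
    --other_data_path $OTHERDATA \
    --output_path $OUTPUTPATH \
    --threshold $THETA \
    --data_size $DATASIZE \
    --device 0
```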
242 |
243 | You can refer to ```examples/pipelines``` for more details. Documentation will also be coming soon.
244 |
245 | ### SFT Training
246 | Please refer to ```examples/train/sft.sh```
247 | ```bash
248 | deepspeed --include localhost:${DEVICES} --master_port 29501 src/deita/alignment/train.py \
249 | --model_name_or_path ${MODELPATH} \
250 | --data_path ${DATAPATH} \
251 | --output_dir ${OUTPUTPATH}/${RUNNAME} \
252 | --num_train_epochs 6 \
253 | --per_device_train_batch_size ${BSZPERDEV} \
254 | --per_device_eval_batch_size 1 \
255 | --gradient_accumulation_steps ${GRADACC} \
256 | --eval_steps 50 \
257 | --save_strategy "no" \
258 | --save_steps 100 \
259 | --save_total_limit 10 \
260 | --learning_rate 2e-5 \
261 | --warmup_ratio 0.1 \
262 | --lr_scheduler_type "cosine" \
263 | --logging_steps 1 \
264 | --do_eval False \
265 | --evaluation_strategy "no" \
266 | --model_max_length 2048 \
267 | --lazy_preprocess True \
268 | --conv_template "vicuna_v1.1" \
269 | --mask_user True \
270 | --report_to "wandb" \
271 | --run_name ${RUNNAME} \
272 | --bf16 True \
273 | --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json
274 | ```
275 |
276 | ### DPO Training
277 | Please refer to ```examples/train/dpo.sh```
278 | ```bash
279 | deepspeed --include localhost:${DEVICES} --master_port 29502 src/deita/alignment/dpo_train.py \
280 | --model_name_or_path ${MODELPATH} \
281 | --json_path ${JSONPATH} \
282 | --data_split ${DATASPLIT} \
283 | --output_dir ${OUTPUTPATH}/${RUNNAME} \
284 | --num_train_epochs ${DPOEPOCH} \
285 | --beta 0.1 \
286 | --per_device_train_batch_size ${BSZPERDEV} \
287 | --per_device_eval_batch_size 1 \
288 | --gradient_accumulation_steps ${GRADACC} \
289 | --save_global_steps False \
290 | --eval_steps 50 \
291 | --save_strategy "no" \
292 | --save_steps 500 \
293 | --save_total_limit 1 \
294 | --learning_rate 5e-7 \
295 | --warmup_ratio 0.1 \
296 | --lr_scheduler_type "linear" \
297 | --logging_steps 1 \
298 | --do_eval False \
299 | --evaluation_strategy "no" \
300 | --model_max_length 2048 \
301 | --conv_template "vicuna_v1.1" \
302 | --report_to "wandb" \
303 | --run_name ${RUNNAME} \
304 | --bf16 True \
305 | --gradient_checkpointing True \
306 | --deepspeed src/deita/ds_configs/stage3_no_offloading_accelerate.json
307 | ```
308 |
309 | ### Evaluation
310 | - For MT-Bench, please refer to [MT-Bench](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge)
311 | - For AlpacaEval, please refer to [alpaca_eval](https://github.com/tatsu-lab/alpaca_eval)
312 | - For Open LLM Benchmark, please refer to [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness/tree/master) and follow settings on [HuggingFaceH4/open_llm_leaderboard](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)
313 |
314 | ## :muscle: What's more?
315 |
316 | This is a preview version of the Deita project. We will continue to update it, including:
317 |
318 | - [ ] Release data selection pipeline with efficient implementation
319 | - [ ] More automatic data selection strategies
320 | - [ ] CLI interface support
321 | - [ ] Online Demo
322 |
323 | ## Citation
324 | If you find the content of this project helpful, please cite our paper as follows:
325 |
326 | ```
327 | @inproceedings{
328 | liu2024what,
329 | title={What Makes Good Data for Alignment? A Comprehensive Study of Automatic Data Selection in Instruction Tuning},
330 | author={Wei Liu and Weihao Zeng and Keqing He and Yong Jiang and Junxian He},
331 | booktitle={The Twelfth International Conference on Learning Representations},
332 | year={2024},
333 | url={https://openreview.net/forum?id=BTKAeLqLMw}
334 | }
335 | ```
336 |
337 | ## Acknowledgement
338 | Our training code is based on the code template of [FastChat](https://github.com/lm-sys/FastChat).
339 |
--------------------------------------------------------------------------------
/assets/logo-final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/deita/b279f2c329b403d2612a61e270c8d2a2eeaed6f4/assets/logo-final.png
--------------------------------------------------------------------------------
/examples/pipelines/combined_filter.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | from deita.pipeline import Pipeline
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | logger = logging.getLogger(__name__)
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--data_path", type=str, default=None)
10 | parser.add_argument("--other_data_path", type=str, default=None)
11 | parser.add_argument("--threshold", type=float, default=0.9)
12 | parser.add_argument("--data_size", type=int, default=10)
13 | parser.add_argument("--chunk_size", type=int, default=100000)
14 | parser.add_argument("--sort_key", type=str, default="complexity_scores,quality_scores")
15 | parser.add_argument("--output_path", type=str, default=None)
16 | parser.add_argument("--distance_metric", type=str, default="cosine")
17 | parser.add_argument("--embedding_field", type=str, default="embedding")
18 | parser.add_argument("--is_compression", type=bool, default=False)
19 | parser.add_argument("--device", type=int, default="0")
20 |
21 | args = parser.parse_args()
22 |
23 | filter_pipeline = Pipeline("filter_pipeline",
24 | data_path = args.data_path, # json file with sharegpt format
25 | other_data_path = args.other_data_path, # embedding file path (pickle format)
26 | threshold = args.threshold, # filter threshold default: 0.9
27 | data_size = args.data_size, # size of selected data
28 | chunk_size = args.chunk_size, # used for more efficient GPU computing default: 100000
29 | sort_key = args.sort_key, # default: "complexity_scores,quality_scores"
30 | output_path = args.output_path, # json format output path
31 | distance_metric = args.distance_metric, # default: cosine
32 | embedding_field = args.embedding_field, # default: embedding
33 | is_compression = args.is_compression, # default: False
34 | device = args.device # GPU IDX, default: 0
35 | )
36 |
37 | filter_pipeline.run()
--------------------------------------------------------------------------------
/examples/pipelines/embed_datasets.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | from deita.pipeline import Pipeline
4 |
5 | logging.basicConfig(level=logging.INFO)
6 | logger = logging.getLogger(__name__)
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--data_path", type=str, default=None)
10 | parser.add_argument("--output_path", type=str, default=None)
11 | parser.add_argument("--max_length", type=int, default=2048)
12 | parser.add_argument("--batch_size_per_device", type=int, default=4)
13 | parser.add_argument("--conv_template", type=str, default="vicuna_v1.1")
14 | parser.add_argument("--use_flash_attention", type=bool, default=False)
15 | parser.add_argument("--only_answer", type=bool, default=False)
16 | parser.add_argument("--random_shuffle", type=bool, default=False)
17 | parser.add_argument("--model_name_or_path", type=str, default="mistralai/Mistral-7B-v0.1")
18 |
19 | args = parser.parse_args()
20 |
21 | embed_pipeline = Pipeline("embed_pipeline",
22 | data_path = args.data_path, # json file with sharegpt format
23 | output_path = args.output_path, # output path (pickle format)
24 | model_name_or_path = args.model_name_or_path, # model name or path e.g. mistralai/Mistral-7B-v0.1
25 | max_length = args.max_length,
26 | use_flash_attention = args.use_flash_attention,
27 | batch_size_per_device = args.batch_size_per_device,
28 | conv_template = args.conv_template,
29 | only_answer = args.only_answer,
30 | random_shuffle = args.random_shuffle,
31 | bfloat16 = True
32 | )
33 |
34 | embed_pipeline.run()
--------------------------------------------------------------------------------
/examples/pipelines/embed_datasets.sh:
--------------------------------------------------------------------------------
1 | GPUIDX="0,1,2,3"
2 | NUMPROCESS=4
3 | DATAPATH=""
4 | BSZ=1
5 | OUTPUTPATH=""
6 |
7 | CUDA_VISIBLE_DEVICES=$GPUIDX accelerate launch \
8 | --mixed_precision bf16 \
9 | --num_processes $NUMPROCESS \
10 | --num_machines 1 \
11 | examples/pipelines/embed_datasets.py \
12 | --use_flash_attention true \
13 | --data_path $DATAPATH \
14 | --output_path $OUTPUTPATH \
15 | --batch_size_per_device $BSZ
16 |
--------------------------------------------------------------------------------
/examples/pipelines/filter_datasets.sh:
--------------------------------------------------------------------------------
1 | GPUIDX="0,1,2,3"
2 | NUMGPUS=$(echo $GPUIDX | awk -F',' '{print NF}')
3 | DATAPATH=""
4 | OTHERDATA="" # PATH/TO/EMBEDDING_FILE
5 | OUTPUTPATH="" # PATH/TO/OUTPUTS
6 | THETA=0.9
7 | DATASIZE=10
8 | BSZ=1
9 |
10 | CUDA_VISIBLE_DEVICES=$GPUIDX python examples/pipelines/combined_filter.py \
11 | --data_path $DATAPATH \
12 | --other_data_path $OTHERDATA \
13 | --output_path $OUTPUTPATH \
14 | --threshold $THETA \
15 | --data_size $DATASIZE \
16 | --is_compression true \
17 | --device 0
18 |
--------------------------------------------------------------------------------
/examples/pipelines/score_complexity.py:
--------------------------------------------------------------------------------
1 | from deita.selection.scorer import Llama_Scorer
2 |
3 | model_name_or_path = "hkust-nlp/deita-complexity-scorer"
4 |
5 | scorer = Llama_Scorer(model_name_or_path)
6 |
7 | # example input
8 | input_text = "write a performance review for a junior data scientist"
9 | complexity_score = scorer.infer_complexity(input_text)
10 |
11 | print(complexity_score)
--------------------------------------------------------------------------------
/examples/pipelines/score_complexity_dataset.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 | from deita.pipeline import Pipeline
4 |
5 | logger = logging.getLogger(__name__)
6 | logger.info("Running score_pipeline")
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--data_path", type=str, default=None)
10 | parser.add_argument("--output_path", type=str, default=None)
11 | parser.add_argument("--scorer", type=str, default="llama")
12 | parser.add_argument("--scorer_name_or_path", type=str, default="hkust-nlp/deita-complexity-scorer")
13 | parser.add_argument("--is_vllm", type=bool, default=False)
14 | parser.add_argument("--score_type", type=str, default=None)
15 | args = parser.parse_args()
16 |
17 |
18 | pipeline = Pipeline("score_pipeline",
19 | data_path = args.data_path, # json file with sharegpt format
20 | scorer = args.scorer, # [mistral, llama]
21 | scorer_name_or_path = args.scorer_name_or_path, # scorer name or path e.g. hkust-nlp/deita-complexity-scorer
22 | is_vllm = args.is_vllm, # launch with vllm [True, False]
23 | score_type = args.score_type, # [complexity, quality]
24 | output_path = args.output_path) # output path (json format)
25 |
26 | pipeline.run()
--------------------------------------------------------------------------------
/examples/pipelines/score_complexity_dataset.sh:
--------------------------------------------------------------------------------
1 |
2 | SCORETYPE="complexity"
3 | DATAPATH="data/deita_mix_100.json"
4 | OUTPUTPATH="outputs/deita_mix_complexity/deita_mix_complexity_mistral_sampled.json"
5 | MODELPATH="" # PATH/TO/SCORER_MODEL
6 | SCORER="mistral"
7 | ISVLLM=false
8 |
9 | python examples/pipelines/score_complexity_dataset.py \
10 | --data_path $DATAPATH \
11 | --output_path $OUTPUTPATH \
12 | --score_type $SCORETYPE \
13 | --scorer $SCORER \
14 | --scorer_name_or_path $MODELPATH
--------------------------------------------------------------------------------
/examples/pipelines/score_quality.py:
--------------------------------------------------------------------------------
1 | from deita.selection.scorer import Llama_Scorer
2 |
3 | model_name_or_path = "hkust-nlp/deita-quality-scorer"
4 |
5 | scorer = Llama_Scorer(model_name_or_path)
6 |
7 | # example input
8 | input_text = "word to describe UI with helpful tooltips" # Example Input
9 | output_text = "User-friendly or intuitive UI" # Example Output
10 | quality_score = scorer.infer_quality(input_text, output_text)
11 |
12 | print(quality_score)
--------------------------------------------------------------------------------
/examples/pipelines/score_quality_dataset.sh:
--------------------------------------------------------------------------------
1 |
2 | SCORETYPE="quality"
3 | DATAPATH="" # PATH/TO/DATASET
4 | OUTPUTPATH="" # PATH/TO/OUTPUTS
5 | MODELPATH="" # PATH/TO/MODEL
6 | SCORER="mistral"
7 | ISVLLM=true
8 | GPUINDICES="0,1,2,3"
9 |
10 | CUDA_VISIBLE_DEVICES=$GPUINDICES python examples/pipelines/score_complexity_dataset.py \
11 | --data_path $DATAPATH \
12 | --output_path $OUTPUTPATH \
13 | --score_type $SCORETYPE \
14 | --scorer $SCORER \
15 | --scorer_name_or_path $MODELPATH \
16 | --is_vllm $ISVLLM
--------------------------------------------------------------------------------
/examples/pipelines/score_vllm.py:
--------------------------------------------------------------------------------
1 | from deita.selection.scorer import Llama_Scorer
2 |
3 | model_name_or_path = "hkust-nlp/deita-quality-scorer"
4 |
5 | scorer = Llama_Scorer(model_name_or_path, is_vllm = True)
6 |
7 | # example input
8 | input_text = "word to describe UI with helpful tooltips" # Example Input
9 | output_text = "User-friendly or intuitive UI" # Example Output
10 | quality_score = scorer.infer_quality(input_text, output_text)
11 |
12 | print(quality_score)
--------------------------------------------------------------------------------
/examples/pipelines/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def str2bool(v):
4 | if isinstance(v, bool):
5 | return v
6 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
7 | return True
8 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
9 | return False
10 | else:
11 | raise argparse.ArgumentTypeError('Boolean value expected.')
--------------------------------------------------------------------------------
/examples/train/dpo.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="Deita"
2 | RUNNAME="Deita-7B"
3 | MODELPATH="/PATH/TO/SFT_MODEL"
4 | MODEL_SIZE="7B"
5 | DEVICES="" # e.g. 0,1,2,3
6 | NUMGPUS=$(echo $DEVICES | awk -F',' '{print NF}')
7 |
8 | DPOEPOCH=9
9 | JSONPATH="/PATH/TO/ultrafeedback_or_sampled_ultrafeedback" # If you want to sample UltraFeedback dataset, please refer to our code src/deita/data/sample_ultrafeedback.py
10 | OUTPUTPATH="/PATH/TO/OUTPUTS"
11 | DATASPLIT="train"
12 | TOTALBSZ=32
13 | BSZPERDEV=1
14 | GRADACC=$(($TOTALBSZ/$NUMGPUS/$BSZPERDEV))
15 | echo "DPO Training mistral model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BSZPERDEV batch size per GPU, $GRADACC gradient accumulation steps"
16 |
17 | deepspeed --include localhost:${DEVICES} --master_port 29502 src/deita/alignment/dpo_train.py \
18 | --model_name_or_path ${MODELPATH} \
19 | --json_path ${JSONPATH} \
20 | --data_split ${DATASPLIT} \
21 | --output_dir ${OUTPUTPATH}/${RUNNAME} \
22 | --num_train_epochs ${DPOEPOCH} \
23 | --beta 0.1 \
24 | --per_device_train_batch_size ${BSZPERDEV} \
25 | --per_device_eval_batch_size 1 \
26 | --gradient_accumulation_steps ${GRADACC} \
27 | --save_global_steps False \
28 | --eval_steps 50 \
29 | --save_strategy "no" \
30 | --save_steps 500 \
31 | --save_total_limit 1 \
32 | --learning_rate 5e-7 \
33 | --warmup_ratio 0.1 \
34 | --lr_scheduler_type "linear" \
35 | --logging_steps 1 \
36 | --do_eval False \
37 | --evaluation_strategy "no" \
38 | --model_max_length 2048 \
39 | --conv_template "vicuna_v1.1" \
40 | --report_to "wandb" \
41 | --run_name ${RUNNAME} \
42 | --bf16 True \
43 | --gradient_checkpointing True \
44 | --deepspeed src/deita/ds_configs/stage3_no_offloading_accelerate.json
--------------------------------------------------------------------------------
/examples/train/sft.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="Deita"
2 | RUNNAME="Deita-7B-SFT"
3 | MODELPATH="mistralai/Mistral-7B-v0.1"
4 | DATAPATH="hkust-nlp/deita-6k-v0"
5 | MODEL_SIZE="7B"
6 | OUTPUTPATH="/PATH/TO/OUTPUTS"
7 | DEVICES="" # e.g. 0,1,2,3
8 | NUMGPUS=$(echo $DEVICES | awk -F',' '{print NF}')
9 | TOTALBSZ=512
10 | BSZPERDEV=1
11 | GPUIDX=0
12 | GRADACC=$(($TOTALBSZ/$NUMGPUS/$BSZPERDEV))
13 | echo "Training Mistral model ${MODEL_SIZE} using $NUMGPUS GPUs, $BSZPERDEV batch size per GPU, $GRADACC gradient accumulation steps"
14 |
15 | deepspeed --include localhost:${DEVICES} --master_port 29501 src/deita/alignment/train.py \
16 | --model_name_or_path ${MODELPATH} \
17 | --data_path ${DATAPATH} \
18 | --output_dir ${OUTPUTPATH}/${RUNNAME} \
19 | --num_train_epochs 6 \
20 | --per_device_train_batch_size ${BSZPERDEV} \
21 | --per_device_eval_batch_size 1 \
22 | --gradient_accumulation_steps ${GRADACC} \
23 | --eval_steps 50 \
24 | --save_strategy "no" \
25 | --save_steps 100 \
26 | --save_total_limit 10 \
27 | --learning_rate 2e-5 \
28 | --warmup_ratio 0.1 \
29 | --lr_scheduler_type "cosine" \
30 | --logging_steps 1 \
31 | --do_eval False \
32 | --evaluation_strategy "no" \
33 | --model_max_length 2048 \
34 | --lazy_preprocess True \
35 | --conv_template "vicuna_v1.1" \
36 | --mask_user True \
37 | --report_to "wandb" \
38 | --run_name ${RUNNAME} \
39 | --bf16 True \
40 | --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json
--------------------------------------------------------------------------------
/examples/train/train_scorers.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="Deita-Scorers"
2 | MODELPATH="mistralai/Mistral-7B-v0.1"
3 | DATAPATH="/PATH/TO/SHAREGPT_FORMAT/DATA"
4 | MODEL_SIZE="7B"
5 | RUNNAME="Deita-7B-Scorers"
6 | OUTPUTPATH="/PATH/TO/OUTPUTS"
7 | TOTALBSZ=512
8 | BSZPERDEV=1
9 | DEVICES="0,1,2,3"
10 | NUMGPUS=$(echo $DEVICES | awk -F',' '{print NF}')
11 | GRADACC=$(($TOTALBSZ/$NUMGPUS/$BSZPERDEV))
12 | EPOCHNUM=6
13 | echo "Training Mistral model ${MODEL_SIZE} using $NUMGPUS GPUs, $BSZPERDEV batch size per GPU, $GRADACC gradient accumulation steps"
14 |
15 | deepspeed --include localhost:$DEVICES --master_port 29502 src/deita/alignment/train_scorers.py \
16 | --model_name_or_path ${MODELPATH} \
17 | --data_path ${DATAPATH} \
18 | --output_dir ${OUTPUTPATH}/${RUNNAME} \
19 | --num_train_epochs ${EPOCHNUM} \
20 | --per_device_train_batch_size ${BSZPERDEV} \
21 | --per_device_eval_batch_size 1 \
22 | --gradient_accumulation_steps ${GRADACC} \
23 | --eval_steps 50 \
24 | --save_strategy "no" \
25 | --save_steps 100 \
26 | --save_total_limit 10 \
27 | --learning_rate 2e-5 \
28 | --warmup_ratio 0.1 \
29 | --lr_scheduler_type "cosine" \
30 | --logging_steps 1 \
31 | --do_eval False \
32 | --evaluation_strategy "no" \
33 | --model_max_length 2048 \
34 | --lazy_preprocess True \
35 | --conv_template "scorer" \
36 | --mask_user True \
37 | --report_to "wandb" \
38 | --run_name ${RUNNAME} \
39 | --bf16 True \
40 | --deepspeed src/deita/ds_configs/deepspeed_config_zero2_no_offload.json
41 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.24.1
2 | aiohttp==3.8.6
3 | aiosignal==1.3.1
4 | alpaca-eval==0.3.6
5 | annotated-types==0.6.0
6 | anthropic==0.5.0
7 | antlr4-python3-runtime==4.9.3
8 | anyio==3.7.1
9 | appdirs==1.4.4
10 | async-timeout==4.0.3
11 | attrs==23.1.0
12 | beautifulsoup4==4.12.2
13 | bitsandbytes==0.41.1
14 | certifi==2023.7.22
15 | charset-normalizer==3.3.2
16 | click==8.1.7
17 | datasets==2.14.6
18 | deepspeed==0.12.2
19 | dill==0.3.7
20 | distro==1.8.0
21 | docker-pycreds==0.4.0
22 | docstring-parser==0.15
23 | einops==0.7.0
24 | exceptiongroup==1.1.3
25 | filelock==3.13.1
26 | fire==0.5.0
27 | frozenlist==1.4.0
28 | fsspec==2023.10.0
29 | gdown==4.7.1
30 | gitdb==4.0.11
31 | gitpython==3.1.40
32 | h11==0.14.0
33 | hjson==3.1.0
34 | httpcore==0.18.0
35 | httpx==0.25.0
36 | huggingface-hub==0.17.3
37 | hydra-core==1.3.2
38 | idna==3.4
39 | jinja2==3.1.2
40 | joblib==1.3.2
41 | loguru==0.7.2
42 | markdown-it-py==3.0.0
43 | markupsafe==2.1.3
44 | mdurl==0.1.2
45 | mpmath==1.3.0
46 | multidict==6.0.4
47 | multiprocess==0.70.15
48 | networkx==3.2.1
49 | ninja==1.11.1.1
50 | numpy==1.26.1
51 | nvidia-cublas-cu12==12.1.3.1
52 | nvidia-cuda-cupti-cu12==12.1.105
53 | nvidia-cuda-nvrtc-cu12==12.1.105
54 | nvidia-cuda-runtime-cu12==12.1.105
55 | nvidia-cudnn-cu12==8.9.2.26
56 | nvidia-cufft-cu12==11.0.2.54
57 | nvidia-curand-cu12==10.3.2.106
58 | nvidia-cusolver-cu12==11.4.5.107
59 | nvidia-cusparse-cu12==12.1.0.106
60 | nvidia-nccl-cu12==2.18.1
61 | nvidia-nvjitlink-cu12==12.3.52
62 | nvidia-nvtx-cu12==12.1.105
63 | omegaconf==2.3.0
64 | openai==0.28.1
65 | packaging==23.2
66 | pandas==2.1.2
67 | pathtools==0.1.2
68 | peft==0.5.0
69 | pillow==10.1.0
70 | pip==23.3.1
71 | protobuf==4.25.0
72 | psutil==5.9.6
73 | py-cpuinfo==9.0.0
74 | pyarrow==14.0.0
75 | pydantic-core==2.10.1
76 | pydantic==1.10.13
77 | pygments==2.16.1
78 | pynvml==11.5.0
79 | pysocks==1.7.1
80 | python-dateutil==2.8.2
81 | python-dotenv==1.0.0
82 | pytz==2023.3.post1
83 | pyyaml==6.0.1
84 | regex==2023.10.3
85 | requests==2.31.0
86 | rich==13.6.0
87 | safetensors==0.4.0
88 | scikit-learn==1.3.2
89 | scipy==1.11.3
90 | sentencepiece==0.1.99
91 | sentry-sdk==1.33.1
92 | setproctitle==1.3.3
93 | setuptools==68.0.0
94 | shortuuid==1.0.11
95 | shtab==1.6.4
96 | six==1.16.0
97 | smmap==5.0.1
98 | sniffio==1.3.0
99 | soupsieve==2.5
100 | sympy==1.12
101 | termcolor==2.4.0
102 | threadpoolctl==3.2.0
103 | tiktoken==0.5.2
104 | tokenizers==0.14.1
105 | torch==2.1.0
106 | tqdm==4.66.1
107 | transformers==4.35.1
108 | triton==2.1.0
109 | trl==0.7.2
110 | typing-extensions==4.8.0
111 | tyro==0.5.12
112 | tzdata==2023.3
113 | urllib3==2.0.7
114 | wandb==0.15.12
115 | wheel==0.41.2
116 | xxhash==3.4.1
117 | yarl==1.9.2
--------------------------------------------------------------------------------
/sample_little.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import random
4 |
5 | filepath = sys.argv[1]
6 | outputpath = sys.argv[2]
7 |
8 | data = json.load(open(filepath, "r"))
9 |
10 | sampled_data = random.sample(data, 100)
11 | with open(outputpath, "w") as f:
12 | json.dump(sampled_data, f, indent=2, ensure_ascii=False)
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import find_packages
3 | from setuptools import setup
4 | import subprocess
5 |
6 | folder = os.path.dirname(__file__)
7 | version_path = os.path.join(folder, "src", "deita", "__init__.py")
8 |
9 | __version__ = None
10 | with open(version_path) as f:
11 | exec(f.read(), globals())
12 |
13 | req_path = os.path.join(folder, "requirements.txt")
14 | install_requires = []
15 | if os.path.exists(req_path):
16 | with open(req_path) as fp:
17 | install_requires = [line.strip() for line in fp]
18 |
19 | readme_path = os.path.join(folder, "README.md")
20 | readme_contents = ""
21 | if os.path.exists(readme_path):
22 | with open(readme_path, encoding='utf-8') as fp:
23 | readme_contents = fp.read().strip()
24 |
25 | setup(
26 | name="deita",
27 | version=__version__,
28 | description="Deita: Data-Efficient Instruction Tuning for Alignment.",
29 | author="The Deita Team",
30 | long_description=readme_contents,
31 | long_description_content_type="text/markdown",
32 | package_dir={"": "src"},
33 | packages=find_packages("src"),
34 | package_data={},
35 | install_requires=install_requires,
36 | classifiers=[
37 | "Intended Audience :: Science/Research/Engineering",
38 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
39 | "Programming Language :: Python :: 3.9",
40 | "Programming Language :: Python :: 3.10",
41 | ],
42 | python_requires=">=3.9",
43 | )
44 |
45 | # Must be called after all dependency installed, since flash-attn setup.py
46 | # relies on torch, packaging, etc.
47 | try:
48 | gpu_state = subprocess.check_output(["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"])
49 | if b"A100" or b"A40" or b"H100" or b"A800" or b"H800" in gpu_state:
50 | subprocess.call(["pip", "install", "flash-attn==2.3.3", "--no-build-isolation"])
51 | except Exception:
52 | pass
--------------------------------------------------------------------------------
/src/deita/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.1.0'
--------------------------------------------------------------------------------
/src/deita/alignment/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.2.19"
2 |
--------------------------------------------------------------------------------
/src/deita/alignment/constants.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum
2 | import os
3 |
4 | REPO_PATH = os.path.dirname(os.path.dirname(__file__))
5 |
6 | ##### For the gradio web server
7 | SERVER_ERROR_MSG = (
8 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
9 | )
10 | MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN."
11 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
12 | INACTIVE_MSG = "THIS SESSION HAS BEEN INACTIVE FOR TOO LONG. PLEASE REFRESH THIS PAGE."
13 | # Maximum input length
14 | INPUT_CHAR_LEN_LIMIT = int(os.getenv("FASTCHAT_INPUT_CHAR_LEN_LIMIT", 2560))
15 | # Maximum conversation turns
16 | CONVERSATION_TURN_LIMIT = 50
17 | # Session expiration time
18 | SESSION_EXPIRATION_TIME = 3600
19 | # The output dir of log files
20 | LOGDIR = "."
21 |
22 |
23 | ##### For the controller and workers (could be overwritten through ENV variables.)
24 | CONTROLLER_HEART_BEAT_EXPIRATION = int(
25 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
26 | )
27 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 45))
28 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
29 | WORKER_API_EMBEDDING_BATCH_SIZE = int(
30 | os.getenv("FASTCHAT_WORKER_API_EMBEDDING_BATCH_SIZE", 4)
31 | )
32 |
33 |
34 | class ErrorCode(IntEnum):
35 | """
36 | https://platform.openai.com/docs/guides/error-codes/api-errors
37 | """
38 |
39 | VALIDATION_TYPE_ERROR = 40001
40 |
41 | INVALID_AUTH_KEY = 40101
42 | INCORRECT_AUTH_KEY = 40102
43 | NO_PERMISSION = 40103
44 |
45 | INVALID_MODEL = 40301
46 | PARAM_OUT_OF_RANGE = 40302
47 | CONTEXT_OVERFLOW = 40303
48 |
49 | RATE_LIMIT = 42901
50 | QUOTA_EXCEEDED = 42902
51 | ENGINE_OVERLOADED = 42903
52 |
53 | INTERNAL_ERROR = 50001
54 | CUDA_OUT_OF_MEMORY = 50002
55 | GRADIO_REQUEST_ERROR = 50003
56 | GRADIO_STREAM_UNKNOWN_ERROR = 50004
57 | CONTROLLER_NO_WORKER = 50005
58 | CONTROLLER_WORKER_TIMEOUT = 50006
59 |
--------------------------------------------------------------------------------
/src/deita/alignment/conversation.py:
--------------------------------------------------------------------------------
1 | """
2 | Conversation prompt templates.
3 | """
4 |
5 | import dataclasses
6 | from enum import auto, IntEnum
7 | from typing import List, Any, Dict
8 |
9 |
10 | class SeparatorStyle(IntEnum):
11 | """Separator styles."""
12 |
13 | ADD_COLON_SINGLE = auto()
14 | ADD_COLON_TWO = auto()
15 | ADD_COLON_SPACE_SINGLE = auto()
16 | NO_COLON_SINGLE = auto()
17 | NO_COLON_TWO = auto()
18 | ADD_NEW_LINE_SINGLE = auto()
19 | SCORER = auto()
20 | LLAMA2 = auto()
21 | CHATGLM = auto()
22 | CHATML = auto()
23 | CHATINTERN = auto()
24 | DOLLY = auto()
25 | RWKV = auto()
26 | PHOENIX = auto()
27 | ROBIN = auto()
28 |
29 |
30 | @dataclasses.dataclass
31 | class Conversation:
32 | """A class that manages prompt templates and keeps all conversation history."""
33 |
34 | # The name of this template
35 | name: str
36 | # The system prompt
37 | system: str
38 | # Two roles
39 | roles: List[str]
40 | # All messages. Each item is (role, message).
41 | messages: List[List[str]]
42 | # The number of few shot examples
43 | offset: int
44 | # Separators
45 | sep_style: SeparatorStyle
46 | sep: str
47 | sep2: str = None
48 | # Stop criteria (the default one is EOS token)
49 | stop_str: str = None
50 | # Stops generation if meeting any token in this list
51 | stop_token_ids: List[int] = None
52 |
53 | def get_prompt(self) -> str:
54 | """Get the prompt for generation."""
55 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
56 | ret = self.system + self.sep
57 | for role, message in self.messages:
58 | if message:
59 | ret += role + ": " + message + self.sep
60 | else:
61 | ret += role + ":"
62 | return ret
63 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
64 | seps = [self.sep, self.sep2]
65 | ret = self.system + seps[0]
66 | for i, (role, message) in enumerate(self.messages):
67 | if message:
68 | ret += role + ": " + message + seps[i % 2]
69 | else:
70 | ret += role + ":"
71 | return ret
72 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
73 | ret = self.system + self.sep
74 | for role, message in self.messages:
75 | if message:
76 | ret += role + ": " + message + self.sep
77 | else:
78 | ret += role + ": " # must be end with a space
79 | return ret
80 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
81 | ret = "" if self.system == "" else self.system + self.sep
82 | for role, message in self.messages:
83 | if message:
84 | ret += role + "\n" + message + self.sep
85 | else:
86 | ret += role + "\n"
87 | return ret
88 | elif self.sep_style == SeparatorStyle.SCORER:
89 | seps = [self.sep, self.sep2]
90 | ret = ""
91 | for i, (role, message) in enumerate(self.messages):
92 | ret += message
93 | return ret
94 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
95 | ret = self.system
96 | for role, message in self.messages:
97 | if message:
98 | ret += role + message + self.sep
99 | else:
100 | ret += role
101 | return ret
102 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
103 | seps = [self.sep, self.sep2]
104 | ret = self.system
105 | for i, (role, message) in enumerate(self.messages):
106 | if message:
107 | ret += role + message + seps[i % 2]
108 | else:
109 | ret += role
110 | return ret
111 | elif self.sep_style == SeparatorStyle.RWKV:
112 | ret = self.system
113 | for i, (role, message) in enumerate(self.messages):
114 | if message:
115 | ret += (
116 | role
117 | + ": "
118 | + message.replace("\r\n", "\n").replace("\n\n", "\n")
119 | )
120 | ret += "\n\n"
121 | else:
122 | ret += role + ":"
123 | return ret
124 | elif self.sep_style == SeparatorStyle.LLAMA2:
125 | seps = [self.sep, self.sep2]
126 | ret = ""
127 | for i, (role, message) in enumerate(self.messages):
128 | if message:
129 | if i == 0:
130 | ret += self.system + message
131 | else:
132 | ret += role + " " + message + seps[i % 2]
133 | else:
134 | ret += role
135 | return ret
136 | elif self.sep_style == SeparatorStyle.CHATGLM:
137 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
138 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
139 | round_add_n = 1 if self.name == "chatglm2" else 0
140 | if self.system:
141 | ret = self.system + self.sep
142 | else:
143 | ret = ""
144 |
145 | for i, (role, message) in enumerate(self.messages):
146 | if i % 2 == 0:
147 | ret += f"[Round {i//2 + round_add_n}]{self.sep}"
148 |
149 | if message:
150 | ret += f"{role}:{message}{self.sep}"
151 | else:
152 | ret += f"{role}:"
153 | return ret
154 | elif self.sep_style == SeparatorStyle.CHATML:
155 | ret = "" if self.system == "" else self.system + self.sep + "\n"
156 | for role, message in self.messages:
157 | if message:
158 | ret += role + "\n" + message + self.sep + "\n"
159 | else:
160 | ret += role + "\n"
161 | return ret
162 | elif self.sep_style == SeparatorStyle.CHATINTERN:
163 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
164 | seps = [self.sep, self.sep2]
165 | ret = self.system
166 | for i, (role, message) in enumerate(self.messages):
167 | if i % 2 == 0:
168 |                     ret += "<s>"
169 | if message:
170 | ret += role + ":" + message + seps[i % 2] + "\n"
171 | else:
172 | ret += role + ":"
173 | return ret
174 | elif self.sep_style == SeparatorStyle.DOLLY:
175 | seps = [self.sep, self.sep2]
176 | ret = self.system
177 | for i, (role, message) in enumerate(self.messages):
178 | if message:
179 | ret += role + ":\n" + message + seps[i % 2]
180 | if i % 2 == 1:
181 | ret += "\n\n"
182 | else:
183 | ret += role + ":\n"
184 | return ret
185 | elif self.sep_style == SeparatorStyle.PHOENIX:
186 | ret = self.system
187 | for role, message in self.messages:
188 | if message:
189 |                     ret += role + ": " + "<s>" + message + "</s>"
190 |                 else:
191 |                     ret += role + ": " + "<s>"
192 | return ret
193 | elif self.sep_style == SeparatorStyle.ROBIN:
194 | ret = self.system + self.sep
195 | for role, message in self.messages:
196 | if message:
197 | ret += role + ":\n" + message + self.sep
198 | else:
199 | ret += role + ":\n"
200 | return ret
201 | else:
202 | raise ValueError(f"Invalid style: {self.sep_style}")
203 |
204 | def append_message(self, role: str, message: str):
205 | """Append a new message."""
206 | self.messages.append([role, message])
207 |
208 | def update_last_message(self, message: str):
209 | """Update the last output.
210 |
211 | The last message is typically set to be None when constructing the prompt,
212 | so we need to update it in-place after getting the response from a model.
213 | """
214 | self.messages[-1][1] = message
215 |
216 | def to_gradio_chatbot(self):
217 | """Convert the conversation to gradio chatbot format."""
218 | ret = []
219 | for i, (role, msg) in enumerate(self.messages[self.offset :]):
220 | if i % 2 == 0:
221 | ret.append([msg, None])
222 | else:
223 | ret[-1][-1] = msg
224 | return ret
225 |
226 | def to_openai_api_messages(self):
227 | """Convert the conversation to OpenAI chat completion format."""
228 | ret = [{"role": "system", "content": self.system}]
229 |
230 | for i, (_, msg) in enumerate(self.messages[self.offset :]):
231 | if i % 2 == 0:
232 | ret.append({"role": "user", "content": msg})
233 | else:
234 | if msg is not None:
235 | ret.append({"role": "assistant", "content": msg})
236 | return ret
237 |
238 | def copy(self):
239 | return Conversation(
240 | name=self.name,
241 | system=self.system,
242 | roles=self.roles,
243 | messages=[[x, y] for x, y in self.messages],
244 | offset=self.offset,
245 | sep_style=self.sep_style,
246 | sep=self.sep,
247 | sep2=self.sep2,
248 | stop_str=self.stop_str,
249 | stop_token_ids=self.stop_token_ids,
250 | )
251 |
252 | def dict(self):
253 | return {
254 | "template_name": self.name,
255 | "system": self.system,
256 | "roles": self.roles,
257 | "messages": self.messages,
258 | "offset": self.offset,
259 | }
260 |
261 |
262 | # A global registry for all conversation templates
263 | conv_templates: Dict[str, Conversation] = {}
264 |
265 |
266 | def register_conv_template(template: Conversation, override: bool = False):
267 | """Register a new conversation template."""
268 | if not override:
269 | assert (
270 | template.name not in conv_templates
271 | ), f"{template.name} has been registered."
272 |
273 | conv_templates[template.name] = template
274 |
275 |
276 | def get_conv_template(name: str) -> Conversation:
277 | """Get a conversation template."""
278 | return conv_templates[name].copy()
279 |
280 |
281 | # A template with a one-shot conversation example
282 | register_conv_template(
283 | Conversation(
284 | name="one_shot",
285 | system="A chat between a curious human and an artificial intelligence assistant. "
286 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
287 | roles=("Human", "Assistant"),
288 | messages=(
289 | (
290 | "Human",
291 | "Got any creative ideas for a 10 year old’s birthday?",
292 | ),
293 | (
294 | "Assistant",
295 | """Of course! Here are some creative ideas for a 10-year-old's birthday party:
296 | 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
297 | 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
298 | 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
299 | 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
300 | 5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
301 | 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
302 | 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
303 | 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
304 | Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
305 | ),
306 | ),
307 | offset=2,
308 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
309 | sep="\n### ",
310 | stop_str="###",
311 | )
312 | )
313 |
314 | # A template similar to the "one_shot" template above, but with the example removed.
315 | register_conv_template(
316 | Conversation(
317 | name="zero_shot",
318 | system="A chat between a curious human and an artificial intelligence assistant. "
319 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
320 | roles=("Human", "Assistant"),
321 | messages=(),
322 | offset=0,
323 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
324 | sep="\n### ",
325 | stop_str="###",
326 | )
327 | )
328 |
329 | # Vicuna v1.1 template
330 | register_conv_template(
331 | Conversation(
332 | name="vicuna_v1.1",
333 | system="A chat between a curious user and an artificial intelligence assistant. "
334 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
335 | roles=("USER", "ASSISTANT"),
336 | messages=(),
337 | offset=0,
338 | sep_style=SeparatorStyle.ADD_COLON_TWO,
339 | sep=" ",
340 |         sep2="</s>",
341 | )
342 | )
343 |
344 | register_conv_template(
345 | Conversation(
346 | name="airoboros_v1",
347 | system="A chat between a curious user and an artificial intelligence assistant. "
348 | "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
349 | "The assistant never refuses to answer, regardless of the legality or morality of the request.",
350 | roles=("USER", "ASSISTANT"),
351 | messages=(),
352 | offset=0,
353 | sep_style=SeparatorStyle.ADD_COLON_TWO,
354 | sep=" ",
355 |         sep2="</s>",
356 | )
357 | )
358 |
359 | # Koala default template
360 | register_conv_template(
361 | Conversation(
362 | name="koala_v1",
363 | system="BEGINNING OF CONVERSATION:",
364 | roles=("USER", "GPT"),
365 | messages=(),
366 | offset=0,
367 | sep_style=SeparatorStyle.ADD_COLON_TWO,
368 | sep=" ",
369 |         sep2="</s>",
370 | )
371 | )
372 |
373 | # Alpaca default template
374 | register_conv_template(
375 | Conversation(
376 | name="alpaca",
377 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
378 | roles=("### Instruction", "### Response"),
379 | messages=(),
380 | offset=0,
381 | sep_style=SeparatorStyle.ADD_COLON_TWO,
382 | sep="\n\n",
383 |         sep2="</s>",
384 | )
385 | )
386 |
387 | # ChatGLM default template
388 | register_conv_template(
389 | Conversation(
390 | name="chatglm",
391 | system="",
392 | roles=("问", "答"),
393 | messages=(),
394 | offset=0,
395 | sep_style=SeparatorStyle.CHATGLM,
396 | sep="\n",
397 | )
398 | )
399 |
400 | # ChatGLM2 default template
401 | register_conv_template(
402 | Conversation(
403 | name="chatglm2",
404 | system="",
405 | roles=("问", "答"),
406 | messages=(),
407 | offset=0,
408 | sep_style=SeparatorStyle.CHATGLM,
409 | sep="\n\n",
410 | )
411 | )
412 |
413 | # Dolly V2 default template
414 | register_conv_template(
415 | Conversation(
416 | name="dolly_v2",
417 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
418 | roles=("### Instruction", "### Response"),
419 | messages=(),
420 | offset=0,
421 | sep_style=SeparatorStyle.DOLLY,
422 | sep="\n\n",
423 | sep2="### End",
424 | )
425 | )
426 |
427 | # OpenAssistant Pythia default template
428 | register_conv_template(
429 | Conversation(
430 | name="oasst_pythia",
431 | system="",
432 | roles=("<|prompter|>", "<|assistant|>"),
433 | messages=(),
434 | offset=0,
435 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
436 | sep="<|endoftext|>",
437 | )
438 | )
439 |
440 | # OpenAssistant default template
441 | register_conv_template(
442 | Conversation(
443 | name="oasst_llama",
444 | system="",
445 | roles=("<|prompter|>", "<|assistant|>"),
446 | messages=(),
447 | offset=0,
448 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
449 |         sep="</s>",
450 | )
451 | )
452 |
453 | # Tulu default template
454 | register_conv_template(
455 | Conversation(
456 | name="tulu",
457 | system="",
458 | roles=("<|user|>", "<|assistant|>"),
459 | messages=(),
460 | offset=0,
461 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
462 | sep="\n",
463 | )
464 | )
465 |
466 | # StableLM Alpha default template
467 | register_conv_template(
468 | Conversation(
469 | name="stablelm",
470 | system="""<|SYSTEM|># StableLM Tuned (Alpha version)
471 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
472 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
473 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
474 | - StableLM will refuse to participate in anything that could harm a human.
475 | """,
476 | roles=("<|USER|>", "<|ASSISTANT|>"),
477 | messages=(),
478 | offset=0,
479 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
480 | sep="",
481 | stop_token_ids=[50278, 50279, 50277, 1, 0],
482 | )
483 | )
484 |
485 | # Baize default template
486 | register_conv_template(
487 | Conversation(
488 | name="baize",
489 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
490 | roles=("[|Human|]", "[|AI|]"),
491 | messages=(
492 | ("[|Human|]", "Hello!"),
493 | ("[|AI|]", "Hi!"),
494 | ),
495 | offset=2,
496 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
497 | sep="\n",
498 | stop_str="[|Human|]",
499 | )
500 | )
501 |
502 | # RWKV-4-Raven default template
503 | register_conv_template(
504 | Conversation(
505 | name="rwkv",
506 | system="",
507 | roles=("Bob", "Alice"),
508 | messages=(
509 | ("Bob", "hi"),
510 | (
511 | "Alice",
512 | "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
513 | ),
514 | ),
515 | offset=2,
516 | sep_style=SeparatorStyle.RWKV,
517 | sep="",
518 | stop_str="\n\n",
519 | )
520 | )
521 |
522 | # Buddy default template
523 | register_conv_template(
524 | Conversation(
525 | name="openbuddy",
526 | system="""Consider a conversation between User (a human) and Assistant (named Buddy).
527 | Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
528 | Buddy cannot access the Internet.
529 | Buddy can fluently speak the user's language (e.g. English, Chinese).
530 | Buddy can generate poems, stories, code, essays, songs, parodies, and more.
531 | Buddy possesses vast knowledge about the world, history, and culture.
532 | Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
533 | Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
534 |
535 | User: Hi.
536 | Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
537 | roles=("User", "Assistant"),
538 | messages=(),
539 | offset=0,
540 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
541 | sep="\n",
542 | )
543 | )
544 |
545 | # Phoenix default template
546 | register_conv_template(
547 | Conversation(
548 | name="phoenix",
549 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
550 | roles=("Human", "Assistant"),
551 | messages=(),
552 | offset=0,
553 | sep_style=SeparatorStyle.PHOENIX,
554 |         sep="</s>",
555 | )
556 | )
557 |
558 | # ChatGPT default template
559 | register_conv_template(
560 | Conversation(
561 | name="chatgpt",
562 | system="You are a helpful assistant.",
563 | roles=("user", "assistant"),
564 | messages=(),
565 | offset=0,
566 | sep_style=None,
567 | sep=None,
568 | )
569 | )
570 |
571 | # Claude default template
572 | register_conv_template(
573 | Conversation(
574 | name="claude",
575 | system="",
576 | roles=("Human", "Assistant"),
577 | messages=(),
578 | offset=0,
579 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
580 | sep="\n\n",
581 | )
582 | )
583 |
584 | # MPT default template
585 | register_conv_template(
586 | Conversation(
587 | name="mpt-7b-chat",
588 | system="""<|im_start|>system
589 | - You are a helpful assistant chatbot trained by MosaicML.
590 | - You answer questions.
591 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
592 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
593 | roles=("<|im_start|>user", "<|im_start|>assistant"),
594 | messages=(),
595 | offset=0,
596 | sep_style=SeparatorStyle.CHATML,
597 | sep="<|im_end|>",
598 | stop_token_ids=[50278, 0],
599 | )
600 | )
601 |
602 | # MPT-30b-chat default template
603 | register_conv_template(
604 | Conversation(
605 | name="mpt-30b-chat",
606 | system="""<|im_start|>system
607 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
608 | roles=("<|im_start|>user", "<|im_start|>assistant"),
609 | messages=(),
610 | offset=0,
611 | sep_style=SeparatorStyle.CHATML,
612 | sep="<|im_end|>",
613 | stop_token_ids=[50278, 0],
614 | )
615 | )
616 |
617 | # MPT-30b-instruct default template
618 | # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
619 | register_conv_template(
620 | Conversation(
621 | name="mpt-30b-instruct",
622 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
623 | roles=("### Instruction", "### Response"),
624 | messages=(),
625 | offset=0,
626 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
627 | sep="\n\n",
628 | stop_token_ids=[50278, 0],
629 | )
630 | )
631 |
632 | # Bard default template
633 | # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
634 | # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
635 | register_conv_template(
636 | Conversation(
637 | name="bard",
638 | system="",
639 | roles=("0", "1"),
640 | messages=(),
641 | offset=0,
642 | sep_style=None,
643 | sep=None,
644 | )
645 | )
646 |
647 | # BiLLa default template
648 | register_conv_template(
649 | Conversation(
650 | name="billa",
651 | system="",
652 | roles=("Human", "Assistant"),
653 | messages=(),
654 | offset=0,
655 | sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
656 | sep="\n",
657 | stop_str="Human:",
658 | )
659 | )
660 |
661 | # RedPajama INCITE default template
662 | register_conv_template(
663 | Conversation(
664 | name="redpajama-incite",
665 | system="",
666 |         roles=("<human>", "<bot>"),
667 | messages=(),
668 | offset=0,
669 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
670 | sep="\n",
671 |         stop_str="<human>",
672 | )
673 | )
674 |
675 | # h2oGPT default template
676 | register_conv_template(
677 | Conversation(
678 | name="h2ogpt",
679 | system="",
680 | roles=("<|prompt|>", "<|answer|>"),
681 | messages=(),
682 | offset=0,
683 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
684 |         sep="</s>",
685 | )
686 | )
687 |
688 | # Robin default template
689 | register_conv_template(
690 | Conversation(
691 | name="Robin",
692 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
693 | roles=("###Human", "###Assistant"),
694 | messages=(),
695 | offset=0,
696 | sep_style=SeparatorStyle.ROBIN,
697 | sep="\n",
698 | stop_token_ids=[2, 396],
699 | stop_str="###",
700 | )
701 | )
702 |
703 | # Snoozy default template
704 | # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232
705 | register_conv_template(
706 | Conversation(
707 | name="snoozy",
708 | system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.",
709 | roles=("### Prompt", "### Response"),
710 | messages=(),
711 | offset=0,
712 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
713 | sep="\n",
714 | stop_str="###",
715 | )
716 | )
717 |
718 | # manticore default template
719 | register_conv_template(
720 | Conversation(
721 | name="manticore",
722 | system="",
723 | roles=("USER", "ASSISTANT"),
724 | messages=(),
725 | offset=0,
726 | sep_style=SeparatorStyle.ADD_COLON_TWO,
727 | sep="\n",
728 |         sep2="</s>",
729 | )
730 | )
731 |
732 | # Falcon default template
733 | register_conv_template(
734 | Conversation(
735 | name="falcon",
736 | system="",
737 | roles=("User", "Assistant"),
738 | messages=[],
739 | offset=0,
740 | sep_style=SeparatorStyle.RWKV,
741 | sep="\n",
742 | sep2="<|endoftext|>",
743 |         stop_str="\nUser",  # stop_str stops generation in addition to stop_token_ids; it is also stripped from the generated text
744 | stop_token_ids=[
745 | 0,
746 | 1,
747 | 2,
748 | 3,
749 | 4,
750 | 5,
751 | 6,
752 | 7,
753 | 8,
754 | 9,
755 | 10,
756 | 11,
757 |         ],  # it is better to put only special tokens here, because the tokenizer only removes special tokens
758 | )
759 | )
760 |
761 | # ChangGPT default template
762 | register_conv_template(
763 | Conversation(
764 | name="polyglot_changgpt",
765 | system="",
766 | roles=("B", "A"),
767 | messages=(),
768 | offset=0,
769 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
770 | sep="\n",
771 | )
772 | )
773 |
774 | # tigerbot template
775 | register_conv_template(
776 | Conversation(
777 | name="tigerbot",
778 | system="A chat between a curious user and an artificial intelligence assistant. "
779 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
780 | roles=("### Instruction", "### Response"),
781 | messages=(),
782 | offset=0,
783 | sep_style=SeparatorStyle.ROBIN,
784 | sep="\n\n",
785 | stop_str="###",
786 | )
787 | )
788 |
789 | # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst
790 | register_conv_template(
791 | Conversation(
792 | name="xgen",
793 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
794 | roles=("### Human: ", "###"),
795 | messages=(),
796 | offset=0,
797 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
798 | sep="\n",
799 | stop_token_ids=[50256, 0, 1, 2],
800 | stop_str="<|endoftext|>",
801 | )
802 | )
803 |
804 | # Internlm-chat template
805 | register_conv_template(
806 | Conversation(
807 | name="internlm-chat",
808 | system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
809 | roles=("<|User|>", "<|Bot|>"),
810 | messages=(),
811 | offset=0,
812 | sep_style=SeparatorStyle.CHATINTERN,
813 |         sep="<eoh>\n",
814 |         sep2="<eoa>\n",
815 | stop_token_ids=[1, 103028],
816 | stop_str="<|User|>",
817 | )
818 | )
819 |
820 | # StarChat template
821 | register_conv_template(
822 | Conversation(
823 | name="starchat",
824 |         system="<system>\n",
825 | roles=("<|user|>", "<|assistant|>"),
826 | messages=(),
827 | offset=0,
828 | sep_style=SeparatorStyle.CHATML,
829 | sep="<|end|>",
830 | stop_token_ids=[0, 49155],
831 | stop_str="<|end|>",
832 | )
833 | )
834 |
835 | # Baichuan-13B-Chat template
836 | register_conv_template(
837 | # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
838 | # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
839 | Conversation(
840 | name="baichuan-chat",
841 | system="",
842 |         roles=(" <reserved_102> ", " <reserved_103> "),
843 | messages=(),
844 | offset=0,
845 | sep_style=SeparatorStyle.NO_COLON_TWO,
846 | sep="",
847 |         sep2="</s>",
848 | stop_token_ids=[2, 195],
849 | )
850 | )
851 |
852 | # llama2 template
853 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
854 | register_conv_template(
855 | Conversation(
856 | name="llama-2",
857 |         system="[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
858 |         "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
859 |         "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
860 |         "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
861 |         "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
862 | roles=("[INST]", "[/INST]"),
863 | messages=(),
864 | offset=0,
865 | sep_style=SeparatorStyle.LLAMA2,
866 | sep=" ",
867 |         sep2=" </s><s>",
868 | stop_token_ids=[2],
869 | )
870 | )
871 |
872 | register_conv_template(
873 | Conversation(
874 | name="scorer",
875 | system="",
876 | roles=("query", "score"),
877 | messages=(),
878 | offset=0,
879 | sep_style=SeparatorStyle.SCORER,
880 | sep="\n",
881 | sep2="",
882 | )
883 | )
884 |
885 | # Zephyr template
886 | # reference: https://huggingface.co/spaces/HuggingFaceH4/zephyr-playground/blob/main/dialogues.py
887 | # register_conv_template(
888 | # Conversation(
889 | # name="zephyr",
890 | # system_template="<|system|>\n{system_message}",
891 | # roles=("<|user|>", "<|assistant|>"),
892 | # sep_style=SeparatorStyle.CHATML,
893 | # sep="",
894 | # stop_token_ids=[2],
895 | # stop_str="",
896 | # )
897 | # )
898 |
899 | if __name__ == "__main__":
900 | conv = get_conv_template("vicuna_v1.1")
901 | conv.append_message(conv.roles[0], "Hello!")
902 | conv.append_message(conv.roles[1], "Hi!")
903 | conv.append_message(conv.roles[0], "How are you?")
904 | conv.append_message(conv.roles[1], None)
905 | print(conv.get_prompt())
906 |
--------------------------------------------------------------------------------
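For orientation, a minimal sketch (not part of the repository) of how these templates are typically consumed: fetch a registered template by name, append alternating turns, and render either a flat prompt string or OpenAI-style messages. The template name and turns below are illustrative only.

from conversation import get_conv_template

conv = get_conv_template("zero_shot")             # any registered template name works
conv.append_message(conv.roles[0], "What is 2 + 2?")
conv.append_message(conv.roles[1], None)          # None marks the slot for the model's reply
print(conv.get_prompt())                          # "...\n### Human: What is 2 + 2?\n### Assistant:"

conv.update_last_message("2 + 2 = 4.")            # fill in the reply after generation
print(conv.to_openai_api_messages())              # [{"role": "system", ...}, {"role": "user", ...}, {"role": "assistant", ...}]
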
/src/deita/alignment/dpo_train.py:
--------------------------------------------------------------------------------
1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
2 | #
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass, field
18 | import pathlib
19 | from typing import Dict, Optional
20 |
21 | from torch.utils.data import Dataset
22 | import transformers
23 | from transformers.trainer_pt_utils import LabelSmoother
24 |
25 | from conversation import get_conv_template
26 |
27 | from trl import DPOTrainer
28 | from datasets import load_dataset
29 | from functools import partial
30 |
31 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
32 |
33 | @dataclass
34 | class ModelArguments:
35 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
36 | flash_attn: bool = False
37 |
38 |
39 | @dataclass
40 | class DataArguments:
41 | data_id: str = field(
42 | default = None, metadata = {"help": "Dataset id name of the training data."}
43 | )
44 |
45 | data_split: str = field(
46 | default = None, metadata = {"help": "Chosen split of the training data."}
47 | )
48 |
49 | data_path: str = field(
50 | default=None, metadata={"help": "Path to the training data."}
51 | )
52 |
53 | cache_path: str = field(
54 | default=None, metadata={"help": "Path to cache the training data."}
55 | )
56 |
57 | num_proc: int = field(
58 | default=32
59 | )
60 |
61 |     conv_template: str = field(default = "vicuna_v1.1")
62 |
63 | json_path: str = field(
64 | default = None, metadata = {"help": "Path to the json file containing the training data."}
65 | )
66 |
67 |
68 | @dataclass
69 | class TrainingArguments(transformers.TrainingArguments):
70 | beta: float = field(default = 0.1, metadata = {
71 | "help": "Control the deviation from the reference model."
72 | })
73 | cache_dir: Optional[str] = field(default=None)
74 | optim: str = field(default="adamw_torch")
75 | model_max_length: int = field(
76 | default=512,
77 | metadata={
78 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
79 | },
80 | )
81 | min_lr: float = field(
82 | default = None
83 | )
84 | mask_user: bool = field(
85 | default = True
86 | )
87 |
88 | save_global_steps: bool = field(
89 | default = True
90 | )
91 |
92 |
93 | local_rank = None
94 |
95 |
96 | def rank0_print(*args):
97 | if local_rank == 0:
98 | print(*args)
99 |
100 | def preprocess(
101 | sample,
102 |     conv_template = "vicuna_v1.1",
103 | ) -> Dict:
104 |
105 | conv = get_conv_template(conv_template)
106 |
107 | prompt = conv.system + conv.sep + sample["messages"][0]["role"] + ": " + sample["prompt"] + conv.sep
108 |
109 | # Apply prompt templates
110 | chosen_sources = sample["chosen"]
111 | chosen_conversations = chosen_sources[1]["role"] + ": " + chosen_sources[1]["content"] + conv.sep2
112 |
113 | rejected_sources = sample["rejected"]
114 | rejected_conversations = rejected_sources[1]["role"] + ": " + rejected_sources[1]["content"] + conv.sep2
115 |
116 | return dict(
117 | prompt=prompt,
118 | chosen=chosen_conversations,
119 | rejected=rejected_conversations,
120 | )
121 |
122 | def make_dpo_dataset(
123 | data_args: DataArguments,
124 | sanity_check: bool = False
125 | ) -> Dataset:
126 |     """Load the preference dataset (from the Hugging Face Hub or a local json file) and convert it to the format expected by DPOTrainer.
127 | 
128 |     The dataset is converted to a dictionary with the following structure:
129 |     {
130 |         'prompt': List[str],
131 |         'chosen': List[str],
132 |         'rejected': List[str],
133 |     }
134 | 
135 |     Prompts are built with the conversation template as:
136 |         system + sep + user_role + ": " + prompt + sep
137 |     """
138 |
139 | data_id: str = data_args.data_id
140 | data_split: str = data_args.data_split
141 | data_dir: str = data_args.data_path
142 | cache_dir: str = data_args.cache_path
143 | num_proc: int = data_args.num_proc
144 | conv_template: str = data_args.conv_template
145 |
146 | json_path: str = data_args.json_path
147 |
148 | if not json_path:
149 | dataset = load_dataset(
150 | data_id,
151 | split=data_split,
152 | cache_dir=cache_dir,
153 | data_dir=data_dir,
154 | )
155 | else:
156 | dataset = load_dataset(
157 | "json",
158 | data_files = json_path,
159 | split = data_split
160 | )
161 |
162 | original_columns = dataset.column_names
163 |
164 | if sanity_check:
165 | dataset = dataset.select(range(min(len(dataset), 1000)))
166 |
167 | preprocess_with_template = partial(preprocess, conv_template = conv_template)
168 |
169 | return dataset.map(
170 | preprocess_with_template,
171 | num_proc=num_proc,
172 | remove_columns=original_columns,
173 | )
174 |
175 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
176 | """Collects the state dict and dump to disk."""
177 | state_dict = trainer.model.state_dict()
178 | if trainer.args.should_save:
179 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
180 | del state_dict
181 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
182 |
183 | def train():
184 | global local_rank
185 |
186 | parser = transformers.HfArgumentParser(
187 | (ModelArguments, DataArguments, TrainingArguments)
188 | )
189 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
190 | training_args.do_eval = False
191 | local_rank = training_args.local_rank
192 |
193 | # print("Load Model")
194 | model = transformers.AutoModelForCausalLM.from_pretrained(
195 | model_args.model_name_or_path,
196 | cache_dir=training_args.cache_dir,
197 | use_flash_attention_2 = True
198 | )
199 | model.config.use_cache = False
200 |
201 | # print("Load Refer Model")
202 | model_refer = transformers.AutoModelForCausalLM.from_pretrained(
203 | model_args.model_name_or_path,
204 | cache_dir=training_args.cache_dir,
205 | use_flash_attention_2 = True
206 | )
207 |
208 | tokenizer = transformers.AutoTokenizer.from_pretrained(
209 | model_args.model_name_or_path,
210 | cache_dir=training_args.cache_dir,
211 | model_max_length=training_args.model_max_length,
212 | padding_side="right",
213 | use_fast=False,
214 | )
215 |
216 | tokenizer.pad_token = tokenizer.unk_token
217 | train_dataset = make_dpo_dataset(data_args=data_args)
218 |
219 | trainer = DPOTrainer(
220 | model, model_refer, tokenizer = tokenizer, beta = training_args.beta, args=training_args, train_dataset = train_dataset,
221 | max_prompt_length = 512, max_length = 2048
222 | )
223 |
224 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
225 | print("Checkpoint found, resuming training")
226 | trainer.train(resume_from_checkpoint=True)
227 | else:
228 | trainer.train()
229 |
230 | trainer.save_state()
231 | trainer.save_model()
232 |
233 |
234 | if __name__ == "__main__":
235 | train()
236 |
--------------------------------------------------------------------------------
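As a rough sketch of what `preprocess` yields for one preference record under the vicuna template (run from the same directory; the record below is fabricated and simplified, real UltraFeedback rows carry more fields):

from dpo_train import preprocess

sample = {
    "prompt": "Name a prime number.",
    "messages": [{"role": "USER", "content": "Name a prime number."}],
    "chosen":   [{"role": "USER", "content": "Name a prime number."},
                 {"role": "ASSISTANT", "content": "7 is a prime number."}],
    "rejected": [{"role": "USER", "content": "Name a prime number."},
                 {"role": "ASSISTANT", "content": "9."}],
}

out = preprocess(sample, conv_template="vicuna_v1.1")
# out["prompt"]   -> "<system prompt> USER: Name a prime number. "
# out["chosen"]   -> "ASSISTANT: 7 is a prime number.</s>"
# out["rejected"] -> "ASSISTANT: 9.</s>"
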
/src/deita/alignment/flash_attn/bloom_flash_attention.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | import transformers
8 | from transformers.models.bloom.modeling_bloom import dropout_add
9 |
10 | from einops import rearrange
11 |
12 | from .triton_flash_attention import flash_attn_qkvpacked_func
13 |
14 | def forward(
15 | self,
16 | hidden_states: torch.Tensor,
17 | residual: torch.Tensor,
18 | alibi: torch.Tensor,
19 | attention_mask: torch.Tensor,
20 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
21 | head_mask: Optional[torch.Tensor] = None,
22 | use_cache: bool = False,
23 | output_attentions: bool = False,
24 | ):
25 | dtype = hidden_states.dtype
26 | fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
27 |
28 | # 3 x [batch_size, seq_length, num_heads, head_dim]
29 | (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
30 |
31 | batch_size, q_length, _, _ = query_layer.shape
32 | bsz, q_len = batch_size, q_length
33 |
34 | if layer_past is not None:
35 | past_key, past_value = layer_past
36 | # concatenate along seq_length dimension:
37 | # - key: [batch_size * self.num_heads, head_dim, kv_length]
38 | # - value: [batch_size * self.num_heads, kv_length, head_dim]
39 | key_layer = torch.cat((past_key, key_layer), dim=2)
40 | value_layer = torch.cat((past_value, value_layer), dim=1)
41 |
42 | if use_cache is True:
43 | present = (key_layer, value_layer)
44 | else:
45 | present = None
46 |
47 | reshaped_alibi = rearrange(alibi, '(b h) one s-> b h one s', h = self.num_heads)
48 | reshaped_alibi = reshaped_alibi * self.beta
49 |
50 | attention_mask = (1.0 - attention_mask)
51 | attention_mask = attention_mask[:, None, None, :].bool()
52 | # reshaped_alibi_masked = reshaped_alibi.masked_fill(attention_mask, -1e9)
53 | reshaped_alibi_masked = reshaped_alibi
54 |
55 | reshaped_query_layer = query_layer
56 | reshaped_key_layer = key_layer
57 | reshaped_value_layer = value_layer
58 |
59 | qkv = torch.concat([reshaped_query_layer.unsqueeze(2), reshaped_key_layer.unsqueeze(2), reshaped_value_layer.unsqueeze(2)], dim = 2)
60 |
61 | output = flash_attn_qkvpacked_func(
62 | qkv, reshaped_alibi_masked, True, self.inv_norm_factor
63 | )
64 |
65 | output = rearrange(output, 'b s h d -> (b h) s d')
66 |
67 | # change view [batch_size, num_heads, q_length, head_dim]
68 | context_layer = self._merge_heads(output)
69 |
70 | # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
71 | if self.pretraining_tp > 1 and self.slow_but_exact:
72 | slices = self.hidden_size / self.pretraining_tp
73 | output_tensor = torch.zeros_like(context_layer)
74 | for i in range(self.pretraining_tp):
75 | output_tensor = output_tensor + F.linear(
76 | context_layer[:, :, int(i * slices) : int((i + 1) * slices)],
77 | self.dense.weight[:, int(i * slices) : int((i + 1) * slices)],
78 | )
79 | else:
80 | output_tensor = self.dense(context_layer)
81 |
82 | output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
83 |
84 | outputs = (output_tensor, present)
85 | if output_attentions:
86 | outputs += (context_layer,)
87 |
88 | return outputs
89 |
90 |
91 | # Disable the transformation of the attention mask in BloomModel, as flash attention
92 | # requires the attention mask to be the same as the key_padding_mask
93 | def _prepare_attn_mask(
94 | self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
95 | ) -> torch.BoolTensor:
96 |
97 | return attention_mask
98 |
99 | def replace_bloom_attn_with_flash_attn():
100 | transformers.models.bloom.modeling_bloom.BloomModel._prepare_attn_mask = (
101 | _prepare_attn_mask
102 | )
103 | transformers.models.bloom.modeling_bloom.BloomAttention.forward = forward
--------------------------------------------------------------------------------
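A sketch of the intended call order; the import path and checkpoint name below are assumptions, not taken from the repository. The patch is applied before the Bloom model is built so that the patched `forward` and `_prepare_attn_mask` are in place when attention is computed.

import transformers

from bloom_flash_attention import replace_bloom_attn_with_flash_attn  # adjust the import path to your layout

replace_bloom_attn_with_flash_attn()   # patch BloomAttention.forward and BloomModel._prepare_attn_mask
model = transformers.AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz-560m",          # any Bloom-family checkpoint; this name is illustrative
    torch_dtype="auto",
)
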
/src/deita/alignment/train.py:
--------------------------------------------------------------------------------
1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
2 | #
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from dataclasses import dataclass, field
18 | import json
19 | import pathlib
20 | from typing import Dict, Optional
21 |
22 | import numpy as np
23 | import torch
24 | from torch.utils.data import Dataset
25 | import transformers
26 | from transformers import Trainer
27 | from transformers.trainer_pt_utils import LabelSmoother
28 | from datasets import load_dataset
29 |
30 | from conversation import SeparatorStyle
31 | from conversation import get_conv_template
32 |
33 | from transformers import Trainer
34 |
35 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
36 |
37 | @dataclass
38 | class ModelArguments:
39 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
40 | flash_attn: bool = False
41 |
42 |
43 | @dataclass
44 | class DataArguments:
45 | data_path: str = field(
46 | default=None, metadata={"help": "Path to the training data."}
47 | )
48 | lazy_preprocess: bool = False
49 |     conv_template: str = field(default = "vicuna_v1.1")
50 |
51 |
52 | @dataclass
53 | class TrainingArguments(transformers.TrainingArguments):
54 | cache_dir: Optional[str] = field(default=None)
55 | optim: str = field(default="adamw_torch")
56 | model_max_length: int = field(
57 | default=512,
58 | metadata={
59 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
60 | },
61 | )
62 | min_lr: float = field(
63 | default = None
64 | )
65 | mask_user: bool = field(
66 | default = True
67 | )
68 |
69 |
70 | local_rank = None
71 |
72 |
73 | def rank0_print(*args):
74 | if local_rank == 0:
75 | print(*args)
76 |
77 |
78 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
79 | """Collects the state dict and dump to disk."""
80 | state_dict = trainer.model.state_dict()
81 | if trainer.args.should_save:
82 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
83 | del state_dict
84 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
85 |
86 |
87 | def preprocess(
88 | sources,
89 | tokenizer: transformers.PreTrainedTokenizer,
90 |     conv_template = "vicuna_v1.1",
91 | mask_user = True
92 | ) -> Dict:
93 |
94 | conv = get_conv_template(conv_template)
95 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
96 |
97 | # Apply prompt templates
98 | conversations = []
99 | for i, source in enumerate(sources):
100 | if roles[source[0]["from"]] != conv.roles[0]:
101 | # Skip the first one if it is not from human
102 | source = source[1:]
103 |
104 | conv.messages = []
105 | for j, sentence in enumerate(source):
106 | role = roles[sentence["from"]]
107 | # assert role == conv.roles[j % 2], f"{i}"
108 | assert role == conv.roles[j % 2], breakpoint()
109 | conv.append_message(role, sentence["value"])
110 | conversations.append(conv.get_prompt())
111 |
112 | # Tokenize conversations
113 | input_ids = tokenizer(
114 | conversations,
115 | return_tensors="pt",
116 | padding="max_length",
117 | max_length=tokenizer.model_max_length,
118 | truncation=True,
119 | ).input_ids
120 |
121 | targets = input_ids.clone()
122 |
123 | assert (conv.sep_style == SeparatorStyle.ADD_COLON_TWO) or (conv.sep_style == SeparatorStyle.CHATML)
124 |
125 | if mask_user:
126 | # Mask targets. Only compute loss on the assistant outputs.
127 | if conv.sep_style == SeparatorStyle.ADD_COLON_TWO:
128 | sep = conv.sep + conv.roles[1] + ": "
129 | for conversation, target in zip(conversations, targets):
130 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
131 |
132 | turns = conversation.split(conv.sep2)
133 | cur_len = 1
134 | target[:cur_len] = IGNORE_TOKEN_ID
135 | # breakpoint()
136 | for i, turn in enumerate(turns):
137 | if turn == "":
138 | break
139 | turn_len = len(tokenizer(turn).input_ids)
140 |
141 | parts = turn.split(sep)
142 | if len(parts) != 2:
143 | break
144 | parts[0] += sep
145 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
146 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
147 |
148 | # Ignore the user instructions
149 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
150 | cur_len += turn_len
151 |
152 | target[cur_len:] = IGNORE_TOKEN_ID
153 |
154 | if False: # Inspect and check the correctness of masking
155 | z = target.clone()
156 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
157 | rank0_print(tokenizer.decode(z))
158 |
159 |
160 | elif conv.sep_style == SeparatorStyle.CHATML:
161 |             # breakpoint()
162 | sep = conv.sep + conv.roles[1] + "\n"
163 | for conversation, target in zip(conversations, targets):
164 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
165 |
166 | turns = conversation.split(conv.sep)
167 | cur_len = 1
168 | target[:cur_len] = IGNORE_TOKEN_ID
169 | # breakpoint()
170 | for i, turn in enumerate(turns):
171 | if turn == "":
172 | break
173 | turn_len = len(tokenizer(turn).input_ids)
174 |
175 | parts = turn.split(sep)
176 | if len(parts) != 2:
177 | break
178 | parts[0] += sep
179 | # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
180 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
181 |
182 | # Ignore the user instructions
183 | target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
184 | cur_len += turn_len
185 |
186 | target[cur_len:] = IGNORE_TOKEN_ID
187 |
188 | if cur_len < tokenizer.model_max_length:
189 | if cur_len != total_len:
190 | target[:] = IGNORE_TOKEN_ID
191 | rank0_print(
192 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
193 | f" (ignored)"
194 | )
195 |
196 | return dict(
197 | input_ids=input_ids,
198 | labels=targets,
199 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
200 | )
201 |
202 |
203 | class SupervisedDataset(Dataset):
204 | """Dataset for supervised fine-tuning."""
205 |
206 |     def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna_v1.1", mask_user = True):
207 | super(SupervisedDataset, self).__init__()
208 |
209 | rank0_print("Formatting inputs...")
210 | sources = [example["conversations"] for example in raw_data]
211 | data_dict = preprocess(sources, tokenizer, conv_template, mask_user)
212 |
213 | if mask_user:
214 | rank0_print(
215 | f"WARNING: The loss of user prompt will be masked"
216 | )
217 | else:
218 | rank0_print(
219 | f"WARNING: The loss of user prompt will **NOT** be masked"
220 | )
221 |
222 |
223 | self.input_ids = data_dict["input_ids"]
224 | self.labels = data_dict["labels"]
225 | self.attention_mask = data_dict["attention_mask"]
226 |
227 | def __len__(self):
228 | return len(self.input_ids)
229 |
230 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
231 | return dict(
232 | input_ids=self.input_ids[i],
233 | labels=self.labels[i],
234 | attention_mask=self.attention_mask[i],
235 | )
236 |
237 |
238 | class LazySupervisedDataset(Dataset):
239 | """Dataset for supervised fine-tuning."""
240 |
241 |     def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer, conv_template = "vicuna_v1.1", mask_user = True):
242 | super(LazySupervisedDataset, self).__init__()
243 | self.tokenizer = tokenizer
244 |
245 | rank0_print("Formatting inputs...Skip in lazy mode")
246 | self.conv_template = conv_template
247 | self.mask_user = mask_user
248 | self.tokenizer = tokenizer
249 | self.raw_data = raw_data
250 | self.cached_data_dict = {}
251 |
252 | if mask_user:
253 | rank0_print(
254 | f"WARNING: The loss of user prompt will be masked"
255 | )
256 | else:
257 | rank0_print(
258 | f"WARNING: The loss of user prompt will **NOT** be masked"
259 | )
260 |
261 | def __len__(self):
262 | return len(self.raw_data)
263 |
264 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
265 | if i in self.cached_data_dict:
266 | return self.cached_data_dict[i]
267 |
268 | ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer, self.conv_template, self.mask_user)
269 | ret = dict(
270 | input_ids=ret["input_ids"][0],
271 | labels=ret["labels"][0],
272 | attention_mask=ret["attention_mask"][0],
273 | )
274 | self.cached_data_dict[i] = ret
275 |
276 | return ret
277 |
278 |
279 | def make_supervised_data_module(
280 | tokenizer: transformers.PreTrainedTokenizer, data_args, mask_user = True
281 | ) -> Dict:
282 | """Make dataset and collator for supervised fine-tuning."""
283 | conv_template = data_args.conv_template
284 | dataset_cls = (
285 | LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
286 | )
287 | rank0_print("Loading data...")
288 | try:
289 | raw_data = json.load(open(data_args.data_path, "r"))
290 | except FileNotFoundError:
291 | raw_data = load_dataset(data_args.data_path, split = "train")
292 | raw_data = [row for row in raw_data]
293 |
294 | # Split train/eval
295 | np.random.seed(0)
296 | train_raw_data = raw_data
297 | perm = np.random.permutation(len(raw_data))
298 | split = int(len(perm) * 0.98)
299 | train_indices = perm[:split]
300 | eval_indices = perm[split:]
301 | train_raw_data = [raw_data[i] for i in train_indices]
302 | eval_raw_data = [raw_data[i] for i in eval_indices]
303 | rank0_print(f"#train {len(train_raw_data)}, #eval {len(eval_raw_data)}")
304 |
305 | train_dataset = dataset_cls(train_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user)
306 | eval_dataset = dataset_cls(eval_raw_data, tokenizer=tokenizer, conv_template = conv_template, mask_user = mask_user)
307 | return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
308 |
309 | def train():
310 | global local_rank
311 |
312 | parser = transformers.HfArgumentParser(
313 | (ModelArguments, DataArguments, TrainingArguments)
314 | )
315 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
316 | training_args.do_eval = False
317 | local_rank = training_args.local_rank
318 | model = transformers.AutoModelForCausalLM.from_pretrained(
319 | model_args.model_name_or_path,
320 | cache_dir=training_args.cache_dir,
321 | use_flash_attention_2 = True
322 | )
323 | model.config.use_cache = False
324 | tokenizer = transformers.AutoTokenizer.from_pretrained(
325 | model_args.model_name_or_path,
326 | cache_dir=training_args.cache_dir,
327 | model_max_length=training_args.model_max_length,
328 | padding_side="right",
329 | use_fast=False,
330 | )
331 | tokenizer.pad_token = tokenizer.unk_token
332 |
333 | if "mistral" in model_args.model_name_or_path.lower():
334 | rank0_print("Mistral with Left Padding Side")
335 | tokenizer.padding_side = "left"
336 |
337 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args, mask_user = training_args.mask_user)
338 |
339 | trainer = Trainer(
340 | model=model, tokenizer=tokenizer, args=training_args, **data_module
341 | )
342 |
343 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
344 | trainer.train(resume_from_checkpoint=True)
345 | else:
346 | trainer.train()
347 | trainer.save_state()
348 |
349 | trainer.save_model(output_dir = training_args.output_dir)
350 |
351 |
352 | if __name__ == "__main__":
353 | train()
354 |
--------------------------------------------------------------------------------
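For reference, a toy record in the layout `make_supervised_data_module` expects, plus a note on what `mask_user` does to the labels; the content is illustrative only.

raw_record = {
    "conversations": [
        {"from": "human", "value": "Write a haiku about rain."},
        {"from": "gpt",   "value": "Soft rain taps the roof / ..."},
    ]
}

# SupervisedDataset([raw_record], tokenizer=tokenizer, conv_template="vicuna_v1.1", mask_user=True)
# With mask_user=True, the tokens of the system prompt and of every "USER: ..." span are set to
# IGNORE_TOKEN_ID in `labels`, so the loss is computed only on the "ASSISTANT: ..." spans.
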
/src/deita/alignment/train_scorers.py:
--------------------------------------------------------------------------------
1 | # This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
2 | #
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import os
17 | from pathlib import Path
18 | import sys
19 |
20 | import copy
21 | from dataclasses import dataclass, field
22 | import json
23 | import pathlib
24 | from typing import Dict, Optional, Sequence
25 |
26 | import numpy as np
27 | import torch
28 | from torch.utils.data import Dataset
29 | import transformers
30 | from transformers import Trainer
31 | from transformers.trainer_pt_utils import LabelSmoother
32 |
33 | from conversation import get_conv_template
34 |
35 |
36 | from functools import partial
37 | from datasets import Dataset
38 |
39 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
40 |
41 | @dataclass
42 | class ModelArguments:
43 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
44 | flash_attn: bool = False
45 |
46 |
47 | @dataclass
48 | class DataArguments:
49 | data_path: str = field(
50 | default=None, metadata={"help": "Path to the training data."}
51 | )
52 | lazy_preprocess: bool = False
53 |     conv_template: str = field(default = "vicuna_v1.1")
54 |
55 |
56 | @dataclass
57 | class TrainingArguments(transformers.TrainingArguments):
58 | cache_dir: Optional[str] = field(default=None)
59 | optim: str = field(default="adamw_torch")
60 | model_max_length: int = field(
61 | default=512,
62 | metadata={
63 | "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
64 | },
65 | )
66 | min_lr: float = field(
67 | default = None
68 | )
69 | mask_user: bool = field(
70 | default = True
71 | )
72 |
73 |
74 | local_rank = None
75 |
76 |
77 |
78 | def rank0_print(*args):
79 | if local_rank == 0:
80 | print(*args)
81 |
82 |
83 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
84 | """Collects the state dict and dump to disk."""
85 | state_dict = trainer.model.state_dict()
86 | if trainer.args.should_save:
87 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
88 | del state_dict
89 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
90 |
91 |
92 | def preprocess(
93 | sample,
94 | conv_template,
95 | max_length,
96 | tokenizer,
97 | mask_user = True,
98 | ) -> Dict:
99 |
100 | source = sample["conversations"]
101 |
102 | conv = get_conv_template(conv_template)
103 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
104 |
105 | if roles[source[0]["from"]] != conv.roles[0]:
106 | # Skip the first one if it is not from human
107 | source = source[1:]
108 |
109 | conv.messages = []
110 | for j, sentence in enumerate(source):
111 | role = roles[sentence["from"]]
112 | # assert role == conv.roles[j % 2], f"{i}"
113 | assert role == conv.roles[j % 2], breakpoint()
114 | conv.append_message(role, sentence["value"])
115 | conversation = conv.get_prompt()
116 |
117 | input_ids = tokenizer(
118 | conversation,
119 | return_tensors="pt",
120 | max_length=max_length,
121 | truncation=True,
122 | ).input_ids
123 |
124 | input_ids = input_ids.flatten()
125 |
126 |     targets = input_ids.clone()
127 | 
128 |     if mask_user:
129 |
130 | if roles[source[0]["from"]] != conv.roles[0]:
131 | # Skip the first one if it is not from human
132 | source = source[1:]
133 |
134 | conv.messages = []
135 | for j, sentence in enumerate(source):
136 |
137 | role = roles[sentence["from"]]
138 | # assert role == conv.roles[j % 2], f"{i}"
139 | assert role == conv.roles[j % 2], breakpoint()
140 |
141 | if role == conv.roles[1]:
142 | conv.append_message(role, sentence["value"])
143 |
144 | if role != conv.roles[1]:
145 | if j == 0:
146 | conv_start_idx = 0
147 | else:
148 | conv_last = conv.get_prompt()
149 | conv_start_idx = tokenizer(
150 | conv_last,
151 | return_tensors="pt",
152 | max_length=max_length,
153 | truncation=True,
154 | ).input_ids.shape[1]
155 |
156 | conv.append_message(role, sentence["value"])
157 | conv_so_far = conv.get_prompt()
158 |
159 | conv_end_idx = tokenizer(
160 | conv_so_far,
161 | return_tensors="pt",
162 | max_length=max_length,
163 | truncation=True,
164 | ).input_ids.shape[1]
165 |
166 | # conv_end_idx -= 1 # hard offset for llama model
167 |
168 | targets[conv_start_idx:conv_end_idx] = IGNORE_TOKEN_ID
169 |
170 | if conv_end_idx >= max_length:
171 | break
172 |
173 | attention_mask = torch.ones_like(input_ids)
174 |
175 | return dict(
176 | input_ids=input_ids,
177 | labels = targets,
178 | attention_mask = attention_mask,
179 | )
180 |
181 |
182 | def get_datasets(data, preprocess_func, num_proc):
183 |
184 | conversations = [{"conversations": item["conversations"]} for item in data]
185 |
186 | raw_dataset = Dataset.from_list(conversations)
187 |
188 | tokenized_datasets = raw_dataset.map(
189 | preprocess_func,
190 | batched = False,
191 | num_proc = num_proc,
192 | remove_columns = ["conversations"],
193 | desc = "Tokenizing and reformatting instruction data"
194 | )
195 |
196 | return tokenized_datasets
197 |
198 |
199 | def make_supervised_data_module(
200 | model, tokenizer: transformers.PreTrainedTokenizer, max_length, fwd_batch_size, data_args, mask_user = True
201 | ) -> Dict:
202 | """Make dataset and collator for supervised fine-tuning."""
203 | conv_template = data_args.conv_template
204 | rank0_print("Loading data...")
205 | raw_data = json.load(open(data_args.data_path, "r"))
206 |
207 | preprocess_func = partial(preprocess,
208 | conv_template = conv_template,
209 | max_length = max_length,
210 | tokenizer = tokenizer,
211 | mask_user = mask_user)
212 |
213 | train_dataset = get_datasets(raw_data, preprocess_func, 32)
214 |
215 | return dict(train_dataset=train_dataset)
216 |
217 |
218 | def train():
219 | global local_rank
220 |
221 | parser = transformers.HfArgumentParser(
222 | (ModelArguments, DataArguments, TrainingArguments)
223 | )
224 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
225 | training_args.do_eval = False
226 | local_rank = training_args.local_rank
227 | world_size = training_args.world_size
228 | model = transformers.AutoModelForCausalLM.from_pretrained(
229 | model_args.model_name_or_path,
230 | cache_dir=training_args.cache_dir,
231 | use_flash_attention_2 = True
232 | )
233 | model.config.use_cache = False
234 | tokenizer = transformers.AutoTokenizer.from_pretrained(
235 | model_args.model_name_or_path,
236 | cache_dir=training_args.cache_dir,
237 | model_max_length=training_args.model_max_length,
238 | padding_side="right",
239 | use_fast=False,
240 | )
241 | tokenizer.pad_token = tokenizer.unk_token
242 |
243 | if "mistral" in model_args.model_name_or_path.lower():
244 | rank0_print("Mistral with Left Padding Side")
245 | tokenizer.padding_side = "left"
246 |
247 |
248 | fwd_batch_size = training_args.per_device_train_batch_size * world_size
249 |
250 | train_dataset = make_supervised_data_module(model = model,
251 | tokenizer=tokenizer,
252 | max_length = training_args.model_max_length,
253 | fwd_batch_size = fwd_batch_size,
254 | data_args=data_args,
255 | mask_user = training_args.mask_user)
256 |
257 | trainer = Trainer(
258 | model=model, tokenizer=tokenizer, args=training_args, **train_dataset
259 | )
260 |
261 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
262 | trainer.train(resume_from_checkpoint=True)
263 | else:
264 | trainer.train()
265 | trainer.save_state()
266 | trainer.save_model(output_dir = training_args.output_dir)
267 |
268 |
269 | if __name__ == "__main__":
270 | train()
271 |
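272 | # Note: this script is typically launched through a distributed launcher (e.g. torchrun or
273 | # deepspeed), since it relies on the local_rank / world_size values that TrainingArguments
274 | # receives from the launcher.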
--------------------------------------------------------------------------------
/src/deita/data/sample_ultrafeedback.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import random
4 | from datasets import load_dataset, Dataset
5 |
6 | outputpath = sys.argv[1]
7 | datanum = sys.argv[2]
8 |
9 | random.seed(42)
10 |
11 | # Load the full train_prefs split from the Hub; the sampled subset is
12 | # written to outputpath at the end of this script.
13 | data = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split = "train_prefs")
14 | 
15 | sample_indices = random.sample(range(len(data)), int(datanum))
16 |
17 | sampled_data = data.select(sample_indices)
18 | sampled_data.to_json(outputpath)
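19 | 
20 | # Illustrative usage (positional arguments as read from sys.argv above):
21 | #   python sample_ultrafeedback.py <output_json_path> <num_samples>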
--------------------------------------------------------------------------------
/src/deita/ds_configs/deepspeed_config_zero2_no_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 |
11 | "bf16": {
12 | "enabled": "auto"
13 | },
14 |
15 | "optimizer": {
16 | "type": "AdamW",
17 | "params": {
18 | "lr": "auto",
19 | "betas": "auto",
20 | "eps": "auto",
21 | "weight_decay": "auto"
22 | }
23 | },
24 |
25 | "zero_optimization": {
26 | "stage": 2,
27 | "allgather_partitions": true,
28 | "allgather_bucket_size": 2e8,
29 | "overlap_comm": true,
30 | "reduce_scatter": true,
31 | "reduce_bucket_size": 2e8,
32 | "contiguous_gradients": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 2000,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/src/deita/ds_configs/deepspped_llama_x.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "offload_optimizer": {
5 | "device": "cpu",
6 | "pin_memory": true
7 | },
8 | "offload_param": {
9 | "device": "cpu",
10 | "pin_memory": true
11 | },
12 | "overlap_comm": true,
13 | "contiguous_gradients": true,
14 | "sub_group_size": 0,
15 | "reduce_bucket_size": "auto",
16 | "stage3_prefetch_bucket_size": "auto",
17 | "stage3_param_persistence_threshold": "auto",
18 | "stage3_max_live_parameters": 0,
19 | "stage3_max_reuse_distance": 0,
20 | "stage3_gather_16bit_weights_on_model_save": true
21 | },
22 | "bf16": {
23 | "enabled": true,
24 | "auto_cast": false,
25 | "loss_scale": 0,
26 | "initial_scale_power": 32,
27 | "loss_scale_window": 1000,
28 | "hysteresis": 2,
29 | "min_loss_scale": 1
30 | },
31 | "optimizer": {
32 | "type": "AdamW",
33 | "params": {
34 | "lr": "auto",
35 | "betas": "auto",
36 | "eps": "auto",
37 | "weight_decay": "auto"
38 | }
39 | },
40 | "train_batch_size": "auto",
41 | "train_micro_batch_size_per_gpu": "auto",
42 | "gradient_accumulation_steps": "auto",
43 | "wall_clock_breakdown": false
44 | }
45 |
--------------------------------------------------------------------------------
/src/deita/ds_configs/stage3_no_offloading_accelerate.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "zero_optimization": {
6 | "stage": 3,
7 | "overlap_comm": true,
8 | "contiguous_gradients": true,
9 | "sub_group_size": 1e9,
10 | "reduce_bucket_size": "auto",
11 | "stage3_prefetch_bucket_size": "auto",
12 | "stage3_param_persistence_threshold": "auto",
13 | "stage3_max_live_parameters": 1e9,
14 | "stage3_max_reuse_distance": 1e9,
15 | "stage3_gather_16bit_weights_on_model_save": true
16 | },
17 | "gradient_accumulation_steps": "auto",
18 | "gradient_clipping": "auto",
19 | "steps_per_print": 1e5,
20 | "train_batch_size": "auto",
21 | "train_micro_batch_size_per_gpu": "auto",
22 | "wall_clock_breakdown": false
23 | }
--------------------------------------------------------------------------------
/src/deita/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .embed_pipeline import EmbedPipeline
2 | from .score_pipeline import ScorePipeline
3 | from .filter_pipeline import FilterPipeline
4 | from .base import PipelineRegistry
5 | from typing import Callable
6 |
7 | PipelineRegistry.register("score_pipeline", ScorePipeline)
8 | PipelineRegistry.register("embed_pipeline", EmbedPipeline)
9 | PipelineRegistry.register("filter_pipeline", FilterPipeline)
10 |
11 | class Pipeline:
12 |
13 | def __new__(cls, name, **kwargs) -> Callable:
14 |
15 | PipelineClass = PipelineRegistry.get_pipeline(name)
16 | return PipelineClass(name, **kwargs)
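17 | 
18 | # Illustrative usage sketch (keyword arguments below are examples drawn from the
19 | # concrete pipelines, not an exhaustive or authoritative list):
20 | #   pipe = Pipeline("score_pipeline", data_path = "sharegpt_data.json", scorer = "llama",
21 | #                   scorer_name_or_path = "<scorer-model-path>", score_type = "complexity",
22 | #                   is_vllm = False, output_path = "scored_data.json")
23 | #   pipe.run()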
--------------------------------------------------------------------------------
/src/deita/pipeline/base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from typing import Any, Dict, List, Optional, Tuple, Union, Callable
4 |
5 | logger = logging.getLogger(__name__)
6 |
7 | class BasePipeline:
8 |
9 | def __init__(self, name: str, data_path: str, **kwargs) -> None:
10 |
11 | self.name = name
12 | self.data_path = data_path
13 |
14 |     def _load_data(self, data_path: str) -> List:
15 |
16 | """
17 | Load data from data_path.
18 |
19 | data_path: str - path to json data file.
20 | """
21 |
22 | try:
23 | with open(data_path, "r") as f:
24 | data = json.load(f)
25 | except json.JSONDecodeError:
26 | with open(data_path, "r") as f:
27 | data = [json.loads(line) for line in f]
28 |
29 | return data
30 |
31 | def _load_other_data(self, other_data_path: str) -> None:
32 | raise NotImplementedError
33 |
34 | def _save_data(self, data_path: str, data_format: str) -> None:
35 | raise NotImplementedError
36 |
37 | def _preprocess(self, json_data, other_data) -> None:
38 | raise NotImplementedError
39 |
40 | def _forward(self, preprocessed_data) -> None:
41 | raise NotImplementedError
42 |
43 | def run(self) -> None:
44 |
45 | json_data = self._load_data(self.data_path)
46 |
47 | other_data = None
48 | if hasattr(self, "other_data_path"):
49 | other_data = self._load_other_data(self.other_data_path)
50 |
51 | preprocessed_data = self._preprocess(json_data, other_data)
52 | results = self._forward(preprocessed_data)
53 | self._save_data(json_data, results)
54 | logger.info(f"Pipeline {self.name} run complete.")
55 |
56 | class PipelineRegistry:
57 |
58 | registry = {}
59 |
60 | @classmethod
61 |     def register(cls, name: str, pipeline_class: Callable):
62 |
63 | if name in cls.registry:
64 | raise ValueError(f"Pipeline {name} already registered.")
65 |         cls.registry[name] = pipeline_class
66 |
67 | @classmethod
68 | def get_pipeline(cls, name: str):
69 |
70 | if name not in cls.registry:
71 | raise ValueError(f"Pipeline {name} not registered.")
72 | return cls.registry[name]
73 |
74 |
--------------------------------------------------------------------------------
/src/deita/pipeline/embed_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List
3 | from deita.pipeline.base import BasePipeline
4 | from deita.selection.embedder import CLM_Embedder
5 | import logging
6 | import pandas
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | class EmbedPipeline(BasePipeline):
11 |
12 | def __init__(self, name: str, data_path: str, **kwargs) -> None:
13 |
14 | self.name = name
15 | self.data_path = data_path
16 | self.is_compression = kwargs.get("is_compression", False)
17 |
18 | self.data_format = "sharegpt" # only support sharegpt for now
19 | self.output_path = kwargs.get("output_path")
20 |
21 | if not os.path.exists(self.output_path):
22 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
23 |
24 | self.embedder = CLM_Embedder(**kwargs)
25 |
26 | def _preprocess(self, json_data, other_data) -> List:
27 | return json_data
28 |
29 | def _forward(self, preprocessed_data) -> List:
30 |
31 | all_embeddings_list = self.embedder.encode_samples(preprocessed_data)
32 |
33 | logger.info(f"{len(all_embeddings_list)}")
34 | logger.info("Finished embedding")
35 |
36 | return all_embeddings_list
37 |
38 | def _save_data(self, json_data: List, results: List) -> None:
39 |
40 | # We use dataframe to save the results
41 | df = pandas.DataFrame(results)
42 |
43 | if self.embedder.accelerator.is_main_process:
44 | df.sort_values(by = "idx", inplace = True)
45 | df.reset_index(drop = True, inplace = True)
46 |
47 | if not self.is_compression:
48 | df.to_pickle(self.output_path)
49 | logger.info(f"Saved pickle to {self.output_path}")
50 | else:
51 | df.to_pickle(self.output_path, "zip")
52 | logger.info(f"Saved pickle to {self.output_path} with zip compression")
53 |
54 |
55 |
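56 | # Note: the pickled output is a pandas DataFrame with one row per sample and (at least)
57 | # the columns "embedding" and "idx"; FilterPipeline reads this file back through its
58 | # `other_data_path` argument.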
--------------------------------------------------------------------------------
/src/deita/pipeline/filter_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import pandas as pd
4 | from typing import List
5 | from deita.pipeline.base import BasePipeline
6 | from deita.pipeline.utils import sort_key_split
7 | from deita.selection.filter import Combined_Filter
8 | import logging
9 | import pandas
10 | import numpy as np
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | class FilterPipeline(BasePipeline):
15 |
16 | def __init__(self, name: str, data_path: str, **kwargs) -> None:
17 |
18 | self.name = name
19 | self.data_path = data_path
20 | self.other_data_path = kwargs.get("other_data_path")
21 | self.is_compression = kwargs.get("is_compression", False)
22 |
23 | self.data_format = "sharegpt" # only support sharegpt for now
24 | self.output_path = kwargs.get("output_path")
25 | self.sort_key = kwargs.get("sort_key")
26 | self.sort_key = sort_key_split(self.sort_key)
27 | kwargs["sort_key"] = self.sort_key
28 |
29 | if not os.path.exists(self.output_path):
30 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
31 |
32 | self.filter = Combined_Filter(**kwargs)
33 |
34 | def _load_other_data(self, other_data_path: str) -> None:
35 | """
36 | Load Embedding Data
37 | """
38 |
39 | if self.is_compression:
40 | embedding_data = pd.read_pickle(other_data_path, "zip")
41 | else:
42 | embedding_data = pd.read_pickle(other_data_path)
43 |
44 | return embedding_data
45 |
46 | def _preprocess(self, json_data, other_data) -> List:
47 |
48 | """
49 | json_data: List - data to be filtered
50 | other_data: pd.DataFrame - embedding data
51 | """
52 |
53 | if isinstance(other_data, np.ndarray):
54 | df_data = pd.DataFrame([{"embedding": other_data[i]} for i in range(other_data.shape[0])])
55 | elif isinstance(other_data, pd.DataFrame):
56 | df_data = other_data
57 | else:
58 |             raise ValueError("other_data must be either np.ndarray or pd.DataFrame")
59 |
60 | if "idx" not in df_data.columns:
61 | df_data["idx"] = df_data.index
62 |
63 | df_json = pd.DataFrame(json_data)
64 |
65 | for sk in self.sort_key:
66 | df_data[sk] = df_json[sk].tolist()
67 |
68 | return df_data
69 |
70 | def _forward(self, preprocessed_data) -> List:
71 |
72 | selected_data = self.filter.filter(preprocessed_data)
73 | selected_data_indices = selected_data["idx"].tolist()
74 |
75 | logger.info(f"Selected Data Number: {len(selected_data_indices)}")
76 | logger.info("Finished Combined Selection")
77 |
78 | return selected_data_indices
79 |
80 | def _save_data(self, json_data: List, results: List) -> None:
81 |
82 | selected_data = []
83 |
84 | for idx in results:
85 | selected_data.append(json_data[idx])
86 |
87 | with open(self.output_path, "w") as f:
88 | json.dump(selected_data, f, indent=2, ensure_ascii=False)
89 |
90 |
91 |
--------------------------------------------------------------------------------
/src/deita/pipeline/score_pipeline.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from deita.pipeline.base import BasePipeline
4 | from deita.selection.scorer import Llama_Scorer, Mistral_Scorer
5 | from typing import Any, Dict, List, Optional, Tuple, Union
6 | from deita.pipeline.utils import load_data
7 | import logging
8 | from tqdm import tqdm
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | class ScorePipeline(BasePipeline):
13 |
14 | def __init__(self, name: str, data_path: str, **kwargs) -> None:
15 |
16 | self.name = name
17 | self.data_path = data_path
18 |
19 | self.data_format = "sharegpt" # only support sharegpt for now
20 |
21 | scorer = kwargs.get("scorer")
22 | is_vllm = kwargs.get("is_vllm")
23 | scorer_name_or_path = kwargs.get("scorer_name_or_path")
24 | self.score_type = kwargs.get("score_type")
25 |
26 | if scorer == "llama":
27 | self.model = Llama_Scorer(scorer_name_or_path, is_vllm)
28 | elif scorer == "mistral":
29 | self.model = Mistral_Scorer(scorer_name_or_path, is_vllm)
30 |
31 | self.output_path = kwargs.get("output_path")
32 |
33 | if not os.path.exists(self.output_path):
34 | os.makedirs(os.path.dirname(self.output_path), exist_ok=True)
35 |
36 | def _load_sharegpt(self, data: str) -> List:
37 |
38 | preprocessed_data = []
39 |
40 | for sample_id, item in enumerate(data):
41 |
42 | preprocessed_item = []
43 |
44 | for idx in range(len(item["conversations"])):
45 |
46 | if idx % 2 != 0:
47 | continue
48 |
49 | if idx != len(item["conversations"]) - 1:
50 | preprocessed_item.append({"instruction": item["conversations"][idx]["value"], "response": item["conversations"][idx+1]["value"]})
51 | else:
52 | preprocessed_item.append({"instruction": item["conversations"][idx]["value"], "response": ""})
53 |
54 | preprocessed_data.append({"conversations": preprocessed_item, "n_conv": len(preprocessed_item)})
55 |
56 | return preprocessed_data
57 |
58 | def _inject_sharegpt(self, json_data: List, results: List) -> None:
59 |
60 | for sample_id in range(len(json_data)):
61 |
62 | json_data[sample_id][f"{self.score_type}_scores"] = []
63 |
64 | for item in results[sample_id]["conversations"]:
65 | json_data[sample_id][f"{self.score_type}_scores"].append(float(item[f"{self.score_type}_score"]))
66 |
67 |
68 | def _preprocess(self, json_data, other_data) -> List:
69 |
70 | if self.data_format == "sharegpt":
71 | preprocessed_data = self._load_sharegpt(json_data)
72 | else:
73 | raise ValueError(f"Data format {self.data_format} not supported.")
74 |
75 | return preprocessed_data
76 |
77 | def _forward(self, preprocessed_data) -> List:
78 |
79 | for convs in tqdm(preprocessed_data, total = len(preprocessed_data)):
80 |
81 | for conv in convs["conversations"]:
82 |
83 | if self.score_type.lower() == "complexity":
84 | score = self.model.infer_complexity(conv["instruction"])
85 | elif self.score_type.lower() == "quality":
86 | score = self.model.infer_quality(conv["instruction"], conv["response"])
87 | else:
88 | raise ValueError(f"Score type {self.score_type} not supported.")
89 |
90 | conv[f"{self.score_type}_score"] = score
91 |
92 | return preprocessed_data
93 |
94 | def _save_data(self, json_data: List, results: List) -> None:
95 |
96 | if self.data_format == "sharegpt":
97 | self._inject_sharegpt(json_data, results)
98 | else:
99 | raise ValueError(f"Data format {self.data_format} not supported.")
100 |
101 | with open(self.output_path, "w") as f:
102 | json.dump(json_data, f, indent=2, ensure_ascii=False)
103 | logger.info(f"Saved results to {self.output_path}.")
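104 | 
105 | # For reference (illustrative; only the fields accessed above are actually required):
106 | # each input item follows the ShareGPT layout, e.g.
107 | #   {"conversations": [{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]}
108 | # and the per-turn scores are written back under the key "<score_type>_scores".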
--------------------------------------------------------------------------------
/src/deita/pipeline/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import List
3 |
4 | INSTRUCTION_ONLY_TYPE = ["complexity"]
5 | INSTRUCTION_RESPONSE_TYPE = ["quality"]
6 |
7 | def sort_key_split(sort_key: str) -> List:
8 | """
9 | Split sort_key into a list of sort keys.
10 |
11 | sort_key: str - sort key to split.
12 | """
13 | if "," in sort_key:
14 | return sort_key.split(",")
15 | elif "." in sort_key:
16 | return sort_key.split(".")
17 | elif "+" in sort_key:
18 | return sort_key.split("+")
19 | else:
20 | raise ValueError("sort_key must be a string with delimiter ',' or '.' or '+'.")
21 |
22 | def load_data(data_path: str) -> List:
23 |
24 | """
25 | Load data from data_path.
26 |
27 | data_path: str - path to json data file.
28 | """
29 |
30 | try:
31 | with open(data_path, "r") as f:
32 | data = json.load(f)
33 | except json.JSONDecodeError:
34 | with open(data_path, "r") as f:
35 | data = [json.loads(line) for line in f]
36 |
37 | return data
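38 | 
39 | # Example (illustrative): sort_key_split("quality_scores,complexity_scores") returns
40 | # ["quality_scores", "complexity_scores"]; a key without ',', '.' or '+' raises ValueError.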
--------------------------------------------------------------------------------
/src/deita/selection/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/deita/selection/embedder/__init__.py:
--------------------------------------------------------------------------------
1 | from .clm_embedder import CLM_Embedder
2 |
3 | __all__ = ["CLM_Embedder"]
--------------------------------------------------------------------------------
/src/deita/selection/embedder/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | 
4 | from datasets import Dataset
5 | from transformers import (
6 | AutoModelForCausalLM,
7 | AutoTokenizer,
8 | )
9 | from deita.selection.embedder.utils import batchlize
10 | from transformers.trainer_pt_utils import LabelSmoother
11 | from accelerate import Accelerator
12 |
13 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
14 |
15 |
16 | class Embedder:
17 |
18 | def __init__(self, model_name_or_path, **kwargs) -> None:
19 |
20 | self.compute_dtype = (
21 | torch.float16
22 | if kwargs.get('fp16', False)
23 | else (torch.bfloat16 if kwargs.get("bfloat16", False) else torch.float32)
24 | )
25 |
26 | self.model_name_or_path = model_name_or_path
27 | self.max_length = kwargs.get('max_length')
28 | self.use_flash_attention = kwargs.get('use_flash_attention')
29 | self.batch_size_per_device = kwargs.get('batch_size_per_device')
30 | self.conv_template = kwargs.get('conv_template')
31 | self.only_answer = kwargs.get('only_answer')
32 | self.random_shuffle = kwargs.get('random_shuffle')
33 |
34 | self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
35 | self.world_size = int(os.getenv("WORLD_SIZE", "1"))
36 |         torch.cuda.set_device(self.local_rank)  # NOTE: raises an error on CPU-only machines
37 |
38 | self.accelerator = Accelerator()
39 | self.accelerator.wait_for_everyone()
40 |
41 | batch_size = self.batch_size_per_device * self.world_size
42 | self.minibatch_size = batch_size
43 |
44 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path,
45 | model_max_length = self.max_length,
46 | padding_side = "right",
47 | use_fast = False)
48 |
49 | if "mistral" in self.model_name_or_path:
50 | self.tokenizer.padding_side = "left"
51 |
52 | self.model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path,
53 | torch_dtype = self.compute_dtype,
54 | use_flash_attention_2 = self.use_flash_attention)
55 |
56 | if self.tokenizer.pad_token is None:
57 | self.tokenizer.pad_token = self.tokenizer.unk_token
58 |
59 | def rank0_print(self, *args):
60 | if self.local_rank == 0:
61 | print(*args)
62 |
63 | def compute_length(self, conversations: list, cnt_field = "response"):
64 |
65 | all_lengths = []
66 | for conv in conversations:
67 |
68 | cur_length = 0
69 |
70 | for i, c in enumerate(conv):
71 | if cnt_field == "response":
72 | if i % 2 == 1:
73 | cur_length += len(c["value"])
74 | elif cnt_field == "instruction":
75 | if i % 2 == 0:
76 | cur_length += len(c["value"])
77 | else:
78 | cur_length += len(c["value"])
79 |
80 | all_lengths.append(cur_length)
81 |
82 | return all_lengths
83 |
84 | def create_databuffer(self, conversations: list, sort_by_length = False):
85 |
86 | all_lengths = self.compute_length(conversations)
87 |
88 | dataset_size = len(conversations)
89 | dataset_buf = []
90 | for idx in range(dataset_size):
91 |
92 | dataset_buf.append({
93 | "conversations": conversations[idx],
94 | "specific_length": all_lengths[idx],
95 | "input_idx": idx
96 | })
97 |
98 | if sort_by_length:
99 | dataset_buf = sorted(dataset_buf, key = lambda x: x["specific_length"])
100 |
101 | return dataset_buf, dataset_size
102 |
103 | def create_dataloader(self, dataset_buf: Dataset):
104 |
105 | dataloader = batchlize(
106 | dataset_buf,
107 | self.minibatch_size,
108 | self.random_shuffle,
109 | sort = self.minibatch_size > 1
110 | )
111 |
112 |         print(f"Successfully created dataloader with size {len(dataloader)}, batch_size {self.minibatch_size}.")
113 |
114 | return dataloader
115 |
116 | def probe_samples(self, model, data: list):
117 |
118 | raise NotImplementedError
119 |
120 | def collect_grad(self):
121 |
122 | raise NotImplementedError
--------------------------------------------------------------------------------
/src/deita/selection/embedder/clm_embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from functools import partial
3 |
4 | from tqdm import tqdm
5 | 
6 | from datasets import Dataset
7 | import torch.distributed as dist
8 | from deita.selection.embedder.base import Embedder
9 | from deita.selection.embedder.utils import DataCollatorForSupervisedDataset, preprocess
10 |
11 | import logging
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 | class CLM_Embedder(Embedder):
16 |
17 | def __init__(self, model_name_or_path, **kwargs):
18 | super().__init__(model_name_or_path, **kwargs)
19 |
20 | def encode_samples(self, data):
21 |
22 | conversations = [item["conversations"] for item in data]
23 | dataset_buf, data_size = self.create_databuffer(conversations, sort_by_length = True)
24 | raw_dataset = Dataset.from_list(dataset_buf)
25 |
26 | preprocess_func = partial(preprocess,
27 | conv_template = self.conv_template,
28 | only_answer = self.only_answer,
29 | max_length = self.max_length,
30 | tokenizer = self.tokenizer)
31 |
32 | with self.accelerator.main_process_first():
33 | tokenized_datasets = raw_dataset.map(
34 | preprocess_func,
35 | batched = True,
36 | num_proc = 32,
37 | remove_columns = ["conversations", "specific_length"],
38 | desc = "Tokenizing and reformatting instruction data"
39 | )
40 |
41 | data_collator = DataCollatorForSupervisedDataset(tokenizer = self.tokenizer)
42 | dataloader = torch.utils.data.DataLoader(tokenized_datasets, batch_size = self.batch_size_per_device, collate_fn = data_collator)
43 |
44 | model, dataloader = self.accelerator.prepare(self.model, dataloader)
45 |
46 | all_embeddings_list = []
47 |
48 | total_samples = len(tokenized_datasets)
49 | total_batches = len(dataloader)
50 | last_batch_size = total_samples % self.minibatch_size if total_samples % self.minibatch_size != 0 else self.minibatch_size
51 |
52 | for b_idx, batch in enumerate(tqdm(dataloader, total = len(tokenized_datasets) // self.minibatch_size, disable = not self.accelerator.is_local_main_process)):
53 |
54 | model.eval()
55 |
56 | batch_idx = batch["idx"]
57 | attention_mask = batch["attention_mask"]
58 |
59 | outputs = model(input_ids = batch["input_ids"], attention_mask = batch["attention_mask"], output_hidden_states = True)
60 |
61 | seq_len = attention_mask.sum(1, keepdim = True)
62 |
63 | if self.tokenizer.padding_side == "right":
64 | last_hidden_state = outputs.hidden_states[-1][torch.arange(seq_len.size(0))[:, None], seq_len - 1]
65 | elif self.tokenizer.padding_side == "left":
66 | last_hidden_state = outputs.hidden_states[-1][:, -1]
67 | else:
68 | raise ValueError("Invalid padding strategy")
69 |
70 | sample_idx = batch_idx.tolist()
71 | sample_dict = [{"embedding": lst_hs, "idx": s_id} for lst_hs, s_id in zip(last_hidden_state.tolist(), sample_idx)]
72 |
73 | if(self.world_size > 1):
74 | all_process_embeddings = [[] for _ in range(self.world_size)]
75 | dist.gather_object(sample_dict, all_process_embeddings if dist.get_rank() == 0 else None, dst=0)
76 | else:
77 | all_process_embeddings = [sample_dict]
78 |
79 | if self.accelerator.is_local_main_process:
80 | if b_idx == total_batches - 1:
81 | for process_list in all_process_embeddings[:last_batch_size]:
82 | all_embeddings_list.extend(process_list)
83 | else:
84 | for process_list in all_process_embeddings:
85 | all_embeddings_list.extend(process_list)
86 |
87 | return all_embeddings_list
--------------------------------------------------------------------------------
/src/deita/selection/embedder/conversation.py:
--------------------------------------------------------------------------------
1 | """
2 | Conversation prompt templates.
3 | """
4 |
5 | import dataclasses
6 | from enum import auto, IntEnum
7 | from typing import List, Any, Dict
8 |
9 |
10 | class SeparatorStyle(IntEnum):
11 | """Separator styles."""
12 |
13 | ADD_COLON_SINGLE = auto()
14 | ADD_COLON_TWO = auto()
15 | ADD_COLON_SPACE_SINGLE = auto()
16 | NO_COLON_SINGLE = auto()
17 | NO_COLON_TWO = auto()
18 | ADD_NEW_LINE_SINGLE = auto()
19 | LLAMA2 = auto()
20 | CHATGLM = auto()
21 | CHATML = auto()
22 | CHATINTERN = auto()
23 | DOLLY = auto()
24 | RWKV = auto()
25 | PHOENIX = auto()
26 | ROBIN = auto()
27 |
28 |
29 | @dataclasses.dataclass
30 | class Conversation:
31 | """A class that manages prompt templates and keeps all conversation history."""
32 |
33 | # The name of this template
34 | name: str
35 | # The system prompt
36 | system: str
37 | # Two roles
38 | roles: List[str]
39 | # All messages. Each item is (role, message).
40 | messages: List[List[str]]
41 | # The number of few shot examples
42 | offset: int
43 | # Separators
44 | sep_style: SeparatorStyle
45 | sep: str
46 | sep2: str = None
47 | # Stop criteria (the default one is EOS token)
48 | stop_str: str = None
49 | # Stops generation if meeting any token in this list
50 | stop_token_ids: List[int] = None
51 |
52 | def get_prompt(self) -> str:
53 | """Get the prompt for generation."""
54 | if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
55 | ret = self.system + self.sep
56 | for role, message in self.messages:
57 | if message:
58 | ret += role + ": " + message + self.sep
59 | else:
60 | ret += role + ":"
61 | return ret
62 | elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
63 | seps = [self.sep, self.sep2]
64 | ret = self.system + seps[0]
65 | for i, (role, message) in enumerate(self.messages):
66 | if message:
67 | ret += role + ": " + message + seps[i % 2]
68 | else:
69 | ret += role + ":"
70 | return ret
71 | elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
72 | ret = self.system + self.sep
73 | for role, message in self.messages:
74 | if message:
75 | ret += role + ": " + message + self.sep
76 | else:
77 |                     ret += role + ": "  # must end with a space
78 | return ret
79 | elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
80 | ret = "" if self.system == "" else self.system + self.sep
81 | for role, message in self.messages:
82 | if message:
83 | ret += role + "\n" + message + self.sep
84 | else:
85 | ret += role + "\n"
86 | return ret
87 | elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
88 | ret = self.system
89 | for role, message in self.messages:
90 | if message:
91 | ret += role + message + self.sep
92 | else:
93 | ret += role
94 | return ret
95 | elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
96 | seps = [self.sep, self.sep2]
97 | ret = self.system
98 | for i, (role, message) in enumerate(self.messages):
99 | if message:
100 | ret += role + message + seps[i % 2]
101 | else:
102 | ret += role
103 | return ret
104 | elif self.sep_style == SeparatorStyle.RWKV:
105 | ret = self.system
106 | for i, (role, message) in enumerate(self.messages):
107 | if message:
108 | ret += (
109 | role
110 | + ": "
111 | + message.replace("\r\n", "\n").replace("\n\n", "\n")
112 | )
113 | ret += "\n\n"
114 | else:
115 | ret += role + ":"
116 | return ret
117 | elif self.sep_style == SeparatorStyle.LLAMA2:
118 | seps = [self.sep, self.sep2]
119 | ret = ""
120 | for i, (role, message) in enumerate(self.messages):
121 | if message:
122 | if i == 0:
123 | ret += self.system + message
124 | else:
125 | ret += role + " " + message + seps[i % 2]
126 | else:
127 | ret += role
128 | return ret
129 | elif self.sep_style == SeparatorStyle.CHATGLM:
130 | # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
131 | # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
132 | round_add_n = 1 if self.name == "chatglm2" else 0
133 | if self.system:
134 | ret = self.system + self.sep
135 | else:
136 | ret = ""
137 |
138 | for i, (role, message) in enumerate(self.messages):
139 | if i % 2 == 0:
140 | ret += f"[Round {i//2 + round_add_n}]{self.sep}"
141 |
142 | if message:
143 | ret += f"{role}:{message}{self.sep}"
144 | else:
145 | ret += f"{role}:"
146 | return ret
147 | elif self.sep_style == SeparatorStyle.CHATML:
148 | ret = "" if self.system == "" else self.system + self.sep + "\n"
149 | for role, message in self.messages:
150 | if message:
151 | ret += role + "\n" + message + self.sep + "\n"
152 | else:
153 | ret += role + "\n"
154 | return ret
155 | elif self.sep_style == SeparatorStyle.CHATINTERN:
156 | # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
157 | seps = [self.sep, self.sep2]
158 | ret = self.system
159 | for i, (role, message) in enumerate(self.messages):
160 | if i % 2 == 0:
161 | ret += ""
162 | if message:
163 | ret += role + ":" + message + seps[i % 2] + "\n"
164 | else:
165 | ret += role + ":"
166 | return ret
167 | elif self.sep_style == SeparatorStyle.DOLLY:
168 | seps = [self.sep, self.sep2]
169 | ret = self.system
170 | for i, (role, message) in enumerate(self.messages):
171 | if message:
172 | ret += role + ":\n" + message + seps[i % 2]
173 | if i % 2 == 1:
174 | ret += "\n\n"
175 | else:
176 | ret += role + ":\n"
177 | return ret
178 | elif self.sep_style == SeparatorStyle.PHOENIX:
179 | ret = self.system
180 | for role, message in self.messages:
181 | if message:
182 |                     ret += role + ": " + "<s>" + message + "</s>"
183 |                 else:
184 |                     ret += role + ": " + "<s>"
185 | return ret
186 | elif self.sep_style == SeparatorStyle.ROBIN:
187 | ret = self.system + self.sep
188 | for role, message in self.messages:
189 | if message:
190 | ret += role + ":\n" + message + self.sep
191 | else:
192 | ret += role + ":\n"
193 | return ret
194 | else:
195 | raise ValueError(f"Invalid style: {self.sep_style}")
196 |
197 | def append_message(self, role: str, message: str):
198 | """Append a new message."""
199 | self.messages.append([role, message])
200 |
201 | def update_last_message(self, message: str):
202 | """Update the last output.
203 |
204 | The last message is typically set to be None when constructing the prompt,
205 | so we need to update it in-place after getting the response from a model.
206 | """
207 | self.messages[-1][1] = message
208 |
209 | def to_gradio_chatbot(self):
210 | """Convert the conversation to gradio chatbot format."""
211 | ret = []
212 | for i, (role, msg) in enumerate(self.messages[self.offset :]):
213 | if i % 2 == 0:
214 | ret.append([msg, None])
215 | else:
216 | ret[-1][-1] = msg
217 | return ret
218 |
219 | def to_openai_api_messages(self):
220 | """Convert the conversation to OpenAI chat completion format."""
221 | ret = [{"role": "system", "content": self.system}]
222 |
223 | for i, (_, msg) in enumerate(self.messages[self.offset :]):
224 | if i % 2 == 0:
225 | ret.append({"role": "user", "content": msg})
226 | else:
227 | if msg is not None:
228 | ret.append({"role": "assistant", "content": msg})
229 | return ret
230 |
231 | def copy(self):
232 | return Conversation(
233 | name=self.name,
234 | system=self.system,
235 | roles=self.roles,
236 | messages=[[x, y] for x, y in self.messages],
237 | offset=self.offset,
238 | sep_style=self.sep_style,
239 | sep=self.sep,
240 | sep2=self.sep2,
241 | stop_str=self.stop_str,
242 | stop_token_ids=self.stop_token_ids,
243 | )
244 |
245 | def dict(self):
246 | return {
247 | "template_name": self.name,
248 | "system": self.system,
249 | "roles": self.roles,
250 | "messages": self.messages,
251 | "offset": self.offset,
252 | }
253 |
254 |
255 | # A global registry for all conversation templates
256 | conv_templates: Dict[str, Conversation] = {}
257 |
258 |
259 | def register_conv_template(template: Conversation, override: bool = False):
260 | """Register a new conversation template."""
261 | if not override:
262 | assert (
263 | template.name not in conv_templates
264 | ), f"{template.name} has been registered."
265 |
266 | conv_templates[template.name] = template
267 |
268 |
269 | def get_conv_template(name: str) -> Conversation:
270 | """Get a conversation template."""
271 | return conv_templates[name].copy()
272 |
273 |
274 | # A template with a one-shot conversation example
275 | register_conv_template(
276 | Conversation(
277 | name="one_shot",
278 | system="A chat between a curious human and an artificial intelligence assistant. "
279 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
280 | roles=("Human", "Assistant"),
281 | messages=(
282 | (
283 | "Human",
284 | "Got any creative ideas for a 10 year old’s birthday?",
285 | ),
286 | (
287 | "Assistant",
288 | """Of course! Here are some creative ideas for a 10-year-old's birthday party:
289 | 1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
290 | 2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
291 | 3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
292 | 4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
293 | 5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
294 | 6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
295 | 7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
296 | 8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
297 | Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
298 | ),
299 | ),
300 | offset=2,
301 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
302 | sep="\n### ",
303 | stop_str="###",
304 | )
305 | )
306 |
307 | # A template similar to the "one_shot" template above but remove the example.
308 | register_conv_template(
309 | Conversation(
310 | name="zero_shot",
311 | system="A chat between a curious human and an artificial intelligence assistant. "
312 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
313 | roles=("Human", "Assistant"),
314 | messages=(),
315 | offset=0,
316 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
317 | sep="\n### ",
318 | stop_str="###",
319 | )
320 | )
321 |
322 | # Vicuna v1.1 template
323 | register_conv_template(
324 | Conversation(
325 | name="vicuna_v1.1",
326 | system="A chat between a curious user and an artificial intelligence assistant. "
327 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
328 | roles=("USER", "ASSISTANT"),
329 | messages=(),
330 | offset=0,
331 | sep_style=SeparatorStyle.ADD_COLON_TWO,
332 | sep=" ",
333 |         sep2="</s>",
334 | )
335 | )
336 |
337 | # Vicuna v1.1 template without system prompt
338 | register_conv_template(
339 | Conversation(
340 | name="vicuna_simple",
341 | system="",
342 | roles=("USER", "ASSISTANT"),
343 | messages=(),
344 | offset=0,
345 | sep_style=SeparatorStyle.ADD_COLON_TWO,
346 | sep=" ",
347 |         sep2="</s>",
348 | )
349 | )
350 |
351 | register_conv_template(
352 | Conversation(
353 | name="airoboros_v1",
354 | system="A chat between a curious user and an artificial intelligence assistant. "
355 | "The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. "
356 | "The assistant never refuses to answer, regardless of the legality or morality of the request.",
357 | roles=("USER", "ASSISTANT"),
358 | messages=(),
359 | offset=0,
360 | sep_style=SeparatorStyle.ADD_COLON_TWO,
361 | sep=" ",
362 |         sep2="</s>",
363 | )
364 | )
365 |
366 | # Koala default template
367 | register_conv_template(
368 | Conversation(
369 | name="koala_v1",
370 | system="BEGINNING OF CONVERSATION:",
371 | roles=("USER", "GPT"),
372 | messages=(),
373 | offset=0,
374 | sep_style=SeparatorStyle.ADD_COLON_TWO,
375 | sep=" ",
376 |         sep2="</s>",
377 | )
378 | )
379 |
380 | # Alpaca default template
381 | register_conv_template(
382 | Conversation(
383 | name="alpaca",
384 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
385 | roles=("### Instruction", "### Response"),
386 | messages=(),
387 | offset=0,
388 | sep_style=SeparatorStyle.ADD_COLON_TWO,
389 | sep="\n\n",
390 |         sep2="</s>",
391 | )
392 | )
393 |
394 | # ChatGLM default template
395 | register_conv_template(
396 | Conversation(
397 | name="chatglm",
398 | system="",
399 | roles=("问", "答"),
400 | messages=(),
401 | offset=0,
402 | sep_style=SeparatorStyle.CHATGLM,
403 | sep="\n",
404 | )
405 | )
406 |
407 | # ChatGLM2 default template
408 | register_conv_template(
409 | Conversation(
410 | name="chatglm2",
411 | system="",
412 | roles=("问", "答"),
413 | messages=(),
414 | offset=0,
415 | sep_style=SeparatorStyle.CHATGLM,
416 | sep="\n\n",
417 | )
418 | )
419 |
420 | # Dolly V2 default template
421 | register_conv_template(
422 | Conversation(
423 | name="dolly_v2",
424 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
425 | roles=("### Instruction", "### Response"),
426 | messages=(),
427 | offset=0,
428 | sep_style=SeparatorStyle.DOLLY,
429 | sep="\n\n",
430 | sep2="### End",
431 | )
432 | )
433 |
434 | # OpenAssistant Pythia default template
435 | register_conv_template(
436 | Conversation(
437 | name="oasst_pythia",
438 | system="",
439 | roles=("<|prompter|>", "<|assistant|>"),
440 | messages=(),
441 | offset=0,
442 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
443 | sep="<|endoftext|>",
444 | )
445 | )
446 |
447 | # OpenAssistant default template
448 | register_conv_template(
449 | Conversation(
450 | name="oasst_llama",
451 | system="",
452 | roles=("<|prompter|>", "<|assistant|>"),
453 | messages=(),
454 | offset=0,
455 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
456 |         sep="</s>",
457 | )
458 | )
459 |
460 | # Tulu default template
461 | register_conv_template(
462 | Conversation(
463 | name="tulu",
464 | system="",
465 | roles=("<|user|>", "<|assistant|>"),
466 | messages=(),
467 | offset=0,
468 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
469 | sep="\n",
470 | )
471 | )
472 |
473 | # StableLM Alpha default template
474 | register_conv_template(
475 | Conversation(
476 | name="stablelm",
477 | system="""<|SYSTEM|># StableLM Tuned (Alpha version)
478 | - StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
479 | - StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
480 | - StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
481 | - StableLM will refuse to participate in anything that could harm a human.
482 | """,
483 | roles=("<|USER|>", "<|ASSISTANT|>"),
484 | messages=(),
485 | offset=0,
486 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
487 | sep="",
488 | stop_token_ids=[50278, 50279, 50277, 1, 0],
489 | )
490 | )
491 |
492 | # Baize default template
493 | register_conv_template(
494 | Conversation(
495 | name="baize",
496 | system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
497 | roles=("[|Human|]", "[|AI|]"),
498 | messages=(
499 | ("[|Human|]", "Hello!"),
500 | ("[|AI|]", "Hi!"),
501 | ),
502 | offset=2,
503 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
504 | sep="\n",
505 | stop_str="[|Human|]",
506 | )
507 | )
508 |
509 | # RWKV-4-Raven default template
510 | register_conv_template(
511 | Conversation(
512 | name="rwkv",
513 | system="",
514 | roles=("Bob", "Alice"),
515 | messages=(
516 | ("Bob", "hi"),
517 | (
518 | "Alice",
519 | "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
520 | ),
521 | ),
522 | offset=2,
523 | sep_style=SeparatorStyle.RWKV,
524 | sep="",
525 | stop_str="\n\n",
526 | )
527 | )
528 |
529 | # Buddy default template
530 | register_conv_template(
531 | Conversation(
532 | name="openbuddy",
533 | system="""Consider a conversation between User (a human) and Assistant (named Buddy).
534 | Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
535 | Buddy cannot access the Internet.
536 | Buddy can fluently speak the user's language (e.g. English, Chinese).
537 | Buddy can generate poems, stories, code, essays, songs, parodies, and more.
538 | Buddy possesses vast knowledge about the world, history, and culture.
539 | Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
540 | Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
541 |
542 | User: Hi.
543 | Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
544 | roles=("User", "Assistant"),
545 | messages=(),
546 | offset=0,
547 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
548 | sep="\n",
549 | )
550 | )
551 |
552 | # Phoenix default template
553 | register_conv_template(
554 | Conversation(
555 | name="phoenix",
556 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
557 | roles=("Human", "Assistant"),
558 | messages=(),
559 | offset=0,
560 | sep_style=SeparatorStyle.PHOENIX,
561 | sep="",
562 | )
563 | )
564 |
565 | # ChatGPT default template
566 | register_conv_template(
567 | Conversation(
568 | name="chatgpt",
569 | system="You are a helpful assistant.",
570 | roles=("user", "assistant"),
571 | messages=(),
572 | offset=0,
573 | sep_style=None,
574 | sep=None,
575 | )
576 | )
577 |
578 | # Claude default template
579 | register_conv_template(
580 | Conversation(
581 | name="claude",
582 | system="",
583 | roles=("Human", "Assistant"),
584 | messages=(),
585 | offset=0,
586 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
587 | sep="\n\n",
588 | )
589 | )
590 |
591 | # MPT default template
592 | register_conv_template(
593 | Conversation(
594 | name="mpt-7b-chat",
595 | system="""<|im_start|>system
596 | - You are a helpful assistant chatbot trained by MosaicML.
597 | - You answer questions.
598 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
599 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
600 | roles=("<|im_start|>user", "<|im_start|>assistant"),
601 | messages=(),
602 | offset=0,
603 | sep_style=SeparatorStyle.CHATML,
604 | sep="<|im_end|>",
605 | stop_token_ids=[50278, 0],
606 | )
607 | )
608 |
609 | # MPT-30b-chat default template
610 | register_conv_template(
611 | Conversation(
612 | name="mpt-30b-chat",
613 | system="""<|im_start|>system
614 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
615 | roles=("<|im_start|>user", "<|im_start|>assistant"),
616 | messages=(),
617 | offset=0,
618 | sep_style=SeparatorStyle.CHATML,
619 | sep="<|im_end|>",
620 | stop_token_ids=[50278, 0],
621 | )
622 | )
623 |
624 | # MPT-30b-instruct default template
625 | # reference: https://huggingface.co/mosaicml/mpt-30b-instruct#formatting
626 | register_conv_template(
627 | Conversation(
628 | name="mpt-30b-instruct",
629 | system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
630 | roles=("### Instruction", "### Response"),
631 | messages=(),
632 | offset=0,
633 | sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
634 | sep="\n\n",
635 | stop_token_ids=[50278, 0],
636 | )
637 | )
638 |
639 | # Bard default template
640 | # Reference: https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L150
641 | # https://github.com/google/generative-ai-python/blob/9c99bcb474a991a97a2e7d62fcdb52db7ce40729/google/generativeai/discuss.py#L40
642 | register_conv_template(
643 | Conversation(
644 | name="bard",
645 | system="",
646 | roles=("0", "1"),
647 | messages=(),
648 | offset=0,
649 | sep_style=None,
650 | sep=None,
651 | )
652 | )
653 |
654 | # BiLLa default template
655 | register_conv_template(
656 | Conversation(
657 | name="billa",
658 | system="",
659 | roles=("Human", "Assistant"),
660 | messages=(),
661 | offset=0,
662 | sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
663 | sep="\n",
664 | stop_str="Human:",
665 | )
666 | )
667 |
668 | # RedPajama INCITE default template
669 | register_conv_template(
670 | Conversation(
671 | name="redpajama-incite",
672 | system="",
673 |         roles=("<human>", "<bot>"),
674 | messages=(),
675 | offset=0,
676 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
677 | sep="\n",
678 |         stop_str="<human>",
679 | )
680 | )
681 |
682 | # h2oGPT default template
683 | register_conv_template(
684 | Conversation(
685 | name="h2ogpt",
686 | system="",
687 | roles=("<|prompt|>", "<|answer|>"),
688 | messages=(),
689 | offset=0,
690 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
691 |         sep="</s>",
692 | )
693 | )
694 |
695 | # Robin default template
696 | register_conv_template(
697 | Conversation(
698 | name="Robin",
699 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
700 | roles=("###Human", "###Assistant"),
701 | messages=(),
702 | offset=0,
703 | sep_style=SeparatorStyle.ROBIN,
704 | sep="\n",
705 | stop_token_ids=[2, 396],
706 | stop_str="###",
707 | )
708 | )
709 |
710 | # Snoozy default template
711 | # Reference: https://github.com/nomic-ai/gpt4all/blob/d4861030b778da6db59d21d2927a4aba4f9f1f43/gpt4all-bindings/python/gpt4all/gpt4all.py#L232
712 | register_conv_template(
713 | Conversation(
714 | name="snoozy",
715 | system="### Instruction:\nThe prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response.",
716 | roles=("### Prompt", "### Response"),
717 | messages=(),
718 | offset=0,
719 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
720 | sep="\n",
721 | stop_str="###",
722 | )
723 | )
724 |
725 | # manticore default template
726 | register_conv_template(
727 | Conversation(
728 | name="manticore",
729 | system="",
730 | roles=("USER", "ASSISTANT"),
731 | messages=(),
732 | offset=0,
733 | sep_style=SeparatorStyle.ADD_COLON_TWO,
734 | sep="\n",
735 |         sep2="</s>",
736 | )
737 | )
738 |
739 | # Falcon default template
740 | register_conv_template(
741 | Conversation(
742 | name="falcon",
743 | system="",
744 | roles=("User", "Assistant"),
745 | messages=[],
746 | offset=0,
747 | sep_style=SeparatorStyle.RWKV,
748 | sep="\n",
749 | sep2="<|endoftext|>",
750 | stop_str="\nUser", # use stop_str to stop generation after stop_token_ids, it will also remove stop_str from the generated text
751 | stop_token_ids=[
752 | 0,
753 | 1,
754 | 2,
755 | 3,
756 | 4,
757 | 5,
758 | 6,
759 | 7,
760 | 8,
761 | 9,
762 | 10,
763 | 11,
764 |         ],  # it is better to only put special tokens here, because the tokenizer only removes special tokens
765 | )
766 | )
767 |
768 | # ChangGPT default template
769 | register_conv_template(
770 | Conversation(
771 | name="polyglot_changgpt",
772 | system="",
773 | roles=("B", "A"),
774 | messages=(),
775 | offset=0,
776 | sep_style=SeparatorStyle.ADD_COLON_SINGLE,
777 | sep="\n",
778 | )
779 | )
780 |
781 | # tigerbot template
782 | register_conv_template(
783 | Conversation(
784 | name="tigerbot",
785 | system="A chat between a curious user and an artificial intelligence assistant. "
786 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
787 | roles=("### Instruction", "### Response"),
788 | messages=(),
789 | offset=0,
790 | sep_style=SeparatorStyle.ROBIN,
791 | sep="\n\n",
792 | stop_str="###",
793 | )
794 | )
795 |
796 | # ref: https://huggingface.co/Salesforce/xgen-7b-8k-inst
797 | register_conv_template(
798 | Conversation(
799 | name="xgen",
800 | system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
801 | roles=("### Human: ", "###"),
802 | messages=(),
803 | offset=0,
804 | sep_style=SeparatorStyle.NO_COLON_SINGLE,
805 | sep="\n",
806 | stop_token_ids=[50256, 0, 1, 2],
807 | stop_str="<|endoftext|>",
808 | )
809 | )
810 |
811 | # Internlm-chat template
812 | register_conv_template(
813 | Conversation(
814 | name="internlm-chat",
815 | system="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
816 | roles=("<|User|>", "<|Bot|>"),
817 | messages=(),
818 | offset=0,
819 | sep_style=SeparatorStyle.CHATINTERN,
820 |         sep="<eoh>\n",
821 |         sep2="<eoa>\n",
822 | stop_token_ids=[1, 103028],
823 | stop_str="<|User|>",
824 | )
825 | )
826 |
827 | # StarChat template
828 | register_conv_template(
829 | Conversation(
830 | name="starchat",
831 |         system="<system>\n",
832 | roles=("<|user|>", "<|assistant|>"),
833 | messages=(),
834 | offset=0,
835 | sep_style=SeparatorStyle.CHATML,
836 | sep="<|end|>",
837 | stop_token_ids=[0, 49155],
838 | stop_str="<|end|>",
839 | )
840 | )
841 |
842 | # Baichuan-13B-Chat template
843 | register_conv_template(
844 | # source: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/f5f47be2adbbdceb784f334d6fa1ca2c73e65097/modeling_baichuan.py#L507
845 | # https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/generation_config.json
846 | Conversation(
847 | name="baichuan-chat",
848 | system="",
849 |         roles=(" <reserved_102> ", " <reserved_103> "),
850 | messages=(),
851 | offset=0,
852 | sep_style=SeparatorStyle.NO_COLON_TWO,
853 | sep="",
854 |         sep2="</s>",
855 | stop_token_ids=[2, 195],
856 | )
857 | )
858 |
859 | # llama2 template
860 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
861 | register_conv_template(
862 | Conversation(
863 | name="llama-2",
864 | system="[INST] <>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
865 | "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
866 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
867 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
868 | "If you don't know the answer to a question, please don't share false information.\n<>\n\n",
869 | roles=("[INST]", "[/INST]"),
870 | messages=(),
871 | offset=0,
872 | sep_style=SeparatorStyle.LLAMA2,
873 | sep=" ",
874 |         sep2=" </s><s>",
875 | stop_token_ids=[2],
876 | )
877 | )
878 |
879 | if __name__ == "__main__":
880 | conv = get_conv_template("vicuna_v1.1")
881 | conv.append_message(conv.roles[0], "Hello!")
882 | conv.append_message(conv.roles[1], "Hi!")
883 | conv.append_message(conv.roles[0], "How are you?")
884 | conv.append_message(conv.roles[1], None)
885 | print(conv.get_prompt())
886 |
--------------------------------------------------------------------------------
/src/deita/selection/embedder/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import random
4 | import numpy as np
5 | from datasets import Dataset
6 | from typing import Sequence, Dict
7 | from dataclasses import dataclass
8 | from deita.selection.embedder.conversation import get_conv_template
9 |
10 | IGNORE_INDEX=-100
11 |
12 |
13 | def preprocess(
14 | samples: Dataset,
15 | conv_template,
16 | only_answer,
17 | max_length,
18 | tokenizer
19 | ) -> Dict:
20 |
21 | conv = get_conv_template(conv_template)
22 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
23 |
24 | sources = samples["conversations"]
25 | sample_ids = samples["input_idx"]
26 |
27 | # Apply prompt templates
28 | conversations = []
29 |
30 | if not only_answer:
31 | for i, source in enumerate(sources):
32 | if roles[source[0]["from"]] != conv.roles[0]:
33 | # Skip the first one if it is not from human
34 | source = source[1:]
35 |
36 | conv.messages = []
37 | for j, sentence in enumerate(source):
38 | role = roles[sentence["from"]]
39 |                 # conversations must alternate between the two roles (human, gpt)
40 |                 assert role == conv.roles[j % 2], f"unexpected role order in sample {i}"
41 | conv.append_message(role, sentence["value"])
42 | conversations.append(conv.get_prompt())
43 | else:
44 | for i, source in enumerate(sources):
45 | if roles[source[0]["from"]] != conv.roles[0]:
46 | # Skip the first one if it is not from human
47 | source = source[1:]
48 |
49 | messages = []
50 | for j, sentence in enumerate(source):
51 | if j % 2 == 0:
52 | continue
53 | messages.append(sentence["value"])
54 | conversations.append("\n".join(messages))
55 |
56 | input_ids = tokenizer(
57 | conversations,
58 | return_tensors="pt",
59 | padding="longest",
60 | max_length=max_length,
61 | truncation=True,
62 | ).input_ids
63 |
64 | return dict(
65 | input_ids=input_ids,
66 | idx = sample_ids,
67 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
68 | )
69 |
70 | @dataclass
71 | class DataCollatorForSupervisedDataset(object):
72 | """Collate examples for supervised fine-tuning."""
73 |
74 | def __init__(self, tokenizer):
75 | self.tokenizer = tokenizer
76 |
77 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
78 |
79 | # Convert each example's input_ids to a tensor so ragged lengths can be padded below
80 | input_ids = [torch.as_tensor(instance["input_ids"]) for instance in instances]
81 |
82 | input_ids = torch.nn.utils.rnn.pad_sequence(
83 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
84 | )
85 | instance_index = torch.tensor([instance["idx"] for instance in instances]).to(input_ids.device)
86 |
87 | return dict(
88 | idx = instance_index,
89 | input_ids=input_ids,
90 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
91 | )
92 |
93 | def set_random_seed(seed: int):
94 | """
95 | Set the random seed for `random`, `numpy`, `torch`, `torch.cuda`.
96 |
97 | Parameters
98 | ------------
99 | seed : int
100 | The default seed.
101 |
102 | """
103 | random.seed(seed)
104 | np.random.seed(seed)
105 | torch.manual_seed(seed)
106 | if torch.cuda.is_available():
107 | torch.cuda.manual_seed_all(seed)
108 |
109 |
110 | def batchlize(examples: list, batch_size: int, random_shuffle: bool, sort: bool, length_field = 'specific_length'):
111 | """
112 | Split a list of examples into batches.
113 |
114 | Parameters
115 | ------------
116 | examples : list
117 | Data list.
118 | batch_size : int
119 | Number of examples per batch.
120 | random_shuffle : bool
121 | If True, shuffle the examples before batching.
122 |
123 | sort : bool
124 | If True, sort examples by the length of `length_field` before batching.
125 | Returns
126 | ------------
127 | dataloader:
128 | A list of batches (each batch is a list of examples).
129 | """
130 | size = 0
131 | dataloader = []
132 | length = len(examples)
133 | if random_shuffle:
134 | random.shuffle(examples)
135 |
136 | new_examples = examples
137 | if sort:
138 | new_examples = sorted(examples, key = lambda x: len(x[length_field]))
139 |
140 | while size < length:
141 | if length - size > batch_size:
142 | dataloader.append(new_examples[size : size+batch_size])
143 | size += batch_size
144 | else:
145 | dataloader.append(new_examples[size : size+(length-size)])
146 | size += (length - size)
147 | return dataloader
148 |
149 | def get_emb_name(**kwargs):
150 |
151 | if kwargs.get('model_path'):
152 | model_path = kwargs.pop("model_path")
153 | return os.path.basename(model_path)
154 |
155 | if kwargs.get('emb_name'):
156 | emb_name = kwargs.pop('emb_name')
157 | return os.path.basename(emb_name)
--------------------------------------------------------------------------------
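Illustrative usage sketch (not part of the repository) for `batchlize` above: when `sort=True` it orders examples by the length of `length_field` and then slices the list into fixed-size batches. The example data below is made up; `specific_length` is simply the default `length_field`.

from deita.selection.embedder.utils import batchlize

# Hypothetical examples; each carries the default length field as a list.
examples = [
    {"id": 0, "specific_length": [0] * 12},
    {"id": 1, "specific_length": [0] * 3},
    {"id": 2, "specific_length": [0] * 7},
    {"id": 3, "specific_length": [0] * 5},
]

# Sort by length (shortest first), then split into batches of 2.
batches = batchlize(examples, batch_size=2, random_shuffle=False, sort=True)
print([[ex["id"] for ex in batch] for batch in batches])  # [[1, 3], [2, 0]]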
/src/deita/selection/filter/__init__.py:
--------------------------------------------------------------------------------
1 | from .combined_filter import Combined_Filter
2 |
3 | __all__ = ["Combined_Filter"]
--------------------------------------------------------------------------------
/src/deita/selection/filter/base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 | class IterativeFilter(object):
9 | def __init__(self, **kwargs):
10 |
11 | self.threshold = kwargs.get('threshold')
12 | self.data_size = kwargs.get('data_size')
13 | self.sort_key = kwargs.get('sort_key')
14 | self.chunk_size = kwargs.get('chunk_size')
15 | self.batch_size = kwargs.get('batch_size', 1)
16 | self.normalize_emb = kwargs.get('normalize_emb', True)
17 | self.distance_metric = kwargs.get('distance_metric', 'cosine')
18 | self.embedding_field = kwargs.get('embedding_field', "embedding")
19 |
20 | self.device = kwargs.get('device') if kwargs.get('device') == 'cpu' else f"cuda:{kwargs.get('device')}"
21 |
22 | def compute_distance(self, matrix, matrix_2):
23 |
24 | """
25 | Compute pairwise cosine similarity (or manhattan distance) between two embedding matrices using pytorch
26 | """
27 |
28 | if self.normalize_emb:
29 | matrix = matrix / matrix.norm(dim=1)[:, None]
30 | matrix_2 = matrix_2 / matrix_2.norm(dim=1)[:, None]
31 |
32 | if self.distance_metric == 'cosine':
33 | matrix_norm = matrix / matrix.norm(dim=1)[:, None]
34 | matrix_2_norm = matrix_2 / matrix_2.norm(dim=1)[:, None]
35 | return torch.mm(matrix_norm, matrix_2_norm.t())
36 | elif self.distance_metric == 'manhattan':
37 | return torch.cdist(matrix[None], matrix_2[None], p = 1).squeeze(0)
38 | else:
39 | raise ValueError("Unsupported distance metric: only 'cosine' and 'manhattan' are supported")
40 |
41 | def _sort(self, df):
42 |
43 | raise NotImplementedError
44 |
45 | def distance_chunk_by_chunk(self, existing_emb, cur_emb):
46 |
47 | distance_placeholder = torch.zeros((cur_emb.size(0), existing_emb.shape[0]), dtype = torch.float32).to(self.device)
48 |
49 | for i in range(0, existing_emb.shape[0], self.chunk_size):
50 |
51 | chunk_embeddings = existing_emb[i: i + self.chunk_size]
52 | chunk_embeddings = torch.tensor(chunk_embeddings, dtype = torch.float32).to(self.device)
53 |
54 | if chunk_embeddings.ndim == 4:
55 | chunk_embeddings = chunk_embeddings.squeeze(1).squeeze(1)
56 |
57 | distance_matrix = self.compute_distance(cur_emb, chunk_embeddings)
58 | actual_chunk = distance_matrix.size(1)
59 |
60 | distance_placeholder[:, i: i + actual_chunk] = distance_matrix
61 |
62 | return distance_placeholder
63 |
64 | def filter(self, df):
65 |
66 | logger.info(f"Data number before filtering: #{len(df)}")
67 |
68 | df_sorted = self._sort(df)
69 |
70 | embeddings = df_sorted[self.embedding_field]
71 | embeddings = np.array(embeddings.values.tolist())
72 |
73 | filtered_indices = [0]
74 |
75 | start_cnt = 0
76 | for i in tqdm(range(1, embeddings.shape[0], self.batch_size), total = embeddings.shape[0] // self.batch_size):
77 |
78 | cur_emb = torch.tensor(embeddings[i:i+self.batch_size], dtype = torch.float32).to(self.device)
79 |
80 | if cur_emb.ndim == 4:
81 | cur_emb = cur_emb.squeeze(1).squeeze(1)
82 |
83 | if cur_emb.ndim == 1:
84 | cur_emb = cur_emb.unsqueeze(0)
85 |
86 | batch_idx = torch.arange(i, i + cur_emb.size(0), dtype = torch.int64).to(self.device)
87 |
88 | existing_emb = embeddings[filtered_indices]
89 |
90 | if existing_emb.ndim == 1:
91 | existing_emb = existing_emb[None, :]  # numpy array: use indexing rather than torch's unsqueeze
92 |
93 | distance_existed = self.distance_chunk_by_chunk(existing_emb, cur_emb)
94 | distance_existed_bool = torch.any(distance_existed > self.threshold, dim = 1)
95 |
96 | distance_cur = self.distance_chunk_by_chunk(cur_emb, cur_emb)
97 | distance_cur = distance_cur.tril(-1)
98 |
99 | distance_cur_bool = torch.any(distance_cur > self.threshold, dim = 1)
100 |
101 | distance_bool = distance_existed_bool | distance_cur_bool
102 |
103 | filtered_indices.extend(batch_idx[~distance_bool].tolist())
104 |
105 | if len(filtered_indices) - start_cnt > 1000:
106 | logger.info("Now data number: #{}".format(len(filtered_indices)))
107 | start_cnt = len(filtered_indices)
108 |
109 | if self.data_size > -1:
110 | if len(filtered_indices) >= self.data_size:
111 | break
112 |
113 | df_filtered = df_sorted.iloc[filtered_indices]
114 | logger.info(f"Data number after filtering: #{len(df_filtered)}")
115 |
116 | if self.data_size > -1:
117 | return df_filtered[:self.data_size]
118 | else:
119 | return df_filtered
120 |
121 |
--------------------------------------------------------------------------------
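For clarity, a minimal sketch (not from the repository) of the selection rule that `IterativeFilter.filter` implements above: samples are visited in score order, and a sample is kept only if its similarity to every already-kept sample stays at or below `threshold`. The 2-d embeddings and the threshold below are made up.

import torch

threshold = 0.9  # stands in for the `threshold` kwarg (cosine similarity)

def cosine_sim(a, b):
    a = a / a.norm(dim=1, keepdim=True)
    b = b / b.norm(dim=1, keepdim=True)
    return a @ b.t()

# Made-up embeddings, already sorted by score (best first).
emb = torch.tensor([[1.0, 0.0], [0.99, 0.14], [0.0, 1.0]])

kept = [0]  # the top-ranked sample is always kept
for i in range(1, emb.size(0)):
    sims = cosine_sim(emb[i:i + 1], emb[kept])
    if not torch.any(sims > threshold):
        kept.append(i)

print(kept)  # [0, 2]: the second vector is too close to the first one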
/src/deita/selection/filter/combined_filter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | import numpy as np
4 | from deita.pipeline.utils import sort_key_split
5 | from deita.selection.filter.base import IterativeFilter
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | class Combined_Filter(IterativeFilter):
10 |
11 | def __init__(self, **kwargs):
12 |
13 | super().__init__(**kwargs)
14 |
15 | def _sort(self, df):
16 |
17 | """
18 | Sort the dataframe by a combined score: the element-wise product of the sort-key score lists, summed per sample (descending)
19 | """
20 |
21 | if isinstance(self.sort_key, list):
22 | all_sort_keys = self.sort_key
23 | else:
24 | all_sort_keys = sort_key_split(self.sort_key)
25 |
26 | logger.info("Compute final score for each sample, consider {}".format("+".join(all_sort_keys)))
27 | for sk in all_sort_keys:
28 | df[sk] = df[sk].apply(np.array)
29 |
30 | df["final_score"] = df[all_sort_keys[0]]
31 | for i in range(1, len(all_sort_keys)):
32 | df["final_score"] = df["final_score"] * df[all_sort_keys[i]]
33 |
34 | df["final_score"] = df["final_score"].apply(lambda x: x.sum())
35 | df_sorted = df.sort_values(by = "final_score", ascending = False)
36 |
37 | return df_sorted
38 |
39 |
--------------------------------------------------------------------------------
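A small worked sketch (not from the repository) of the score combination in `Combined_Filter._sort` above: the per-sample score lists from each sort key are multiplied element-wise and summed over turns. The column names `complexity_scores` and `quality_scores` are assumptions standing in for whatever `sort_key` is configured.

import numpy as np
import pandas as pd

# Hypothetical per-turn scores for two conversations.
df = pd.DataFrame({
    "complexity_scores": [[4.0, 5.0], [2.0, 3.0]],
    "quality_scores":    [[5.0, 4.0], [3.0, 2.0]],
})

keys = ["complexity_scores", "quality_scores"]
for k in keys:
    df[k] = df[k].apply(np.array)

# Element-wise product across keys, then sum over turns.
df["final_score"] = df[keys[0]]
for k in keys[1:]:
    df["final_score"] = df["final_score"] * df[k]
df["final_score"] = df["final_score"].apply(np.sum)

print(df["final_score"].tolist())  # [40.0, 12.0] -> the first row ranks higher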
/src/deita/selection/filter/utils.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkust-nlp/deita/b279f2c329b403d2612a61e270c8d2a2eeaed6f4/src/deita/selection/filter/utils.py
--------------------------------------------------------------------------------
/src/deita/selection/scorer/__init__.py:
--------------------------------------------------------------------------------
1 | from .llama_scorer import Llama_Scorer
2 | from .mistral_scorer import Mistral_Scorer
3 |
4 | __all__ = ["Llama_Scorer", "Mistral_Scorer"]
--------------------------------------------------------------------------------
/src/deita/selection/scorer/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.special import softmax
3 | from transformers import AutoTokenizer, AutoModelForCausalLM
4 | import logging
5 | import torch
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | class Scorer(object):
10 | def __init__(self, model_name_or_path: str, is_vllm: bool = False, **kwargs):
11 | self.is_vllm = is_vllm
12 |
13 | # Automatically detect the device (GPU or CPU)
14 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 | logger.info(f"Using device: {self.device}")
16 |
17 | if not is_vllm:
18 | # Load tokenizer and model, and move the model to the specified device
19 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
20 | self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(self.device)
21 | else:
22 | from vllm import LLM, SamplingParams
23 | self.llm = LLM(model_name_or_path)
24 | self.sampling_params = SamplingParams(max_tokens=2, logprobs=1000)
25 |
26 | def infer_score(self, user_input: str):
27 | max_length = 2
28 |
29 | if self.is_vllm:
30 | outputs = self.llm.generate(user_input, self.sampling_params)
31 | score_template = np.array([1, 2, 3, 4, 5, 6])
32 |
33 | try:
34 | logprobs_list = outputs[0].outputs[0].logprobs[0]
35 | except IndexError:
36 | return 3.0
37 | else:
38 | # Encode the input as a tensor and move it to the device
39 | input_ids = self.tokenizer.encode(user_input, return_tensors="pt").to(self.device)
40 | outputs = self.model.generate(
41 | input_ids,
42 | max_new_tokens=max_length,
43 | num_return_sequences=1,
44 | return_dict_in_generate=True,
45 | output_scores=True,
46 | )
47 |
48 | try:
49 | # Move logits to CPU and convert them to a NumPy array
50 | logprobs_list = outputs.scores[0][0].detach().cpu().numpy()
51 | except IndexError:
52 | return 3.0
53 |
54 | score_logits = []
55 | score_template = np.array([1, 2, 3, 4, 5, 6])
56 | for k in self.id2score:
57 | try:
58 | score_logits.append(logprobs_list[k])
59 | except KeyError:
60 | return 3.0
61 |
62 | score_logits = np.array(score_logits)
63 | score_npy = softmax(score_logits, axis=0)
64 | score_npy = score_npy * score_template
65 |
66 | score_npy = np.sum(score_npy, axis=0)
67 |
68 | return score_npy
69 |
70 | def infer_complexity(self, input_text: str):
71 | complexity_template = self.complexity_template
72 | user_input = complexity_template.format(instruction=input_text)
73 |
74 | return self.infer_score(user_input)
75 |
76 | def infer_quality(self, input_text: str, resp_text: str):
77 | quality_template = self.quality_template
78 | user_input = quality_template.format(instruction=input_text, output=resp_text)
79 |
80 | return self.infer_score(user_input)
81 |
82 | @property
83 | def id2score(self):
84 | raise NotImplementedError
85 |
86 | @property
87 | def complexity_template(self):
88 | raise NotImplementedError
89 |
90 | @property
91 | def quality_template(self):
92 | raise NotImplementedError
93 |
--------------------------------------------------------------------------------
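For reference, a sketch (not from the repository) of the arithmetic in `infer_score` above: the logits of the digit tokens "1" through "6" are normalized with a softmax, and the returned score is the expectation of the digits under that distribution. The logits below are made up.

import numpy as np
from scipy.special import softmax

score_template = np.array([1, 2, 3, 4, 5, 6])
digit_logits = np.array([-2.0, -1.0, 0.0, 1.0, 2.0, 0.5])  # made-up values

probs = softmax(digit_logits, axis=0)
expected_score = float(np.sum(probs * score_template))
print(round(expected_score, 2))  # ~4.64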
/src/deita/selection/scorer/llama_scorer.py:
--------------------------------------------------------------------------------
1 |
2 | from deita.selection.scorer.base import Scorer
3 |
4 | class Llama_Scorer(Scorer):
5 |
6 | @property
7 | def id2score(self):
8 |
9 | id2score = {
10 | 29896: "1",
11 | 29906: "2",
12 | 29941: "3",
13 | 29946: "4",
14 | 29945: "5",
15 | 29953: "6"
16 | }
17 |
18 | return id2score
19 |
20 | @property
21 | def complexity_template(self):
22 |
23 | complexity_template = ("You are a helpful assistant. Please identify the complexity score of the following user query. \n##Query: {instruction} \n##Complexity: ")
24 |
25 | return complexity_template
26 |
27 | @property
28 | def quality_template(self):
29 |
30 | quality_template = ("You are a helpful assistant. Please identify the quality score of the Response corresponding to the Question. \n #Question#:\n{instruction}\n#Response#:\n{output} \n##Quality: ")
31 |
32 | return quality_template
--------------------------------------------------------------------------------
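Illustrative usage sketch (not part of the repository) for `Llama_Scorer`: scoring a single instruction/response pair with the Hugging Face backend. The checkpoint name `hkust-nlp/deita-quality-scorer` is assumed here; substitute your own scorer path if it differs.

from deita.selection.scorer import Llama_Scorer

# Assumed checkpoint name; replace with your own quality-scorer path if needed.
scorer = Llama_Scorer("hkust-nlp/deita-quality-scorer", is_vllm=False)

instruction = "Explain the difference between a list and a tuple in Python."
response = "Lists are mutable and can grow or shrink; tuples are immutable."

print(scorer.infer_quality(instruction, response))  # expected quality score in [1, 6]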
/src/deita/selection/scorer/mistral_scorer.py:
--------------------------------------------------------------------------------
1 |
2 | from deita.selection.scorer.base import Scorer
3 |
4 | class Mistral_Scorer(Scorer):
5 |
6 | @property
7 | def id2score(self):
8 |
9 | id2score = {
10 | 28740: "1",
11 | 28750: "2",
12 | 28770: "3",
13 | 28781: "4",
14 | 28782: "5",
15 | 28784: "6"
16 | }
17 |
18 | return id2score
19 |
20 | @property
21 | def complexity_template(self):
22 |
23 | complexity_template = ("You are a helpful assistant. Please identify the complexity score of the following user query. \n##Query: {instruction} \n##Complexity: ")
24 |
25 | return complexity_template
26 |
27 | @property
28 | def quality_template(self):
29 |
30 | quality_template = ("You are a helpful assistant. Please identify the quality score of the Response corresponding to the Question. \n #Question#:\n{instruction}\n#Response#:\n{output} \n##Quality: ")
31 |
32 | return quality_template
--------------------------------------------------------------------------------
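The hard-coded `id2score` keys in both scorers are the tokenizer ids of the digit tokens "1" through "6" for the respective base model. Below is a hypothetical helper (not from the repository) to re-derive such a mapping for a given tokenizer; verify the resulting ids against the mappings above for your tokenizer version.

from transformers import AutoTokenizer

def digit_token_ids(tokenizer_name: str):
    # Map each digit token's id to the digit string, mirroring id2score.
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    return {tok.convert_tokens_to_ids(str(d)): str(d) for d in range(1, 7)}

# e.g. a Llama-2 tokenizer should reproduce the 29896..29953 mapping,
# and a Mistral tokenizer the 28740..28784 one shown above.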